├── .gitignore
├── LICENSE
├── README.md
├── clip
│   ├── __init__.py
│   ├── bpe_simple_vocab_16e6.txt.gz
│   ├── clip.py
│   ├── model.py
│   └── simple_tokenizer.py
├── configs
│   ├── anet
│   │   ├── anet_few_shot.yaml
│   │   ├── anet_k400_finetune.yaml
│   │   ├── anet_k400_finetune_336.yaml
│   │   └── anet_zero_shot.yaml
│   ├── hmdb51
│   │   ├── hmdb_few_shot.yaml
│   │   ├── hmdb_k400_finetune.yaml
│   │   ├── hmdb_k400_finetune_336.yaml
│   │   └── hmdb_zero_shot.yaml
│   ├── k400
│   │   ├── k400_few_shot.yaml
│   │   ├── k400_train_rgb_rn50.yaml
│   │   ├── k400_train_rgb_vitb-16-f16.yaml
│   │   ├── k400_train_rgb_vitb-16-f8.yaml
│   │   ├── k400_train_rgb_vitb-32-f16.yaml
│   │   ├── k400_train_rgb_vitb-32-f8.yaml
│   │   ├── k400_train_rgb_vitl-14-336-f32.yaml
│   │   ├── k400_train_rgb_vitl-14-336-f8.yaml
│   │   └── k400_train_rgb_vitl-14-f8.yaml
│   ├── k600
│   │   ├── k600_zero_shot_split1.yaml
│   │   ├── k600_zero_shot_split2.yaml
│   │   └── k600_zero_shot_split3.yaml
│   └── ucf101
│       ├── ucf_few_shot.yaml
│       ├── ucf_k400_finetune.yaml
│       ├── ucf_k400_finetune_336.yaml
│       └── ucf_zero_shot.yaml
├── datasets
│   ├── __init__.py
│   ├── datasets.py
│   └── transforms.py
├── exps
│   ├── anet
│   │   └── ViT-L
│   │       ├── 14
│   │       │   └── f16
│   │       │       └── log.txt
│   │       └── 14-336px
│   │           └── f16
│   │               └── log.txt
│   ├── hmdb51
│   │   └── ViT-L
│   │       └── 14
│   │           └── f16
│   │               └── log.txt
│   ├── k400
│   │   ├── ViT-B
│   │   │   ├── 16
│   │   │   │   ├── f16
│   │   │   │   │   └── log.txt
│   │   │   │   └── f8
│   │   │   │       └── log.txt
│   │   │   └── 32
│   │   │       ├── f16
│   │   │       │   └── log.txt
│   │   │       └── f8
│   │   │           └── log.txt
│   │   └── ViT-L
│   │       ├── 14
│   │       │   └── f8
│   │       │       └── log.txt
│   │       └── 14-336px
│   │           ├── f32
│   │           │   └── log.txt
│   │           └── f8
│   │               └── log.txt
│   └── ucf101
│       └── ViT-L
│           ├── 14
│           │   └── f16
│           │       └── log.txt
│           └── 14-336px
│               └── f16
│                   └── log.txt
├── lists
│   ├── anet
│   │   ├── anet1.3_label2name.json
│   │   ├── anet_full_for_zeroshot.txt
│   │   ├── anet_train_instance_fps1.txt
│   │   └── anet_val_video_fps1.txt
│   ├── anet1.3_labels.csv
│   ├── hmdb51
│   │   ├── hmdb_full_for_zeroshot.txt
│   │   ├── train_rgb_split_1.txt
│   │   ├── train_rgb_split_2.txt
│   │   ├── train_rgb_split_3.txt
│   │   ├── val_rgb_split_1.txt
│   │   ├── val_rgb_split_2.txt
│   │   └── val_rgb_split_3.txt
│   ├── hmdb51_labels.csv
│   ├── k400
│   │   ├── kinetics_rgb_train_se320.txt
│   │   ├── kinetics_rgb_val_se320.txt
│   │   ├── kinetics_video_train_se320.txt
│   │   └── kinetics_video_val_se320.txt
│   ├── k600
│   │   ├── k160_labels_split1.csv
│   │   ├── k160_labels_split2.csv
│   │   ├── k160_labels_split3.csv
│   │   ├── test_split1_exist.txt
│   │   ├── test_split2_exist.txt
│   │   └── test_split3_exist.txt
│   ├── kinetics_400_labels.csv
│   ├── ucf101
│   │   ├── train_rgb_split_1.txt
│   │   ├── train_rgb_split_2.txt
│   │   ├── train_rgb_split_3.txt
│   │   ├── ucf_full_for_zeroshot.txt
│   │   ├── val_rgb_split_1.txt
│   │   ├── val_rgb_split_2.txt
│   │   └── val_rgb_split_3.txt
│   └── ucf_labels.csv
├── modules
│   ├── coop.py
│   ├── temporal_modeling.py
│   ├── text_prompt.py
│   └── video_clip.py
├── scripts
│   ├── run_test.sh
│   ├── run_test_zeroshot.sh
│   ├── run_train.sh
│   ├── run_train_multinodes.sh
│   └── run_train_nce.sh
├── teaser.png
├── test.py
├── test_anet.py
├── test_zeroshot.py
├── text4vis.png
├── train.py
├── train_nce.py
└── utils
    ├── Augmentation.py
    ├── NCELoss.py
    ├── logger.py
    ├── lr_scheduler.py
    ├── solver.py
    └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | .DS_Store
128 | .vscode/
129 | exp/
130 | # Pyre type checker
131 | .pyre/
132 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Wenhao Wu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # 🔥【AAAI'2023, IJCV'2023】Revisiting Classifier: Transferring Vision-Language Models for Video Recognition
4 |
5 | [](https://ojs.aaai.org/index.php/AAAI/article/view/25386/25158)
6 | [](https://link.springer.com/article/10.1007/s11263-023-01876-w)
7 |
8 |
9 | [Wenhao Wu](https://whwu95.github.io/)1,2, [Zhun Sun](https://scholar.google.co.jp/citations?user=Y-3iZ9EAAAAJ&hl=en)2, [Wanli Ouyang](https://wlouyang.github.io/)3,1
10 |
11 |
12 | 1[The University of Sydney](https://www.sydney.edu.au/), 2[Baidu](https://vis.baidu.com/#/), 3[Shanghai AI Lab](https://www.shlab.org.cn/)
13 |
14 |
15 |
16 |
17 | ***
18 | [](https://paperswithcode.com/sota/action-classification-on-kinetics-400?p=transferring-textual-knowledge-for-visual)
19 | [](https://paperswithcode.com/sota/action-recognition-in-videos-on-activitynet?p=transferring-textual-knowledge-for-visual)
20 | [](https://paperswithcode.com/sota/action-recognition-in-videos-on-ucf101?p=transferring-textual-knowledge-for-visual)
21 | [](https://paperswithcode.com/sota/zero-shot-action-recognition-on-kinetics?p=transferring-textual-knowledge-for-visual)
22 | [](https://paperswithcode.com/sota/zero-shot-action-recognition-on-activitynet?p=transferring-textual-knowledge-for-visual)
23 | [](https://paperswithcode.com/sota/zero-shot-action-recognition-on-ucf101?p=transferring-textual-knowledge-for-visual)
24 | [](https://paperswithcode.com/sota/zero-shot-action-recognition-on-hmdb51?p=transferring-textual-knowledge-for-visual)
25 |
26 |
27 | This is the official implementation of the **AAAI paper** [Revisiting Classifier: Transferring Vision-Language Models for Video Recognition](https://arxiv.org/abs/2207.01297), and **IJCV paper** [Transferring Vision-Language Models for Visual Recognition: A Classifier Perspective](https://link.springer.com/article/10.1007/s11263-023-01876-w).
28 |
29 | 🙋 I also have other cross-modal video projects that may interest you ✨.
30 |
31 |
32 | > [**Bidirectional Cross-Modal Knowledge Exploration for Video Recognition with Pre-trained Vision-Language Models**](https://arxiv.org/abs/2301.00182)
33 | > Wenhao Wu, Xiaohan Wang, Haipeng Luo, Jingdong Wang, Yi Yang, Wanli Ouyang
34 | > [](https://openaccess.thecvf.com/content/CVPR2023/html/Wu_Bidirectional_Cross-Modal_Knowledge_Exploration_for_Video_Recognition_With_Pre-Trained_Vision-Language_CVPR_2023_paper.html) [](https://github.com/whwu95/BIKE)
35 |
36 |
37 | > [**Cap4Video: What Can Auxiliary Captions Do for Text-Video Retrieval?**](https://arxiv.org/abs/2301.00184)
38 | > Wenhao Wu, Haipeng Luo, Bo Fang, Jingdong Wang, Wanli Ouyang
39 | > Accepted by CVPR 2023 as 🌟Highlight🌟 | [](https://openaccess.thecvf.com/content/CVPR2023/html/Wu_Cap4Video_What_Can_Auxiliary_Captions_Do_for_Text-Video_Retrieval_CVPR_2023_paper.html) [](https://github.com/whwu95/Cap4Video)
40 |
41 |
42 |
43 |
44 |
45 |
46 | ## 📣 Updates
47 | - [x] **`Aug 07, 2023`** The extension of Text4Vis has been accepted by **International Journal of Computer Vision (IJCV)**.
48 | - [x] **`Dec 22, 2022`** Models: The pre-trained models & logs.
49 | - [x] **`Nov 30, 2022`** Config: All the configs (general/few-shot/zero-shot video recognition) on Kinetics-400 & 600, ActivityNet, UCF, and HMDB.
50 | - [x] **`Nov 30, 2022`** Code: Zero-shot Evaluation: Half-classes evaluation and Full-classes evaluation.
51 | - [x] **`Nov 28, 2022`** Code: Single-Machine/Multi-Machine Multi-GPU Distributed Training, Distributed testing.
52 | - [x] **`Nov 19, 2022`** 🎉Our paper has been accepted by **AAAI-2023**.
53 | - [x] **`Jul 1, 2022`** 💡Our [initial Arxiv paper](https://arxiv.org/abs/2207.01297v1) is released.
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 | ## 🌈 Overview
62 | In Text4Vis, we revisit the role of the linear classifier and replace it with knowledge from the pre-trained language model: we utilize the well-pretrained text encoder to generate good semantic targets for efficient transfer learning.
63 |
64 | ![teaser](teaser.png)
65 | ![text4vis](text4vis.png)
66 |
67 | ## Content
68 | - [Prerequisites](#prerequisites)
69 | - [Data Preparation](#data-preparation)
70 | - [Model Zoo](#model-zoo)
71 | - [Training](#training)
72 | - [Testing](#testing)
73 | - [BibTeX & Citation](#bibtex)
74 | - [Acknowledgment](#acknowledgment)
75 |
76 |
77 |
78 | ## 📕 Prerequisites
79 | The code is built with the following libraries:
80 |
81 | - [PyTorch](https://pytorch.org/) >= 1.8
82 | - RandAugment
83 | - pprint
84 | - tqdm
85 | - dotmap
86 | - yaml
87 | - csv
88 | - Optional: decord (for on-the-fly video training)
89 | - Optional: torchnet (for mAP evaluation on ActivityNet)
90 |
91 |
92 | ## 📚 Data Preparation
93 |
94 | #### Video Loader
95 | **(Recommended)** To train all of our models, we extract videos into frames for fast reading. Please refer to the [MVFNet](https://github.com/whwu95/MVFNet/blob/main/data_process/DATASETS.md) repo for a detailed guide to data processing.
96 | The annotation file is a text file in which each line gives the directory of a video's frames, the total number of frames, and the video's label, separated by whitespace. Here is the format:
97 | ```sh
98 | abseiling/-7kbO0v4hag_000107_000117 300 0
99 | abseiling/-bwYZwnwb8E_000013_000023 300 0
100 | ```
101 |
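As an illustration, here is a minimal sketch of parsing such an annotation file into (frame_dir, num_frames, label) records. This is not code from this repo; `VideoRecord` and `load_frame_annotations` are hypothetical names.

```python
from typing import List, NamedTuple

class VideoRecord(NamedTuple):
    frame_dir: str   # directory holding the extracted frames of one video
    num_frames: int  # total number of frames in that directory
    label: int       # integer class label

def load_frame_annotations(path: str) -> List[VideoRecord]:
    """Parse whitespace-separated lines: <frame_dir> <num_frames> <label>."""
    records = []
    with open(path) as f:
        for line in f:
            if line.strip():
                frame_dir, num_frames, label = line.split()
                records.append(VideoRecord(frame_dir, int(num_frames), int(label)))
    return records
```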
102 | **(Optional)** We can also decode the videos on the fly using [decord](https://github.com/dmlc/decord). This approach should work but has not been tested; all of the released models were trained on pre-extracted frames. Example annotation:
103 | ```sh
104 | abseiling/-7kbO0v4hag_000107_000117.mp4 0
105 | abseiling/-bwYZwnwb8E_000013_000023.mp4 0
106 | ```
107 |
108 | #### Annotation
109 | Annotation information consists of two parts: video label, and category description.
110 |
111 | - Video Label: As mentioned above, this part is the same as in traditional video recognition. Please refer to `lists/k400/kinetics_rgb_train_se320.txt` for the format.
112 | - Category Description: We also need a textual description for each video category. Please refer to `lists/kinetics_400_labels.csv` for the format.
113 |
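For illustration, a minimal sketch of reading the category-description CSV, assuming an `id,name` layout as in `lists/kinetics_400_labels.csv` (`load_class_names` is a hypothetical helper, not a function from this repo):

```python
import csv

def load_class_names(csv_path: str) -> dict:
    """Map class id -> textual category description, assuming an id,name header."""
    id_to_name = {}
    with open(csv_path, newline='') as f:
        for row in csv.DictReader(f):
            id_to_name[int(row['id'])] = row['name']
    return id_to_name

# e.g. id_to_name = load_class_names('lists/kinetics_400_labels.csv')
```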
114 |
115 |
116 | ## 📱 Model Zoo
117 |
118 | Here we provide some off-the-shelf pre-trained checkpoints of our models in the following tables.
119 |
120 | *#Frame = #input_frame x #spatial crops x #temporal clips*
121 | #### Kinetics-400
122 |
123 | | Architecture |#Frame | Top-1 Acc.(%) | checkpoint | Train log| config|
124 | |:------------:|:-------------------:|:------------------:|:-----------------:|:--------------:|:--------------:|
125 | | ViT-B/32 | 8x3x4 | 80.0 | [Github](https://github.com/whwu95/Text4Vis/releases/download/v1/k400-vitb-32-f8.pt) | [log](exps/k400/ViT-B/32/f8/log.txt) | [config](configs/k400/k400_train_rgb_vitb-32-f8.yaml) |
126 | | ViT-B/32 | 16x3x4 | 80.5 | [Github](https://github.com/whwu95/Text4Vis/releases/download/v1/k400-vitb-32-f16.pt) | [log](exps/k400/ViT-B/32/f16/log.txt) | [config](configs/k400/k400_train_rgb_vitb-32-f16.yaml) |
127 | | ViT-B/16 | 8x3x4 | 82.9 | [Github](https://github.com/whwu95/Text4Vis/releases/download/v1/k400-vitb-16-f8.pt) | [log](exps/k400/ViT-B/16/f8/log.txt) | [config](configs/k400/k400_train_rgb_vitb-16-f8.yaml) |
128 | | ViT-B/16 | 16x3x4 | 83.6 | [Github](https://github.com/whwu95/Text4Vis/releases/download/v1/k400-vitb-16-f16.pt)| [log](exps/k400/ViT-B/16/f16/log.txt) | [config](configs/k400/k400_train_rgb_vitb-16-f16.yaml) |
129 | | ViT-L/14* | 8x3x4 | 86.4 | [OneDrive](https://unisyd-my.sharepoint.com/:f:/g/personal/wenhao_wu_sydney_edu_au/EqsXAPrVnUFBv77XCk6CEzQBZXqEcWySWzr2MoIBfD29tw?e=T0HaTe) | [log](exps/k400/ViT-L/14/f8/log.txt) | [config](configs/k400/k400_train_rgb_vitl-14-f8.yaml) |
130 | | ViT-L/14-336 | 8x3x4 | 87.1 | [OneDrive](https://unisyd-my.sharepoint.com/:f:/g/personal/wenhao_wu_sydney_edu_au/EqsXAPrVnUFBv77XCk6CEzQBZXqEcWySWzr2MoIBfD29tw?e=T0HaTe) | [log](exps/k400/ViT-L/14-336px/f8/log.txt) | [config](configs/k400/k400_train_rgb_vitl-14-336-f8.yaml) |
131 | | ViT-L/14-336 | 32x3x1 | 87.8 | [OneDrive](https://unisyd-my.sharepoint.com/:f:/g/personal/wenhao_wu_sydney_edu_au/EqsXAPrVnUFBv77XCk6CEzQBZXqEcWySWzr2MoIBfD29tw?e=T0HaTe) | [log](exps/k400/ViT-L/14-336px/f32/log.txt) | [config](configs/k400/k400_train_rgb_vitl-14-336-f32.yaml) |
132 |
133 | *Note: * indicates that this ViT-L model is used for the zero-shot evaluation on UCF, HMDB, ActivityNet and Kinetics-600.*
134 |
135 | #### ActivityNet
136 | | Architecture |#Frame | mAP (%) | checkpoint | Train log| config|
137 | |:------------:|:-------------------:|:------------------:|:-----------------:|:--------------:|:--------------:|
138 | | ViT-L/14 | 16x1x1 | 96.5 | [OneDrive](https://unisyd-my.sharepoint.com/:f:/g/personal/wenhao_wu_sydney_edu_au/EqsXAPrVnUFBv77XCk6CEzQBZXqEcWySWzr2MoIBfD29tw?e=T0HaTe) | [log](exps/anet/ViT-L/14/f16/log.txt) | [config](configs/anet/anet_k400_finetune.yaml) |
139 | | ViT-L/14-336 | 16x1x1 | 96.9 | [OneDrive](https://unisyd-my.sharepoint.com/:f:/g/personal/wenhao_wu_sydney_edu_au/EqsXAPrVnUFBv77XCk6CEzQBZXqEcWySWzr2MoIBfD29tw?e=T0HaTe) | [log](exps/anet/ViT-L/14-336px/f16/log.txt) | [config](configs/anet/anet_k400_finetune_336.yaml) |
140 |
141 | #### UCF-101
142 | | Architecture |#Frame | Top-1 Acc. (%) | checkpoint | Train log| config|
143 | |:------------:|:-------------------:|:------------------:|:-----------------:|:--------------:|:--------------:|
144 | | ViT-L/14 | 16x1x1 | 98.1 | [OneDrive](https://unisyd-my.sharepoint.com/:f:/g/personal/wenhao_wu_sydney_edu_au/EqsXAPrVnUFBv77XCk6CEzQBZXqEcWySWzr2MoIBfD29tw?e=T0HaTe) | [log](exps/ucf101/ViT-L/14/f16/log.txt) | [config](configs/ucf101/ucf_k400_finetune.yaml) |
145 |
146 |
147 | #### HMDB-51
148 | | Architecture |#Frame | Top-1 Acc. (%) | checkpoint | Train log| config|
149 | |:------------:|:-------------------:|:------------------:|:-----------------:|:--------------:|:--------------:|
150 | | ViT-L/14 | 16x1x1 | 81.3 | [OneDrive](https://unisyd-my.sharepoint.com/:f:/g/personal/wenhao_wu_sydney_edu_au/EqsXAPrVnUFBv77XCk6CEzQBZXqEcWySWzr2MoIBfD29tw?e=T0HaTe) | [log](exps/hmdb51/ViT-L/14/f16/log.txt) | [config](configs/hmdb51/hmdb_k400_finetune.yaml) |
151 |
152 |
153 |
154 | ## 🚀 Training
155 | This implementation supports Multi-GPU `DistributedDataParallel` training, which is faster and simpler than `DataParallel` used in [ActionCLIP](https://github.com/sallymmx/actionclip).
156 |
157 | - **Single Machine**: To train our model on Kinetics-400 with 8 GPUs on a *Single Machine*, you can run:
158 | ```sh
159 | # For example, train the 8 Frames ViT-B/32.
160 | sh scripts/run_train.sh configs/k400/k400_train_rgb_vitb-32-f8.yaml
161 | ```
162 |
163 | - **Multiple Machines**: We also provide a script to train larger models on *Multiple Machines* (e.g., 2 machines with 16 GPUs); you can run:
164 | ```sh
165 | # For example, we train the 8 Frames ViT-L/14 with 2 machines as follows:
166 | # On the first machine, set --master_addr to the IP of the first machine and --nnodes to 2.
167 | # Compared with the single-machine training script, only the node_id argument needs to be added.
168 | sh scripts/run_train_multinodes.sh configs/k400/k400_train_rgb_vitl-14-f8.yaml 0
169 |
170 | # On the second machine, --master_addr is still the IP of the first machine.
171 | sh scripts/run_train_multinodes.sh configs/k400/k400_train_rgb_vitl-14-f8.yaml 1
172 | ```
173 |
174 | - **Few-shot Recognition**: To train our model under the *Few-shot* scenario, you only need to add one line to the general config file:
175 | ```yaml
176 | # You can refer to config/k400/k400_few_shot.yaml
177 | data:
178 | ... # general configurations
179 | shot: 2 # i.e., 2-shot setting
180 | ```
181 |
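The configs above are plain YAML. Below is a minimal sketch of loading one into dot-accessible form, assuming the `yaml` and `dotmap` dependencies listed in the prerequisites (illustrative only, not necessarily the repo's exact loading code):

```python
import yaml
from dotmap import DotMap

def load_config(path: str) -> DotMap:
    """Read a YAML config (e.g. configs/k400/k400_few_shot.yaml) into a DotMap."""
    with open(path) as f:
        return DotMap(yaml.safe_load(f))

# e.g. cfg = load_config('configs/k400/k400_few_shot.yaml'); cfg.data.shot == 2
```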
182 |
183 | ## ⚡ Testing
184 | We support single-view validation and multi-view validation (4 temporal clips x 3 spatial crops).
185 |
186 | #### General/Few-shot Video Recognition
187 | ```sh
188 | # Single view evaluation. e.g., ViT-B/32 8 Frames on Kinetics-400
189 | sh scripts/run_test.sh configs/k400/k400_train_rgb_vitb-32-f8.yaml exp/k400/ViT-B/32/f8/last_model.pt
190 |
191 | # Multi-view evaluation (4 clips x 3 crops). e.g., ViT-B/32 8 Frames on Kinetics-400
192 | sh scripts/run_test.sh configs/k400/k400_train_rgb_vitb-32-f8.yaml exp/k400/ViT-B/32/f8/last_model.pt --test_crops 3 --test_clips 4
193 | ```
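Conceptually, multi-view testing averages the per-view predictions of a video before taking the arg-max. A minimal sketch of that aggregation (illustrative only, not this repo's exact implementation):

```python
import torch

def aggregate_views(view_logits: torch.Tensor) -> torch.Tensor:
    """view_logits: (num_views, num_classes) logits for one video,
    e.g. 4 clips x 3 crops = 12 views. Returns the predicted class id
    after averaging the softmax probabilities over the views."""
    probs = view_logits.softmax(dim=-1).mean(dim=0)
    return probs.argmax(dim=-1)
```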
194 |
195 |
196 | #### Zero-shot Evaluation
197 |
198 | We use the Kinetics-400 pre-trained model (e.g., ViT-L/14 with 8 frames) to perform cross-dataset zero-shot evaluation on UCF101, HMDB51, ActivityNet, and Kinetics-600.
199 |
200 | - Half-classes Evaluation: A traditional protocol that randomly selects half of the test dataset's classes, repeats the selection ten times, and reports the mean accuracy and standard deviation over the ten runs (see the sketch below).
201 |
202 |
203 | - Full-classes Evaluation: Perform evaluation on the entire dataset.
204 |
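A minimal sketch of the half-classes protocol described above (illustrative only; `evaluate_subset` is a hypothetical callback that returns the accuracy measured on the sampled classes):

```python
import random
import statistics

def half_classes_protocol(all_classes, evaluate_subset, runs=10, seed=0):
    """Sample half of the classes, evaluate, repeat `runs` times,
    and report the mean accuracy and standard deviation over the runs."""
    rng = random.Random(seed)
    accs = []
    for _ in range(runs):
        subset = rng.sample(all_classes, len(all_classes) // 2)
        accs.append(evaluate_subset(subset))
    return statistics.mean(accs), statistics.stdev(accs)
```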
205 | ```sh
206 | # On ActivityNet: reporting the half-classes and full-classes results
207 | sh scripts/run_test_zeroshot.sh configs/anet/anet_zero_shot.yaml exp/k400/ViT-L/14/f8/last_model.pt
208 |
209 | # On UCF101: reporting the half-classes and full-classes results
210 | sh scripts/run_test_zeroshot.sh configs/ucf101/ucf_zero_shot.yaml exp/k400/ViT-L/14/f8/last_model.pt
211 |
212 | # On HMDB51: reporting the half-classes and full-classes results
213 | sh scripts/run_test_zeroshot.sh configs/hmdb51/hmdb_zero_shot.yaml exp/k400/ViT-L/14/f8/last_model.pt
214 |
215 | # On Kinetics-600: manually calculate the mean accuracy and standard deviation over the three splits.
216 | sh scripts/run_test.sh configs/k600/k600_zero_shot_split1.yaml exp/k400/ViT-L/14/f8/last_model.pt
217 | sh scripts/run_test.sh configs/k600/k600_zero_shot_split2.yaml exp/k400/ViT-L/14/f8/last_model.pt
218 | sh scripts/run_test.sh configs/k600/k600_zero_shot_split3.yaml exp/k400/ViT-L/14/f8/last_model.pt
219 | ```
220 |
221 |
222 |
223 |
224 | ## 📌 BibTeX & Citation
225 | If you find this repository useful, please star🌟 this repo and cite📑 our paper:
226 |
227 | ```bibtex
228 | @inproceedings{wu2023revisiting,
229 | title={Revisiting classifier: Transferring vision-language models for video recognition},
230 | author={Wu, Wenhao and Sun, Zhun and Ouyang, Wanli},
231 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
232 | volume={37},
233 | number={3},
234 | pages={2847--2855},
235 | year={2023}
236 | }
237 |
238 | @article{wu2023transferring,
239 | title={Transferring vision-language models for visual recognition: A classifier perspective},
240 | author={Wu, Wenhao and Sun, Zhun and Song, Yuxin and Wang, Jingdong and Ouyang, Wanli},
241 | journal={International Journal of Computer Vision},
242 | pages={1--18},
243 | year={2023},
244 | publisher={Springer}
245 | }
246 | ```
247 |
248 | If you also find [BIKE](https://github.com/whwu95/BIKE) useful, please cite the paper:
249 |
250 | ```bibtex
251 | @inproceedings{bike,
252 | title={Bidirectional Cross-Modal Knowledge Exploration for Video Recognition with Pre-trained Vision-Language Models},
253 | author={Wu, Wenhao and Wang, Xiaohan and Luo, Haipeng and Wang, Jingdong and Yang, Yi and Ouyang, Wanli},
254 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
255 | year={2023}
256 | }
257 | ```
258 |
259 |
260 |
261 | ## 🎗️ Acknowledgement
262 |
263 | This repository is built on top of [ActionCLIP](https://github.com/sallymmx/actionclip) and [CLIP](https://github.com/openai/CLIP). Sincere thanks for their wonderful work.
264 |
265 |
266 | ## 👫 Contact
267 | For any questions, please file an issue.
268 |
269 |
--------------------------------------------------------------------------------
/clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip import *
--------------------------------------------------------------------------------
/clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/whwu95/Text4Vis/d61b34d0208a03ce6146edcc51033b2a040cb249/clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/clip/clip.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import urllib
4 | import warnings
5 | from typing import Any, Union, List
6 | from pkg_resources import packaging
7 |
8 | import torch
9 | from PIL import Image
10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
11 | from tqdm import tqdm
12 |
13 | from .model import build_model
14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer
15 |
16 | try:
17 | from torchvision.transforms import InterpolationMode
18 | BICUBIC = InterpolationMode.BICUBIC
19 | except ImportError:
20 | BICUBIC = Image.BICUBIC
21 |
22 |
23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended")
25 |
26 |
27 | __all__ = ["available_models", "load", "tokenize"]
28 | _tokenizer = _Tokenizer()
29 |
30 | _MODELS = {
31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
35 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
36 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
37 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
38 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
39 | "ViT-L/14-336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"
40 | }
41 |
42 |
43 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
44 | os.makedirs(root, exist_ok=True)
45 | filename = os.path.basename(url)
46 |
47 | expected_sha256 = url.split("/")[-2]
48 | download_target = os.path.join(root, filename)
49 |
50 | if os.path.exists(download_target) and not os.path.isfile(download_target):
51 | raise RuntimeError(f"{download_target} exists and is not a regular file")
52 |
53 | if os.path.isfile(download_target):
54 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
55 | return download_target
56 | else:
57 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
58 |
59 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
60 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
61 | while True:
62 | buffer = source.read(8192)
63 | if not buffer:
64 | break
65 |
66 | output.write(buffer)
67 | loop.update(len(buffer))
68 |
69 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
70 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not match")
71 |
72 | return download_target
73 |
74 |
75 | def _convert_image_to_rgb(image):
76 | return image.convert("RGB")
77 |
78 |
79 | def _transform(n_px):
80 | return Compose([
81 | Resize(n_px, interpolation=BICUBIC),
82 | CenterCrop(n_px),
83 | _convert_image_to_rgb,
84 | ToTensor(),
85 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
86 | ])
87 |
88 |
89 | def available_models() -> List[str]:
90 | """Returns the names of available CLIP models"""
91 | return list(_MODELS.keys())
92 |
93 |
94 | def load(
95 | name: str,
96 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
97 | jit=True,
98 | internal_modeling=False, joint_st=False, T=8, dropout=0.,
99 | emb_dropout=0.,
100 | pretrain=True):
101 | """Load a CLIP model
102 |
103 | Parameters
104 | ----------
105 | name : str
106 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
107 |
108 | device : Union[str, torch.device]
109 | The device to put the loaded model
110 |
111 | jit : bool
112 | Whether to load the optimized JIT model or more hackable non-JIT model (default).
113 |
114 | download_root: str
115 | path to download the model files; by default, it uses "~/.cache/clip"
116 |
117 | Returns
118 | -------
119 | model : torch.nn.Module
120 | The CLIP model
121 |
122 | preprocess : Callable[[PIL.Image], torch.Tensor]
123 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
124 | """
125 | if name in _MODELS:
126 | model_path = _download(_MODELS[name])
127 | elif os.path.isfile(name):
128 | model_path = name
129 | else:
130 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
131 |
132 | try:
133 | # loading JIT archive
134 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
135 | state_dict = None
136 | except RuntimeError:
137 | # loading saved state dict
138 | if jit:
139 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
140 | jit = False
141 | state_dict = torch.load(model_path, map_location="cpu")
142 |
143 | if not jit:
144 | model = build_model(state_dict or model.state_dict(), joint=joint_st, tm=internal_modeling, T=T, dropout=dropout, emb_dropout=emb_dropout, pretrain=pretrain).to(device)
145 | if str(device) == "cpu":
146 | model.float()
147 | return model, model.state_dict()
148 |
149 | # patch the device names
150 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
151 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
152 |
153 | def patch_device(module):
154 | try:
155 | graphs = [module.graph] if hasattr(module, "graph") else []
156 | except RuntimeError:
157 | graphs = []
158 |
159 | if hasattr(module, "forward1"):
160 | graphs.append(module.forward1.graph)
161 |
162 | for graph in graphs:
163 | for node in graph.findAllNodes("prim::Constant"):
164 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
165 | node.copyAttributes(device_node)
166 |
167 | model.apply(patch_device)
168 |
169 | if str(device) == "cpu":
170 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
171 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
172 | float_node = float_input.node()
173 |
174 | def patch_float(module):
175 | try:
176 | graphs = [module.graph] if hasattr(module, "graph") else []
177 | except RuntimeError:
178 | graphs = []
179 |
180 | if hasattr(module, "forward1"):
181 | graphs.append(module.forward1.graph)
182 |
183 | for graph in graphs:
184 | for node in graph.findAllNodes("aten::to"):
185 | inputs = list(node.inputs())
186 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
187 | if inputs[i].node()["value"] == 5:
188 | inputs[i].node().copyAttributes(float_node)
189 |
190 | model.apply(patch_float)
191 | patch_float(model.encode_image)
192 | patch_float(model.encode_text)
193 |
194 | model.float()
195 |
196 | return model, _transform(model.input_resolution.item())
197 |
198 |
199 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
200 | """
201 | Returns the tokenized representation of given input string(s)
202 |
203 | Parameters
204 | ----------
205 | texts : Union[str, List[str]]
206 | An input string or a list of input strings to tokenize
207 |
208 | context_length : int
209 | The context length to use; all CLIP models use 77 as the context length
210 |
211 | truncate: bool
212 | Whether to truncate the text in case its encoding is longer than the context length
213 |
214 | Returns
215 | -------
216 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
217 | """
218 | if isinstance(texts, str):
219 | texts = [texts]
220 |
221 | sot_token = _tokenizer.encoder["<|startoftext|>"]
222 | eot_token = _tokenizer.encoder["<|endoftext|>"]
223 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
224 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
225 |
226 | for i, tokens in enumerate(all_tokens):
227 | if len(tokens) > context_length:
228 | if truncate:
229 | tokens = tokens[:context_length]
230 | tokens[-1] = eot_token
231 | else:
232 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
233 | result[i, :len(tokens)] = torch.tensor(tokens)
234 |
235 | return result
236 |
--------------------------------------------------------------------------------
/clip/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 |
14 |
15 | @lru_cache()
16 | def bytes_to_unicode():
17 | """
18 | Returns list of utf-8 byte and a corresponding list of unicode strings.
19 | The reversible bpe codes work on unicode strings.
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 | This is a significant percentage of your normal, say, 32K bpe vocab.
23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 | And avoids mapping to whitespace/control characters the bpe code barfs on.
25 | """
26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 | cs = bs[:]
28 | n = 0
29 | for b in range(2**8):
30 | if b not in bs:
31 | bs.append(b)
32 | cs.append(2**8+n)
33 | n += 1
34 | cs = [chr(n) for n in cs]
35 | return dict(zip(bs, cs))
36 |
37 |
38 | def get_pairs(word):
39 | """Return set of symbol pairs in a word.
40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """
42 | pairs = set()
43 | prev_char = word[0]
44 | for char in word[1:]:
45 | pairs.add((prev_char, char))
46 | prev_char = char
47 | return pairs
48 |
49 |
50 | def basic_clean(text):
51 | text = ftfy.fix_text(text)
52 | text = html.unescape(html.unescape(text))
53 | return text.strip()
54 |
55 |
56 | def whitespace_clean(text):
57 | text = re.sub(r'\s+', ' ', text)
58 | text = text.strip()
59 | return text
60 |
61 |
62 | class SimpleTokenizer(object):
63 | def __init__(self, bpe_path: str = default_bpe()):
64 | self.byte_encoder = bytes_to_unicode()
65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 | merges = merges[1:49152-256-2+1]
68 | merges = [tuple(merge.split()) for merge in merges]
69 | vocab = list(bytes_to_unicode().values())
70 | vocab = vocab + [v+'</w>' for v in vocab]
71 | for merge in merges:
72 | vocab.append(''.join(merge))
73 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74 | self.encoder = dict(zip(vocab, range(len(vocab))))
75 | self.decoder = {v: k for k, v in self.encoder.items()}
76 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79 |
80 | def bpe(self, token):
81 | if token in self.cache:
82 | return self.cache[token]
83 | word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84 | pairs = get_pairs(word)
85 |
86 | if not pairs:
87 | return token+'</w>'
88 |
89 | while True:
90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91 | if bigram not in self.bpe_ranks:
92 | break
93 | first, second = bigram
94 | new_word = []
95 | i = 0
96 | while i < len(word):
97 | try:
98 | j = word.index(first, i)
99 | new_word.extend(word[i:j])
100 | i = j
101 | except:
102 | new_word.extend(word[i:])
103 | break
104 |
105 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
106 | new_word.append(first+second)
107 | i += 2
108 | else:
109 | new_word.append(word[i])
110 | i += 1
111 | new_word = tuple(new_word)
112 | word = new_word
113 | if len(word) == 1:
114 | break
115 | else:
116 | pairs = get_pairs(word)
117 | word = ' '.join(word)
118 | self.cache[token] = word
119 | return word
120 |
121 | def encode(self, text):
122 | bpe_tokens = []
123 | text = whitespace_clean(basic_clean(text)).lower()
124 | for token in re.findall(self.pat, text):
125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127 | return bpe_tokens
128 |
129 | def decode(self, tokens):
130 | text = ''.join([self.decoder[token] for token in tokens])
131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132 | return text
133 |
--------------------------------------------------------------------------------
/configs/anet/anet_few_shot.yaml:
--------------------------------------------------------------------------------
1 | resume:
2 | pretrain:
3 | seed: 1024
4 | data:
5 | dataset: anet
6 | modality: RGB
7 | num_segments: 8
8 | seg_length: 1
9 | batch_size: 8
10 | workers: 4
11 | num_classes: 200
12 | image_tmpl: 'image_{:06d}.jpg'
13 | train_root: '/bpfs/v2_mnt/VIS/wuwenhao/anet/anet_instance_frames_v1.3_train_vids_fps1'
14 | val_root: '/bpfs/v2_mnt/VIS/wuwenhao/anet/activitynet_val_resize_img_256_340_fps1'
15 | train_list: 'lists/anet/anet_train_instance_fps1.txt'
16 | val_list: 'lists/anet/anet_val_video_fps1.txt' #
17 | label_list: 'lists/anet1.3_labels.csv'
18 | input_size: 224
19 | random_shift: True
20 | shot: 2
21 | network:
22 | arch: ViT-L/14 #ViT-B/32 ViT-B/16
23 | init: True
24 | drop_out: 0.0
25 | emb_dropout: 0.0
26 | type: clip_anet
27 | sim_header: None
28 | drop: 0
29 | solver:
30 | type: cosine
31 | epochs: 30
32 | start_epoch: 0
33 | optim: adamw
34 | lr: 5.e-5
35 | lr_warmup_step: 5
36 | weight_decay: 0.2
37 | loss_type: CE
38 | evaluate: False
39 | clip_ratio: 0.1
40 | grad_accumulation_steps: 1
41 | logging:
42 | print_freq: 10
43 | eval_freq: 5
--------------------------------------------------------------------------------
/configs/anet/anet_k400_finetune.yaml:
--------------------------------------------------------------------------------
1 | pretrain: exp_sota/k400/ViT-L/14/f16/last_model.pt
2 | resume:
3 | seed: 1024
4 | data:
5 | dataset: anet
6 | modality: RGB
7 | num_segments: 16
8 | seg_length: 1
9 | batch_size: 4
10 | workers: 4
11 | num_classes: 200
12 | image_tmpl: 'image_{:06d}.jpg'
13 | train_root: '/bpfs/v2_mnt/VIS/wuwenhao/anet/anet_instance_frames_v1.3_train_vids_fps1'
14 | val_root: '/bpfs/v2_mnt/VIS/wuwenhao/anet/activitynet_val_resize_img_256_340_fps1'
15 | train_list: 'lists/anet/anet_train_instance_fps1.txt'
16 | val_list: 'lists/anet/anet_val_video_fps1.txt' #
17 | label_list: 'lists/anet1.3_labels.csv'
18 | input_size: 224
19 | random_shift: True
20 | network:
21 | arch: ViT-L/14
22 | init: True
23 | drop_out: 0.0
24 | emb_dropout: 0.0
25 | type: clip_anet
26 | sim_header: Transf
27 | drop: 0
28 | solver:
29 | type: cosine
30 | epochs: 30
31 | start_epoch: 0
32 | optim: adamw
33 | lr: 5.e-5
34 | lr_warmup_step: 5
35 | weight_decay: 0.2
36 | loss_type: CE
37 | evaluate: False
38 | clip_ratio: 0.1
39 | grad_accumulation_steps: 2
40 | logging:
41 | print_freq: 10
42 | eval_freq: 5
43 |
44 |
--------------------------------------------------------------------------------
/configs/anet/anet_k400_finetune_336.yaml:
--------------------------------------------------------------------------------
1 | pretrain: exp_sota/k400/ViT-L/14-336px/f16/last_model.pt
2 | resume:
3 | seed: 1024
4 | data:
5 | dataset: anet
6 | modality: RGB
7 | num_segments: 16
8 | seg_length: 1
9 | batch_size: 1
10 | workers: 4
11 | num_classes: 200
12 | image_tmpl: 'image_{:06d}.jpg'
13 | train_root: '/bpfs/v2_mnt/VIS/wuwenhao/anet/anet_instance_frames_v1.3_train_vids_fps1'
14 | val_root: '/bpfs/v2_mnt/VIS/wuwenhao/anet/activitynet_val_resize_img_256_340_fps1'
15 | train_list: 'lists/anet/anet_train_instance_fps1.txt'
16 | val_list: 'lists/anet/anet_val_video_fps1.txt' #
17 | label_list: 'lists/anet1.3_labels.csv'
18 | input_size: 336
19 | random_shift: True
20 | network:
21 | arch: ViT-L/14-336px
22 | init: True
23 | drop_out: 0.0
24 | emb_dropout: 0.0
25 | type: clip_anet
26 | sim_header: Transf
27 | drop: 0
28 | solver:
29 | type: cosine
30 | epochs: 30
31 | start_epoch: 0
32 | optim: adamw
33 | lr: 5.e-5
34 | lr_warmup_step: 5
35 | weight_decay: 0.2
36 | loss_type: CE
37 | evaluate: False
38 | clip_ratio: 0.1
39 | grad_accumulation_steps: 8
40 | logging:
41 | print_freq: 10
42 | eval_freq: 5
43 |
44 |
--------------------------------------------------------------------------------
/configs/anet/anet_zero_shot.yaml:
--------------------------------------------------------------------------------
1 | seed: 1024
2 | data:
3 | dataset: anet
4 | modality: RGB
5 | num_segments: 8
6 | seg_length: 1
7 | batch_size: 32
8 | workers: 4
9 | num_classes: 200
10 | image_tmpl: 'image_{:06d}.jpg'
11 | val_root: /bpfs/v2_mnt/VIS/wuwenhao/anet
12 | val_list: 'lists/anet/anet_full_for_zeroshot.txt'
13 | label_list: 'lists/anet1.3_labels.csv'
14 | index_bias: 1
15 | input_size: 224
16 | network:
17 | arch: ViT-L/14 #ViT-L/14 #ViT-B/32 ViT-B/16
18 | init: True
19 | drop_out: 0.0
20 | emb_dropout: 0.0
21 | type: clip_anet
22 | sim_header: Transf
23 | logging:
24 | print_freq: 10
25 | eval_freq: 5
--------------------------------------------------------------------------------
/configs/hmdb51/hmdb_few_shot.yaml:
--------------------------------------------------------------------------------
1 | pretrain:
2 | resume:
3 | seed: 1024
4 | data:
5 | dataset: hmdb51
6 | modality: RGB
7 | num_segments: 8
8 | seg_length: 1
9 | batch_size: 8
10 | workers: 4
11 | num_classes: 51
12 | image_tmpl: 'image_{:06d}.jpg'
13 | train_root: '/bpfs/v2_mnt/VIS/wuwenhao/hmdb51_rgb_img_256_340'
14 | val_root: '/bpfs/v2_mnt/VIS/wuwenhao/hmdb51_rgb_img_256_340'
15 | train_list: 'lists/hmdb51/train_rgb_split_1.txt'
16 | val_list: 'lists/hmdb51/val_rgb_split_1.txt'
17 | label_list: 'lists/hmdb51_labels.csv'
18 | input_size: 224
19 | random_shift: True
20 | shot: 2
21 | network:
22 | arch: ViT-L/14 #ViT-B/32 ViT-B/16
23 | init: True
24 | drop_out: 0.0
25 | emb_dropout: 0.0
26 | type: clip_hmdb
27 | sim_header: None
28 | drop: 0
29 | solver:
30 | type: cosine
31 | epochs: 30
32 | start_epoch: 0
33 | optim: adamw
34 | lr: 5.e-5
35 | lr_warmup_step: 5
36 | weight_decay: 0.2
37 | loss_type: CE
38 | evaluate: False
39 | clip_ratio: 0.1
40 | grad_accumulation_steps: 1
41 | logging:
42 | print_freq: 10
43 | eval_freq: 5
--------------------------------------------------------------------------------
/configs/hmdb51/hmdb_k400_finetune.yaml:
--------------------------------------------------------------------------------
1 | pretrain: exp_sota/k400/ViT-L/14/f16/last_model.pt
2 | resume:
3 | seed: 1024
4 | data:
5 | dataset: hmdb51
6 | modality: RGB
7 | num_segments: 16
8 | seg_length: 1
9 | batch_size: 4
10 | workers: 4
11 | num_classes: 51
12 | image_tmpl: 'image_{:06d}.jpg'
13 | train_root: '/bpfs/v2_mnt/VIS/wuwenhao/hmdb51_rgb_img_256_340'
14 | val_root: '/bpfs/v2_mnt/VIS/wuwenhao/hmdb51_rgb_img_256_340'
15 | train_list: 'lists/hmdb51/train_rgb_split_1.txt'
16 | val_list: 'lists/hmdb51/val_rgb_split_1.txt'
17 | label_list: 'lists/hmdb51_labels.csv'
18 | input_size: 224
19 | random_shift: True
20 | network:
21 | arch: ViT-L/14 #ViT-B/32 ViT-B/16
22 | init: True
23 | drop_out: 0.0
24 | emb_dropout: 0.0
25 | type: clip_hmdb
26 | sim_header: Transf
27 | drop: 0
28 | solver:
29 | type: cosine
30 | epochs: 30
31 | start_epoch: 0
32 | optim: adamw
33 | lr: 5.e-5
34 | lr_warmup_step: 5
35 | weight_decay: 0.2
36 | loss_type: CE
37 | evaluate: False
38 | clip_ratio: 0.1
39 | grad_accumulation_steps: 2
40 | logging:
41 | print_freq: 10
42 | eval_freq: 5
--------------------------------------------------------------------------------
/configs/hmdb51/hmdb_k400_finetune_336.yaml:
--------------------------------------------------------------------------------
1 | pretrain: exp_sota/k400/ViT-L/14-336px/f16/last_model.pt
2 | resume:
3 | seed: 1024
4 | data:
5 | dataset: hmdb51
6 | modality: RGB
7 | num_segments: 16
8 | seg_length: 1
9 | batch_size: 1
10 | workers: 4
11 | num_classes: 51
12 | image_tmpl: 'image_{:06d}.jpg'
13 | train_root: '/bpfs/v2_mnt/VIS/wuwenhao/hmdb51_rgb_img_256_340'
14 | val_root: '/bpfs/v2_mnt/VIS/wuwenhao/hmdb51_rgb_img_256_340'
15 | train_list: 'lists/hmdb51/train_rgb_split_1.txt'
16 | val_list: 'lists/hmdb51/val_rgb_split_1.txt'
17 | label_list: 'lists/hmdb51_labels.csv'
18 | input_size: 336
19 | random_shift: True
20 | network:
21 | arch: ViT-L/14-336px #ViT-B/32 ViT-B/16
22 | init: True
23 | drop_out: 0.0
24 | emb_dropout: 0.0
25 | type: clip_hmdb
26 | sim_header: Transf
27 | drop: 0
28 | solver:
29 | type: cosine
30 | epochs: 30
31 | start_epoch: 0
32 | optim: adamw
33 | lr: 5.e-5
34 | lr_warmup_step: 5
35 | weight_decay: 0.2
36 | loss_type: CE
37 | evaluate: False
38 | clip_ratio: 0.1
39 | grad_accumulation_steps: 8
40 | logging:
41 | print_freq: 10
42 | eval_freq: 5
--------------------------------------------------------------------------------
/configs/hmdb51/hmdb_zero_shot.yaml:
--------------------------------------------------------------------------------
1 | seed: 1024
2 | data:
3 | dataset: hmdb51
4 | modality: RGB
5 | num_segments: 8
6 | seg_length: 1
7 | batch_size: 16
8 | workers: 8
9 | num_classes: 51
10 | image_tmpl: 'image_{:06d}.jpg'
11 | val_root: '/bpfs/v2_mnt/VIS/wuwenhao/hmdb51_rgb_img_256_340'
12 | val_list: 'lists/hmdb51/hmdb_full_for_zeroshot.txt'
13 | label_list: 'lists/hmdb51_labels.csv'
14 | index_bias: 1
15 | input_size: 224
16 | network:
17 | arch: ViT-L/14 #ViT-B/32 ViT-B/16
18 | init: True
19 | drop_out: 0.0
20 | emb_dropout: 0.0
21 | type: clip_hmdb
22 | sim_header: Transf
23 | logging:
24 | print_freq: 10
25 | eval_freq: 1
--------------------------------------------------------------------------------
/configs/k400/k400_few_shot.yaml:
--------------------------------------------------------------------------------
1 | resume:
2 | pretrain:
3 | seed: 1024
4 | data:
5 | dataset: k400
6 | modality: RGB
7 | num_segments: 8
8 | seg_length: 1
9 | batch_size: 8
10 | workers: 4
11 | num_classes: 400
12 | image_tmpl: 'img_{:05d}.jpg'
13 | train_root: '/bpfs/v2_mnt/VIS/test/k400/train_320_frames'
14 | train_list: 'lists/k400/kinetics_rgb_train_se320.txt'
15 | val_root: /bpfs/v2_mnt/VIS/test/k400/kinetics_400_val_320_opencv
16 | val_list: lists/k400/kinetics_rgb_val_se320.txt
17 | label_list: 'lists/kinetics_400_labels.csv'
18 | index_bias: 1
19 | input_size: 224
20 | randaug:
21 | N: 2 #2
22 | M: 9 #9
23 | random_shift: True
24 | shot: 2
25 | network:
26 | arch: ViT-L/14 #ViT-B/32 ViT-B/16
27 | init: True
28 | drop_out: 0.0
29 | emb_dropout: 0.0
30 | type: clip_k400
31 | sim_header: None
32 | drop: 0
33 | solver:
34 | type: cosine
35 | epochs: 30
36 | start_epoch: 0
37 | epoch_offset: 0
38 | optim: adamw
39 | lr: 5.e-5
40 | lr_warmup_step: 5
41 | weight_decay: 0.2
42 | loss_type: CE
43 | evaluate: False
44 | clip_ratio: 0.1
45 | grad_accumulation_steps: 1
46 | logging:
47 | print_freq: 10
48 | eval_freq: 1
--------------------------------------------------------------------------------
/configs/k400/k400_train_rgb_rn50.yaml:
--------------------------------------------------------------------------------
1 | resume:
2 | pretrain:
3 | seed: 1024
4 | data:
5 | dataset: k400
6 | modality: RGB
7 | num_segments: 8
8 | seg_length: 1
9 | batch_size: 32
10 | workers: 4
11 | num_classes: 400
12 | image_tmpl: 'img_{:05d}.jpg'
13 | train_root: '/bpfs/v2_mnt/VIS/test/k400/train_320_frames'
14 | train_list: 'lists/k400/kinetics_rgb_train_se320.txt'
15 | val_root: /bpfs/v2_mnt/VIS/test/k400/kinetics_400_val_320_opencv
16 | val_list: lists/k400/kinetics_rgb_val_se320.txt
17 | label_list: 'lists/kinetics_400_labels.csv'
18 | input_size: 224
19 | randaug:
20 | N: 2 #2
21 | M: 9 #9
22 | random_shift: True
23 | network:
24 | arch: RN50 # RN50 RN101 RN50x4 RN50x16 RN50x64
25 | init: True
26 | tm: False # False tsm
27 | drop_out: 0.0
28 | emb_dropout: 0.0
29 | type: clip_k400
30 | sim_header: Transf
31 | joint_st: False
32 | drop: 0
33 | fix_text: True
34 | fix_video: False
35 | solver:
36 | type: cosine
37 | epochs: 30
38 | start_epoch: 0
39 | epoch_offset: 0
40 | optim: adamw
41 | lr: 5.e-5
42 | lr_warmup_step: 5
43 | weight_decay: 0.2
44 | loss_type: CE
45 | evaluate: False
46 | clip_ratio: 0.1
47 | grad_accumulation_steps: 1
48 | logging:
49 | print_freq: 10
50 | eval_freq: 1
--------------------------------------------------------------------------------
/configs/k400/k400_train_rgb_vitb-16-f16.yaml:
--------------------------------------------------------------------------------
1 | resume:
2 | pretrain:
3 | seed: 1024
4 | data:
5 | dataset: k400
6 | modality: RGB
7 | num_segments: 16
8 | seg_length: 1
9 | batch_size: 16
10 | workers: 4
11 | num_classes: 400
12 | image_tmpl: 'img_{:05d}.jpg'
13 | train_root: '/bpfs/v2_mnt/VIS/test/k400/train_320_frames'
14 | train_list: 'lists/k400/kinetics_rgb_train_se320.txt'
15 | val_root: /bpfs/v2_mnt/VIS/test/k400/kinetics_400_val_320_opencv
16 | val_list: lists/k400/kinetics_rgb_val_se320.txt
17 | label_list: 'lists/kinetics_400_labels.csv'
18 | input_size: 224
19 | randaug:
20 | N: 2 #2
21 | M: 9 #9
22 | random_shift: True
23 | network:
24 | arch: ViT-B/16 #ViT-B/32 ViT-B/16
25 | init: True
26 | tm: False # False tsm tokent1d tokenshift
27 | drop_out: 0.0
28 | emb_dropout: 0.0
29 | type: clip_k400
30 | sim_header: Transf # Transf None
31 | joint_st: False
32 | drop: 0
33 | fix_text: True
34 | fix_video: False
35 | solver:
36 | type: cosine
37 | epochs: 30
38 | start_epoch: 0
39 | epoch_offset: 0
40 | optim: adamw
41 | lr: 5.e-5
42 | lr_warmup_step: 5
43 | weight_decay: 0.2
44 | loss_type: CE
45 | evaluate: False
46 | clip_ratio: 0.1
47 | grad_accumulation_steps: 2
48 | logging:
49 | print_freq: 10
50 | eval_freq: 1
--------------------------------------------------------------------------------
/configs/k400/k400_train_rgb_vitb-16-f8.yaml:
--------------------------------------------------------------------------------
1 | resume:
2 | pretrain:
3 | seed: 1024
4 | data:
5 | dataset: k400
6 | modality: RGB
7 | num_segments: 8
8 | seg_length: 1
9 | batch_size: 32
10 | workers: 4
11 | num_classes: 400
12 | image_tmpl: 'img_{:05d}.jpg'
13 | train_root: '/bpfs/v2_mnt/VIS/test/k400/train_320_frames'
14 | train_list: 'lists/k400/kinetics_rgb_train_se320.txt'
15 | val_root: /bpfs/v2_mnt/VIS/test/k400/kinetics_400_val_320_opencv
16 | val_list: lists/k400/kinetics_rgb_val_se320.txt
17 | label_list: 'lists/kinetics_400_labels.csv'
18 | input_size: 224
19 | randaug:
20 | N: 2 #2
21 | M: 9 #9
22 | random_shift: True
23 | network:
24 | arch: ViT-B/16 #ViT-B/32 ViT-B/16
25 | init: True
26 | tm: False # False tsm tokent1d tokenshift
27 | drop_out: 0.0
28 | emb_dropout: 0.0
29 | type: clip_k400
30 | sim_header: Transf # Transf None
31 | joint_st: False
32 | drop: 0
33 | fix_text: True
34 | fix_video: False
35 | solver:
36 | type: cosine
37 | epochs: 30
38 | start_epoch: 0
39 | epoch_offset: 0
40 | optim: adamw
41 | lr: 5.e-5
42 | lr_warmup_step: 5
43 | weight_decay: 0.2
44 | loss_type: CE
45 | evaluate: False
46 | clip_ratio: 0.1
47 | grad_accumulation_steps: 1
48 | logging:
49 | print_freq: 10
50 | eval_freq: 1
--------------------------------------------------------------------------------
/configs/k400/k400_train_rgb_vitb-32-f16.yaml:
--------------------------------------------------------------------------------
1 | resume:
2 | pretrain:
3 | seed: 1024
4 | data:
5 | dataset: k400
6 | modality: RGB
7 | num_segments: 16
8 | seg_length: 1
9 | batch_size: 32
10 | workers: 4
11 | num_classes: 400
12 | image_tmpl: 'img_{:05d}.jpg'
13 | train_root: '/bpfs/v2_mnt/VIS/test/k400/train_320_frames'
14 | train_list: 'lists/k400/kinetics_rgb_train_se320.txt'
15 | val_root: /bpfs/v2_mnt/VIS/test/k400/kinetics_400_val_320_opencv
16 | val_list: lists/k400/kinetics_rgb_val_se320.txt
17 | label_list: 'lists/kinetics_400_labels.csv'
18 | input_size: 224
19 | randaug:
20 | N: 2 #2
21 | M: 9 #9
22 | random_shift: True
23 | network:
24 | arch: ViT-B/32 #ViT-B/32 ViT-B/16
25 | init: True
26 | tm: False # False tsm tokent1d tokenshift
27 | drop_out: 0.0
28 | emb_dropout: 0.0
29 | type: clip_k400
30 | sim_header: Transf # Transf None
31 | joint_st: False
32 | drop: 0
33 | fix_text: True
34 | fix_video: False
35 | solver:
36 | type: cosine
37 | epochs: 30
38 | start_epoch: 0
39 | epoch_offset: 0
40 | optim: adamw
41 | lr: 5.e-5
42 | lr_warmup_step: 5
43 | weight_decay: 0.2
44 | loss_type: CE
45 | evaluate: False
46 | clip_ratio: 0.1
47 | grad_accumulation_steps: 1
48 | logging:
49 | print_freq: 10
50 | eval_freq: 1
--------------------------------------------------------------------------------
/configs/k400/k400_train_rgb_vitb-32-f8.yaml:
--------------------------------------------------------------------------------
1 | resume:
2 | pretrain:
3 | seed: 1024
4 | data:
5 | dataset: k400
6 | modality: RGB
7 | num_segments: 8
8 | seg_length: 1
9 | batch_size: 32
10 | workers: 4
11 | num_classes: 400
12 | image_tmpl: 'img_{:05d}.jpg'
13 | train_root: '/bpfs/v2_mnt/VIS/test/k400/train_320_frames'
14 | train_list: 'lists/k400/kinetics_rgb_train_se320.txt'
15 | val_root: /bpfs/v2_mnt/VIS/test/k400/kinetics_400_val_320_opencv
16 | val_list: lists/k400/kinetics_rgb_val_se320.txt
17 | label_list: 'lists/kinetics_400_labels.csv'
18 | input_size: 224
19 | randaug:
20 | N: 2 #2
21 | M: 9 #9
22 | random_shift: True
23 | network:
24 | arch: ViT-B/32 #ViT-B/32 ViT-B/16
25 | init: True
26 | tm: False # False tsm tokent1d tokenshift
27 | drop_out: 0.0
28 | emb_dropout: 0.0
29 | type: clip_k400
30 | sim_header: Transf # Transf None
31 | joint_st: False
32 | drop: 0
33 | fix_text: True
34 | fix_video: False
35 | solver:
36 | type: cosine
37 | epochs: 30
38 | start_epoch: 0
39 | epoch_offset: 0
40 | optim: adamw
41 | lr: 5.e-5
42 | lr_warmup_step: 5
43 | weight_decay: 0.2
44 | loss_type: CE
45 | evaluate: False
46 | clip_ratio: 0.1
47 | grad_accumulation_steps: 1
48 | logging:
49 | print_freq: 10
50 | eval_freq: 1
--------------------------------------------------------------------------------
/configs/k400/k400_train_rgb_vitl-14-336-f32.yaml:
--------------------------------------------------------------------------------
1 | resume:
2 | pretrain:
3 | seed: 1024
4 | data:
5 | dataset: k400
6 | modality: RGB
7 | num_segments: 32
8 | seg_length: 1
9 | batch_size: 1
10 | workers: 4
11 | num_classes: 400
12 | image_tmpl: 'img_{:05d}.jpg'
13 | train_root: '/bpfs/v2_mnt/VIS/test/k400/train_320_frames'
14 | train_list: 'lists/k400/kinetics_rgb_train_se320.txt'
15 | val_root: /bpfs/v2_mnt/VIS/test/k400/kinetics_400_val_320_opencv
16 | val_list: lists/k400/kinetics_rgb_val_se320.txt
17 | label_list: 'lists/kinetics_400_labels.csv'
18 | input_size: 336
19 | randaug:
20 | N: 2 #2
21 | M: 9 #9
22 | random_shift: True
23 | network:
24 | arch: ViT-L/14-336px #ViT-B/32 ViT-B/16
25 | init: True
26 | tm: False # False tsm tokent1d tokenshift
27 | drop_out: 0.0
28 | emb_dropout: 0.0
29 | type: clip_k400
30 | sim_header: Transf # Transf None
31 | joint_st: False
32 | drop: 0
33 | fix_text: True
34 | fix_video: False
35 | solver:
36 | type: cosine
37 | epochs: 30
38 | start_epoch: 0
39 | epoch_offset: 0
40 | optim: adamw
41 | lr: 5.e-5
42 | lr_warmup_step: 5
43 | weight_decay: 0.2
44 | loss_type: CE
45 | evaluate: False
46 | clip_ratio: 0.1
47 | grad_accumulation_steps: 8
48 | logging:
49 | print_freq: 10
50 | eval_freq: 1
--------------------------------------------------------------------------------
/configs/k400/k400_train_rgb_vitl-14-336-f8.yaml:
--------------------------------------------------------------------------------
1 | resume:
2 | pretrain:
3 | seed: 1024
4 | data:
5 | dataset: k400
6 | modality: RGB
7 | num_segments: 8
8 | seg_length: 1
9 | batch_size: 2
10 | workers: 4
11 | num_classes: 400
12 | image_tmpl: 'img_{:05d}.jpg'
13 | train_root: '/bpfs/v2_mnt/VIS/test/k400/train_320_frames'
14 | train_list: 'lists/k400/kinetics_rgb_train_se320.txt'
15 | val_root: /bpfs/v2_mnt/VIS/test/k400/kinetics_400_val_320_opencv
16 | val_list: lists/k400/kinetics_rgb_val_se320.txt
17 | label_list: 'lists/kinetics_400_labels.csv'
18 | input_size: 336
19 | randaug:
20 | N: 2 #2
21 | M: 9 #9
22 | random_shift: True
23 | network:
24 | arch: ViT-L/14-336px #ViT-B/32 ViT-B/16
25 | init: True
26 | tm: False # False tsm tokent1d tokenshift
27 | drop_out: 0.0
28 | emb_dropout: 0.0
29 | type: clip_k400
30 | sim_header: Transf # Transf None
31 | joint_st: False
32 | drop: 0
33 | fix_text: True
34 | fix_video: False
35 | solver:
36 | type: cosine
37 | epochs: 30
38 | start_epoch: 0
39 | epoch_offset: 0
40 | optim: adamw
41 | lr: 5.e-5
42 | lr_warmup_step: 5
43 | weight_decay: 0.2
44 | loss_type: CE
45 | evaluate: False
46 | clip_ratio: 0.1
47 | grad_accumulation_steps: 4
48 | logging:
49 | print_freq: 10
50 | eval_freq: 1
--------------------------------------------------------------------------------
/configs/k400/k400_train_rgb_vitl-14-f8.yaml:
--------------------------------------------------------------------------------
1 | resume:
2 | pretrain:
3 | seed: 1024
4 | data:
5 | dataset: k400
6 | modality: RGB
7 | num_segments: 8
8 | seg_length: 1
9 | batch_size: 8
10 | workers: 4
11 | num_classes: 400
12 | image_tmpl: 'img_{:05d}.jpg'
13 | train_root: '/bpfs/v2_mnt/VIS/test/k400/train_320_frames'
14 | train_list: 'lists/k400/kinetics_rgb_train_se320.txt'
15 | val_root: /bpfs/v2_mnt/VIS/test/k400/kinetics_400_val_320_opencv
16 | val_list: lists/k400/kinetics_rgb_val_se320.txt
17 | label_list: 'lists/kinetics_400_labels.csv'
18 | input_size: 224
19 | randaug:
20 | N: 2 #2
21 | M: 9 #9
22 | random_shift: True
23 | network:
24 | arch: ViT-L/14 #ViT-B/32 ViT-B/16
25 | init: True
26 | tm: False # False tsm tokent1d tokenshift
27 | drop_out: 0.0
28 | emb_dropout: 0.0
29 | type: clip_k400
30 | sim_header: Transf # Transf None
31 | joint_st: False
32 | drop: 0
33 | fix_text: True
34 | fix_video: False
35 | solver:
36 | type: cosine
37 | epochs: 30
38 | start_epoch: 0
39 | epoch_offset: 0
40 | optim: adamw
41 | lr: 5.e-5
42 | lr_warmup_step: 5
43 | weight_decay: 0.2
44 | loss_type: CE
45 | evaluate: False
46 | clip_ratio: 0.1
47 | grad_accumulation_steps: 2
48 | logging:
49 | print_freq: 10
50 | eval_freq: 1
--------------------------------------------------------------------------------
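
Taken together, the three k400 ViT-L configs above trade frame count and resolution against per-GPU memory: batch_size is the per-GPU batch and grad_accumulation_steps multiplies it into the effective optimization batch, assuming the training scripts apply standard gradient accumulation (the global batch additionally scales with the number of GPUs, which these files do not specify). A quick check of the values above:

    # Effective per-GPU optimization batch implied by the three ViT-L configs above,
    # assuming plain gradient accumulation in the training loop.
    vit_l_configs = {
        'vitl-14-336-f32': (1, 8),   # (batch_size, grad_accumulation_steps)
        'vitl-14-336-f8':  (2, 4),
        'vitl-14-f8':      (8, 2),
    }
    for name, (bs, accum) in vit_l_configs.items():
        print(name, bs * accum)      # 8, 8 and 16 respectively
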
/configs/k600/k600_zero_shot_split1.yaml:
--------------------------------------------------------------------------------
1 | resume:
2 | pretrain:
3 | seed: 1024
4 | data:
5 | dataset: k600
6 | modality: video
7 | num_segments: 8
8 | seg_length: 1
9 | batch_size: 8
10 | workers: 4
11 | num_classes: 160
12 | image_tmpl: 'img_{:05d}.jpg'
13 | val_root: '/bpfs/v2_mnt/VIS/wuwenhao/k600_test_video'
14 | val_list: lists/k600/test_split1_exist.txt
15 | label_list: lists/k600/k160_labels_split1.csv
16 | index_bias: 1
17 | input_size: 224
18 | network:
19 | arch: ViT-L/14 #ViT-B/32 ViT-B/16
20 | init: True
21 | drop_out: 0.0
22 | emb_dropout: 0.0
23 | type: clip_k600
24 | sim_header: Transf
25 | drop: 0
26 | logging:
27 | print_freq: 10
28 | eval_freq: 1
--------------------------------------------------------------------------------
/configs/k600/k600_zero_shot_split2.yaml:
--------------------------------------------------------------------------------
1 | resume:
2 | pretrain:
3 | seed: 1024
4 | data:
5 | dataset: k600
6 | modality: video
7 | num_segments: 8
8 | seg_length: 1
9 | batch_size: 8
10 | workers: 4
11 | num_classes: 160
12 | image_tmpl: 'img_{:05d}.jpg'
13 | val_root: '/bpfs/v2_mnt/VIS/wuwenhao/k600_test_video'
14 | val_list: lists/k600/test_split2_exist.txt
15 | label_list: lists/k600/k160_labels_split2.csv
16 | index_bias: 1
17 | input_size: 224
18 | network:
19 | arch: ViT-L/14 #ViT-B/32 ViT-B/16
20 | init: True
21 | drop_out: 0.0
22 | emb_dropout: 0.0
23 | type: clip_k600
24 | sim_header: Transf
25 | drop: 0
26 | logging:
27 | print_freq: 10
28 | eval_freq: 1
--------------------------------------------------------------------------------
/configs/k600/k600_zero_shot_split3.yaml:
--------------------------------------------------------------------------------
1 | resume:
2 | pretrain:
3 | seed: 1024
4 | data:
5 | dataset: k600
6 | modality: video
7 | num_segments: 8
8 | seg_length: 1
9 | batch_size: 8
10 | workers: 4
11 | num_classes: 160
12 | image_tmpl: 'img_{:05d}.jpg'
13 | val_root: '/bpfs/v2_mnt/VIS/wuwenhao/k600_test_video'
14 | val_list: lists/k600/test_split3_exist.txt
15 | label_list: lists/k600/k160_labels_split3.csv
16 | index_bias: 1
17 | input_size: 224
18 | network:
19 | arch: ViT-L/14 #ViT-B/32 ViT-B/16
20 | init: True
21 | drop_out: 0.0
22 | emb_dropout: 0.0
23 | type: clip_k600
24 | sim_header: Transf
25 | drop: 0
26 | logging:
27 | print_freq: 10
28 | eval_freq: 1
--------------------------------------------------------------------------------
/configs/ucf101/ucf_few_shot.yaml:
--------------------------------------------------------------------------------
1 | pretrain:
2 | resume:
3 | seed: 1024
4 | data:
5 | dataset: ucf101
6 | modality: RGB
7 | num_segments: 8
8 | seg_length: 1
9 | batch_size: 8
10 | workers: 4
11 | num_classes: 101
12 | image_tmpl: 'image_{:04d}.jpg'
13 | train_root: '/bpfs/v2_mnt/VIS/wuwenhao/UCF101-frames'
14 | val_root: '/bpfs/v2_mnt/VIS/wuwenhao/UCF101-frames'
15 | train_list: 'lists/ucf101/train_rgb_split_1.txt'
16 | val_list: 'lists/ucf101/val_rgb_split_1.txt'
17 | label_list: 'lists/ucf_labels.csv'
18 | input_size: 224
19 | random_shift: True
20 | shot: 2
21 | network:
22 | arch: ViT-L/14 #ViT-B/32 ViT-B/16
23 | init: True
24 | drop_out: 0.0
25 | emb_dropout: 0.0
26 | type: clip_hmdb
27 | sim_header: None
28 | drop: 0
29 | solver:
30 | type: cosine
31 | epochs: 30
32 | start_epoch: 0
33 | optim: adamw
34 | lr: 5.e-5
35 | lr_warmup_step: 5
36 | weight_decay: 0.2
37 | loss_type: CE
38 | evaluate: False
39 | clip_ratio: 0.1
40 | grad_accumulation_steps: 1
41 | logging:
42 | print_freq: 10
43 | eval_freq: 5
--------------------------------------------------------------------------------
/configs/ucf101/ucf_k400_finetune.yaml:
--------------------------------------------------------------------------------
1 | pretrain: exp_sota/k400/ViT-L/14/f16/last_model.pt
2 | resume:
3 | seed: 1024
4 | data:
5 | dataset: ucf101
6 | modality: RGB
7 | num_segments: 16
8 | seg_length: 1
9 | batch_size: 4
10 | workers: 4
11 | num_classes: 101
12 | image_tmpl: 'image_{:04d}.jpg'
13 | train_root: '/bpfs/v2_mnt/VIS/wuwenhao/UCF101-frames'
14 | val_root: '/bpfs/v2_mnt/VIS/wuwenhao/UCF101-frames'
15 | train_list: 'lists/ucf101/train_rgb_split_1.txt'
16 | val_list: 'lists/ucf101/val_rgb_split_1.txt'
17 | label_list: 'lists/ucf_labels.csv'
18 | input_size: 224
19 | random_shift: True
20 | network:
21 | arch: ViT-L/14 #ViT-B/32 ViT-B/16
22 | init: True
23 | drop_out: 0.0
24 | emb_dropout: 0.0
25 | type: clip_ucf
26 | sim_header: Transf
27 | drop: 0
28 | solver:
29 | type: cosine
30 | epochs: 30
31 | start_epoch: 0
32 | optim: adamw
33 | lr: 5.e-5
34 | lr_warmup_step: 5
35 | weight_decay: 0.2
36 | loss_type: CE
37 | evaluate: False
38 | clip_ratio: 0.1
39 | grad_accumulation_steps: 2
40 | logging:
41 | print_freq: 10
42 | eval_freq: 5
--------------------------------------------------------------------------------
/configs/ucf101/ucf_k400_finetune_336.yaml:
--------------------------------------------------------------------------------
1 | pretrain: exp_sota/k400/ViT-L/14-336px/f16/last_model.pt
2 | resume:
3 | seed: 1024
4 | data:
5 | dataset: ucf101
6 | modality: RGB
7 | num_segments: 16
8 | seg_length: 1
9 | batch_size: 1
10 | workers: 4
11 | num_classes: 101
12 | image_tmpl: 'image_{:04d}.jpg'
13 | train_root: '/bpfs/v2_mnt/VIS/wuwenhao/UCF101-frames'
14 | val_root: '/bpfs/v2_mnt/VIS/wuwenhao/UCF101-frames'
15 | train_list: 'lists/ucf101/train_rgb_split_1.txt'
16 | val_list: 'lists/ucf101/val_rgb_split_1.txt'
17 | label_list: 'lists/ucf_labels.csv'
18 | input_size: 336
19 | random_shift: True
20 | network:
21 | arch: ViT-L/14-336px #ViT-B/32 ViT-B/16
22 | init: True
23 | drop_out: 0.0
24 | emb_dropout: 0.0
25 | type: clip_ucf
26 | sim_header: Transf
27 | drop: 0
28 | solver:
29 | type: cosine
30 | epochs: 30
31 | start_epoch: 0
32 | optim: adamw
33 | lr: 5.e-5
34 | lr_warmup_step: 5
35 | weight_decay: 0.2
36 | loss_type: CE
37 | evaluate: False
38 | clip_ratio: 0.1
39 | grad_accumulation_steps: 8
40 | logging:
41 | print_freq: 10
42 | eval_freq: 5
--------------------------------------------------------------------------------
/configs/ucf101/ucf_zero_shot.yaml:
--------------------------------------------------------------------------------
1 | seed: 1024
2 | data:
3 | dataset: ucf101
4 | modality: RGB
5 | num_segments: 8
6 | seg_length: 1
7 | batch_size: 16
8 | workers: 8
9 | num_classes: 101
10 | image_tmpl: 'image_{:04d}.jpg'
11 | val_root: '/bpfs/v2_mnt/VIS/wuwenhao/UCF101-frames'
12 |   val_list: 'lists/ucf101/ucf_full_for_zeroshot.txt'
13 | label_list: 'lists/ucf_labels.csv'
14 | index_bias: 1
15 | input_size: 224
16 | network:
17 | arch: ViT-L/14
18 | init: True
19 | drop_out: 0.0
20 | emb_dropout: 0.0
21 | type: clip_ucf
22 | sim_header: Transf
23 | logging:
24 | print_freq: 10
25 | eval_freq: 1
--------------------------------------------------------------------------------
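
All of the configs above are plain YAML sharing the same data/network/logging sections (the training configs add a solver section), so a minimal way to inspect one is to load it with PyYAML. This is only a sketch: the repo's own scripts may wrap the resulting dict differently, but the path and field values below come straight from ucf_zero_shot.yaml above.

    # Sketch: read one of the configs above with PyYAML and print a few nested
    # fields. Only yaml.safe_load is assumed here.
    import yaml

    with open('configs/ucf101/ucf_zero_shot.yaml') as f:
        cfg = yaml.safe_load(f)

    print(cfg['data']['num_segments'])   # 8
    print(cfg['network']['arch'])        # ViT-L/14
    print(cfg['data']['label_list'])     # lists/ucf_labels.csv
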
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .datasets import *
--------------------------------------------------------------------------------
/datasets/datasets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data as data
3 | import decord
4 | import os
5 | import numpy as np
6 | from numpy.random import randint
7 | import io
8 | import pandas as pd
9 | import random
10 | from PIL import Image
11 | import math
12 | import copy
13 |
14 |
15 | class VideoRecord(object):
16 | def __init__(self, row):
17 | self._data = row
18 |
19 | @property
20 | def path(self):
21 | return self._data[0]
22 |
23 | @property
24 | def num_frames(self):
25 | return int(self._data[1])
26 |
27 | @property
28 | def label(self):
29 | return int(self._data[-1])
30 |
31 |
32 | class Video_dataset(data.Dataset):
33 | def __init__(self, root_path, list_file, labels_file,
34 | num_segments=1, modality='RGB', new_length=1,
35 | image_tmpl='img_{:05d}.jpg', transform=None,
36 | random_shift=True, test_mode=False,
37 | index_bias=1, dense_sample=False, test_clips=3):
38 |
39 | self.root_path = root_path
40 | self.list_file = list_file
41 | self.num_segments = num_segments
42 | self.modality = modality
43 | self.seg_length = new_length
44 | self.image_tmpl = image_tmpl
45 | self.transform = transform
46 | self.random_shift = random_shift
47 | self.test_mode = test_mode
48 | self.loop=False
49 | self.index_bias = index_bias
50 | self.labels_file = labels_file
51 | self.sample_range = 128
52 |         self.dense_sample = dense_sample  # use dense sampling as in I3D
53 | self.test_clips = test_clips
54 | if self.dense_sample:
55 | print('=> Using dense sample for the dataset...')
56 |
57 | if self.index_bias is None:
58 | if self.image_tmpl == "frame{:d}.jpg":
59 | self.index_bias = 0
60 | else:
61 | self.index_bias = 1
62 | self._parse_list()
63 | self.initialized = False
64 |
65 | @property
66 | def total_length(self):
67 | return self.num_segments * self.seg_length
68 |
69 | @property
70 | def classes(self):
71 | classes_all = pd.read_csv(self.labels_file)
72 | return classes_all.values.tolist()
73 |
74 | def _parse_list(self):
75 |         # parse the list file; for training, drop videos with too few frames
76 | tmp = [x.strip().split(' ') for x in open(self.list_file)]
77 |         if len(tmp[0]) == 3:  # "path num_frames label" lists; skip this filtering for "raw_video label" lists used with video decoding
78 | if not self.test_mode:
79 | tmp = [item for item in tmp if int(item[1]) >= 8]
80 | self.video_list = [VideoRecord(item) for item in tmp]
81 | print('video number:%d' % (len(self.video_list)))
82 |
83 | def _sample_indices(self, video_list):
84 | if self.dense_sample:
85 | sample_pos = max(1, 1 + len(video_list) - self.sample_range)
86 | interval = self.sample_range // self.num_segments
87 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1)
88 | base_offsets = np.arange(self.num_segments) * interval
89 | offsets = (base_offsets + start_idx) % len(video_list)
90 | return np.array(offsets) + self.index_bias
91 | else:
92 | if len(video_list) <= self.total_length:
93 | if self.loop:
94 | return np.mod(np.arange(
95 | self.total_length) + randint(len(video_list) // 2),
96 | len(video_list)) + self.index_bias
97 | offsets = np.concatenate((
98 | np.arange(len(video_list)),
99 | randint(len(video_list),
100 | size=self.total_length - len(video_list))))
101 | return np.sort(offsets) + self.index_bias
102 | offsets = list()
103 | ticks = [i * len(video_list) // self.num_segments
104 | for i in range(self.num_segments + 1)]
105 |
106 | for i in range(self.num_segments):
107 | tick_len = ticks[i + 1] - ticks[i]
108 | tick = ticks[i]
109 | if tick_len >= self.seg_length:
110 | tick += randint(tick_len - self.seg_length + 1)
111 | offsets.extend([j for j in range(tick, tick + self.seg_length)])
112 | return np.array(offsets) + self.index_bias
113 |
114 | def _get_val_indices(self, video_list):
115 | if self.dense_sample:
116 | sample_pos = max(1, 1 + len(video_list) - self.sample_range)
117 | t_stride = self.sample_range // self.num_segments
118 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1)
119 | offsets = [(idx * t_stride + start_idx) % len(video_list) for idx in range(self.num_segments)]
120 | return np.array(offsets) + self.index_bias
121 | else:
122 | tick = len(video_list) / float(self.num_segments)
123 | offsets = [int(tick * x) % len(video_list) for x in range(self.num_segments)]
124 | return np.array(offsets) + self.index_bias
125 |
126 |
127 | def _get_test_indices(self, video_list):
128 | if self.dense_sample:
129 | # multi-clip for dense sampling
130 | num_clips = self.test_clips
131 | sample_pos = max(0, len(video_list) - self.sample_range)
132 | interval = self.sample_range // self.num_segments
133 |             start_list = [clip_idx * math.floor(sample_pos / (num_clips - 1)) for clip_idx in range(num_clips)]
134 | base_offsets = np.arange(self.num_segments) * interval
135 | offsets = []
136 | for start_idx in start_list:
137 | offsets.extend((base_offsets + start_idx) % len(video_list))
138 | return np.array(offsets) + self.index_bias
139 | else:
140 | # multi-clip for uniform sampling
141 | num_clips = self.test_clips
142 | tick = len(video_list) / float(self.num_segments)
143 | start_list = np.linspace(0, tick - 1, num=num_clips, dtype=int)
144 | offsets = []
145 | for start_idx in start_list.tolist():
146 | offsets += [
147 | int(start_idx + tick * x) % len(video_list)
148 | for x in range(self.num_segments)
149 | ]
150 | return np.array(offsets) + self.index_bias
151 |
152 |
153 |
154 | def _decord_decode(self, video_path):
155 | try:
156 | container = decord.VideoReader(video_path)
157 | except Exception as e:
158 | print("Failed to decode {} with exception: {}".format(
159 | video_path, e))
160 | return None
161 |
162 | return container
163 |
164 | def __getitem__(self, index):
165 | # decode frames to video_list
166 | if self.modality == 'video':
167 | _num_retries = 10
168 | for i_try in range(_num_retries):
169 | record = copy.deepcopy(self.video_list[index])
170 | directory = os.path.join(self.root_path, record.path)
171 | video_list = self._decord_decode(directory)
172 | # video_list = self._decord_pyav(directory)
173 | if video_list is None:
174 | print("Failed to decode video idx {} from {}; trial {}".format(
175 | index, directory, i_try)
176 | )
177 |                 index = random.randint(0, len(self.video_list) - 1)  # randint is inclusive; keep the retry index in range
178 | continue
179 | break
180 | else:
181 | record = self.video_list[index]
182 | video_list = os.listdir(os.path.join(self.root_path, record.path))
183 |
184 | if not self.test_mode: # train/val
185 | segment_indices = self._sample_indices(video_list) if self.random_shift else self._get_val_indices(video_list)
186 | else: # test
187 | segment_indices = self._get_test_indices(video_list)
188 |
189 | return self.get(record, video_list, segment_indices)
190 |
191 |
192 | def _load_image(self, directory, idx):
193 | if self.modality == 'RGB':
194 | try:
195 | return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(idx))).convert('RGB')]
196 | except Exception:
197 | print('error loading image:', os.path.join(self.root_path, directory, self.image_tmpl.format(idx)))
198 | return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(1))).convert('RGB')]
199 |
200 |
201 | def get(self, record, video_list, indices):
202 | images = list()
203 | for seg_ind in indices:
204 | p = int(seg_ind)
205 | if self.modality == 'video':
206 | seg_imgs = [Image.fromarray(video_list[p-1].asnumpy()).convert('RGB')]
207 | else:
208 | seg_imgs = self._load_image(record.path,p)
209 | images.extend(seg_imgs)
210 | if p < len(video_list):
211 | p += 1
212 | process_data, record_label = self.transform((images,record.label))
213 | return process_data, record_label
214 |
215 | def __len__(self):
216 | return len(self.video_list)
217 |
--------------------------------------------------------------------------------
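
A hedged usage sketch for the Video_dataset class above: the paths and values mirror the UCF-101 configs earlier in this dump, and toy_transform is a hypothetical stand-in for the repo's real augmentation pipeline (it only stacks the PIL frames into a tensor and does no resizing, so batching across videos assumes the frames already share a resolution).

    # Sketch only: build Video_dataset with values taken from the UCF-101 configs
    # above; toy_transform is a stand-in, not the repo's transform (see
    # utils/Augmentation.py for the real pipeline).
    import torch
    from torchvision.transforms.functional import to_tensor
    from datasets import Video_dataset

    def toy_transform(batch):
        images, label = batch                      # list of PIL frames, int label
        return torch.stack([to_tensor(img) for img in images]), label

    dataset = Video_dataset(
        root_path='/bpfs/v2_mnt/VIS/wuwenhao/UCF101-frames',
        list_file='lists/ucf101/val_rgb_split_1.txt',
        labels_file='lists/ucf_labels.csv',
        num_segments=8, modality='RGB', new_length=1,
        image_tmpl='image_{:04d}.jpg',             # frame 1 -> image_0001.jpg
        transform=toy_transform, random_shift=False, index_bias=1)

    frames, label = dataset[0]                     # (8, 3, H, W) float tensor, int label
    print(len(dataset), frames.shape, label)
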
/lists/anet/anet1.3_label2name.json:
--------------------------------------------------------------------------------
1 | {"0": "Applying sunscreen", "1": "Archery", "2": "Arm wrestling", "3": "Assembling bicycle", "4": "BMX", "5": "Baking cookies", "6": "Ballet", "7": "Bathing dog", "8": "Baton twirling", "9": "Beach soccer", "10": "Beer pong", "11": "Belly dance", "12": "Blow-drying hair", "13": "Blowing leaves", "14": "Braiding hair", "15": "Breakdancing", "16": "Brushing hair", "17": "Brushing teeth", "18": "Building sandcastles", "19": "Bullfighting", "20": "Bungee jumping", "21": "Calf roping", "22": "Camel ride", "23": "Canoeing", "24": "Capoeira", "25": "Carving jack-o-lanterns", "26": "Changing car wheel", "27": "Cheerleading", "28": "Chopping wood", "29": "Clean and jerk", "30": "Cleaning shoes", "31": "Cleaning sink", "32": "Cleaning windows", "33": "Clipping cat claws", "34": "Cricket", "35": "Croquet", "36": "Cumbia", "37": "Curling", "38": "Cutting the grass", "39": "Decorating the Christmas tree", "40": "Disc dog", "41": "Discus throw", "42": "Dodgeball", "43": "Doing a powerbomb", "44": "Doing crunches", "45": "Doing fencing", "46": "Doing karate", "47": "Doing kickboxing", "48": "Doing motocross", "49": "Doing nails", "50": "Doing step aerobics", "51": "Drinking beer", "52": "Drinking coffee", "53": "Drum corps", "54": "Elliptical trainer", "55": "Fixing bicycle", "56": "Fixing the roof", "57": "Fun sliding down", "58": "Futsal", "59": "Gargling mouthwash", "60": "Getting a haircut", "61": "Getting a piercing", "62": "Getting a tattoo", "63": "Grooming dog", "64": "Grooming horse", "65": "Hammer throw", "66": "Hand car wash", "67": "Hand washing clothes", "68": "Hanging wallpaper", "69": "Having an ice cream", "70": "High jump", "71": "Hitting a pinata", "72": "Hopscotch", "73": "Horseback riding", "74": "Hula hoop", "75": "Hurling", "76": "Ice fishing", "77": "Installing carpet", "78": "Ironing clothes", "79": "Javelin throw", "80": "Kayaking", "81": "Kite flying", "82": "Kneeling", "83": "Knitting", "84": "Laying tile", "85": "Layup drill in basketball", "86": "Long jump", "87": "Longboarding", "88": "Making a cake", "89": "Making a lemonade", "90": "Making a sandwich", "91": "Making an omelette", "92": "Mixing drinks", "93": "Mooping floor", "94": "Mowing the lawn", "95": "Paintball", "96": "Painting", "97": "Painting fence", "98": "Painting furniture", "99": "Peeling potatoes", "100": "Ping-pong", "101": "Plastering", "102": "Plataform diving", "103": "Playing accordion", "104": "Playing badminton", "105": "Playing bagpipes", "106": "Playing beach volleyball", "107": "Playing blackjack", "108": "Playing congas", "109": "Playing drums", "110": "Playing field hockey", "111": "Playing flauta", "112": "Playing guitarra", "113": "Playing harmonica", "114": "Playing ice hockey", "115": "Playing kickball", "116": "Playing lacrosse", "117": "Playing piano", "118": "Playing polo", "119": "Playing pool", "120": "Playing racquetball", "121": "Playing rubik cube", "122": "Playing saxophone", "123": "Playing squash", "124": "Playing ten pins", "125": "Playing violin", "126": "Playing water polo", "127": "Pole vault", "128": "Polishing forniture", "129": "Polishing shoes", "130": "Powerbocking", "131": "Preparing pasta", "132": "Preparing salad", "133": "Putting in contact lenses", "134": "Putting on makeup", "135": "Putting on shoes", "136": "Rafting", "137": "Raking leaves", "138": "Removing curlers", "139": "Removing ice from car", "140": "Riding bumper cars", "141": "River tubing", "142": "Rock climbing", "143": "Rock-paper-scissors", "144": "Rollerblading", "145": "Roof shingle removal", 
"146": "Rope skipping", "147": "Running a marathon", "148": "Sailing", "149": "Scuba diving", "150": "Sharpening knives", "151": "Shaving", "152": "Shaving legs", "153": "Shot put", "154": "Shoveling snow", "155": "Shuffleboard", "156": "Skateboarding", "157": "Skiing", "158": "Slacklining", "159": "Smoking a cigarette", "160": "Smoking hookah", "161": "Snatch", "162": "Snow tubing", "163": "Snowboarding", "164": "Spinning", "165": "Spread mulch", "166": "Springboard diving", "167": "Starting a campfire", "168": "Sumo", "169": "Surfing", "170": "Swimming", "171": "Swinging at the playground", "172": "Table soccer", "173": "Tai chi", "174": "Tango", "175": "Tennis serve with ball bouncing", "176": "Throwing darts", "177": "Trimming branches or hedges", "178": "Triple jump", "179": "Tug of war", "180": "Tumbling", "181": "Using parallel bars", "182": "Using the balance beam", "183": "Using the monkey bar", "184": "Using the pommel horse", "185": "Using the rowing machine", "186": "Using uneven bars", "187": "Vacuuming floor", "188": "Volleyball", "189": "Wakeboarding", "190": "Walking the dog", "191": "Washing dishes", "192": "Washing face", "193": "Washing hands", "194": "Waterskiing", "195": "Waxing skis", "196": "Welding", "197": "Windsurfing", "198": "Wrapping presents", "199": "Zumba"}
--------------------------------------------------------------------------------
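
The JSON above is a flat id-to-name mapping with string keys, so reading it back is a single json.load call (path as in the tree above):

    # Sketch: load the ActivityNet label-to-name mapping shown above.
    import json

    with open('lists/anet/anet1.3_label2name.json') as f:
        label2name = json.load(f)

    print(label2name['0'], '|', label2name['199'])   # Applying sunscreen | Zumba
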
/lists/anet1.3_labels.csv:
--------------------------------------------------------------------------------
1 | id,name
2 | 0,Applying sunscreen
3 | 1,Archery
4 | 2,Arm wrestling
5 | 3,Assembling bicycle
6 | 4,BMX
7 | 5,Baking cookies
8 | 6,Ballet
9 | 7,Bathing dog
10 | 8,Baton twirling
11 | 9,Beach soccer
12 | 10,Beer pong
13 | 11,Belly dance
14 | 12,Blow-drying hair
15 | 13,Blowing leaves
16 | 14,Braiding hair
17 | 15,Breakdancing
18 | 16,Brushing hair
19 | 17,Brushing teeth
20 | 18,Building sandcastles
21 | 19,Bullfighting
22 | 20,Bungee jumping
23 | 21,Calf roping
24 | 22,Camel ride
25 | 23,Canoeing
26 | 24,Capoeira
27 | 25,Carving jack-o-lanterns
28 | 26,Changing car wheel
29 | 27,Cheerleading
30 | 28,Chopping wood
31 | 29,Clean and jerk
32 | 30,Cleaning shoes
33 | 31,Cleaning sink
34 | 32,Cleaning windows
35 | 33,Clipping cat claws
36 | 34,Cricket
37 | 35,Croquet
38 | 36,Cumbia
39 | 37,Curling
40 | 38,Cutting the grass
41 | 39,Decorating the Christmas tree
42 | 40,Disc dog
43 | 41,Discus throw
44 | 42,Dodgeball
45 | 43,Doing a powerbomb
46 | 44,Doing crunches
47 | 45,Doing fencing
48 | 46,Doing karate
49 | 47,Doing kickboxing
50 | 48,Doing motocross
51 | 49,Doing nails
52 | 50,Doing step aerobics
53 | 51,Drinking beer
54 | 52,Drinking coffee
55 | 53,Drum corps
56 | 54,Elliptical trainer
57 | 55,Fixing bicycle
58 | 56,Fixing the roof
59 | 57,Fun sliding down
60 | 58,Futsal
61 | 59,Gargling mouthwash
62 | 60,Getting a haircut
63 | 61,Getting a piercing
64 | 62,Getting a tattoo
65 | 63,Grooming dog
66 | 64,Grooming horse
67 | 65,Hammer throw
68 | 66,Hand car wash
69 | 67,Hand washing clothes
70 | 68,Hanging wallpaper
71 | 69,Having an ice cream
72 | 70,High jump
73 | 71,Hitting a pinata
74 | 72,Hopscotch
75 | 73,Horseback riding
76 | 74,Hula hoop
77 | 75,Hurling
78 | 76,Ice fishing
79 | 77,Installing carpet
80 | 78,Ironing clothes
81 | 79,Javelin throw
82 | 80,Kayaking
83 | 81,Kite flying
84 | 82,Kneeling
85 | 83,Knitting
86 | 84,Laying tile
87 | 85,Layup drill in basketball
88 | 86,Long jump
89 | 87,Longboarding
90 | 88,Making a cake
91 | 89,Making a lemonade
92 | 90,Making a sandwich
93 | 91,Making an omelette
94 | 92,Mixing drinks
95 | 93,Mooping floor
96 | 94,Mowing the lawn
97 | 95,Paintball
98 | 96,Painting
99 | 97,Painting fence
100 | 98,Painting furniture
101 | 99,Peeling potatoes
102 | 100,Ping-pong
103 | 101,Plastering
104 | 102,Plataform diving
105 | 103,Playing accordion
106 | 104,Playing badminton
107 | 105,Playing bagpipes
108 | 106,Playing beach volleyball
109 | 107,Playing blackjack
110 | 108,Playing congas
111 | 109,Playing drums
112 | 110,Playing field hockey
113 | 111,Playing flauta
114 | 112,Playing guitarra
115 | 113,Playing harmonica
116 | 114,Playing ice hockey
117 | 115,Playing kickball
118 | 116,Playing lacrosse
119 | 117,Playing piano
120 | 118,Playing polo
121 | 119,Playing pool
122 | 120,Playing racquetball
123 | 121,Playing rubik cube
124 | 122,Playing saxophone
125 | 123,Playing squash
126 | 124,Playing ten pins
127 | 125,Playing violin
128 | 126,Playing water polo
129 | 127,Pole vault
130 | 128,Polishing forniture
131 | 129,Polishing shoes
132 | 130,Powerbocking
133 | 131,Preparing pasta
134 | 132,Preparing salad
135 | 133,Putting in contact lenses
136 | 134,Putting on makeup
137 | 135,Putting on shoes
138 | 136,Rafting
139 | 137,Raking leaves
140 | 138,Removing curlers
141 | 139,Removing ice from car
142 | 140,Riding bumper cars
143 | 141,River tubing
144 | 142,Rock climbing
145 | 143,Rock-paper-scissors
146 | 144,Rollerblading
147 | 145,Roof shingle removal
148 | 146,Rope skipping
149 | 147,Running a marathon
150 | 148,Sailing
151 | 149,Scuba diving
152 | 150,Sharpening knives
153 | 151,Shaving
154 | 152,Shaving legs
155 | 153,Shot put
156 | 154,Shoveling snow
157 | 155,Shuffleboard
158 | 156,Skateboarding
159 | 157,Skiing
160 | 158,Slacklining
161 | 159,Smoking a cigarette
162 | 160,Smoking hookah
163 | 161,Snatch
164 | 162,Snow tubing
165 | 163,Snowboarding
166 | 164,Spinning
167 | 165,Spread mulch
168 | 166,Springboard diving
169 | 167,Starting a campfire
170 | 168,Sumo
171 | 169,Surfing
172 | 170,Swimming
173 | 171,Swinging at the playground
174 | 172,Table soccer
175 | 173,Tai chi
176 | 174,Tango
177 | 175,Tennis serve with ball bouncing
178 | 176,Throwing darts
179 | 177,Trimming branches or hedges
180 | 178,Triple jump
181 | 179,Tug of war
182 | 180,Tumbling
183 | 181,Using parallel bars
184 | 182,Using the balance beam
185 | 183,Using the monkey bar
186 | 184,Using the pommel horse
187 | 185,Using the rowing machine
188 | 186,Using uneven bars
189 | 187,Vacuuming floor
190 | 188,Volleyball
191 | 189,Wakeboarding
192 | 190,Walking the dog
193 | 191,Washing dishes
194 | 192,Washing face
195 | 193,Washing hands
196 | 194,Waterskiing
197 | 195,Waxing skis
198 | 196,Welding
199 | 197,Windsurfing
200 | 198,Wrapping presents
201 | 199,Zumba
--------------------------------------------------------------------------------
/lists/hmdb51_labels.csv:
--------------------------------------------------------------------------------
1 | id,name
2 | 0,brush hair
3 | 1,cartwheel
4 | 2,catch
5 | 3,chew
6 | 4,clap
7 | 5,climb
8 | 6,climb stairs
9 | 7,dive
10 | 8,draw sword
11 | 9,dribble
12 | 10,drink
13 | 11,eat
14 | 12,fall floor
15 | 13,fencing
16 | 14,flic flac
17 | 15,golf
18 | 16,handstand
19 | 17,hit
20 | 18,hug
21 | 19,jump
22 | 20,kick
23 | 21,kick ball
24 | 22,kiss
25 | 23,laugh
26 | 24,pick
27 | 25,pour
28 | 26,pullup
29 | 27,punch
30 | 28,push
31 | 29,pushup
32 | 30,ride bike
33 | 31,ride horse
34 | 32,run
35 | 33,shake hands
36 | 34,shoot ball
37 | 35,shoot bow
38 | 36,shoot gun
39 | 37,sit
40 | 38,situp
41 | 39,smile
42 | 40,smoke
43 | 41,somersault
44 | 42,stand
45 | 43,swing baseball
46 | 44,sword
47 | 45,sword exercise
48 | 46,talk
49 | 47,throw
50 | 48,turn
51 | 49,walk
52 | 50,wave
53 |
--------------------------------------------------------------------------------
/lists/k600/k160_labels_split1.csv:
--------------------------------------------------------------------------------
1 | id,name
2 | 0,acting in play
3 | 1,adjusting glasses
4 | 2,arguing
5 | 3,attending conference
6 | 4,backflip (human)
7 | 5,base jumping
8 | 6,bathing dog
9 | 7,battle rope training
10 | 8,blowdrying hair
11 | 9,blowing bubble gum
12 | 10,bodysurfing
13 | 11,bottling
14 | 12,breaking boards
15 | 13,breathing fire
16 | 14,building sandcastle
17 | 15,bull fighting
18 | 16,bulldozing
19 | 17,burping
20 | 18,calligraphy
21 | 19,capsizing
22 | 20,casting fishing line
23 | 21,changing gear in car
24 | 22,chiseling stone
25 | 23,chiseling wood
26 | 24,chopping meat
27 | 25,coloring in
28 | 26,combing hair
29 | 27,cooking scallops
30 | 28,cracking knuckles
31 | 29,crossing eyes
32 | 30,cumbia
33 | 31,curling (sport)
34 | 32,cutting apple
35 | 33,cutting orange
36 | 34,delivering mail
37 | 35,docking boat
38 | 36,doing jigsaw puzzle
39 | 37,drooling
40 | 38,dumpster diving
41 | 39,dyeing eyebrows
42 | 40,embroidering
43 | 41,fencing (sport)
44 | 42,fixing bicycle
45 | 43,flint knapping
46 | 44,fly tying
47 | 45,geocaching
48 | 46,getting a piercing
49 | 47,gold panning
50 | 48,gospel singing in church
51 | 49,head stand
52 | 50,home roasting coffee
53 | 51,hugging baby
54 | 52,ice swimming
55 | 53,ironing hair
56 | 54,jaywalking
57 | 55,jumping bicycle
58 | 56,jumping jacks
59 | 57,karaoke
60 | 58,lawn mower racing
61 | 59,laying concrete
62 | 60,laying stone
63 | 61,lifting hat
64 | 62,lighting fire
65 | 63,lock picking
66 | 64,longboarding
67 | 65,luge
68 | 66,making cheese
69 | 67,making paper aeroplanes
70 | 68,marriage proposal
71 | 69,massaging neck
72 | 70,moon walking
73 | 71,mosh pit dancing
74 | 72,mountain climber (exercise)
75 | 73,mushroom foraging
76 | 74,needle felting
77 | 75,opening wine bottle
78 | 76,packing
79 | 77,passing soccer ball
80 | 78,photobombing
81 | 79,photocopying
82 | 80,pinching
83 | 81,pirouetting
84 | 82,planing wood
85 | 83,playing beer pong
86 | 84,playing blackjack
87 | 85,playing darts
88 | 86,playing field hockey
89 | 87,playing gong
90 | 88,playing hand clapping games
91 | 89,playing laser tag
92 | 90,playing lute
93 | 91,playing maracas
94 | 92,playing marbles
95 | 93,playing ocarina
96 | 94,playing pan pipes
97 | 95,playing pinball
98 | 96,playing polo
99 | 97,playing rubiks cube
100 | 98,playing with trains
101 | 99,poking bellybutton
102 | 100,popping balloons
103 | 101,preparing salad
104 | 102,pushing wheelbarrow
105 | 103,putting in contact lenses
106 | 104,putting on eyeliner
107 | 105,putting on foundation
108 | 106,putting on lipstick
109 | 107,putting on mascara
110 | 108,putting on sari
111 | 109,putting on shoes
112 | 110,raising eyebrows
113 | 111,repairing puncture
114 | 112,riding snow blower
115 | 113,roasting pig
116 | 114,rolling pastry
117 | 115,rope pushdown
118 | 116,sausage making
119 | 117,sawing wood
120 | 118,scrubbing face
121 | 119,separating eggs
122 | 120,sewing
123 | 121,shaping bread dough
124 | 122,shining flashlight
125 | 123,shucking oysters
126 | 124,sipping cup
127 | 125,skiing mono
128 | 126,sleeping
129 | 127,smelling feet
130 | 128,smoking pipe
131 | 129,square dancing
132 | 130,standing on hands
133 | 131,steer roping
134 | 132,sucking lolly
135 | 133,swinging baseball bat
136 | 134,tackling
137 | 135,tagging graffiti
138 | 136,talking on cell phone
139 | 137,tasting wine
140 | 138,threading needle
141 | 139,throwing knife
142 | 140,throwing snowballs
143 | 141,throwing tantrum
144 | 142,tie dying
145 | 143,tiptoeing
146 | 144,trimming shrubs
147 | 145,tying shoe laces
148 | 146,using a microscope
149 | 147,using a power drill
150 | 148,using a sledge hammer
151 | 149,using a wrench
152 | 150,using atm
153 | 151,using puppets
154 | 152,vacuuming floor
155 | 153,visiting the zoo
156 | 154,wading through water
157 | 155,watching tv
158 | 156,waving hand
159 | 157,winking
160 | 158,wood burning (art)
161 | 159,yarn spinning
162 |
--------------------------------------------------------------------------------
/lists/k600/k160_labels_split2.csv:
--------------------------------------------------------------------------------
1 | id,name
2 | 0,acting in play
3 | 1,arguing
4 | 2,assembling bicycle
5 | 3,backflip (human)
6 | 4,base jumping
7 | 5,bathing dog
8 | 6,battle rope training
9 | 7,blowing bubble gum
10 | 8,bottling
11 | 9,breathing fire
12 | 10,building lego
13 | 11,building sandcastle
14 | 12,bull fighting
15 | 13,calculating
16 | 14,calligraphy
17 | 15,card stacking
18 | 16,card throwing
19 | 17,carving ice
20 | 18,chewing gum
21 | 19,chiseling stone
22 | 20,chopping meat
23 | 21,chopping vegetables
24 | 22,coloring in
25 | 23,contorting
26 | 24,cooking scallops
27 | 25,cosplaying
28 | 26,cracking back
29 | 27,cracking knuckles
30 | 28,crossing eyes
31 | 29,cumbia
32 | 30,curling (sport)
33 | 31,cutting apple
34 | 32,delivering mail
35 | 33,directing traffic
36 | 34,docking boat
37 | 35,doing jigsaw puzzle
38 | 36,drooling
39 | 37,dumpster diving
40 | 38,dyeing eyebrows
41 | 39,embroidering
42 | 40,falling off bike
43 | 41,fidgeting
44 | 42,fixing bicycle
45 | 43,geocaching
46 | 44,getting a piercing
47 | 45,gold panning
48 | 46,hand washing clothes
49 | 47,head stand
50 | 48,historical reenactment
51 | 49,huddling
52 | 50,ice swimming
53 | 51,installing carpet
54 | 52,ironing hair
55 | 53,jumping bicycle
56 | 54,jumping jacks
57 | 55,land sailing
58 | 56,lawn mower racing
59 | 57,laying stone
60 | 58,laying tiles
61 | 59,leatherworking
62 | 60,licking
63 | 61,lifting hat
64 | 62,lighting fire
65 | 63,lock picking
66 | 64,longboarding
67 | 65,looking at phone
68 | 66,luge
69 | 67,making balloon shapes
70 | 68,making bubbles
71 | 69,marriage proposal
72 | 70,massaging neck
73 | 71,moon walking
74 | 72,mosh pit dancing
75 | 73,mushroom foraging
76 | 74,opening door
77 | 75,opening refrigerator
78 | 76,photobombing
79 | 77,pinching
80 | 78,pirouetting
81 | 79,planing wood
82 | 80,playing beer pong
83 | 81,playing blackjack
84 | 82,playing darts
85 | 83,playing dominoes
86 | 84,playing field hockey
87 | 85,playing gong
88 | 86,playing hand clapping games
89 | 87,playing laser tag
90 | 88,playing lute
91 | 89,playing maracas
92 | 90,playing netball
93 | 91,playing ocarina
94 | 92,playing pan pipes
95 | 93,playing ping pong
96 | 94,playing polo
97 | 95,playing rubiks cube
98 | 96,playing scrabble
99 | 97,playing with trains
100 | 98,poking bellybutton
101 | 99,polishing metal
102 | 100,popping balloons
103 | 101,pouring beer
104 | 102,preparing salad
105 | 103,putting in contact lenses
106 | 104,putting on eyeliner
107 | 105,putting on foundation
108 | 106,putting on mascara
109 | 107,putting on sari
110 | 108,repairing puncture
111 | 109,riding snow blower
112 | 110,roasting marshmallows
113 | 111,roasting pig
114 | 112,rolling pastry
115 | 113,rope pushdown
116 | 114,scrapbooking
117 | 115,scrubbing face
118 | 116,separating eggs
119 | 117,sewing
120 | 118,shaping bread dough
121 | 119,shining flashlight
122 | 120,shopping
123 | 121,shuffling feet
124 | 122,sipping cup
125 | 123,skiing mono
126 | 124,skipping stone
127 | 125,smashing
128 | 126,smelling feet
129 | 127,smoking pipe
130 | 128,square dancing
131 | 129,steer roping
132 | 130,sucking lolly
133 | 131,swimming front crawl
134 | 132,swinging baseball bat
135 | 133,sword swallowing
136 | 134,tagging graffiti
137 | 135,talking on cell phone
138 | 136,tasting wine
139 | 137,threading needle
140 | 138,throwing knife
141 | 139,throwing snowballs
142 | 140,throwing tantrum
143 | 141,throwing water balloon
144 | 142,tie dying
145 | 143,tightrope walking
146 | 144,trimming shrubs
147 | 145,twiddling fingers
148 | 146,tying shoe laces
149 | 147,using a paint roller
150 | 148,using a power drill
151 | 149,using a wrench
152 | 150,using atm
153 | 151,using bagging machine
154 | 152,using circular saw
155 | 153,using inhaler
156 | 154,visiting the zoo
157 | 155,wading through mud
158 | 156,wading through water
159 | 157,waving hand
160 | 158,weaving fabric
161 | 159,winking
162 |
--------------------------------------------------------------------------------
/lists/k600/k160_labels_split3.csv:
--------------------------------------------------------------------------------
1 | id,name
2 | 0,adjusting glasses
3 | 1,alligator wrestling
4 | 2,archaeological excavation
5 | 3,arguing
6 | 4,assembling bicycle
7 | 5,attending conference
8 | 6,backflip (human)
9 | 7,base jumping
10 | 8,battle rope training
11 | 9,blowdrying hair
12 | 10,blowing bubble gum
13 | 11,bouncing on bouncy castle
14 | 12,breaking boards
15 | 13,breathing fire
16 | 14,building lego
17 | 15,bull fighting
18 | 16,bulldozing
19 | 17,burping
20 | 18,calculating
21 | 19,calligraphy
22 | 20,capsizing
23 | 21,card stacking
24 | 22,card throwing
25 | 23,carving ice
26 | 24,casting fishing line
27 | 25,chiseling stone
28 | 26,chiseling wood
29 | 27,chopping meat
30 | 28,clam digging
31 | 29,coloring in
32 | 30,combing hair
33 | 31,contorting
34 | 32,cooking scallops
35 | 33,cosplaying
36 | 34,cracking knuckles
37 | 35,crossing eyes
38 | 36,cumbia
39 | 37,curling (sport)
40 | 38,delivering mail
41 | 39,directing traffic
42 | 40,doing jigsaw puzzle
43 | 41,dyeing eyebrows
44 | 42,falling off bike
45 | 43,falling off chair
46 | 44,fencing (sport)
47 | 45,fixing bicycle
48 | 46,flint knapping
49 | 47,fly tying
50 | 48,geocaching
51 | 49,hand washing clothes
52 | 50,head stand
53 | 51,historical reenactment
54 | 52,home roasting coffee
55 | 53,huddling
56 | 54,ice swimming
57 | 55,installing carpet
58 | 56,ironing hair
59 | 57,jaywalking
60 | 58,jumping bicycle
61 | 59,karaoke
62 | 60,land sailing
63 | 61,lawn mower racing
64 | 62,laying concrete
65 | 63,laying stone
66 | 64,laying tiles
67 | 65,leatherworking
68 | 66,licking
69 | 67,lifting hat
70 | 68,lighting fire
71 | 69,lock picking
72 | 70,longboarding
73 | 71,looking at phone
74 | 72,luge
75 | 73,making balloon shapes
76 | 74,making bubbles
77 | 75,making cheese
78 | 76,making paper aeroplanes
79 | 77,marriage proposal
80 | 78,mosh pit dancing
81 | 79,mushroom foraging
82 | 80,needle felting
83 | 81,opening door
84 | 82,opening refrigerator
85 | 83,opening wine bottle
86 | 84,packing
87 | 85,passing soccer ball
88 | 86,photobombing
89 | 87,photocopying
90 | 88,pinching
91 | 89,planing wood
92 | 90,playing blackjack
93 | 91,playing darts
94 | 92,playing field hockey
95 | 93,playing gong
96 | 94,playing laser tag
97 | 95,playing lute
98 | 96,playing maracas
99 | 97,playing ocarina
100 | 98,playing ping pong
101 | 99,playing polo
102 | 100,poking bellybutton
103 | 101,popping balloons
104 | 102,pouring beer
105 | 103,preparing salad
106 | 104,pushing wheelbarrow
107 | 105,putting in contact lenses
108 | 106,putting on eyeliner
109 | 107,putting on lipstick
110 | 108,putting on mascara
111 | 109,putting on sari
112 | 110,putting on shoes
113 | 111,raising eyebrows
114 | 112,repairing puncture
115 | 113,riding snow blower
116 | 114,roasting marshmallows
117 | 115,rolling pastry
118 | 116,rope pushdown
119 | 117,sawing wood
120 | 118,scrubbing face
121 | 119,shaping bread dough
122 | 120,shopping
123 | 121,shuffling feet
124 | 122,sipping cup
125 | 123,skiing mono
126 | 124,skipping stone
127 | 125,sleeping
128 | 126,smashing
129 | 127,square dancing
130 | 128,staring
131 | 129,steer roping
132 | 130,sucking lolly
133 | 131,swimming front crawl
134 | 132,swinging baseball bat
135 | 133,sword swallowing
136 | 134,tasting wine
137 | 135,threading needle
138 | 136,throwing knife
139 | 137,throwing snowballs
140 | 138,throwing tantrum
141 | 139,tie dying
142 | 140,tightrope walking
143 | 141,tiptoeing
144 | 142,trimming shrubs
145 | 143,using a microscope
146 | 144,using a paint roller
147 | 145,using a power drill
148 | 146,using a sledge hammer
149 | 147,using a wrench
150 | 148,using atm
151 | 149,using bagging machine
152 | 150,using inhaler
153 | 151,using puppets
154 | 152,visiting the zoo
155 | 153,wading through water
156 | 154,walking through snow
157 | 155,watching tv
158 | 156,waving hand
159 | 157,weaving fabric
160 | 158,winking
161 | 159,wood burning (art)
162 |
--------------------------------------------------------------------------------
/lists/kinetics_400_labels.csv:
--------------------------------------------------------------------------------
1 | id,name
2 | 0,abseiling
3 | 1,air drumming
4 | 2,answering questions
5 | 3,applauding
6 | 4,applying cream
7 | 5,archery
8 | 6,arm wrestling
9 | 7,arranging flowers
10 | 8,assembling computer
11 | 9,auctioning
12 | 10,baby waking up
13 | 11,baking cookies
14 | 12,balloon blowing
15 | 13,bandaging
16 | 14,barbequing
17 | 15,bartending
18 | 16,beatboxing
19 | 17,bee keeping
20 | 18,belly dancing
21 | 19,bench pressing
22 | 20,bending back
23 | 21,bending metal
24 | 22,biking through snow
25 | 23,blasting sand
26 | 24,blowing glass
27 | 25,blowing leaves
28 | 26,blowing nose
29 | 27,blowing out candles
30 | 28,bobsledding
31 | 29,bookbinding
32 | 30,bouncing on trampoline
33 | 31,bowling
34 | 32,braiding hair
35 | 33,breading or breadcrumbing
36 | 34,breakdancing
37 | 35,brush painting
38 | 36,brushing hair
39 | 37,brushing teeth
40 | 38,building cabinet
41 | 39,building shed
42 | 40,bungee jumping
43 | 41,busking
44 | 42,canoeing or kayaking
45 | 43,capoeira
46 | 44,carrying baby
47 | 45,cartwheeling
48 | 46,carving pumpkin
49 | 47,catching fish
50 | 48,catching or throwing baseball
51 | 49,catching or throwing frisbee
52 | 50,catching or throwing softball
53 | 51,celebrating
54 | 52,changing oil
55 | 53,changing wheel
56 | 54,checking tires
57 | 55,cheerleading
58 | 56,chopping wood
59 | 57,clapping
60 | 58,clay pottery making
61 | 59,clean and jerk
62 | 60,cleaning floor
63 | 61,cleaning gutters
64 | 62,cleaning pool
65 | 63,cleaning shoes
66 | 64,cleaning toilet
67 | 65,cleaning windows
68 | 66,climbing a rope
69 | 67,climbing ladder
70 | 68,climbing tree
71 | 69,contact juggling
72 | 70,cooking chicken
73 | 71,cooking egg
74 | 72,cooking on campfire
75 | 73,cooking sausages
76 | 74,counting money
77 | 75,country line dancing
78 | 76,cracking neck
79 | 77,crawling baby
80 | 78,crossing river
81 | 79,crying
82 | 80,curling hair
83 | 81,cutting nails
84 | 82,cutting pineapple
85 | 83,cutting watermelon
86 | 84,dancing ballet
87 | 85,dancing charleston
88 | 86,dancing gangnam style
89 | 87,dancing macarena
90 | 88,deadlifting
91 | 89,decorating the christmas tree
92 | 90,digging
93 | 91,dining
94 | 92,disc golfing
95 | 93,diving cliff
96 | 94,dodgeball
97 | 95,doing aerobics
98 | 96,doing laundry
99 | 97,doing nails
100 | 98,drawing
101 | 99,dribbling basketball
102 | 100,drinking
103 | 101,drinking beer
104 | 102,drinking shots
105 | 103,driving car
106 | 104,driving tractor
107 | 105,drop kicking
108 | 106,drumming fingers
109 | 107,dunking basketball
110 | 108,dying hair
111 | 109,eating burger
112 | 110,eating cake
113 | 111,eating carrots
114 | 112,eating chips
115 | 113,eating doughnuts
116 | 114,eating hotdog
117 | 115,eating ice cream
118 | 116,eating spaghetti
119 | 117,eating watermelon
120 | 118,egg hunting
121 | 119,exercising arm
122 | 120,exercising with an exercise ball
123 | 121,extinguishing fire
124 | 122,faceplanting
125 | 123,feeding birds
126 | 124,feeding fish
127 | 125,feeding goats
128 | 126,filling eyebrows
129 | 127,finger snapping
130 | 128,fixing hair
131 | 129,flipping pancake
132 | 130,flying kite
133 | 131,folding clothes
134 | 132,folding napkins
135 | 133,folding paper
136 | 134,front raises
137 | 135,frying vegetables
138 | 136,garbage collecting
139 | 137,gargling
140 | 138,getting a haircut
141 | 139,getting a tattoo
142 | 140,giving or receiving award
143 | 141,golf chipping
144 | 142,golf driving
145 | 143,golf putting
146 | 144,grinding meat
147 | 145,grooming dog
148 | 146,grooming horse
149 | 147,gymnastics tumbling
150 | 148,hammer throw
151 | 149,headbanging
152 | 150,headbutting
153 | 151,high jump
154 | 152,high kick
155 | 153,hitting baseball
156 | 154,hockey stop
157 | 155,holding snake
158 | 156,hopscotch
159 | 157,hoverboarding
160 | 158,hugging
161 | 159,hula hooping
162 | 160,hurdling
163 | 161,hurling (sport)
164 | 162,ice climbing
165 | 163,ice fishing
166 | 164,ice skating
167 | 165,ironing
168 | 166,javelin throw
169 | 167,jetskiing
170 | 168,jogging
171 | 169,juggling balls
172 | 170,juggling fire
173 | 171,juggling soccer ball
174 | 172,jumping into pool
175 | 173,jumpstyle dancing
176 | 174,kicking field goal
177 | 175,kicking soccer ball
178 | 176,kissing
179 | 177,kitesurfing
180 | 178,knitting
181 | 179,krumping
182 | 180,laughing
183 | 181,laying bricks
184 | 182,long jump
185 | 183,lunge
186 | 184,making a cake
187 | 185,making a sandwich
188 | 186,making bed
189 | 187,making jewelry
190 | 188,making pizza
191 | 189,making snowman
192 | 190,making sushi
193 | 191,making tea
194 | 192,marching
195 | 193,massaging back
196 | 194,massaging feet
197 | 195,massaging legs
198 | 196,massaging person's head
199 | 197,milking cow
200 | 198,mopping floor
201 | 199,motorcycling
202 | 200,moving furniture
203 | 201,mowing lawn
204 | 202,news anchoring
205 | 203,opening bottle
206 | 204,opening present
207 | 205,paragliding
208 | 206,parasailing
209 | 207,parkour
210 | 208,passing American football (in game)
211 | 209,passing American football (not in game)
212 | 210,peeling apples
213 | 211,peeling potatoes
214 | 212,petting animal (not cat)
215 | 213,petting cat
216 | 214,picking fruit
217 | 215,planting trees
218 | 216,plastering
219 | 217,playing accordion
220 | 218,playing badminton
221 | 219,playing bagpipes
222 | 220,playing basketball
223 | 221,playing bass guitar
224 | 222,playing cards
225 | 223,playing cello
226 | 224,playing chess
227 | 225,playing clarinet
228 | 226,playing controller
229 | 227,playing cricket
230 | 228,playing cymbals
231 | 229,playing didgeridoo
232 | 230,playing drums
233 | 231,playing flute
234 | 232,playing guitar
235 | 233,playing harmonica
236 | 234,playing harp
237 | 235,playing ice hockey
238 | 236,playing keyboard
239 | 237,playing kickball
240 | 238,playing monopoly
241 | 239,playing organ
242 | 240,playing paintball
243 | 241,playing piano
244 | 242,playing poker
245 | 243,playing recorder
246 | 244,playing saxophone
247 | 245,playing squash or racquetball
248 | 246,playing tennis
249 | 247,playing trombone
250 | 248,playing trumpet
251 | 249,playing ukulele
252 | 250,playing violin
253 | 251,playing volleyball
254 | 252,playing xylophone
255 | 253,pole vault
256 | 254,presenting weather forecast
257 | 255,pull ups
258 | 256,pumping fist
259 | 257,pumping gas
260 | 258,punching bag
261 | 259,punching person (boxing)
262 | 260,push up
263 | 261,pushing car
264 | 262,pushing cart
265 | 263,pushing wheelchair
266 | 264,reading book
267 | 265,reading newspaper
268 | 266,recording music
269 | 267,riding a bike
270 | 268,riding camel
271 | 269,riding elephant
272 | 270,riding mechanical bull
273 | 271,riding mountain bike
274 | 272,riding mule
275 | 273,riding or walking with horse
276 | 274,riding scooter
277 | 275,riding unicycle
278 | 276,ripping paper
279 | 277,robot dancing
280 | 278,rock climbing
281 | 279,rock scissors paper
282 | 280,roller skating
283 | 281,running on treadmill
284 | 282,sailing
285 | 283,salsa dancing
286 | 284,sanding floor
287 | 285,scrambling eggs
288 | 286,scuba diving
289 | 287,setting table
290 | 288,shaking hands
291 | 289,shaking head
292 | 290,sharpening knives
293 | 291,sharpening pencil
294 | 292,shaving head
295 | 293,shaving legs
296 | 294,shearing sheep
297 | 295,shining shoes
298 | 296,shooting basketball
299 | 297,shooting goal (soccer)
300 | 298,shot put
301 | 299,shoveling snow
302 | 300,shredding paper
303 | 301,shuffling cards
304 | 302,side kick
305 | 303,sign language interpreting
306 | 304,singing
307 | 305,situp
308 | 306,skateboarding
309 | 307,ski jumping
310 | 308,skiing (not slalom or crosscountry)
311 | 309,skiing crosscountry
312 | 310,skiing slalom
313 | 311,skipping rope
314 | 312,skydiving
315 | 313,slacklining
316 | 314,slapping
317 | 315,sled dog racing
318 | 316,smoking
319 | 317,smoking hookah
320 | 318,snatch weight lifting
321 | 319,sneezing
322 | 320,sniffing
323 | 321,snorkeling
324 | 322,snowboarding
325 | 323,snowkiting
326 | 324,snowmobiling
327 | 325,somersaulting
328 | 326,spinning poi
329 | 327,spray painting
330 | 328,spraying
331 | 329,springboard diving
332 | 330,squat
333 | 331,sticking tongue out
334 | 332,stomping grapes
335 | 333,stretching arm
336 | 334,stretching leg
337 | 335,strumming guitar
338 | 336,surfing crowd
339 | 337,surfing water
340 | 338,sweeping floor
341 | 339,swimming backstroke
342 | 340,swimming breast stroke
343 | 341,swimming butterfly stroke
344 | 342,swing dancing
345 | 343,swinging legs
346 | 344,swinging on something
347 | 345,sword fighting
348 | 346,tai chi
349 | 347,taking a shower
350 | 348,tango dancing
351 | 349,tap dancing
352 | 350,tapping guitar
353 | 351,tapping pen
354 | 352,tasting beer
355 | 353,tasting food
356 | 354,testifying
357 | 355,texting
358 | 356,throwing axe
359 | 357,throwing ball
360 | 358,throwing discus
361 | 359,tickling
362 | 360,tobogganing
363 | 361,tossing coin
364 | 362,tossing salad
365 | 363,training dog
366 | 364,trapezing
367 | 365,trimming or shaving beard
368 | 366,trimming trees
369 | 367,triple jump
370 | 368,tying bow tie
371 | 369,tying knot (not on a tie)
372 | 370,tying tie
373 | 371,unboxing
374 | 372,unloading truck
375 | 373,using computer
376 | 374,using remote controller (not gaming)
377 | 375,using segway
378 | 376,vault
379 | 377,waiting in line
380 | 378,walking the dog
381 | 379,washing dishes
382 | 380,washing feet
383 | 381,washing hair
384 | 382,washing hands
385 | 383,water skiing
386 | 384,water sliding
387 | 385,watering plants
388 | 386,waxing back
389 | 387,waxing chest
390 | 388,waxing eyebrows
391 | 389,waxing legs
392 | 390,weaving basket
393 | 391,welding
394 | 392,whistling
395 | 393,windsurfing
396 | 394,wrapping present
397 | 395,wrestling
398 | 396,writing
399 | 397,yawning
400 | 398,yoga
401 | 399,zumba
402 |
--------------------------------------------------------------------------------
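
The label CSVs in this directory are what the classes property of Video_dataset (datasets/datasets.py above) reads with pandas, so each entry comes back as an [id, name] pair:

    # Same pandas calls as Video_dataset.classes: read_csv(...).values.tolist().
    import pandas as pd

    classes = pd.read_csv('lists/kinetics_400_labels.csv').values.tolist()
    print(len(classes))    # 400
    print(classes[0])      # [0, 'abseiling']
    print(classes[-1])     # [399, 'zumba']
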
/lists/ucf_labels.csv:
--------------------------------------------------------------------------------
1 | id,name
2 | 0,ApplyEyeMakeup
3 | 1,ApplyLipstick
4 | 2,Archery
5 | 3,BabyCrawling
6 | 4,BalanceBeam
7 | 5,BandMarching
8 | 6,BaseballPitch
9 | 7,Basketball
10 | 8,BasketballDunk
11 | 9,BenchPress
12 | 10,Biking
13 | 11,Billiards
14 | 12,BlowDryHair
15 | 13,BlowingCandles
16 | 14,BodyWeightSquats
17 | 15,Bowling
18 | 16,BoxingPunchingBag
19 | 17,BoxingSpeedBag
20 | 18,BreastStroke
21 | 19,BrushingTeeth
22 | 20,CleanAndJerk
23 | 21,CliffDiving
24 | 22,CricketBowling
25 | 23,CricketShot
26 | 24,CuttingInKitchen
27 | 25,Diving
28 | 26,Drumming
29 | 27,Fencing
30 | 28,FieldHockeyPenalty
31 | 29,FloorGymnastics
32 | 30,FrisbeeCatch
33 | 31,FrontCrawl
34 | 32,GolfSwing
35 | 33,Haircut
36 | 34,Hammering
37 | 35,HammerThrow
38 | 36,HandstandPushups
39 | 37,HandstandWalking
40 | 38,HeadMassage
41 | 39,HighJump
42 | 40,HorseRace
43 | 41,HorseRiding
44 | 42,HulaHoop
45 | 43,IceDancing
46 | 44,JavelinThrow
47 | 45,JugglingBalls
48 | 46,JumpingJack
49 | 47,JumpRope
50 | 48,Kayaking
51 | 49,Knitting
52 | 50,LongJump
53 | 51,Lunges
54 | 52,MilitaryParade
55 | 53,Mixing
56 | 54,MoppingFloor
57 | 55,Nunchucks
58 | 56,ParallelBars
59 | 57,PizzaTossing
60 | 58,PlayingCello
61 | 59,PlayingDaf
62 | 60,PlayingDhol
63 | 61,PlayingFlute
64 | 62,PlayingGuitar
65 | 63,PlayingPiano
66 | 64,PlayingSitar
67 | 65,PlayingTabla
68 | 66,PlayingViolin
69 | 67,PoleVault
70 | 68,PommelHorse
71 | 69,PullUps
72 | 70,Punch
73 | 71,PushUps
74 | 72,Rafting
75 | 73,RockClimbingIndoor
76 | 74,RopeClimbing
77 | 75,Rowing
78 | 76,SalsaSpin
79 | 77,ShavingBeard
80 | 78,Shotput
81 | 79,SkateBoarding
82 | 80,Skiing
83 | 81,Skijet
84 | 82,SkyDiving
85 | 83,SoccerJuggling
86 | 84,SoccerPenalty
87 | 85,StillRings
88 | 86,SumoWrestling
89 | 87,Surfing
90 | 88,Swing
91 | 89,TableTennisShot
92 | 90,TaiChi
93 | 91,TennisSwing
94 | 92,ThrowDiscus
95 | 93,TrampolineJumping
96 | 94,Typing
97 | 95,UnevenBars
98 | 96,VolleyballSpiking
99 | 97,WalkingWithDog
100 | 98,WallPushups
101 | 99,WritingOnBoard
102 | 100,YoYo
103 |
--------------------------------------------------------------------------------
/modules/coop.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn import functional as F
6 | from torch.cuda.amp import GradScaler, autocast
7 |
8 |
9 | from clip import clip
10 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
11 |
12 | _tokenizer = _Tokenizer()
13 |
14 |
15 |
16 |
17 | class TextEncoder(nn.Module):
18 | def __init__(self, clip_model):
19 | super().__init__()
20 | self.transformer = clip_model.transformer
21 | self.positional_embedding = clip_model.positional_embedding
22 | self.ln_final = clip_model.ln_final
23 | self.text_projection = clip_model.text_projection
24 | self.dtype = clip_model.dtype
25 |
26 | def forward(self, prompts, tokenized_prompts):
27 | x = prompts + self.positional_embedding.type(self.dtype)
28 | x = x.permute(1, 0, 2) # NLD -> LND
29 | x = self.transformer(x)
30 | x = x.permute(1, 0, 2) # LND -> NLD
31 | x = self.ln_final(x).type(self.dtype)
32 |
33 | # x.shape = [batch_size, n_ctx, transformer.width]
34 | # take features from the eot embedding (eot_token is the highest number in each sequence)
35 | x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
36 |
37 | return x
38 |
39 |
40 | class PromptLearner(nn.Module):
41 | def __init__(self, cfg, classnames, clip_model):
42 | super().__init__()
43 | n_cls = len(classnames)
44 | n_ctx = cfg.COOP.N_CTX
45 | ctx_init = cfg.COOP.CTX_INIT
46 | dtype = clip_model.dtype
47 | ctx_dim = clip_model.ln_final.weight.shape[0]
48 | clip_imsize = clip_model.visual.input_resolution
49 | cfg_imsize = cfg.data.input_size
50 |         assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal clip_imsize ({clip_imsize})"
51 |
52 | if ctx_init:
53 | # use given words to initialize context vectors
54 | ctx_init = ctx_init.replace("_", " ")
55 | n_ctx = len(ctx_init.split(" "))
56 | prompt = clip.tokenize(ctx_init)
57 | with torch.no_grad():
58 | embedding = clip_model.token_embedding(prompt).type(dtype)
59 | ctx_vectors = embedding[0, 1 : 1 + n_ctx, :]
60 | prompt_prefix = ctx_init
61 |
62 | else:
63 | # random initialization
64 | if cfg.COOP.CSC:
65 | print("Initializing class-specific contexts")
66 | ctx_vectors = torch.empty(n_cls, n_ctx, ctx_dim, dtype=dtype)
67 | else:
68 | print("Initializing a generic context")
69 | ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
70 | nn.init.normal_(ctx_vectors, std=0.02)
71 | prompt_prefix = " ".join(["X"] * n_ctx)
72 |
73 | print(f'Initial context: "{prompt_prefix}"')
74 | print(f"Number of context words (tokens): {n_ctx}")
75 |
76 | self.ctx = nn.Parameter(ctx_vectors) # to be optimized
77 |
78 | classnames = [name.replace("_", " ") for name in classnames]
79 | name_lens = [len(_tokenizer.encode(name)) for name in classnames]
80 | prompts = [prompt_prefix + " " + name + "." for name in classnames]
81 |
82 | tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
83 | with torch.no_grad():
84 | embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
85 |
86 | # These token vectors will be saved when in save_model(),
87 | # but they should be ignored in load_model() as we want to use
88 | # those computed using the current class names
89 | self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS
90 | self.register_buffer("token_suffix", embedding[:, 1 + n_ctx :, :]) # CLS, EOS
91 |
92 | self.n_cls = n_cls
93 | self.n_ctx = n_ctx
94 | self.tokenized_prompts = tokenized_prompts # torch.Tensor
95 | self.name_lens = name_lens
96 | self.class_token_position = cfg.COOP.CLASS_TOKEN_POSITION
97 |
98 | def forward(self):
99 | ctx = self.ctx
100 | if ctx.dim() == 2:
101 | ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
102 |
103 | prefix = self.token_prefix
104 | suffix = self.token_suffix
105 |
106 | if self.class_token_position == "end":
107 | prompts = torch.cat(
108 | [
109 | prefix, # (n_cls, 1, dim)
110 | ctx, # (n_cls, n_ctx, dim)
111 | suffix, # (n_cls, *, dim)
112 | ],
113 | dim=1,
114 | )
115 |
116 | elif self.class_token_position == "middle":
117 | half_n_ctx = self.n_ctx // 2
118 | prompts = []
119 | for i in range(self.n_cls):
120 | name_len = self.name_lens[i]
121 | prefix_i = prefix[i : i + 1, :, :]
122 | class_i = suffix[i : i + 1, :name_len, :]
123 | suffix_i = suffix[i : i + 1, name_len:, :]
124 | ctx_i_half1 = ctx[i : i + 1, :half_n_ctx, :]
125 | ctx_i_half2 = ctx[i : i + 1, half_n_ctx:, :]
126 | prompt = torch.cat(
127 | [
128 | prefix_i, # (1, 1, dim)
129 | ctx_i_half1, # (1, n_ctx//2, dim)
130 | class_i, # (1, name_len, dim)
131 | ctx_i_half2, # (1, n_ctx//2, dim)
132 | suffix_i, # (1, *, dim)
133 | ],
134 | dim=1,
135 | )
136 | prompts.append(prompt)
137 | prompts = torch.cat(prompts, dim=0)
138 |
139 | elif self.class_token_position == "front":
140 | prompts = []
141 | for i in range(self.n_cls):
142 | name_len = self.name_lens[i]
143 | prefix_i = prefix[i : i + 1, :, :]
144 | class_i = suffix[i : i + 1, :name_len, :]
145 | suffix_i = suffix[i : i + 1, name_len:, :]
146 | ctx_i = ctx[i : i + 1, :, :]
147 | prompt = torch.cat(
148 | [
149 | prefix_i, # (1, 1, dim)
150 | class_i, # (1, name_len, dim)
151 | ctx_i, # (1, n_ctx, dim)
152 | suffix_i, # (1, *, dim)
153 | ],
154 | dim=1,
155 | )
156 | prompts.append(prompt)
157 | prompts = torch.cat(prompts, dim=0)
158 |
159 | else:
160 | raise ValueError
161 |
162 | return prompts
163 |
164 |
165 | class CoopCLIP(nn.Module):
166 | def __init__(self, cfg, classnames, clip_model, fusion_model, n_seg):
167 | super().__init__()
168 | self.clip_model = clip_model
169 | self.fusion_model = fusion_model
170 | self.prompt_learner = PromptLearner(cfg, classnames, clip_model)
171 | self.tokenized_prompts = self.prompt_learner.tokenized_prompts
172 | self.text_encoder = TextEncoder(clip_model)
173 | self.logit_scale = clip_model.logit_scale
174 | self.dtype = clip_model.dtype
175 | self.n_seg = n_seg
176 | self.cfg = cfg
177 |
178 | def encode_image(self, image):
179 | bt = image.size(0)
180 | b = bt // self.n_seg
181 |
182 | image_emb = self.clip_model.encode_image(image).view(b, self.n_seg, -1)
183 | image_emb = self.fusion_model(image_emb)
184 | return image_emb
185 |
186 |
187 | def forward(self, image, class_id, train=False):
188 | if train:
189 | return self.forward_train(image, class_id)
190 | else:
191 | return self.forward_test(image, class_id)
192 |
193 | def forward_train(self, image, class_id):
194 |
195 | image_features = self.encode_image(image.type(self.dtype)) # B dim
196 |
197 | prompts = self.prompt_learner() # [400 77 512]
198 | tokenized_prompts = self.tokenized_prompts # [400 77]
199 |
200 | if self.cfg.solver.loss_type == 'NCE':
201 | prompts = prompts[class_id] # [bs 77 512]
202 | tokenized_prompts = tokenized_prompts[class_id] # [bs 77]
203 | text_features = self.text_encoder(prompts, tokenized_prompts) # bs Dim
204 | return image_features, text_features, self.logit_scale.exp()
205 | elif self.cfg.solver.loss_type == 'CE':
206 | text_features = self.text_encoder(prompts, tokenized_prompts) # 400 Dim
207 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
208 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
209 |
210 | logit_scale = self.logit_scale.exp()
211 | logits = logit_scale * image_features @ text_features.t() # B 400
212 | return logits
213 | elif self.cfg.solver.loss_type == 'NCE,CE':
214 | bs_image_features = image_features
215 | text_features = self.text_encoder(prompts, tokenized_prompts) # 400 Dim
216 | bs_text_features = text_features[class_id] # bs dim
217 |
218 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
219 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
220 |
221 | logit_scale = self.logit_scale.exp()
222 | logits = logit_scale * image_features @ text_features.t() # B 400
223 | return bs_image_features, bs_text_features, logit_scale, logits
224 |
225 | def forward_test(self, image, class_id):
226 | image_features = self.encode_image(image.type(self.dtype)) # B dim
227 |
228 | prompts = self.prompt_learner() # [400 77 512]
229 | tokenized_prompts = self.tokenized_prompts # [400 77]
230 |
231 | text_features = self.text_encoder(prompts, tokenized_prompts) # 400 Dim
232 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
233 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
234 |
235 | logit_scale = self.logit_scale.exp()
236 | logits = logit_scale * image_features @ text_features.t() # B 400
237 | return logits
--------------------------------------------------------------------------------
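A quick way to see what PromptLearner assembles in the "end" position: each class prompt is [SOS][learned context][class name + EOS + padding], concatenated along the token axis. A minimal sketch with dummy tensors; the sizes below are placeholders, not tied to any real CLIP checkpoint.

import torch

# hypothetical sizes: 5 classes, 16 context tokens, 512-dim token embeddings, 77-token sequences
n_cls, n_ctx, dim, seq_len = 5, 16, 512, 77
prefix = torch.randn(n_cls, 1, dim)                               # SOS embedding per class
ctx = torch.randn(n_ctx, dim).unsqueeze(0).expand(n_cls, -1, -1)  # shared learnable context
suffix = torch.randn(n_cls, seq_len - 1 - n_ctx, dim)             # class tokens + EOS + padding

prompts = torch.cat([prefix, ctx, suffix], dim=1)
assert prompts.shape == (n_cls, seq_len, dim)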
/modules/temporal_modeling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from clip.model import VisualTransformer, ModifiedResNet
5 | import numpy as np
6 |
7 |
8 | class TemporalShift(nn.Module):
9 | def __init__(self, net, n_segment=3, n_div=8, inplace=False):
10 | super(TemporalShift, self).__init__()
11 | self.net = net
12 | self.n_segment = n_segment
13 | self.fold_div = n_div
14 | self.inplace = inplace
15 | if inplace:
16 | print('=> Using in-place shift...')
17 | print('=> Using fold div: {}'.format(self.fold_div))
18 |
19 | def forward(self, x):
20 | x = self.shift(x, self.n_segment, fold_div=self.fold_div, inplace=self.inplace)
21 | x = self.net(x)
22 | return x
23 |
24 | @staticmethod
25 | def shift(x, n_segment, fold_div=3, inplace=False):
26 | nt, c, h, w = x.size()
27 | n_batch = nt // n_segment
28 | x = x.view(n_batch, n_segment, c, h, w)
29 |
30 | fold = c // fold_div
31 | if inplace:
32 |             # In-place shift is disabled: it can produce out-of-order results under parallel execution.
33 |             # A dedicated CUDA kernel may be needed.
34 | raise NotImplementedError
35 | # out = InplaceShift.apply(x, fold)
36 | else:
37 | out = torch.zeros_like(x)
38 | out[:, :-1, :fold] = x[:, 1:, :fold] # shift left
39 | out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold] # shift right
40 | out[:, :, 2 * fold:] = x[:, :, 2 * fold:] # not shift
41 |
42 | return out.view(nt, c, h, w)
43 |
44 |
45 | class TemporalShift_VIT(nn.Module):
46 | def __init__(self, net, n_segment=3, n_div=8, inplace=False):
47 | super(TemporalShift_VIT, self).__init__()
48 | self.net = net
49 | self.n_segment = n_segment
50 | self.fold_div = n_div
51 | self.inplace = inplace
52 | if inplace:
53 | print('=> Using in-place shift...')
54 | print('=> Using fold div: {}'.format(self.fold_div))
55 |
56 | def forward(self, x):
57 | x = self.shift(x, self.n_segment, fold_div=self.fold_div, inplace=self.inplace)
58 | x = self.net(x)
59 | return x
60 |
61 | @staticmethod
62 | def shift(x, n_segment, fold_div=3, inplace=False):
63 | hw, nt, c = x.size()
64 | cls_ = x[0,:,:].unsqueeze(0)
65 | x = x[1:,:,:]
66 | # print(cls_.size())
67 | x = x.permute(1,2,0) # nt,c,hw
68 | n_batch = nt // n_segment
69 | h = int(np.sqrt(hw-1))
70 | w = h
71 | x = x.contiguous().view(n_batch, n_segment, c, h, w)
72 |
73 | fold = c // fold_div
74 | if inplace:
75 |             # In-place shift is disabled: it can produce out-of-order results under parallel execution.
76 |             # A dedicated CUDA kernel may be needed.
77 | raise NotImplementedError
78 | # out = InplaceShift.apply(x, fold)
79 | else:
80 | out = torch.zeros_like(x)
81 | out[:, :-1, :fold] = x[:, 1:, :fold] # shift left
82 | out[:, 1:, fold: 2 * fold] = x[:, :-1, fold: 2 * fold] # shift right
83 | out[:, :, 2 * fold:] = x[:, :, 2 * fold:] # not shift
84 | out = out.contiguous().view(nt, c, h*w)
85 | out = out.permute(2,0,1) #hw, nt, c
86 | out = torch.cat((cls_,out),dim=0)
87 | # print(out.size())
88 | return out
89 |
90 |
91 |
92 | class TokenShift(nn.Module):
93 | def __init__(self, n_segment=3, n_div=4):
94 | super(TokenShift, self).__init__()
95 | self.n_segment = n_segment
96 | self.fold_div = n_div
97 |
98 | def forward(self, x):
99 | # n bt c
100 | n, bt, c = x.size()
101 | b = bt // self.n_segment
102 | x = x.permute(1, 0, 2).contiguous().view(b, self.n_segment, n, c)
103 |
104 | fold = c // self.fold_div
105 | out = torch.zeros_like(x)
106 | out[:, :-1, 0, :fold] = x[:, 1:, 0, :fold] # shift left
107 |         out[:, 1:, 0, fold:2*fold] = x[:, :-1, 0, fold:2*fold] # shift right
108 |
109 | out[:, :, 1:, :2*fold] = x[:, :, 1:, :2*fold] # not shift
110 | out[:, :, :, 2*fold:] = x[:, :, :, 2*fold:] # not shift
111 |
112 | out = out.view(bt, n, c).permute(1, 0, 2).contiguous()
113 |
114 | return out
115 |
116 | def make_tokenshift(net, n_segment, n_div=4, locations_list=[]):
117 | for idx, block in enumerate(net.transformer.resblocks):
118 | if idx in locations_list:
119 | net.transformer.resblocks[idx].control_point = TokenShift(
120 | n_segment=n_segment,
121 | n_div=n_div,
122 | )
123 |
124 |
125 | class TokenT1D(nn.Module):
126 | def __init__(self, in_channels, n_segment=3, n_div=4, mode='shift'):
127 | super(TokenT1D, self).__init__()
128 | self.input_channels = in_channels
129 | self.n_segment = n_segment
130 | self.fold_div = n_div
131 | self.fold = self.input_channels // self.fold_div
132 | self.conv = nn.Conv1d(self.fold_div*self.fold, self.fold_div*self.fold,
133 | kernel_size=3, padding=1, groups=self.fold_div*self.fold,
134 | bias=False)
135 |
136 | if mode == 'shift':
137 | self.conv.weight.requires_grad = True
138 | self.conv.weight.data.zero_()
139 | self.conv.weight.data[:self.fold, 0, 2] = 1 # shift left
140 | self.conv.weight.data[self.fold: 2 * self.fold, 0, 0] = 1 # shift right
141 | if 2*self.fold < self.input_channels:
142 | self.conv.weight.data[2 * self.fold:, 0, 1] = 1 # fixed
143 | elif mode == 'fixed':
144 | self.conv.weight.requires_grad = True
145 | self.conv.weight.data.zero_()
146 | self.conv.weight.data[:, 0, 1] = 1 # fixed
147 | elif mode == 'norm':
148 | self.conv.weight.requires_grad = True
149 |
150 | def forward(self, x):
151 | # n bt c
152 | n, bt, c = x.size()
153 | b = bt // self.n_segment
154 | x = x.permute(1, 0, 2).contiguous().view(b, self.n_segment, n, c)
155 | x = x.permute(0, 2, 3, 1).contiguous() # b, n, c, t
156 | out = torch.zeros_like(x)
157 | out[:, 0] = self.conv(x[:, 0])
158 | out[:, 1:] = x[:, 1:]
159 | out = out.permute(1, 0, 3, 2).contiguous().view(n, bt, c)
160 | return out
161 |
162 | def make_tokenT1D(net, n_segment, n_div=4, locations_list=[]):
163 | for idx, block in enumerate(net.transformer.resblocks):
164 | if idx in locations_list:
165 | block.control_point = TokenT1D(
166 | in_channels=block.control_point.inplanes,
167 | n_segment=n_segment,
168 | n_div=n_div,
169 | )
170 |
171 | class InplaceShift(torch.autograd.Function):
172 | # Special thanks to @raoyongming for the help to this function
173 | @staticmethod
174 | def forward(ctx, input, fold):
175 | # not support higher order gradient
176 | # input = input.detach_()
177 | ctx.fold_ = fold
178 | n, t, c, h, w = input.size()
179 | buffer = input.data.new(n, t, fold, h, w).zero_()
180 | buffer[:, :-1] = input.data[:, 1:, :fold]
181 | input.data[:, :, :fold] = buffer
182 | buffer.zero_()
183 | buffer[:, 1:] = input.data[:, :-1, fold: 2 * fold]
184 | input.data[:, :, fold: 2 * fold] = buffer
185 | return input
186 |
187 | @staticmethod
188 | def backward(ctx, grad_output):
189 | # grad_output = grad_output.detach_()
190 | fold = ctx.fold_
191 | n, t, c, h, w = grad_output.size()
192 | buffer = grad_output.data.new(n, t, fold, h, w).zero_()
193 | buffer[:, 1:] = grad_output.data[:, :-1, :fold]
194 | grad_output.data[:, :, :fold] = buffer
195 | buffer.zero_()
196 | buffer[:, :-1] = grad_output.data[:, 1:, fold: 2 * fold]
197 | grad_output.data[:, :, fold: 2 * fold] = buffer
198 | return grad_output, None
199 |
200 |
201 | class TemporalPool(nn.Module):
202 | def __init__(self, net, n_segment):
203 | super(TemporalPool, self).__init__()
204 | self.net = net
205 | self.n_segment = n_segment
206 |
207 | def forward(self, x):
208 | x = self.temporal_pool(x, n_segment=self.n_segment)
209 | return self.net(x)
210 |
211 | @staticmethod
212 | def temporal_pool(x, n_segment):
213 | nt, c, h, w = x.size()
214 | n_batch = nt // n_segment
215 | x = x.view(n_batch, n_segment, c, h, w).transpose(1, 2) # n, c, t, h, w
216 | x = F.max_pool3d(x, kernel_size=(3, 1, 1), stride=(2, 1, 1), padding=(1, 0, 0))
217 | x = x.transpose(1, 2).contiguous().view(nt // 2, c, h, w)
218 | return x
219 |
220 | def make_temporal_shift_vit(net, n_segment, n_div=8, place='block', temporal_pool=False):
221 | if temporal_pool:
222 | n_segment_list = [n_segment, n_segment // 2, n_segment // 2, n_segment // 2]
223 | else:
224 | n_segment_list = [n_segment] * 4
225 | assert n_segment_list[-1] > 0
226 | print('=> n_segment per stage: {}'.format(n_segment_list))
227 |
228 | if isinstance(net, VisualTransformer):
229 | if place == 'block':
230 | def make_block_temporal(stage, this_segment):
231 | blocks = list(stage.children())
232 | print('=> Processing stage with {} blocks'.format(len(blocks)))
233 | for i, b in enumerate(blocks):
234 | blocks[i] = TemporalShift_VIT(b, n_segment=this_segment, n_div=n_div)
235 | return nn.Sequential(*(blocks))
236 |
237 | net.transformer.resblocks = make_block_temporal(net.transformer.resblocks, n_segment_list[0])
238 |
239 | # net.layer2 = make_block_temporal(net.layer2, n_segment_list[1])
240 | # net.layer3 = make_block_temporal(net.layer3, n_segment_list[2])
241 | # net.layer4 = make_block_temporal(net.layer4, n_segment_list[3])
242 |
243 |
244 | else:
245 | raise NotImplementedError(place)
246 |
247 |
248 |
249 |
250 | def make_temporal_shift(net, n_segment, n_div=8, place='blockres', temporal_pool=False):
251 | if temporal_pool:
252 | n_segment_list = [n_segment, n_segment // 2, n_segment // 2, n_segment // 2]
253 | else:
254 | n_segment_list = [n_segment] * 4
255 | assert n_segment_list[-1] > 0
256 | print('=> n_segment per stage: {}'.format(n_segment_list))
257 |
258 | if isinstance(net, ModifiedResNet):
259 | if place == 'block':
260 | def make_block_temporal(stage, this_segment):
261 | blocks = list(stage.children())
262 | print('=> Processing stage with {} blocks'.format(len(blocks)))
263 | for i, b in enumerate(blocks):
264 | blocks[i] = TemporalShift(b, n_segment=this_segment, n_div=n_div)
265 | return nn.Sequential(*(blocks))
266 |
267 | net.layer1 = make_block_temporal(net.layer1, n_segment_list[0])
268 | net.layer2 = make_block_temporal(net.layer2, n_segment_list[1])
269 | net.layer3 = make_block_temporal(net.layer3, n_segment_list[2])
270 | net.layer4 = make_block_temporal(net.layer4, n_segment_list[3])
271 |
272 | elif 'blockres' in place:
273 | n_round = 1
274 | if len(list(net.layer3.children())) >= 23:
275 | n_round = 2
276 | print('=> Using n_round {} to insert temporal shift'.format(n_round))
277 |
278 | def make_block_temporal(stage, this_segment):
279 | blocks = list(stage.children())
280 | print('=> Processing stage with {} blocks residual'.format(len(blocks)))
281 | for i, b in enumerate(blocks):
282 | if i % n_round == 0:
283 | blocks[i].conv1 = TemporalShift(b.conv1, n_segment=this_segment, n_div=n_div)
284 | return nn.Sequential(*blocks)
285 |
286 | net.layer1 = make_block_temporal(net.layer1, n_segment_list[0])
287 | net.layer2 = make_block_temporal(net.layer2, n_segment_list[1])
288 | net.layer3 = make_block_temporal(net.layer3, n_segment_list[2])
289 | net.layer4 = make_block_temporal(net.layer4, n_segment_list[3])
290 | else:
291 | raise NotImplementedError(place)
292 |
293 |
294 |
295 | def make_temporal_pool(net, n_segment):
296 | import torchvision
297 | if isinstance(net, torchvision.models.ResNet):
298 | print('=> Injecting nonlocal pooling')
299 | net.layer2 = TemporalPool(net.layer2, n_segment)
300 | else:
301 | raise NotImplementedError
302 |
303 |
304 | if __name__ == '__main__':
305 | # test inplace shift v.s. vanilla shift
306 | tsm1 = TemporalShift(nn.Sequential(), n_segment=8, n_div=8, inplace=False)
307 | tsm2 = TemporalShift(nn.Sequential(), n_segment=8, n_div=8, inplace=True)
308 |
309 | print('=> Testing CPU...')
310 | # test forward
311 | with torch.no_grad():
312 | for i in range(10):
313 | x = torch.rand(2 * 8, 3, 224, 224)
314 | y1 = tsm1(x)
315 | y2 = tsm2(x)
316 | assert torch.norm(y1 - y2).item() < 1e-5
317 |
318 | # test backward
319 | with torch.enable_grad():
320 | for i in range(10):
321 | x1 = torch.rand(2 * 8, 3, 224, 224)
322 | x1.requires_grad_()
323 | x2 = x1.clone()
324 | y1 = tsm1(x1)
325 | y2 = tsm2(x2)
326 | grad1 = torch.autograd.grad((y1 ** 2).mean(), [x1])[0]
327 | grad2 = torch.autograd.grad((y2 ** 2).mean(), [x2])[0]
328 | assert torch.norm(grad1 - grad2).item() < 1e-5
329 |
330 | print('=> Testing GPU...')
331 | tsm1.cuda()
332 | tsm2.cuda()
333 | # test forward
334 | with torch.no_grad():
335 | for i in range(10):
336 | x = torch.rand(2 * 8, 3, 224, 224).cuda()
337 | y1 = tsm1(x)
338 | y2 = tsm2(x)
339 | assert torch.norm(y1 - y2).item() < 1e-5
340 |
341 | # test backward
342 | with torch.enable_grad():
343 | for i in range(10):
344 | x1 = torch.rand(2 * 8, 3, 224, 224).cuda()
345 | x1.requires_grad_()
346 | x2 = x1.clone()
347 | y1 = tsm1(x1)
348 | y2 = tsm2(x2)
349 | grad1 = torch.autograd.grad((y1 ** 2).mean(), [x1])[0]
350 | grad2 = torch.autograd.grad((y2 ** 2).mean(), [x2])[0]
351 | assert torch.norm(grad1 - grad2).item() < 1e-5
352 | print('Test passed.')
--------------------------------------------------------------------------------
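For intuition on the shift pattern used throughout this file: within each clip, the first channel fold is copied from the next frame, the second fold from the previous frame, and the remaining channels stay in place. A minimal standalone sketch with toy sizes, independent of the classes above.

import torch

def toy_shift(x, n_segment, fold_div=8):
    nt, c, h, w = x.size()
    x = x.view(nt // n_segment, n_segment, c, h, w)
    fold = c // fold_div
    out = torch.zeros_like(x)
    out[:, :-1, :fold] = x[:, 1:, :fold]                  # fold 1: values from the next frame
    out[:, 1:, fold:2 * fold] = x[:, :-1, fold:2 * fold]  # fold 2: values from the previous frame
    out[:, :, 2 * fold:] = x[:, :, 2 * fold:]             # remaining channels unchanged
    return out.view(nt, c, h, w)

x = torch.arange(16 * 16 * 4 * 4, dtype=torch.float32).view(16, 16, 4, 4)  # 2 clips of 8 frames
y = toy_shift(x, n_segment=8)
print(x[0, 0, 0, 0].item(), y[0, 0, 0, 0].item())  # frame 0, channel 0 now holds frame 1's value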
/modules/text_prompt.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import clip
3 |
4 | def text_prompt(data):
5 | # text_aug = ['{}']
6 | text_aug = ['a video of a person {}.']
7 |
8 | # Kinetics
9 | # text_aug = [
10 | # 'a photo of a person {}.',
11 | # 'a photo of {}.',
12 | # 'a photo of a person using {}.',
13 | # 'a photo of a person doing {}.',
14 | # 'a photo of a person during {}.',
15 | # 'a photo of a person performing {}.',
16 | # 'a photo of a person practicing {}.',
17 | # 'a video of {}.',
18 | # 'a video of a person {}.',
19 | # 'a video of a person using {}.',
20 | # 'a video of a person doing {}.',
21 | # 'a video of a person during {}.',
22 | # 'a video of a person performing {}.',
23 | # 'a video of a person practicing {}.',
24 | # 'a example of {}.',
25 | # 'a example of a person {}.',
26 | # 'a example of a person using {}.',
27 | # 'a example of a person doing {}.',
28 | # 'a example of a person during {}.',
29 | # 'a example of a person performing {}.',
30 | # 'a example of a person practicing {}.',
31 | # 'a demonstration of {}.',
32 | # 'a demonstration of a person {}.',
33 | # 'a demonstration of a person using {}.',
34 | # 'a demonstration of a person doing {}.',
35 | # 'a demonstration of a person during {}.',
36 | # 'a demonstration of a person performing {}.',
37 | # 'a demonstration of a person practicing {}.',
38 | # ]
39 |
40 | text_dict = {}
41 | num_text_aug = len(text_aug)
42 |
43 | for ii, txt in enumerate(text_aug):
44 | text_dict[ii] = torch.cat([clip.tokenize(txt.format(c)) for i, c in data.classes])
45 |
46 | classes = text_dict[0]
47 |
48 | return classes, num_text_aug, text_dict
--------------------------------------------------------------------------------
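text_prompt expects data.classes to yield (index, name) pairs and returns one tokenized 77-token prompt per class. A minimal sketch with a stand-in dataset object; the class names and the import path are assumptions based on this repo's layout.

import torch
from modules.text_prompt import text_prompt

class ToyDataset:
    # text_prompt iterates data.classes as (index, name) pairs
    classes = [(0, 'riding a bike'), (1, 'playing guitar')]

classes, num_text_aug, text_dict = text_prompt(ToyDataset())
print(classes.shape)   # torch.Size([2, 77]): one tokenized prompt per class
print(num_text_aug)    # 1, since a single template is active above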
/modules/video_clip.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from collections import OrderedDict
4 | from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
5 |
6 |
7 | class LayerNorm(nn.Module):
8 | def __init__(self, hidden_size, eps=1e-12):
9 | """Construct a layernorm module in the TF style (epsilon inside the square root).
10 | """
11 | super(LayerNorm, self).__init__()
12 | self.weight = nn.Parameter(torch.ones(hidden_size))
13 | self.bias = nn.Parameter(torch.zeros(hidden_size))
14 | self.variance_epsilon = eps
15 |
16 | def forward(self, x):
17 | u = x.mean(-1, keepdim=True)
18 | s = (x - u).pow(2).mean(-1, keepdim=True)
19 | x = (x - u) / torch.sqrt(s + self.variance_epsilon)
20 | return self.weight * x + self.bias
21 |
22 | class QuickGELU(nn.Module):
23 | def forward(self, x: torch.Tensor):
24 | return x * torch.sigmoid(1.702 * x)
25 |
26 |
27 | class ResidualAttentionBlock(nn.Module):
28 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
29 | super().__init__()
30 |
31 | self.attn = nn.MultiheadAttention(d_model, n_head)
32 | self.ln_1 = LayerNorm(d_model)
33 | self.mlp = nn.Sequential(OrderedDict([
34 | ("c_fc", nn.Linear(d_model, d_model * 4)),
35 | ("gelu", QuickGELU()),
36 | ("c_proj", nn.Linear(d_model * 4, d_model))
37 | ]))
38 | self.ln_2 = LayerNorm(d_model)
39 | self.attn_mask = attn_mask
40 |
41 | def attention(self, x: torch.Tensor):
42 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
43 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
44 |
45 | def forward(self, x: torch.Tensor):
46 | x = x + self.attention(self.ln_1(x))
47 | x = x + self.mlp(self.ln_2(x))
48 | return x
49 |
50 |
51 |
52 | class TemporalTransformer(nn.Module):
53 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
54 | super().__init__()
55 | self.width = width
56 | self.layers = layers
57 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
58 |
59 | def forward(self, x: torch.Tensor):
60 |         return self.resblocks(x)
61 |
62 |
63 | class video_header(nn.Module):
64 | def __init__(self, vid_head, clip_state_dict):
65 | super().__init__()
66 | self.vid_header = vid_head
67 | assert vid_head in ["None", "Transf"]
68 |
69 | if self.vid_header == "Transf":
70 | embed_dim = clip_state_dict["text_projection"].shape[1]
71 |
72 | context_length = clip_state_dict["positional_embedding"].shape[0]
73 | vocab_size = clip_state_dict["token_embedding.weight"].shape[0]
74 | transformer_width = clip_state_dict["ln_final.weight"].shape[0]
75 | transformer_heads = transformer_width // 64
76 |
77 | transformer_layers = len(
78 |                 set(k.split(".")[2] for k in clip_state_dict if k.startswith("transformer.resblocks")))
79 |
80 | self.frame_position_embeddings = nn.Embedding(context_length, embed_dim)
81 |
82 | self.transformer = TemporalTransformer(width=embed_dim, layers=6, heads=transformer_heads)
83 | print('layer=6')
84 |
85 | self.apply(self.init_weights)
86 |
87 | def init_weights(self, module):
88 | """ Initialize the weights.
89 | """
90 | if isinstance(module, (nn.Linear, nn.Embedding)):
91 | # Slightly different from the TF version which uses truncated_normal for initialization
92 | # cf https://github.com/pytorch/pytorch/pull/5617
93 | module.weight.data.normal_(mean=0.0, std=0.02)
94 | elif isinstance(module, LayerNorm):
95 | if 'beta' in dir(module) and 'gamma' in dir(module):
96 | module.beta.data.zero_()
97 | module.gamma.data.fill_(1.0)
98 | else:
99 | module.bias.data.zero_()
100 | module.weight.data.fill_(1.0)
101 | if isinstance(module, nn.Linear) and module.bias is not None:
102 | module.bias.data.zero_()
103 |
104 | def forward(self, x):
105 | b, t, c = x.size()
106 | x = x.contiguous()
107 | if self.vid_header == "None":
108 | pass
109 |
110 | elif self.vid_header == "Transf":
111 | x_original = x
112 | seq_length = t
113 | position_ids = torch.arange(seq_length, dtype=torch.long, device=x.device)
114 | position_ids = position_ids.unsqueeze(0).expand(x.size(0), -1)
115 | frame_position_embeddings = self.frame_position_embeddings(position_ids)
116 | x = x + frame_position_embeddings
117 |
118 | x = x.permute(1, 0, 2) # NLD -> LND
119 | x = self.transformer(x)
120 | x = x.permute(1, 0, 2) # LND -> NLD
121 | x = x.type(x_original.dtype) + x_original
122 |
123 | else:
124 | raise ValueError('Unknown temporal modeling header: {}'.format(self.vid_header))
125 | return x.mean(dim=1, keepdim=False)
126 |
127 |
128 | class VideoCLIP(nn.Module):
129 | def __init__(self, clip_model, video_header, n_seg) :
130 | super(VideoCLIP, self).__init__()
131 | self.visual = clip_model.visual
132 | self.fusion_model = video_header
133 | self.n_seg = n_seg
134 | self.logit_scale = clip_model.logit_scale
135 |
136 | def forward(self, image, text_emb):
137 | image_emb = self.encode_image(image)
138 | image_emb = image_emb / image_emb.norm(dim=-1, keepdim=True)
139 | text_emb = text_emb / text_emb.norm(dim=-1, keepdim=True)
140 | logit_scale = self.logit_scale.exp()
141 | logits = logit_scale * image_emb @ text_emb.t()
142 | return logits
143 |
144 | def encode_image(self, image):
145 | bt = image.size(0)
146 | b = bt // self.n_seg
147 | image_emb = self.visual(image)
148 | if image_emb.size(0) == b: # joint
149 | return image_emb
150 | else:
151 | image_emb = image_emb.view(b, self.n_seg, -1)
152 | image_emb = self.fusion_model(image_emb)
153 | return image_emb
--------------------------------------------------------------------------------
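A minimal sketch of how the "Transf" header fuses per-frame features into one clip-level embedding. The fake state dict below contains only the keys video_header actually reads, with ViT-B/32-like sizes; it is a placeholder, not a real CLIP checkpoint.

import torch
from modules.video_clip import video_header

fake_sd = {
    "text_projection": torch.zeros(512, 512),
    "positional_embedding": torch.zeros(77, 512),
    "token_embedding.weight": torch.zeros(49408, 512),
    "ln_final.weight": torch.zeros(512),
    "transformer.resblocks.0.attn.in_proj_weight": torch.zeros(1),
}

head = video_header("Transf", fake_sd)
frame_feats = torch.randn(4, 8, 512)   # (batch, segments, dim) per-frame features
video_feat = head(frame_feats)
print(video_feat.shape)                # torch.Size([4, 512]) after temporal attention + mean pooling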
/scripts/run_test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | if [ -f "$1" ]; then
3 |     config=$1
4 | else
5 |     echo "need a config file"
6 |     exit 1
7 | fi
8 |
9 | weight=$2
10 |
11 | python -m torch.distributed.launch --master_port 1238 --nproc_per_node=8 \
12 | test.py --config ${config} --weights ${weight} ${@:3}
--------------------------------------------------------------------------------
/scripts/run_test_zeroshot.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | if [ -f "$1" ]; then
3 |     config=$1
4 | else
5 |     echo "need a config file"
6 |     exit 1
7 | fi
8 |
9 | weight=$2
10 |
11 | python -m torch.distributed.launch --master_port 1238 --nproc_per_node=8 \
12 | test_zeroshot.py --config ${config} --weights ${weight} ${@:3}
--------------------------------------------------------------------------------
/scripts/run_train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | if [ -f "$1" ]; then
3 |     config=$1
4 | else
5 |     echo "need a config file"
6 |     exit 1
7 | fi
8 |
9 | now=$(date +"%Y%m%d_%H%M%S")
10 | python -m torch.distributed.launch --master_port 1237 --nproc_per_node=8 \
11 | train.py --config ${config} --log_time $now
--------------------------------------------------------------------------------
/scripts/run_train_multinodes.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | #--- Multi-nodes training hyperparams ---
4 | nnodes=2
5 | master_addr="10.67.212.21"
6 |
7 | # Note:
8 | # 0. Set the master IP address according to your own machines.
9 | # 1. You should scale the learning rate when using more GPUs.
10 | # 2. Command: sh scripts/run_train_multinodes.sh config_file node_rank
11 | #############################################
12 | if [ -f "$1" ]; then
13 |     config=$1
14 | else
15 |     echo "need a config file"
16 |     exit 1
17 | fi
18 |
19 | python -m torch.distributed.launch --master_port 12355 --nproc_per_node=8 \
20 | --nnodes=${nnodes} --node_rank=$2 \
21 | --master_addr=${master_addr} \
22 | train.py --config ${config} ${@:3}
23 |
--------------------------------------------------------------------------------
/scripts/run_train_nce.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | if [ -f "$1" ]; then
3 |     config=$1
4 | else
5 |     echo "need a config file"
6 |     exit 1
7 | fi
8 |
9 | now=$(date +"%Y%m%d_%H%M%S")
10 | python -m torch.distributed.launch --master_port 1237 --nproc_per_node=4 \
11 | train_nce.py --config ${config} --log_time $now
--------------------------------------------------------------------------------
/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/whwu95/Text4Vis/d61b34d0208a03ce6146edcc51033b2a040cb249/teaser.png
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.utils.data import DataLoader
7 | from torch.nn.parallel import DistributedDataParallel
8 | import torch.distributed as dist
9 | import torch.backends.cudnn as cudnn
10 | import torchvision
11 | import time
12 | from utils.utils import init_distributed_mode, AverageMeter, reduce_tensor, accuracy
13 | import clip
14 |
15 | import yaml
16 | from dotmap import DotMap
17 |
18 | from datasets import Video_dataset
19 | from datasets.transforms import GroupScale, GroupCenterCrop, Stack, ToTorchFormatTensor, GroupNormalize, GroupOverSample, GroupFullResSample
20 | from modules.video_clip import video_header, VideoCLIP
21 | from modules.text_prompt import text_prompt
22 |
23 |
24 |
25 | def get_parser():
26 | parser = argparse.ArgumentParser()
27 | parser.add_argument('--config', type=str, help='global config file')
28 | parser.add_argument('--weights', type=str, default=None)
29 | parser.add_argument('--dist_url', default='env://',
30 | help='url used to set up distributed training')
31 | parser.add_argument('--world_size', default=1, type=int,
32 | help='number of distributed processes')
33 | parser.add_argument("--local_rank", type=int,
34 | help='local rank for DistributedDataParallel')
35 | parser.add_argument(
36 | "--precision",
37 | choices=["amp", "fp16", "fp32"],
38 | default="amp",
39 |         help="Floating point precision."
40 | )
41 | parser.add_argument('--test_crops', type=int, default=1)
42 | parser.add_argument('--test_clips', type=int, default=1)
43 | parser.add_argument('--dense', default=False, action="store_true",
44 | help='use multiple clips for test')
45 | args = parser.parse_args()
46 | return args
47 |
48 | def update_dict(state_dict):
49 |     new_dict = {}
50 |     for k, v in state_dict.items():
51 |         new_dict[k.replace('module.', '')] = v
52 |     return new_dict
53 |
54 |
55 | def main(args):
56 | init_distributed_mode(args)
57 |
58 | with open(args.config, 'r') as f:
59 | config = yaml.load(f, Loader=yaml.FullLoader)
60 |
61 | config = DotMap(config)
62 |
63 | device = "cpu"
64 | if torch.cuda.is_available():
65 | device = "cuda"
66 | cudnn.benchmark = True
67 |
68 | # get fp16 model and weight
69 | model, clip_state_dict = clip.load(
70 | config.network.arch,
71 | device='cpu',jit=False,
72 | internal_modeling=config.network.tm,
73 | T=config.data.num_segments,
74 | dropout=config.network.drop_out,
75 | emb_dropout=config.network.emb_dropout,
76 | pretrain=config.network.init,
77 | joint_st= config.network.joint_st) # Must set jit=False for training ViT-B/32
78 |
79 | video_head = video_header(
80 | config.network.sim_header,
81 | clip_state_dict)
82 |
83 | if args.precision == "amp" or args.precision == "fp32":
84 | model = model.float()
85 |
86 |
87 | input_mean = [0.48145466, 0.4578275, 0.40821073]
88 | input_std = [0.26862954, 0.26130258, 0.27577711]
89 |
90 | # rescale size
91 | if 'something' in config.data.dataset:
92 | scale_size = (240, 320)
93 | else:
94 | scale_size = 256 if config.data.input_size == 224 else config.data.input_size
95 |
96 | # crop size
97 | input_size = config.data.input_size
98 |
99 | # control the spatial crop
100 | if args.test_crops == 1: # one crop
101 | cropping = torchvision.transforms.Compose([
102 | GroupScale(scale_size),
103 | GroupCenterCrop(input_size),
104 | ])
105 | elif args.test_crops == 3: # do not flip, so only 3 crops (left right center)
106 | cropping = torchvision.transforms.Compose([
107 | GroupFullResSample(
108 | crop_size=input_size,
109 | scale_size=scale_size,
110 | flip=False)
111 | ])
112 | elif args.test_crops == 5: # do not flip, so only 5 crops
113 | cropping = torchvision.transforms.Compose([
114 | GroupOverSample(
115 | crop_size=input_size,
116 | scale_size=scale_size,
117 | flip=False)
118 | ])
119 | elif args.test_crops == 10:
120 | cropping = torchvision.transforms.Compose([
121 | GroupOverSample(
122 | crop_size=input_size,
123 | scale_size=scale_size,
124 | )
125 | ])
126 | else:
127 |         raise ValueError("Only 1, 3, 5, 10 crops are supported, but got {}".format(args.test_crops))
128 |
129 |
130 | val_data = Video_dataset(
131 | config.data.val_root, config.data.val_list, config.data.label_list,
132 | random_shift=False, num_segments=config.data.num_segments,
133 | modality=config.data.modality,
134 | image_tmpl=config.data.image_tmpl,
135 | test_mode=True,
136 | transform=torchvision.transforms.Compose([
137 | cropping,
138 | Stack(roll=False),
139 | ToTorchFormatTensor(div=True),
140 | GroupNormalize(input_mean,input_std),
141 | ]),
142 | dense_sample=args.dense,
143 | test_clips=args.test_clips)
144 |
145 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_data)
146 | val_loader = DataLoader(val_data,
147 | batch_size=config.data.batch_size,num_workers=config.data.workers,
148 | sampler=val_sampler, pin_memory=True, drop_last=False)
149 |
150 |
151 | model_full = VideoCLIP(model, video_head, config.data.num_segments)
152 |
153 | if os.path.isfile(args.weights):
154 | checkpoint = torch.load(args.weights, map_location='cpu')
155 | if dist.get_rank() == 0:
156 | print('load model: epoch {}'.format(checkpoint['epoch']))
157 |
158 | model_full.load_state_dict(update_dict(checkpoint['model_state_dict']))
159 | del checkpoint
160 |
161 | if args.distributed:
162 | model_full = DistributedDataParallel(model_full.cuda(), device_ids=[args.gpu], find_unused_parameters=True)
163 |
164 |
165 | classes, num_text_aug, text_dict = text_prompt(val_data)
166 | n_class = text_dict[0].size(0)
167 | #### generate classes feature ######
168 | class_feats_file = 'text_feats_{}_{}.pt'.format(config['data']['dataset'], config['network']['arch']).replace('/','')
169 | if os.path.isfile(class_feats_file):
170 | print('=> load classes features from {}'.format(class_feats_file))
171 | classes_features = torch.load(class_feats_file)
172 | else:
173 | model.eval()
174 | with torch.no_grad():
175 | classes_features = model.encode_text(classes) # 400 512
176 | # if dist.get_rank() == 0:
177 | # torch.save(classes_features.cpu(), class_feats_file)
178 |
179 |
180 | prec1 = validate(
181 | val_loader, device,
182 | model_full, config, classes_features, args.test_crops, args.test_clips)
183 |
184 | return
185 |
186 |
187 | def validate_runtime(val_loader, device, model, config, text_features, test_crops, test_clips):
188 |
189 | model.eval()
190 | with torch.no_grad():
191 | batch_size = config.data.batch_size
192 | image = torch.rand(batch_size, config.data.num_segments, 3, config.data.input_size, config.data.input_size)
193 | b, t, c, h, w = image.size()
194 | proc_start_time = time.time()
195 |
196 | for i in range(2000):
197 | image_input = image.to(device).view(-1, c, h, w)
198 | image_features = model.module.encode_image(image_input)
199 | cnt_time = time.time() - proc_start_time
200 |
201 | if i % config.logging.print_freq == 0 and dist.get_rank() == 0:
202 | runtime = float(cnt_time) / (i+1) / (batch_size)
203 | print(
204 | ('Test: [{0}/{1}], average {runtime:.4f} sec/video \t'.format(
205 | i, 2000, runtime=runtime)))
206 | return cnt_time
207 |
208 |
209 | def validate(val_loader, device, model, config, text_features, test_crops, test_clips):
210 |
211 | top1 = AverageMeter()
212 | top5 = AverageMeter()
213 | model.eval()
214 | proc_start_time = time.time()
215 |
216 | with torch.no_grad():
217 | n_class = text_features.size(0)
218 |
219 | for i, (image, class_id) in enumerate(val_loader):
220 | batch_size = class_id.numel()
221 | num_crop = test_crops
222 |
223 | num_crop *= test_clips # 4 clips for testing when using dense sample
224 |
225 | class_id = class_id.to(device)
226 | text_features = text_features.to(device)
227 | n_seg = config.data.num_segments
228 | image = image.view((-1, n_seg, 3) + image.size()[-2:])
229 | b, t, c, h, w = image.size()
230 | image_input = image.to(device).view(-1, c, h, w)
231 | image_features = model.module.encode_image(image_input)
232 | cnt_time = time.time() - proc_start_time
233 | image_features = image_features.reshape(batch_size, num_crop, -1).mean(1) # bs dim
234 |
235 | image_features /= image_features.norm(dim=-1, keepdim=True)
236 | text_features /= text_features.norm(dim=-1, keepdim=True)
237 | similarity = (100.0 * image_features @ text_features.T)
238 | similarity = similarity.view(batch_size, -1, n_class).softmax(dim=-1)
239 | similarity = similarity.mean(dim=1, keepdim=False) # bs 200
240 |
241 | prec = accuracy(similarity, class_id, topk=(1, 5))
242 | prec1 = reduce_tensor(prec[0])
243 | prec5 = reduce_tensor(prec[1])
244 |
245 | top1.update(prec1.item(), class_id.size(0))
246 | top5.update(prec5.item(), class_id.size(0))
247 |
248 | if i % config.logging.print_freq == 0 and dist.get_rank() == 0:
249 | runtime = float(cnt_time) / (i+1) / (batch_size * dist.get_world_size())
250 | print(
251 | ('Test: [{0}/{1}], average {runtime:.4f} sec/video \t'
252 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
253 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
254 | i, len(val_loader), runtime=runtime, top1=top1, top5=top5)))
255 |
256 | if dist.get_rank() == 0:
257 | print('-----Evaluation is finished------')
258 | print('Overall Prec@1 {:.03f}% Prec@5 {:.03f}%'.format(top1.avg, top5.avg))
259 |
260 | return top1.avg
261 |
262 |
263 |
264 | if __name__ == '__main__':
265 | args = get_parser()
266 | main(args)
267 |
268 |
--------------------------------------------------------------------------------
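The core of validate() above is view-averaged zero-shot classification: per-view features are averaged into one embedding per video, L2-normalized, and matched against the class text features. A minimal standalone sketch with random features; the sizes are illustrative only.

import torch

batch_size, test_crops, test_clips, dim, n_class = 4, 3, 2, 512, 400
num_views = test_crops * test_clips

image_features = torch.randn(batch_size * num_views, dim)   # one feature per spatial crop x temporal clip
text_features = torch.randn(n_class, dim)                   # one feature per class prompt

image_features = image_features.reshape(batch_size, num_views, -1).mean(1)  # average over views
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)

similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)     # (4, 400)
print(similarity.argmax(dim=-1))                                            # predicted class per video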
/test_anet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.utils.data import DataLoader
7 | from torch.nn.parallel import DistributedDataParallel
8 | import torch.distributed as dist
9 | import torch.backends.cudnn as cudnn
10 | import torchvision
11 | import time
12 | from utils.utils import init_distributed_mode, AverageMeter, reduce_tensor, accuracy, mean_average_precision
13 | import clip
14 |
15 | import yaml
16 | from dotmap import DotMap
17 |
18 | from datasets.transforms import GroupScale, GroupCenterCrop, Stack, ToTorchFormatTensor, GroupNormalize, GroupOverSample, GroupFullResSample
19 | from modules.video_clip import video_header, VideoCLIP
20 | from modules.text_prompt import text_prompt
21 |
22 |
23 |
24 | def get_parser():
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument('--config', type=str, help='global config file')
27 | parser.add_argument('--weights', type=str, default=None)
28 | parser.add_argument('--dist_url', default='env://',
29 | help='url used to set up distributed training')
30 | parser.add_argument('--world_size', default=1, type=int,
31 | help='number of distributed processes')
32 | parser.add_argument("--local_rank", type=int,
33 | help='local rank for DistributedDataParallel')
34 | parser.add_argument(
35 | "--precision",
36 | choices=["amp", "fp16", "fp32"],
37 | default="amp",
38 |         help="Floating point precision."
39 | )
40 | parser.add_argument('--test_crops', type=int, default=1)
41 | parser.add_argument('--test_clips', type=int, default=1)
42 | parser.add_argument('--dense', default=False, action="store_true",
43 | help='use multiple clips for test')
44 | args = parser.parse_args()
45 | return args
46 |
47 | def update_dict(state_dict):
48 |     new_dict = {}
49 |     for k, v in state_dict.items():
50 |         new_dict[k.replace('module.', '')] = v
51 |     return new_dict
52 |
53 |
54 | def main(args):
55 | init_distributed_mode(args)
56 |
57 | with open(args.config, 'r') as f:
58 | config = yaml.load(f, Loader=yaml.FullLoader)
59 |
60 | if 'something' in config['data']['dataset']:
61 | from datasets.sth import Video_dataset
62 | else:
63 | from datasets.kinetics import Video_dataset
64 |
65 | config = DotMap(config)
66 |
67 | device = "cpu"
68 | if torch.cuda.is_available():
69 | device = "cuda"
70 | cudnn.benchmark = True
71 |
72 | # get fp16 model and weight
73 | model, clip_state_dict = clip.load(
74 | config.network.arch,
75 | device='cpu',jit=False,
76 | internal_modeling=config.network.tm,
77 | T=config.data.num_segments,
78 | dropout=config.network.drop_out,
79 | emb_dropout=config.network.emb_dropout,
80 | pretrain=config.network.init,
81 | joint_st= config.network.joint_st) # Must set jit=False for training ViT-B/32
82 |
83 | video_head = video_header(
84 | config.network.sim_header,
85 | clip_state_dict)
86 |
87 | if args.precision == "amp" or args.precision == "fp32":
88 | model = model.float()
89 |
90 |
91 | input_mean = [0.48145466, 0.4578275, 0.40821073]
92 | input_std = [0.26862954, 0.26130258, 0.27577711]
93 |
94 | # rescale size
95 | if 'something' in config.data.dataset:
96 | scale_size = (240, 320)
97 | else:
98 | scale_size = 256 if config.data.input_size == 224 else config.data.input_size
99 |
100 | # crop size
101 | input_size = config.data.input_size
102 |
103 | # control the spatial crop
104 | if args.test_crops == 1: # one crop
105 | cropping = torchvision.transforms.Compose([
106 | GroupScale(scale_size),
107 | GroupCenterCrop(input_size),
108 | ])
109 | elif args.test_crops == 3: # do not flip, so only 3 crops (left right center)
110 | cropping = torchvision.transforms.Compose([
111 | GroupFullResSample(
112 | crop_size=input_size,
113 | scale_size=scale_size,
114 | flip=False)
115 | ])
116 | elif args.test_crops == 5: # do not flip, so only 5 crops
117 | cropping = torchvision.transforms.Compose([
118 | GroupOverSample(
119 | crop_size=input_size,
120 | scale_size=scale_size,
121 | flip=False)
122 | ])
123 | elif args.test_crops == 10:
124 | cropping = torchvision.transforms.Compose([
125 | GroupOverSample(
126 | crop_size=input_size,
127 | scale_size=scale_size,
128 | )
129 | ])
130 | else:
131 |         raise ValueError("Only 1, 3, 5, 10 crops are supported, but got {}".format(args.test_crops))
132 |
133 |
134 | val_data = Video_dataset(
135 | config.data.val_root, config.data.val_list, config.data.label_list,
136 | random_shift=False, num_segments=config.data.num_segments,
137 | modality=config.data.modality,
138 | image_tmpl=config.data.image_tmpl,
139 | test_mode=True,
140 | transform=torchvision.transforms.Compose([
141 | cropping,
142 | Stack(roll=False),
143 | ToTorchFormatTensor(div=True),
144 | GroupNormalize(input_mean,input_std),
145 | ]),
146 | dense_sample=args.dense,
147 | test_clips=args.test_clips)
148 |
149 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_data)
150 | val_loader = DataLoader(val_data,
151 | batch_size=config.data.batch_size,num_workers=config.data.workers,
152 | sampler=val_sampler, pin_memory=True, drop_last=False)
153 |
154 |
155 | model_full = VideoCLIP(model, video_head, config.data.num_segments)
156 |
157 | if os.path.isfile(args.weights):
158 | checkpoint = torch.load(args.weights, map_location='cpu')
159 | if dist.get_rank() == 0:
160 | print('load model: epoch {}'.format(checkpoint['epoch']))
161 |
162 | model_full.load_state_dict(update_dict(checkpoint['model_state_dict']))
163 | del checkpoint
164 |
165 | if args.distributed:
166 | model_full = DistributedDataParallel(model_full.cuda(), device_ids=[args.gpu], find_unused_parameters=True)
167 |
168 |
169 | classes, num_text_aug, text_dict = text_prompt(val_data)
170 | n_class = text_dict[0].size(0)
171 | #### generate classes feature ######
172 | class_feats_file = 'text_feats_{}_{}.pt'.format(config['data']['dataset'], config['network']['arch']).replace('/','')
173 | if os.path.isfile(class_feats_file):
174 | print('=> load classes features from {}'.format(class_feats_file))
175 | classes_features = torch.load(class_feats_file)
176 | else:
177 | model.eval()
178 | with torch.no_grad():
179 | classes_features = model.encode_text(classes) # 400 512
180 | # if dist.get_rank() == 0:
181 | # torch.save(classes_features.cpu(), class_feats_file)
182 |
183 |
184 | prec1 = validate(
185 | val_loader, device,
186 | model_full, config, classes_features, args.test_crops, args.test_clips)
187 |
188 | return
189 |
190 |
191 | def validate_runtime(val_loader, device, model, config, text_features, test_crops, test_clips):
192 |
193 | model.eval()
194 | with torch.no_grad():
195 | batch_size = config.data.batch_size
196 | image = torch.rand(batch_size, config.data.num_segments, 3, config.data.input_size, config.data.input_size)
197 | b, t, c, h, w = image.size()
198 | proc_start_time = time.time()
199 |
200 | for i in range(2000):
201 | image_input = image.to(device).view(-1, c, h, w)
202 | image_features = model.module.encode_image(image_input)
203 | cnt_time = time.time() - proc_start_time
204 |
205 | if i % config.logging.print_freq == 0 and dist.get_rank() == 0:
206 | runtime = float(cnt_time) / (i+1) / (batch_size)
207 | print(
208 | ('Test: [{0}/{1}], average {runtime:.4f} sec/video \t'.format(
209 | i, 2000, runtime=runtime)))
210 | return cnt_time
211 |
212 |
213 | def validate(val_loader, device, model, config, text_features, test_crops, test_clips):
214 |
215 | top1 = AverageMeter()
216 | top5 = AverageMeter()
217 | model.eval()
218 | proc_start_time = time.time()
219 |
220 | sim_logits = []
221 | labels = []
222 |
223 | with torch.no_grad():
224 | n_class = text_features.size(0)
225 |
226 | for i, (image, class_id) in enumerate(val_loader):
227 | batch_size = class_id.numel()
228 | num_crop = test_crops
229 |
230 | num_crop *= test_clips # 4 clips for testing when using dense sample
231 |
232 | class_id = class_id.to(device)
233 | text_features = text_features.to(device)
234 | n_seg = config.data.num_segments
235 | image = image.view((-1, n_seg, 3) + image.size()[-2:])
236 | b, t, c, h, w = image.size()
237 | image_input = image.to(device).view(-1, c, h, w)
238 | image_features = model.module.encode_image(image_input)
239 | cnt_time = time.time() - proc_start_time
240 | image_features = image_features.reshape(batch_size, num_crop, -1).mean(1) # bs dim
241 |
242 | image_features /= image_features.norm(dim=-1, keepdim=True)
243 | text_features /= text_features.norm(dim=-1, keepdim=True)
244 | similarity = (100.0 * image_features @ text_features.T)
245 | similarity = similarity.view(batch_size, -1, n_class).softmax(dim=-1)
246 | similarity = similarity.mean(dim=1, keepdim=False) # bs 200
247 |
248 | ########## for saving
249 | sim_logits.append(concat_all_gather(similarity))
250 | labels.append(concat_all_gather(class_id))
251 | ##########
252 |
253 | prec = accuracy(similarity, class_id, topk=(1, 5))
254 | prec1 = reduce_tensor(prec[0])
255 | prec5 = reduce_tensor(prec[1])
256 |
257 | top1.update(prec1.item(), class_id.size(0))
258 | top5.update(prec5.item(), class_id.size(0))
259 |
260 | if i % config.logging.print_freq == 0 and dist.get_rank() == 0:
261 | runtime = float(cnt_time) / (i+1) / (batch_size * dist.get_world_size())
262 | print(
263 | ('Test: [{0}/{1}], average {runtime:.4f} sec/video \t'
264 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
265 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
266 | i, len(val_loader), runtime=runtime, top1=top1, top5=top5)))
267 |
268 | if dist.get_rank() == 0:
269 | print('-----Evaluation is finished------')
270 | print('Overall Prec@1 {:.03f}% Prec@5 {:.03f}%'.format(top1.avg, top5.avg))
271 |
272 | sim, gt = sim_logits[0], labels[0]
273 | for i in range(1, len(sim_logits)):
274 | sim = torch.cat((sim, sim_logits[i]), 0)
275 | gt = torch.cat((gt, labels[i]), 0)
276 |
277 | if dist.get_rank() == 0:
278 | mAP = mean_average_precision(sim, gt)
279 | print('Overall mAP: {:.03f}%'.format(mAP[1].item()))
280 |
281 | return top1.avg
282 |
283 |
284 | # utils
285 | @torch.no_grad()
286 | def concat_all_gather(tensor):
287 | """
288 | Performs all_gather operation on the provided tensors.
289 | *** Warning ***: torch.distributed.all_gather has no gradient.
290 | """
291 | tensors_gather = [torch.ones_like(tensor)
292 | for _ in range(torch.distributed.get_world_size())]
293 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
294 |
295 | output = torch.cat(tensors_gather, dim=0)
296 | return output.cpu()
297 |
298 |
299 |
300 | if __name__ == '__main__':
301 | args = get_parser()
302 | main(args)
303 |
304 |
--------------------------------------------------------------------------------
/test_zeroshot.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | import numpy as np
4 | import torch
5 | import torch.nn as nn
6 | from torch.utils.data import DataLoader
7 | from torch.nn.parallel import DistributedDataParallel
8 | import torch.distributed as dist
9 | import torch.backends.cudnn as cudnn
10 | import torchvision
11 | import time
12 | from utils.utils import init_distributed_mode, AverageMeter, reduce_tensor, accuracy
13 | import clip
14 |
15 | import yaml
16 | from dotmap import DotMap
17 |
18 | from datasets import Video_dataset
19 | from datasets.transforms import GroupScale, GroupCenterCrop, Stack, ToTorchFormatTensor, GroupNormalize, GroupOverSample, GroupFullResSample
20 | from modules.video_clip import video_header, VideoCLIP
21 | from modules.text_prompt import text_prompt
22 |
23 |
24 |
25 | def get_parser():
26 | parser = argparse.ArgumentParser()
27 | parser.add_argument('--config', type=str, help='global config file')
28 | parser.add_argument('--weights', type=str, default=None)
29 | parser.add_argument('--dist_url', default='env://',
30 | help='url used to set up distributed training')
31 | parser.add_argument('--world_size', default=1, type=int,
32 | help='number of distributed processes')
33 | parser.add_argument("--local_rank", type=int,
34 | help='local rank for DistributedDataParallel')
35 | parser.add_argument(
36 | "--precision",
37 | choices=["amp", "fp16", "fp32"],
38 | default="amp",
39 |         help="Floating point precision."
40 | )
41 | parser.add_argument('--test_crops', type=int, default=1)
42 | parser.add_argument('--test_clips', type=int, default=1)
43 | parser.add_argument('--dense', default=False, action="store_true",
44 | help='use multiple clips for test')
45 | args = parser.parse_args()
46 | return args
47 |
48 | def update_dict(state_dict):
49 |     new_dict = {}
50 |     for k, v in state_dict.items():
51 |         new_dict[k.replace('module.', '')] = v
52 |     return new_dict
53 |
54 |
55 | def main(args):
56 | init_distributed_mode(args)
57 |
58 | with open(args.config, 'r') as f:
59 | config = yaml.load(f, Loader=yaml.FullLoader)
60 |
61 | config = DotMap(config)
62 |
63 | device = "cpu"
64 | if torch.cuda.is_available():
65 | device = "cuda"
66 | cudnn.benchmark = True
67 |
68 | # get fp16 model and weight
69 | model, clip_state_dict = clip.load(
70 | config.network.arch,
71 | device='cpu',jit=False,
72 | internal_modeling=config.network.tm,
73 | T=config.data.num_segments,
74 | dropout=config.network.drop_out,
75 | emb_dropout=config.network.emb_dropout,
76 | pretrain=config.network.init,
77 | joint_st= config.network.joint_st) # Must set jit=False for training ViT-B/32
78 |
79 | video_head = video_header(
80 | config.network.sim_header,
81 | clip_state_dict)
82 |
83 | if args.precision == "amp" or args.precision == "fp32":
84 | model = model.float()
85 |
86 |
87 | input_mean = [0.48145466, 0.4578275, 0.40821073]
88 | input_std = [0.26862954, 0.26130258, 0.27577711]
89 |
90 | # rescale size
91 | if 'something' in config.data.dataset:
92 | scale_size = (240, 320)
93 | else:
94 | scale_size = 256 if config.data.input_size == 224 else config.data.input_size
95 |
96 | # crop size
97 | input_size = config.data.input_size
98 |
99 | # control the spatial crop
100 | if args.test_crops == 1: # one crop
101 | cropping = torchvision.transforms.Compose([
102 | GroupScale(scale_size),
103 | GroupCenterCrop(input_size),
104 | ])
105 | elif args.test_crops == 3: # do not flip, so only 3 crops (left right center)
106 | cropping = torchvision.transforms.Compose([
107 | GroupFullResSample(
108 | crop_size=input_size,
109 | scale_size=scale_size,
110 | flip=False)
111 | ])
112 | elif args.test_crops == 5: # do not flip, so only 5 crops
113 | cropping = torchvision.transforms.Compose([
114 | GroupOverSample(
115 | crop_size=input_size,
116 | scale_size=scale_size,
117 | flip=False)
118 | ])
119 | elif args.test_crops == 10:
120 | cropping = torchvision.transforms.Compose([
121 | GroupOverSample(
122 | crop_size=input_size,
123 | scale_size=scale_size,
124 | )
125 | ])
126 | else:
127 |         raise ValueError("Only 1, 3, 5, 10 crops are supported, but got {}".format(args.test_crops))
128 |
129 |
130 | val_data = Video_dataset(
131 | config.data.val_root, config.data.val_list, config.data.label_list,
132 | random_shift=False, num_segments=config.data.num_segments,
133 | modality=config.data.modality,
134 | image_tmpl=config.data.image_tmpl,
135 | test_mode=True,
136 | transform=torchvision.transforms.Compose([
137 | cropping,
138 | Stack(roll=False),
139 | ToTorchFormatTensor(div=True),
140 | GroupNormalize(input_mean,input_std),
141 | ]),
142 | dense_sample=args.dense,
143 | test_clips=args.test_clips)
144 |
145 | val_sampler = torch.utils.data.distributed.DistributedSampler(val_data)
146 | val_loader = DataLoader(val_data,
147 | batch_size=config.data.batch_size,num_workers=config.data.workers,
148 | sampler=val_sampler, pin_memory=True, drop_last=False)
149 |
150 |
151 | model_full = VideoCLIP(model, video_head, config.data.num_segments)
152 |
153 | if os.path.isfile(args.weights):
154 | checkpoint = torch.load(args.weights, map_location='cpu')
155 | if dist.get_rank() == 0:
156 | print('load model: epoch {}'.format(checkpoint['epoch']))
157 |
158 | model_full.load_state_dict(update_dict(checkpoint['model_state_dict']))
159 | del checkpoint
160 |
161 | if args.distributed:
162 | model_full = DistributedDataParallel(model_full.cuda(), device_ids=[args.gpu], find_unused_parameters=True)
163 |
164 |
165 | classes, num_text_aug, text_dict = text_prompt(val_data)
166 | n_class = text_dict[0].size(0)
167 | #### generate classes feature ######
168 | class_feats_file = 'text_feats_{}_{}.pt'.format(config['data']['dataset'], config['network']['arch']).replace('/','')
169 | if os.path.isfile(class_feats_file):
170 | print('=> load classes features from {}'.format(class_feats_file))
171 | classes_features = torch.load(class_feats_file)
172 | else:
173 | model.eval()
174 | with torch.no_grad():
175 | classes_features = model.encode_text(classes) # 400 512
176 | # if dist.get_rank() == 0:
177 | # torch.save(classes_features.cpu(), class_feats_file)
178 |
179 |
180 | prec1 = validate(
181 | val_loader, device,
182 | model_full, config, classes_features, args.test_crops, args.test_clips)
183 |
184 | return
185 |
186 |
187 |
188 | def validate(val_loader, device, model, config, text_features, test_crops, test_clips):
189 |
190 | top1 = AverageMeter()
191 | top5 = AverageMeter()
192 | model.eval()
193 | proc_start_time = time.time()
194 |
195 | sim_logits = [] #
196 | labels = [] #
197 | i_features = []
198 |
199 | with torch.no_grad():
200 | n_class = text_features.size(0)
201 |
202 | for i, (image, class_id) in enumerate(val_loader):
203 | batch_size = class_id.numel()
204 | num_crop = test_crops
205 |
206 | num_crop *= test_clips # 4 clips for testing when using dense sample
207 |
208 | class_id = class_id.to(device)
209 | text_features = text_features.to(device)
210 | n_seg = config.data.num_segments
211 | image = image.view((-1, n_seg, 3) + image.size()[-2:])
212 | b, t, c, h, w = image.size()
213 | image_input = image.to(device).view(-1, c, h, w)
214 | image_features = model.module.encode_image(image_input)
215 | cnt_time = time.time() - proc_start_time
216 | image_features = image_features.reshape(batch_size, num_crop, -1).mean(1) # bs dim
217 |
218 | image_features /= image_features.norm(dim=-1, keepdim=True)
219 | text_features /= text_features.norm(dim=-1, keepdim=True)
220 | similarity = (100.0 * image_features @ text_features.T)
221 | similarity = similarity.view(batch_size, -1, n_class).softmax(dim=-1)
222 | similarity = similarity.mean(dim=1, keepdim=False) # bs 200
223 |
224 | ########## gathering
225 | i_features.append(concat_all_gather(image_features))
226 | sim_logits.append(concat_all_gather(similarity))
227 | labels.append(concat_all_gather(class_id))
228 | ##########
229 |
230 |
231 | prec = accuracy(similarity, class_id, topk=(1, 5))
232 | prec1 = reduce_tensor(prec[0])
233 | prec5 = reduce_tensor(prec[1])
234 |
235 | top1.update(prec1.item(), class_id.size(0))
236 | top5.update(prec5.item(), class_id.size(0))
237 |
238 | if i % config.logging.print_freq == 0 and dist.get_rank() == 0:
239 | runtime = float(cnt_time) / (i+1) / (batch_size * dist.get_world_size())
240 | print(
241 | ('Test: [{0}/{1}], average {runtime:.4f} sec/video \t'
242 | 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
243 | 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
244 | i, len(val_loader), runtime=runtime, top1=top1, top5=top5)))
245 |
246 | if dist.get_rank() == 0:
247 | ## half-classes evaluation
248 | sim, la = sim_logits[0], labels[0]
249 | vid_feat = i_features[0]
250 | for i in range(1, len(sim_logits)):
251 | sim = torch.cat((sim, sim_logits[i]), 0)
252 | la = torch.cat((la, labels[i]), 0)
253 | vid_feat = torch.cat((vid_feat, i_features[i]), 0)
254 |
255 | acc_split, acc_split_top5 = multi_split_test(vid_feat.cpu(), text_features.cpu(), la.cpu())
256 | accuracy_split, accuracy_split_std = np.mean(acc_split), np.std(acc_split)
257 | accuracy_split_top5, accuracy_split_top5_std = np.mean(acc_split_top5), np.std(acc_split_top5)
258 | print('-----Half-classes Evaluation-----')
259 | print('Top1: mean {:.03f}%, std {:.03f}%'.format(accuracy_split, accuracy_split_std))
260 | print('Top5: mean {:.03f}%, std {:.03f}%'.format(accuracy_split_top5, accuracy_split_top5_std))
261 |
262 | return top1.avg
263 |
264 | # utils
265 | @torch.no_grad()
266 | def concat_all_gather(tensor):
267 | """
268 | Performs all_gather on the provided tensor and returns the concatenation on CPU.
269 | *** Warning ***: torch.distributed.all_gather has no gradient.
270 | """
271 | tensors_gather = [torch.ones_like(tensor)
272 | for _ in range(torch.distributed.get_world_size())]
273 | torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
274 |
275 | output = torch.cat(tensors_gather, dim=0)
276 | return output.cpu()
277 |
278 | def compute_accuracy(vis_emb, text_emb, label):
279 | n_class = len(text_emb)
280 | n_samples = len(vis_emb)
281 | similarity = (100.0 * vis_emb @ text_emb.T)
282 | similarity = similarity.view(n_samples, -1, n_class).softmax(dim=-1)
283 | similarity = similarity.mean(dim=1, keepdim=False) # b 101
284 | prec = accuracy(similarity, label, topk=(1, 5))
285 | return prec[0], prec[1]
286 |
287 |
288 | def multi_split_test(vis_embs, text_embs, true_label):
289 | # vis_embs: [10000, 768]
290 | # text_embs: [101, 768]
291 | # true_label: [10000,]
292 | full_acc1, full_acc5 = compute_accuracy(vis_embs, text_embs, true_label)
293 | print('-----Full-classes Evaluation------')
294 | print('Overall Top1 {:.03f}% Top5 {:.03f}%'.format(full_acc1.item(), full_acc5.item()))
295 |
296 | # Calculate accuracy per split
297 | # Only when the model has been trained on a different dataset
298 | true_label = true_label.numpy()
299 | accuracy_split, accuracy_split_top5 = np.zeros(10), np.zeros(10)
300 | for split in range(len(accuracy_split)):
301 | np.random.seed(split)
302 | sel_classes = np.random.permutation(len(text_embs))[:len(text_embs) // 2] # [50, ]
303 | sel = [l in sel_classes for l in true_label] # len = 10000
304 | subclasses = np.unique(true_label[sel]) # [50, ]
305 | # label_map = {}
306 | # for i in range(len(subclasses)):
307 | # label_map[subclasses[i]] = i
308 | # new_label = np.array([label_map[l] for l in true_label[sel]])
309 | # new_label = torch.from_numpy(new_label)
310 | # label mapping: [4900, ], new label
311 | tl = np.array([int(np.where(l == subclasses)[0]) for l in true_label[sel]])
312 | tl = torch.from_numpy(tl)
313 | acc, acc5 = compute_accuracy(vis_embs[sel], text_embs[subclasses], tl)
314 | accuracy_split[split] = acc
315 | accuracy_split_top5[split] = acc5
316 |
317 | return accuracy_split, accuracy_split_top5
318 |
319 | if __name__ == '__main__':
320 | args = get_parser()
321 | main(args)
322 |
323 |
--------------------------------------------------------------------------------
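
A quick aside on the half-classes protocol implemented by multi_split_test above: each of the 10 splits samples half of the class ids, keeps only the videos whose ground-truth label falls in that subset, and remaps those labels to a compact 0..K/2-1 range before scoring. The toy sketch below (random features, illustrative shapes, np.searchsorted used in place of the np.where loop) walks through the same remapping end to end.

import numpy as np
import torch

n_class, dim, n_videos = 10, 8, 40
vis_embs = torch.nn.functional.normalize(torch.randn(n_videos, dim), dim=-1)
text_embs = torch.nn.functional.normalize(torch.randn(n_class, dim), dim=-1)
true_label = np.arange(n_videos) % n_class                      # every class appears

np.random.seed(0)
sel_classes = np.random.permutation(n_class)[: n_class // 2]    # half of the class ids
sel = np.isin(true_label, sel_classes)                          # mask over videos
subclasses = np.unique(true_label[sel])                         # sorted kept class ids
tl = torch.from_numpy(np.searchsorted(subclasses, true_label[sel]))  # remapped labels

logits = 100.0 * vis_embs[torch.from_numpy(sel)] @ text_embs[torch.from_numpy(subclasses)].T
top1 = (logits.argmax(dim=-1) == tl).float().mean().item() * 100
print('toy half-classes top-1: {:.2f}%'.format(top1))
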
/text4vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/whwu95/Text4Vis/d61b34d0208a03ce6146edcc51033b2a040cb249/text4vis.png
--------------------------------------------------------------------------------
/utils/Augmentation.py:
--------------------------------------------------------------------------------
1 | from datasets.transforms import *
2 |
3 | def train_augmentation(input_size, flip=True):
4 | if flip:
5 | return torchvision.transforms.Compose([
6 | GroupRandomSizedCrop(input_size),
7 | GroupRandomHorizontalFlip(is_flow=False)])
8 | else:
9 | return torchvision.transforms.Compose([
10 | GroupMultiScaleCrop(input_size, [1, .875, .75, .66]),
11 | GroupRandomHorizontalFlip_sth()])
12 |
13 |
14 | def get_augmentation(training, config):
15 | input_mean = [0.48145466, 0.4578275, 0.40821073]
16 | input_std = [0.26862954, 0.26130258, 0.27577711]
17 | scale_size = 256 if config.data.input_size == 224 else config.data.input_size
18 |
19 | normalize = GroupNormalize(input_mean, input_std)
20 | if 'something' in config.data.dataset:
21 | groupscale = GroupScale((240, 320))
22 | else:
23 | groupscale = GroupScale(int(scale_size))
24 |
25 | if training:
26 | train_aug = train_augmentation(
27 | config.data.input_size,
28 | flip=False if 'something' in config.data.dataset else True)
29 |
30 | unique = torchvision.transforms.Compose([
31 | groupscale,
32 | train_aug,
33 | GroupRandomGrayscale(p=0.2),
34 | ])
35 | else:
36 | unique = torchvision.transforms.Compose([
37 | groupscale,
38 | GroupCenterCrop(config.data.input_size)])
39 |
40 | common = torchvision.transforms.Compose([
41 | Stack(roll=False),
42 | ToTorchFormatTensor(div=True),
43 | normalize])
44 | return torchvision.transforms.Compose([unique, common])
45 |
46 |
47 |
48 |
49 |
50 |
--------------------------------------------------------------------------------
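
A hedged sketch of get_augmentation with a minimal attribute-style config (the real configs carry many more fields, and the exact output shape depends on the transforms defined in datasets/transforms.py): the pipeline consumes a list of PIL frames and returns a single stacked tensor.

from types import SimpleNamespace
from PIL import Image
from utils.Augmentation import get_augmentation

config = SimpleNamespace(data=SimpleNamespace(dataset='k400', input_size=224))
transform = get_augmentation(training=True, config=config)

frames = [Image.new('RGB', (320, 256)) for _ in range(8)]   # 8 dummy RGB frames
clip = transform(frames)
print(clip.shape)   # likely [8 * 3, 224, 224] after Stack + ToTorchFormatTensor
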
/utils/NCELoss.py:
--------------------------------------------------------------------------------
1 | import torch.nn.functional as F
2 | import torch.nn as nn
3 |
4 | class NCELoss(nn.Module):
5 | """Loss that uses a 'hinge' on the lower bound.
6 | This means that for samples with a label value smaller than the threshold, the loss is zero if the prediction is
7 | also smaller than that threshold.
8 | args:
9 | error_matric: What base loss to use (MSE by default).
10 | threshold: Threshold to use for the hinge.
11 | clip: Clip the loss if it is above this value.
12 | """
13 |
14 | def __init__(self, error_metric=nn.KLDivLoss(reduction='mean')):
15 | super().__init__()
16 | print('=========using NCE Loss==========')
17 | self.error_metric = error_metric
18 |
19 | def forward(self, prediction, label):
20 | batch_size = len(prediction)
21 | probs1 = F.log_softmax(prediction, 1)
22 | probs2 = F.softmax(label * 10, 1)
23 | loss = self.error_metric(probs1, probs2) * batch_size
24 | return loss
25 |
26 |
27 | class DualLoss(nn.Module):
28 | def __init__(self, error_metric=nn.KLDivLoss(reduction='mean')):
29 | super().__init__()
30 | print('=========using DS Loss==========')
31 | self.error_metric = error_metric
32 |
33 | def forward(self, prediction, label, temp=1000):
34 | batch_size = len(prediction)
35 | prediction = prediction * F.softmax(prediction/temp, dim=0) * batch_size
36 | probs1 = F.log_softmax(prediction, 1)
37 | probs2 = F.softmax(label * 10, 1)
38 | loss = self.error_metric(probs1, probs2) * batch_size
39 | return loss
--------------------------------------------------------------------------------
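
For context, NCELoss above is driven with video-text similarity logits and a soft same-class ground-truth matrix (the repo builds such a matrix with gen_label in utils/utils.py). A minimal usage sketch, assuming it is run from the repository root with toy tensors:

import torch
from utils.NCELoss import NCELoss

logits = torch.randn(4, 4)                                    # video-to-text similarity logits
class_ids = torch.tensor([0, 1, 0, 2])
soft_gt = (class_ids[:, None] == class_ids[None, :]).float()  # 1 where two samples share a class

criterion = NCELoss()
loss = criterion(logits, soft_gt)
print(loss.item())
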
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import logging
3 | import os
4 | import sys
5 | from termcolor import colored
6 |
7 |
8 | class _ColorfulFormatter(logging.Formatter):
9 | def __init__(self, *args, **kwargs):
10 | self._root_name = kwargs.pop("root_name") + "."
11 | self._abbrev_name = kwargs.pop("abbrev_name", "")
12 | if len(self._abbrev_name):
13 | self._abbrev_name = self._abbrev_name + "."
14 | super(_ColorfulFormatter, self).__init__(*args, **kwargs)
15 |
16 | def formatMessage(self, record):
17 | record.name = record.name.replace(self._root_name, self._abbrev_name)
18 | log = super(_ColorfulFormatter, self).formatMessage(record)
19 | if record.levelno == logging.WARNING:
20 | prefix = colored("WARNING", "red", attrs=["blink"])
21 | elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
22 | prefix = colored("ERROR", "red", attrs=["blink", "underline"])
23 | else:
24 | return log
25 | return prefix + " " + log
26 |
27 |
28 | # so that calling setup_logger multiple times won't add many handlers
29 | @functools.lru_cache()
30 | def setup_logger(
31 | output=None, distributed_rank=0, *, color=True, name="moco", abbrev_name=None
32 | ):
33 | """
34 | Initialize a logger and set its verbosity level to "DEBUG".
35 | Args:
36 | output (str): a file name or a directory to save log. If None, will not save log file.
37 | If ends with ".txt" or ".log", assumed to be a file name.
38 | Otherwise, logs will be saved to `output/log.txt`.
39 | name (str): the root module name of this logger
40 | Returns:
41 | logging.Logger: a logger
42 | """
43 | logger = logging.getLogger(name)
44 | logger.setLevel(logging.DEBUG)
45 | logger.propagate = False
46 |
47 | if abbrev_name is None:
48 | abbrev_name = name
49 |
50 | plain_formatter = logging.Formatter(
51 | "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
52 | )
53 | # stdout logging: master only
54 | if distributed_rank == 0:
55 | ch = logging.StreamHandler(stream=sys.stdout)
56 | ch.setLevel(logging.DEBUG)
57 | if color:
58 | formatter = _ColorfulFormatter(
59 | colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
60 | datefmt="%m/%d %H:%M:%S",
61 | root_name=name,
62 | abbrev_name=str(abbrev_name),
63 | )
64 | else:
65 | formatter = plain_formatter
66 | ch.setFormatter(formatter)
67 | logger.addHandler(ch)
68 |
69 | # file logging: all workers
70 | if output is not None:
71 | if output.endswith(".txt") or output.endswith(".log"):
72 | filename = output
73 | else:
74 | filename = os.path.join(output, "log.txt")
75 | if distributed_rank == 0:
76 | os.makedirs(os.path.dirname(filename), exist_ok=True)
77 |
78 | fh = logging.StreamHandler(_cached_log_stream(filename))
79 | fh.setLevel(logging.DEBUG)
80 | fh.setFormatter(plain_formatter)
81 | logger.addHandler(fh)
82 |
83 | return logger
84 |
85 |
86 | # cache the opened file object, so that different calls to `setup_logger`
87 | # with the same file name can safely write to the same file.
88 | @functools.lru_cache(maxsize=None)
89 | def _cached_log_stream(filename):
90 | return open(filename, "a")
--------------------------------------------------------------------------------
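
A small usage sketch for setup_logger (output path illustrative): with rank 0 it writes colored messages to stdout and appends plain-formatted records to exps/k400/log.txt.

from utils.logger import setup_logger

logger = setup_logger(output='exps/k400', distributed_rank=0, name='Text4Vis')
logger.info('training started')
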
/utils/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from bisect import bisect_right
3 | from functools import partial
4 | from torch.optim import Optimizer
5 | from torch.optim.lr_scheduler import _LRScheduler
6 | def to_tuple(x, L):
7 | if type(x) in (int, float):
8 | return [x] * L
9 | if type(x) in (list, tuple):
10 | if len(x) != L:
11 | raise ValueError('length of {} ({}) != {}'.format(x, len(x), L))
12 | return tuple(x)
13 | raise ValueError('input {} has unknown type {}'.format(x, type(x)))
14 |
15 |
16 | class WarmupLR(_LRScheduler):
17 |
18 | def __init__(self,
19 | optimizer,
20 | warmup_epochs=0,
21 | warmup_powers=1,
22 | warmup_lrs=0,
23 | last_epoch=-1):
24 | self.num_groups = len(optimizer.param_groups)
25 | self.warmup_epochs = to_tuple(warmup_epochs, self.num_groups)
26 | self.warmup_powers = to_tuple(warmup_powers, self.num_groups)
27 | self.warmup_lrs = to_tuple(warmup_lrs, self.num_groups)
28 | super(WarmupLR, self).__init__(optimizer, last_epoch)
29 | assert self.num_groups == len(self.base_lrs)
30 |
31 | def get_lr(self):
32 | curr_lrs = []
33 | for group_index in range(self.num_groups):
34 | if self.last_epoch < self.warmup_epochs[group_index]:
35 | progress = self.last_epoch / self.warmup_epochs[group_index]
36 | factor = progress ** self.warmup_powers[group_index]
37 | lr_gap = self.base_lrs[group_index] - self.warmup_lrs[group_index]
38 | curr_lrs.append(factor * lr_gap + self.warmup_lrs[group_index])
39 | else:
40 | curr_lrs.append(self.get_single_lr_after_warmup(group_index))
41 | return curr_lrs
42 |
43 | def get_single_lr_after_warmup(self, group_index):
44 | raise NotImplementedError
45 |
46 |
47 | class WarmupMultiStepLR(WarmupLR):
48 |
49 | def __init__(self,
50 | optimizer,
51 | milestones,
52 | gamma=0.1,
53 | warmup_epochs=0,
54 | warmup_powers=1,
55 | warmup_lrs=0,
56 | last_epoch=-1):
57 |
58 | if not list(milestones) == sorted(milestones):
59 | raise ValueError('Milestones should be a list of'
60 | ' increasing integers. Got %s' % repr(milestones))
61 | self.milestones = milestones
62 | self.gamma = gamma
63 | super(WarmupMultiStepLR, self).__init__(optimizer,
64 | warmup_epochs,
65 | warmup_powers,
66 | warmup_lrs,
67 | last_epoch)
68 | if self.milestones[0] <= max(self.warmup_epochs):
69 | raise ValueError('milestones[0] ({}) <= max(warmup_epochs) ({})'.format(
70 | milestones[0], max(self.warmup_epochs)))
71 |
72 | def get_single_lr_after_warmup(self, group_index):
73 | factor = self.gamma ** bisect_right(self.milestones, self.last_epoch)
74 | return self.base_lrs[group_index] * factor
75 |
76 |
77 | class WarmupCosineAnnealingLR(WarmupLR):
78 |
79 | def __init__(self,
80 | optimizer,
81 | total_epoch,
82 | final_factor=0,
83 | warmup_epochs=0,
84 | warmup_powers=1,
85 | warmup_lrs=0,
86 | last_epoch=-1):
87 | self.total_epoch = total_epoch
88 | self.final_factor = final_factor
89 | super(WarmupCosineAnnealingLR, self).__init__(optimizer,
90 | warmup_epochs,
91 | warmup_powers,
92 | warmup_lrs,
93 | last_epoch)
94 |
95 | def get_single_lr_after_warmup(self, group_index):
96 | warmup_epoch = self.warmup_epochs[group_index]
97 | progress = (self.last_epoch - warmup_epoch) / (self.total_epoch - warmup_epoch)
98 | progress = min(progress, 1.0)
99 | cosine_progress = (math.cos(math.pi * progress) + 1) / 2
100 | factor = cosine_progress * (1 - self.final_factor) + self.final_factor
101 | return self.base_lrs[group_index] * factor
102 |
103 |
104 | class WarmupExponentialLR(WarmupLR):
105 |
106 | def __init__(self,
107 | optimizer,
108 | total_epoch,
109 | final_factor=1e-3,
110 | warmup_epochs=0,
111 | warmup_powers=1,
112 | warmup_lrs=0,
113 | last_epoch=-1):
114 | if final_factor <= 0:
115 | raise ValueError('final_factor ({}) <= 0 not allowed'.format(final_factor))
116 | self.total_epoch = total_epoch
117 | self.final_factor = final_factor
118 | super(WarmupExponentialLR, self).__init__(optimizer,
119 | warmup_epochs,
120 | warmup_powers,
121 | warmup_lrs,
122 | last_epoch)
123 |
124 | def get_single_lr_after_warmup(self, group_index):
125 | warmup_epoch = self.warmup_epochs[group_index]
126 | progress = (self.last_epoch - warmup_epoch) / (self.total_epoch - warmup_epoch)
127 | progress = min(progress, 1.0)
128 | factor = self.final_factor ** progress
129 | return self.base_lrs[group_index] * factor
130 |
131 |
132 | class ReduceLROnPlateau(object):
133 | """Reduce learning rate when a metric has stopped improving.
134 | Models often benefit from reducing the learning rate by a factor
135 | of 2-10 once learning stagnates. This scheduler reads a metrics
136 | quantity and if no improvement is seen for a 'patience' number
137 | of epochs, the learning rate is reduced.
138 |
139 | Args:
140 | optimizer (Optimizer): Wrapped optimizer.
141 | mode (str): One of `min`, `max`. In `min` mode, lr will
142 | be reduced when the quantity monitored has stopped
143 | decreasing; in `max` mode it will be reduced when the
144 | quantity monitored has stopped increasing. Default: 'min'.
145 | factor (float): Factor by which the learning rate will be
146 | reduced. new_lr = lr * factor. Default: 0.1.
147 | patience (int): Number of epochs with no improvement after
148 | which learning rate will be reduced. For example, if
149 | `patience = 2`, then we will ignore the first 2 epochs
150 | with no improvement, and will only decrease the LR after the
151 | 3rd epoch if the loss still hasn't improved then.
152 | Default: 10.
153 | verbose (bool): If ``True``, prints a message to stdout for
154 | each update. Default: ``False``.
155 | threshold (float): Threshold for measuring the new optimum,
156 | to only focus on significant changes. Default: 1e-4.
157 | threshold_mode (str): One of `rel`, `abs`. In `rel` mode,
158 | dynamic_threshold = best * ( 1 + threshold ) in 'max'
159 | mode or best * ( 1 - threshold ) in `min` mode.
160 | In `abs` mode, dynamic_threshold = best + threshold in
161 | `max` mode or best - threshold in `min` mode. Default: 'rel'.
162 | cooldown (int): Number of epochs to wait before resuming
163 | normal operation after lr has been reduced. Default: 0.
164 | min_lr (float or list): A scalar or a list of scalars. A
165 | lower bound on the learning rate of all param groups
166 | or each group respectively. Default: 0.
167 | eps (float): Minimal decay applied to lr. If the difference
168 | between new and old lr is smaller than eps, the update is
169 | ignored. Default: 1e-8.
170 |
171 | Example:
172 | >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
173 | >>> scheduler = ReduceLROnPlateau(optimizer, 'min')
174 | >>> for epoch in range(10):
175 | >>> train(...)
176 | >>> val_loss = validate(...)
177 | >>> # Note that step should be called after validate()
178 | >>> scheduler.step(val_loss)
179 | """
180 |
181 | def __init__(self, optimizer, mode='min', factor=0.1, patience=10,
182 | verbose=False, threshold=1e-4, threshold_mode='rel',
183 | cooldown=0, min_lr=0, eps=1e-8):
184 |
185 | if factor >= 1.0:
186 | raise ValueError('Factor should be < 1.0.')
187 | self.factor = factor
188 |
189 | if not isinstance(optimizer, Optimizer):
190 | raise TypeError('{} is not an Optimizer'.format(
191 | type(optimizer).__name__))
192 | self.optimizer = optimizer
193 |
194 | if isinstance(min_lr, list) or isinstance(min_lr, tuple):
195 | if len(min_lr) != len(optimizer.param_groups):
196 | raise ValueError("expected {} min_lrs, got {}".format(
197 | len(optimizer.param_groups), len(min_lr)))
198 | self.min_lrs = list(min_lr)
199 | else:
200 | self.min_lrs = [min_lr] * len(optimizer.param_groups)
201 |
202 | self.patience = patience
203 | self.verbose = verbose
204 | self.cooldown = cooldown
205 | self.cooldown_counter = 0
206 | self.mode = mode
207 | self.threshold = threshold
208 | self.threshold_mode = threshold_mode
209 | self.best = None
210 | self.num_bad_epochs = None
211 | self.mode_worse = None # the worse value for the chosen mode
212 | self.is_better = None
213 | self.eps = eps
214 | self.last_epoch = -1
215 | self._init_is_better(mode=mode, threshold=threshold,
216 | threshold_mode=threshold_mode)
217 | self._reset()
218 |
219 | def _reset(self):
220 | """Resets num_bad_epochs counter and cooldown counter."""
221 | self.best = self.mode_worse
222 | self.cooldown_counter = 0
223 | self.num_bad_epochs = 0
224 |
225 | def step(self, metrics, epoch=None):
226 | current = metrics
227 | if epoch is None:
228 | epoch = self.last_epoch = self.last_epoch + 1
229 | self.last_epoch = epoch
230 |
231 | if self.is_better(current, self.best):
232 | self.best = current
233 | self.num_bad_epochs = 0
234 | else:
235 | self.num_bad_epochs += 1
236 |
237 | if self.in_cooldown:
238 | self.cooldown_counter -= 1
239 | self.num_bad_epochs = 0 # ignore any bad epochs in cooldown
240 |
241 | if self.num_bad_epochs > self.patience:
242 | self._reduce_lr(epoch)
243 | self.cooldown_counter = self.cooldown
244 | self.num_bad_epochs = 0
245 |
246 | def _reduce_lr(self, epoch):
247 | for i, param_group in enumerate(self.optimizer.param_groups):
248 | old_lr = float(param_group['lr'])
249 | new_lr = max(old_lr * self.factor, self.min_lrs[i])
250 | if old_lr - new_lr > self.eps:
251 | param_group['lr'] = new_lr
252 | if self.verbose:
253 | print('Epoch {:5d}: reducing learning rate'
254 | ' of group {} to {:.4e}.'.format(epoch, i, new_lr))
255 |
256 | @property
257 | def in_cooldown(self):
258 | return self.cooldown_counter > 0
259 |
260 | def _cmp(self, mode, threshold_mode, threshold, a, best):
261 | if mode == 'min' and threshold_mode == 'rel':
262 | rel_epsilon = 1. - threshold
263 | return a < best * rel_epsilon
264 |
265 | elif mode == 'min' and threshold_mode == 'abs':
266 | return a < best - threshold
267 |
268 | elif mode == 'max' and threshold_mode == 'rel':
269 | rel_epsilon = threshold + 1.
270 | return a > best * rel_epsilon
271 |
272 | else:  # mode == 'max' and threshold_mode == 'abs':
273 | return a > best + threshold
274 |
275 | def _init_is_better(self, mode, threshold, threshold_mode):
276 | if mode not in {'min', 'max'}:
277 | raise ValueError('mode ' + mode + ' is unknown!')
278 | if threshold_mode not in {'rel', 'abs'}:
279 | raise ValueError('threshold mode ' + threshold_mode + ' is unknown!')
280 |
281 | if mode == 'min':
282 | self.mode_worse = math.inf
283 | else:  # mode == 'max':
284 | self.mode_worse = -math.inf
285 |
286 | self.is_better = partial(self._cmp, mode, threshold_mode, threshold)
287 |
288 | def state_dict(self):
289 | return {key: value for key, value in self.__dict__.items() if key not in {'optimizer', 'is_better'}}
290 |
291 | def load_state_dict(self, state_dict):
292 | self.__dict__.update(state_dict)
293 | self._init_is_better(mode=self.mode, threshold=self.threshold, threshold_mode=self.threshold_mode)
294 |
--------------------------------------------------------------------------------
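
A minimal sketch of WarmupCosineAnnealingLR in isolation (toy model and hyper-parameters; the training step itself is omitted): the learning rate ramps up linearly over the first 5 epochs, then follows a cosine decay to zero by epoch 50.

import torch
from utils.lr_scheduler import WarmupCosineAnnealingLR

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
scheduler = WarmupCosineAnnealingLR(optimizer, total_epoch=50, warmup_epochs=5)

for epoch in range(50):
    # optimizer.step() and validation would go here in a real loop
    scheduler.step()                                  # stepped once per epoch
    print(epoch, optimizer.param_groups[0]['lr'])
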
/utils/solver.py:
--------------------------------------------------------------------------------
1 | import torch.optim as optim
2 | from utils.lr_scheduler import WarmupMultiStepLR, WarmupCosineAnnealingLR
3 |
4 | def _optimizer(config, model, video_head):
5 | if config.solver.optim == 'adam':
6 | optimizer = optim.Adam([{'params': model.parameters()},
7 | {'params': video_head.parameters(), 'lr': config.solver.lr}],
8 | lr=config.solver.lr * config.solver.clip_ratio, betas=(0.9, 0.999), eps=1e-8,
9 | weight_decay=0.2)  # hyper-parameters follow the CLIP paper; the smaller lr is safer when fine-tuning on a new dataset
10 | print('Adam')
11 | elif config.solver.optim == 'sgd':
12 |
13 | optimizer = optim.SGD([{'params': model.parameters()},
14 | {'params': video_head.parameters(), 'lr': config.solver.lr}],
15 | config.solver.lr * config.solver.clip_ratio,
16 | momentum=config.solver.momentum,
17 | weight_decay=config.solver.weight_decay)
18 | print('SGD')
19 | elif config.solver.optim == 'adamw':
20 | vision_params = []
21 | text_params = []
22 | for name, param in model.named_parameters():
23 | if 'visual.' in name:
24 | vision_params.append(param)
25 | else:
26 | text_params.append(param)
27 |
28 | # print('[INFO] number of visual parameters:', len(vision_params), flush=True)
29 | # print('[INFO] number of textual parameters:', len(text_params), flush=True)
30 | optimizer = optim.AdamW([{'params': model.parameters(), 'lr': config.solver.lr * config.solver.clip_ratio},
31 | {'params': video_head.parameters(), 'lr': config.solver.lr}],
32 | betas=(0.9, 0.999), lr=config.solver.lr, eps=1e-8,
33 | weight_decay=config.solver.weight_decay)  # hyper-parameters follow the CLIP paper; the smaller lr is safer when fine-tuning on a new dataset
34 | # for param_group in optimizer.param_groups:
35 | # print(param_group['lr'])
36 | else:
37 | raise ValueError('Unknown optimizer: {}'.format(config.solver.optim))
38 | return optimizer
39 |
40 |
41 | def _lr_scheduler(config, optimizer):
42 | if config.solver.type == 'cosine':
43 | lr_scheduler = WarmupCosineAnnealingLR(
44 | optimizer,
45 | config.solver.epochs,
46 | warmup_epochs=config.solver.lr_warmup_step
47 | )
48 | elif config.solver.type == 'multistep':
49 | if isinstance(config.solver.lr_decay_step, list):
50 | milestones = config.solver.lr_decay_step
51 | elif isinstance(config.solver.lr_decay_step, int):
52 | milestones = [
53 | config.solver.lr_decay_step * (i + 1)
54 | for i in range(config.solver.epochs //
55 | config.solver.lr_decay_step)]
56 | else:
57 | raise ValueError("error learning rate decay step: {}".format(type(config.solver.lr_decay_step)))
58 | lr_scheduler = WarmupMultiStepLR(
59 | optimizer,
60 | milestones,
61 | warmup_epochs=config.solver.lr_warmup_step
62 | )
63 | else:
64 | raise ValueError('Unknown lr scheduler: {}'.format(config.solver.type))
65 | return lr_scheduler
66 |
67 |
68 |
--------------------------------------------------------------------------------
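
A minimal sketch wiring _optimizer and _lr_scheduler together with an attribute-style config (the field names mirror the solver section read by the code above; the values and the stand-in modules are illustrative):

from types import SimpleNamespace
import torch.nn as nn
from utils.solver import _optimizer, _lr_scheduler

config = SimpleNamespace(solver=SimpleNamespace(
    optim='adamw', lr=5e-5, clip_ratio=0.1, weight_decay=0.2,
    type='cosine', epochs=30, lr_warmup_step=5))

clip_model = nn.Linear(512, 512)     # stands in for the CLIP backbone
video_head = nn.Linear(512, 512)     # stands in for the temporal modelling head
optimizer = _optimizer(config, clip_model, video_head)
lr_scheduler = _lr_scheduler(config, optimizer)
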
/utils/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | utils for clip
3 | """
4 | import os
5 |
6 | import torch
7 | import torch.distributed as dist
8 | import torch.distributed.nn as distnn
9 | from torch import nn
10 | import numpy as np
11 |
12 | def init_distributed_mode(args):
13 | """ init for distribute mode """
14 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
15 | args.rank = int(os.environ["RANK"])
16 | args.world_size = int(os.environ['WORLD_SIZE'])
17 | args.gpu = int(os.environ['LOCAL_RANK'])
18 | elif 'SLURM_PROCID' in os.environ:
19 | args.rank = int(os.environ['SLURM_PROCID'])
20 | args.gpu = args.rank % torch.cuda.device_count()
21 | else:
22 | print('Not using distributed mode')
23 | args.distributed = False
24 | return
25 |
26 | args.distributed = True
27 |
28 | torch.cuda.set_device(args.gpu)
29 | args.dist_backend = 'nccl'
30 | '''
31 | The print below is commented out to keep the pylint checks quiet:
32 | print('distributed init rank {}: {}'.format(args.rank, args.dist_url), flush=True)
33 | '''
34 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
35 | world_size=args.world_size, rank=args.rank)
36 | torch.distributed.barrier()
37 |
38 |
39 | def ddp_all_reduce(*args):
40 | """ all reduce (op: sum) by ddp """
41 | t = torch.tensor([x for x in args], dtype=torch.float64, device='cuda')
42 | dist.barrier()
43 | dist.all_reduce(t)
44 | t = t.tolist()
45 | return t
46 |
47 |
48 | def ddp_all_gather(*args):
49 | """ all gather by ddp, all gather don't have grad_fn by default """
50 | rets = []
51 | world_size = dist.get_world_size()
52 | for x in args:
53 | if type(x) is torch.Tensor:
54 | ret = [torch.zeros_like(x) for _ in range(world_size)]
55 | dist.barrier()
56 | dist.all_gather(ret, x)
57 | else: # for any picklable object
58 | ret = [None for _ in range(world_size)]
59 | dist.barrier()
60 | dist.all_gather_object(ret, x)
61 | rets.append(ret)
62 | return rets if len(rets) > 1 else rets[0]
63 |
64 |
65 |
66 |
67 | def gather_labels(labels):
68 | # We gather tensors from all gpus
69 | gathered_labels = ddp_all_gather(labels)
70 | all_labels = torch.cat(gathered_labels)
71 | return all_labels
72 |
73 |
74 | def gen_label_cpu(labels):
75 | num = len(labels)
76 | gt = np.zeros(shape=(num,num))
77 | for i, label in enumerate(labels):
78 | for k in range(num):
79 | if labels[k] == label:
80 | gt[i,k] = 1
81 | return gt
82 |
83 |
84 | def gen_label(labels):
85 | num = len(labels)
86 | gt = torch.zeros(size=(num,num))
87 | labels_column = labels.reshape(-1,1).repeat(1,num)
88 | labels_row = labels.repeat(num,1)
89 | gt[labels_column == labels_row] = 1
90 | return gt
91 |
92 | def convert_models_to_fp32(model):
93 | for p in model.parameters():
94 | p.data = p.data.float()
95 | if p.grad is not None:
96 | p.grad.data = p.grad.data.float()
97 |
98 | def convert_models_to_fp16(model):
99 | for p in model.parameters():
100 | p.data = p.data.half()
101 | if p.grad is not None:
102 | p.grad.data = p.grad.data.half()
103 |
104 |
105 |
106 | def gather_features(
107 | image_features, text_features,
108 | local_loss=False, gather_with_grad=False, rank=0, world_size=1):
109 |
110 | # We gather tensors from all gpus
111 | if gather_with_grad:
112 | all_image_features = torch.cat(distnn.all_gather(image_features), dim=0)
113 | all_text_features = torch.cat(distnn.all_gather(text_features), dim=0)
114 | else:
115 | gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
116 | gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
117 | dist.all_gather(gathered_image_features, image_features)
118 | dist.all_gather(gathered_text_features, text_features)
119 | if not local_loss:
120 | # ensure grads for local rank when all_* features don't have a gradient
121 | gathered_image_features[rank] = image_features
122 | gathered_text_features[rank] = text_features
123 | all_image_features = torch.cat(gathered_image_features, dim=0)
124 | all_text_features = torch.cat(gathered_text_features, dim=0)
125 | return all_image_features, all_text_features
126 |
127 |
128 |
129 | def create_logits(image_features, text_features, logit_scale, local_loss=False):
130 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
131 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
132 | if dist.get_world_size() > 1:
133 | all_image_features, all_text_features = gather_features(
134 | image_features, text_features,
135 | local_loss=local_loss, gather_with_grad=False,
136 | rank=dist.get_rank(), world_size=dist.get_world_size())
137 |
138 | # cosine similarity as logits
139 | if local_loss:
140 | logits_per_image = logit_scale * image_features @ all_text_features.T
141 | logits_per_text = logit_scale * text_features @ all_image_features.T
142 | else:
143 | logits_per_image = logit_scale * all_image_features @ all_text_features.T
144 | logits_per_text = logits_per_image.T
145 |
146 | else:
147 | logits_per_image = logit_scale * image_features @ text_features.T
148 | logits_per_text = logit_scale * text_features @ image_features.T
149 |
150 | # shape = [global_batch_size, global_batch_size]
151 | return logits_per_image, logits_per_text
152 |
153 |
154 |
155 |
156 |
157 | def epoch_saving(epoch, model, video_head, optimizer, filename):
158 | torch.save({
159 | 'epoch': epoch,
160 | 'model_state_dict': model.state_dict(),
161 | 'fusion_model_state_dict': video_head.state_dict(),
162 | 'optimizer_state_dict': optimizer.state_dict(),
163 | }, filename)  # just change to your preferred folder/filename
164 |
165 | def best_saving(working_dir, epoch, model, video_head, optimizer):
166 | best_name = '{}/model_best.pt'.format(working_dir)
167 | torch.save({
168 | 'epoch': epoch,
169 | 'model_state_dict': model.state_dict(),
170 | 'fusion_model_state_dict': video_head.state_dict(),
171 | 'optimizer_state_dict': optimizer.state_dict(),
172 | }, best_name) # just change to your preferred folder/filename
173 |
174 |
175 | def reduce_tensor(tensor, n=None):
176 | if n is None:
177 | n = dist.get_world_size()
178 | rt = tensor.clone()
179 | dist.all_reduce(rt, op=dist.ReduceOp.SUM)
180 | rt = rt / n
181 | return rt
182 |
183 |
184 | class AverageMeter:
185 | """Computes and stores the average and current value"""
186 | def __init__(self):
187 | self.reset()
188 |
189 | def reset(self):
190 | self.val = 0
191 | self.avg = 0
192 | self.sum = 0
193 | self.count = 0
194 |
195 | def update(self, val, n=1):
196 | self.val = val
197 | self.sum += val * n
198 | self.count += n
199 | self.avg = self.sum / self.count
200 |
201 | def sync(self):
202 | rank = dist.get_rank()
203 | world_size = dist.get_world_size()
204 | val = torch.tensor(self.val).cuda()
205 | sum_v = torch.tensor(self.sum).cuda()
206 | count = torch.tensor(self.count).cuda()
207 | self.val = reduce_tensor(val, world_size).item()
208 | self.sum = reduce_tensor(sum_v, 1).item()
209 | self.count = reduce_tensor(count, 1).item()
210 | self.avg = self.sum / self.count
211 |
212 |
213 | def accuracy(output, target, topk=(1, )):
214 | """Computes the precision@k for the specified values of k"""
215 | maxk = max(topk)
216 | batch_size = target.size(0)
217 |
218 | _, pred = output.topk(maxk, 1, True, True)
219 | pred = pred.t()
220 | correct = pred.eq(target.view(1, -1).expand_as(pred))
221 | res = []
222 | for k in topk:
223 | correct_k = correct[:k].reshape(-1).float().sum(0)
224 | res.append(correct_k.mul_(100.0 / batch_size))
225 | return res
226 |
227 |
228 | from torchnet import meter
229 | def mean_average_precision(probs, labels):
230 | """Computes MAP for ActivityNet evaluation"""
231 | if not isinstance(probs, torch.Tensor):
232 | probs = torch.Tensor(probs).cuda()
233 |
234 | if not isinstance(labels, torch.Tensor):
235 | labels = torch.tensor(labels).long().cuda()
236 |
237 | gt = torch.zeros_like(probs).int()
238 | acc_meter = meter.ClassErrorMeter(topk=[1, 3], accuracy=True)
239 | gt[torch.LongTensor(range(gt.size(0))), labels] = 1
240 | acc_meter.add(probs, labels)
241 | acc = acc_meter.value()
242 |
243 | probs = torch.nn.functional.softmax(probs, dim=1)
244 |
245 | map_meter = meter.mAPMeter()
246 | map_meter.add(probs, gt)
247 | ap = map_meter.value()
248 | ap = float(ap) * 100
249 | return [torch.tensor(acc[0]).cuda(), torch.tensor(ap).cuda()]
250 |
251 |
252 | if __name__=='__main__':
253 | probs = torch.load('ANet_similarity_336.pth') # similarity logits
254 | labels = torch.load('ANet_labels_336.pth') # class ids
255 |
256 | mAP = mean_average_precision(probs, labels)
257 | print(mAP)
258 |
259 |
--------------------------------------------------------------------------------
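
To close, a toy sketch of how accuracy and AverageMeter from this file work together (no distributed setup needed, though importing utils.utils does require the repo's torchnet dependency): accuracy returns per-batch top-k percentages and AverageMeter keeps a running average weighted by batch size.

import torch
from utils.utils import accuracy, AverageMeter

top1 = AverageMeter()
for _ in range(3):
    logits = torch.randn(8, 400)                  # [batch, n_class]
    target = torch.randint(0, 400, (8,))
    prec1, prec5 = accuracy(logits, target, topk=(1, 5))
    top1.update(prec1.item(), target.size(0))
print('running top-1: {:.2f}%'.format(top1.avg))
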