├── .gitignore ├── LICENSE ├── README.md ├── clip ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── clip.py ├── model.py └── simple_tokenizer.py ├── dataset ├── Augmentation.py ├── __init__.py ├── charades.py ├── kinetics.py ├── mixup.py ├── rand_augment.py ├── random_erasing.py ├── sth.py ├── sthvideo.py └── transforms.py ├── datasets.py ├── engine_for_finetuning.py ├── eva_clip ├── __init__.py ├── bpe_simple_vocab_16e6.txt.gz ├── constants.py ├── eva_vit_model.py ├── factory.py ├── hf_configs.py ├── hf_model.py ├── loss.py ├── model.py ├── model_configs │ ├── EVA01-CLIP-B-16.json │ ├── EVA01-CLIP-g-14-plus.json │ ├── EVA01-CLIP-g-14.json │ ├── EVA02-CLIP-B-16.json │ ├── EVA02-CLIP-L-14-336.json │ ├── EVA02-CLIP-L-14.json │ ├── EVA02-CLIP-bigE-14-plus.json │ └── EVA02-CLIP-bigE-14.json ├── modified_resnet.py ├── openai.py ├── pretrained.py ├── rope.py ├── timm_model.py ├── tokenizer.py ├── transform.py ├── transformer.py └── utils.py ├── lists ├── k400 │ ├── kinetics_rgb_train_se320.txt │ └── kinetics_rgb_val_se320.txt ├── kinetics_400_labels.csv ├── sth_labels.csv ├── sthv1 │ ├── train_rgb.txt │ └── val_rgb.txt └── sthv2 │ ├── train_rgb.txt │ └── val_rgb.txt ├── modules ├── ATM.py └── temporal_modeling.py ├── optim_factory.py ├── pics └── ATM.png ├── run_class_finetuning.py ├── scripts ├── k400 │ ├── train_base.sh │ ├── train_eva_large.sh │ └── train_eva_large_336.sh ├── ssv1 │ ├── test_base_f16.sh │ ├── test_base_f32.sh │ ├── test_base_f8.sh │ ├── test_large_f16.sh │ ├── train_base.sh │ └── train_eva_large.sh └── ssv2 │ ├── test_base_f16.sh │ ├── test_base_f32.sh │ ├── test_base_f8.sh │ ├── test_large_f16.sh │ ├── train_base.sh │ └── train_eva_large.sh ├── test_for_frame.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Wenhao Wu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 |
4 | 5 | 6 |

【ICCV'2023】What Can Simple Arithmetic Operations Do for Temporal Modeling?

7 |
If you like our project, please give us a star ⭐ on GitHub for the latest updates.
8 | 9 | 10 | 11 | [![Conference](http://img.shields.io/badge/ICCV-2023-b6f107.svg)](https://openaccess.thecvf.com/content/ICCV2023/html/Wu_What_Can_Simple_Arithmetic_Operations_Do_for_Temporal_Modeling_ICCV_2023_paper.html) 12 | [![Paper](https://img.shields.io/badge/Arxiv-2307.08908-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2307.08908) 13 | 14 | 15 | [Wenhao Wu](https://whwu95.github.io/)1,2, [Yuxin Song]()2, [Zhun Sun]()2, [Jingdong Wang](https://jingdongwang2017.github.io/)3, [Chang Xu](http://changxu.xyz/)1, [Wanli Ouyang](https://wlouyang.github.io/)3,1 16 | 17 | 18 | 1[The University of Sydney](https://www.sydney.edu.au/), 2[Baidu](https://vis.baidu.com/#/), 3[Shanghai AI Lab](https://www.shlab.org.cn/) 19 | 20 |
21 | 22 | 23 | *** 24 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/what-can-simple-arithmetic-operations-do-for/action-recognition-in-videos-on-something-1)](https://paperswithcode.com/sota/action-recognition-in-videos-on-something-1?p=what-can-simple-arithmetic-operations-do-for) 25 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/what-can-simple-arithmetic-operations-do-for/action-classification-on-kinetics-400)](https://paperswithcode.com/sota/action-classification-on-kinetics-400?p=what-can-simple-arithmetic-operations-do-for) 26 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/what-can-simple-arithmetic-operations-do-for/action-recognition-in-videos-on-something)](https://paperswithcode.com/sota/action-recognition-in-videos-on-something?p=what-can-simple-arithmetic-operations-do-for) 27 | 28 | This is the official implementation of our **ATM** (Arithmetic Temporal Module), which explores the potential of four simple arithmetic operations for temporal modeling. 29 | 30 | Our best model can achieve **89.4%** Top-1 Acc. on Kinetics-400, **65.6%** Top-1 Acc. on Something-Something V1, **74.6%** Top-1 Acc. on Something-Something V2! 31 | 32 | 33 |
🔥 I also have other recent video recognition projects that may interest you ✨.

34 | 35 | > [**Side4Video: Spatial-Temporal Side Network for Memory-Efficient Image-to-Video Transfer Learning**](https://arxiv.org/abs/2311.15769)
36 | > Huanjin Yao, Wenhao Wu, Zhiheng Li
37 | > [![arXiv](https://img.shields.io/badge/Arxiv-2311.15769-b31b1b.svg?logo=arXiv)](https://arxiv.org/abs/2311.15769) [![github](https://img.shields.io/badge/-Github-black?logo=github)](https://github.com/HJYao00/Side4Video) 38 | 39 | 40 | 41 | > [**Bidirectional Cross-Modal Knowledge Exploration for Video Recognition with Pre-trained Vision-Language Models**](https://arxiv.org/abs/2301.00182)
42 | > Wenhao Wu, Xiaohan Wang, Haipeng Luo, Jingdong Wang, Yi Yang, Wanli Ouyang
43 | > [![Conference](http://img.shields.io/badge/CVPR-2023-f9f107.svg)](https://openaccess.thecvf.com/content/CVPR2023/html/Wu_Bidirectional_Cross-Modal_Knowledge_Exploration_for_Video_Recognition_With_Pre-Trained_Vision-Language_CVPR_2023_paper.html) [![github](https://img.shields.io/badge/-Github-black?logo=github)](https://github.com/whwu95/BIKE) 44 | 45 | 46 | > [**Revisiting Classifier: Transferring Vision-Language Models for Video Recognition**](https://arxiv.org/abs/2207.01297)
47 | > Wenhao Wu, Zhun Sun, Wanli Ouyang
48 | > [![Conference](http://img.shields.io/badge/AAAI-2023-f9f107.svg)](https://ojs.aaai.org/index.php/AAAI/article/view/25386/25158) [![Journal](http://img.shields.io/badge/IJCV-2023-Bf107.svg)](https://link.springer.com/article/10.1007/s11263-023-01876-w) [![github](https://img.shields.io/badge/-Github-black?logo=github)](https://github.com/whwu95/Text4Vis) 49 | 50 | 51 | 52 | 53 | 54 |

55 | 56 | 57 | 64 | 65 | 66 | 67 | ## 📣 News 68 | 69 | 70 | - [x] `Nov 29, 2023`: Training code has been released. 71 | - [x] `July 14, 2023`: 🎉 Our **ATM** has been accepted by **ICCV-2023**. 72 | 73 | 74 | 75 | ## 🌈 Overview 76 | ![ATM](pics/ATM.png) 77 | The key motivation behind ATM is to explore the potential of simple arithmetic operations to capture auxiliary temporal clues that may be embedded in current video features, without relying on elaborate designs. ATM can be integrated into both vanilla CNN backbones (e.g., ResNet) and Vision Transformers (e.g., ViT) for video action recognition. 78 | 79 | 80 | ## 🚀 Training & Testing 81 | We offer training and testing scripts for Kinetics-400, Sth-Sth V1, and Sth-Sth V2. Please refer to the [*scripts*](https://github.com/whwu95/ATM/tree/main/scripts) folder for details. For example, you can run: 82 | 83 | ```sh 84 | # Train the 8 Frames ViT-B/32 model on Sth-Sth v1. 85 | sh scripts/ssv1/train_base.sh 86 | 87 | # Test the 8 Frames ViT-B/32 model on Sth-Sth v1. 88 | sh scripts/ssv1/test_base_f8.sh 89 | ``` 90 | 91 | 92 | 93 | 94 | ## 📌 BibTeX & Citation 95 | 96 | If you use our code in your research or wish to refer to the baseline results, please use the following BibTeX entry 😁. 97 | 98 | 99 | ```bibtex 100 | @inproceedings{atm, 101 | title={What Can Simple Arithmetic Operations Do for Temporal Modeling?}, 102 | author={Wu, Wenhao and Song, Yuxin and Sun, Zhun and Wang, Jingdong and Xu, Chang and Ouyang, Wanli}, 103 | booktitle={Proceedings of the IEEE International Conference on Computer Vision (ICCV)}, 104 | year={2023} 105 | } 106 | ``` 107 | 108 | 109 | 110 | ## 🎗️ Acknowledgement 111 | 112 | This repository is built upon portions of [VideoMAE](https://github.com/MCG-NJU/VideoMAE), [CLIP](https://github.com/openai/CLIP), and [EVA](https://github.com/baaivision/EVA). Thanks to the contributors of these great codebases. 113 | 114 | 115 | ## 👫 Contact 116 | For any questions, please file an issue or contact [Wenhao Wu](https://whwu95.github.io/). 
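
## 🧮 Illustrative sketch of arithmetic temporal modeling

The official module lives in `modules/ATM.py`, which is not shown in this excerpt. Purely as an illustration of the Overview above — and not the authors' implementation — the hypothetical PyTorch sketch below shows what temporal modeling with the four arithmetic operations can look like: each frame's features are combined with the next frame's features via subtraction, addition, multiplication, and division, then fused back through a residual connection. The `(B, T, C)` tensor layout and the linear fusion layer are assumptions made only for this example.

```python
import torch
import torch.nn as nn


class ToyArithmeticTemporal(nn.Module):
    """Toy sketch of arithmetic temporal modeling (not the official ATM)."""

    def __init__(self, channels: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.fuse = nn.Linear(4 * channels, channels)  # assumed fusion layer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, T, C) frame-level features; x_next shifts by one frame (last frame repeated).
        x_next = torch.cat([x[:, 1:], x[:, -1:]], dim=1)
        sub = x_next - x               # subtraction: motion-like difference cue
        add = x_next + x               # addition: smoothed temporal context
        mul = x_next * x               # multiplication: correlation-style interaction
        div = x_next / (x + self.eps)  # division: ratio cue, stabilized by eps (illustrative only)
        return x + self.fuse(torch.cat([sub, add, mul, div], dim=-1))


if __name__ == "__main__":
    feats = torch.randn(2, 8, 512)                  # 2 clips, 8 frames, 512-d features
    print(ToyArithmeticTemporal(512)(feats).shape)  # torch.Size([2, 8, 512])
```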
117 | -------------------------------------------------------------------------------- /clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | 3 | 4 | 5 | -------------------------------------------------------------------------------- /clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whwu95/ATM/98ba3aa2ac258cc1b91beefe9317136657ae3d8d/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | from pkg_resources import packaging 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | from tqdm import tqdm 12 | 13 | from .model import build_model 14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | 16 | try: 17 | from torchvision.transforms import InterpolationMode 18 | BICUBIC = InterpolationMode.BICUBIC 19 | except ImportError: 20 | BICUBIC = Image.BICUBIC 21 | 22 | 23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 25 | 26 | 27 | __all__ = ["available_models", "load", "tokenize"] 28 | _tokenizer = _Tokenizer() 29 | 30 | _MODELS = { 31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 35 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 36 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 37 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 38 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 39 | "ViT-L/14-336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt" 40 | } 41 | 42 | 43 | def _download(url: str, root: str = os.path.expanduser("./clip_pretrain")): 44 | os.makedirs(root, exist_ok=True) 45 | filename = os.path.basename(url) 46 | 47 | expected_sha256 = url.split("/")[-2] 48 | download_target = os.path.join(root, filename) 49 | 50 | if os.path.exists(download_target) and not os.path.isfile(download_target): 51 | raise RuntimeError(f"{download_target} exists and is not a regular file") 52 | 53 | if os.path.isfile(download_target): 54 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 55 | return download_target 56 | else: 57 | warnings.warn(f"{download_target} exists, but the 
SHA256 checksum does not match; re-downloading the file") 58 | 59 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 60 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 61 | while True: 62 | buffer = source.read(8192) 63 | if not buffer: 64 | break 65 | 66 | output.write(buffer) 67 | loop.update(len(buffer)) 68 | 69 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 70 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 71 | 72 | return download_target 73 | 74 | 75 | def _convert_image_to_rgb(image): 76 | return image.convert("RGB") 77 | 78 | 79 | def _transform(n_px): 80 | return Compose([ 81 | Resize(n_px, interpolation=BICUBIC), 82 | CenterCrop(n_px), 83 | _convert_image_to_rgb, 84 | ToTensor(), 85 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 86 | ]) 87 | 88 | 89 | def available_models() -> List[str]: 90 | """Returns the names of available CLIP models""" 91 | return list(_MODELS.keys()) 92 | 93 | 94 | def load( 95 | name: str, 96 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", 97 | jit=True, 98 | tsm=False, joint=False, T=8, dropout=0., 99 | emb_dropout=0., 100 | pretrain=True): 101 | """Load a CLIP model 102 | 103 | Parameters 104 | ---------- 105 | name : str 106 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 107 | 108 | device : Union[str, torch.device] 109 | The device to put the loaded model 110 | 111 | jit : bool 112 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 113 | 114 | download_root: str 115 | path to download the model files; by default, it uses "~/.cache/clip" 116 | 117 | Returns 118 | ------- 119 | model : torch.nn.Module 120 | The CLIP model 121 | 122 | preprocess : Callable[[PIL.Image], torch.Tensor] 123 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 124 | """ 125 | if name in _MODELS: 126 | model_path = _download(_MODELS[name]) 127 | elif os.path.isfile(name): 128 | model_path = name 129 | else: 130 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 131 | 132 | try: 133 | # loading JIT archive 134 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 135 | state_dict = None 136 | except RuntimeError: 137 | # loading saved state dict 138 | if jit: 139 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 140 | jit = False 141 | state_dict = torch.load(model_path, map_location="cpu") 142 | 143 | if not jit: 144 | model = build_model(state_dict or model.state_dict(), joint=joint,tsm=tsm,T=T,dropout=dropout, emb_dropout=emb_dropout,pretrain=pretrain).to(device) 145 | if str(device) == "cpu": 146 | model.float() 147 | return model, model.state_dict() 148 | 149 | # patch the device names 150 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 151 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 152 | 153 | def patch_device(module): 154 | try: 155 | graphs = [module.graph] if hasattr(module, "graph") else [] 156 | except RuntimeError: 157 | graphs = [] 158 | 159 | if hasattr(module, "forward1"): 160 | graphs.append(module.forward1.graph) 161 | 162 | for graph in graphs: 163 | for node in graph.findAllNodes("prim::Constant"): 164 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 165 | node.copyAttributes(device_node) 166 | 167 | model.apply(patch_device) 168 | 169 | if str(device) == "cpu": 170 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 171 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 172 | float_node = float_input.node() 173 | 174 | def patch_float(module): 175 | try: 176 | graphs = [module.graph] if hasattr(module, "graph") else [] 177 | except RuntimeError: 178 | graphs = [] 179 | 180 | if hasattr(module, "forward1"): 181 | graphs.append(module.forward1.graph) 182 | 183 | for graph in graphs: 184 | for node in graph.findAllNodes("aten::to"): 185 | inputs = list(node.inputs()) 186 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 187 | if inputs[i].node()["value"] == 5: 188 | inputs[i].node().copyAttributes(float_node) 189 | 190 | model.apply(patch_float) 191 | patch_float(model.encode_image) 192 | patch_float(model.encode_text) 193 | 194 | model.float() 195 | 196 | return model, _transform(model.input_resolution.item()) 197 | 198 | 199 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 200 | """ 201 | Returns the tokenized representation of given input string(s) 202 | 203 | Parameters 204 | ---------- 205 | texts : Union[str, List[str]] 206 | An input string or a list of input strings to tokenize 207 | 208 | context_length : int 209 | The context length to use; all CLIP models use 77 as the context length 210 | 211 | truncate: bool 212 | Whether to truncate the text in case its encoding is longer than the context length 213 | 214 | Returns 215 | ------- 216 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 217 | """ 218 | if isinstance(texts, str): 219 | texts = [texts] 220 | 221 | sot_token = _tokenizer.encoder["<|startoftext|>"] 222 | eot_token = _tokenizer.encoder["<|endoftext|>"] 223 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 224 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 225 | 226 | for i, tokens in enumerate(all_tokens): 227 | if len(tokens) > context_length: 228 | if truncate: 229 | tokens = tokens[:context_length] 230 | tokens[-1] = eot_token 231 | else: 232 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 233 | result[i, :len(tokens)] = torch.tensor(tokens) 234 | 235 | return result 
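# Example usage (hypothetical sketch, not part of the original file). It assumes the
# CLIP-style model built in model.py exposes the usual `encode_text` method. Two notes
# from the code above: unlike upstream OpenAI CLIP, `jit` defaults to True here, and the
# non-JIT branch returns (model, state_dict) rather than (model, preprocess). Weights are
# downloaded to ./clip_pretrain on first use.
if __name__ == "__main__":
    model, _ = load("ViT-B/32", device="cpu", jit=False, T=8)  # non-JIT build with 8 frames
    tokens = tokenize(["a person opening a door", "a person closing a door"])
    with torch.no_grad():
        text_features = model.encode_text(tokens)  # assumed CLIP-style text encoder
    print(text_features.shape)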
236 | -------------------------------------------------------------------------------- /clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 
102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /dataset/Augmentation.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data._utils.collate import default_collate 2 | from dataset.transforms import * 3 | from dataset.random_erasing import RandomErasing 4 | 5 | 6 | class GroupTransform(object): 7 | def __init__(self, transform): 8 | self.worker = transform 9 | 10 | def __call__(self, img): 11 | img_group, label = img 12 | return [self.worker(img) for img in img_group], label 13 | 14 | 15 | class SplitLabel(object): 16 | def __init__(self, transform): 17 | self.worker = transform 18 | 19 | def __call__(self, img): 20 | img_group, label = img 21 | return self.worker(img_group), label 22 | 23 | 24 | 25 | def train_augmentation(input_size, flip=True): 26 | if flip: 27 | return torchvision.transforms.Compose([ 28 | GroupRandomSizedCrop(input_size), 29 | GroupRandomHorizontalFlip(is_flow=False)]) 30 | else: 31 | return torchvision.transforms.Compose([ 32 | GroupRandomSizedCrop(input_size), 33 | # GroupMultiScaleCrop(input_size, [1, .875, .75, .66]), 34 | GroupRandomHorizontalFlip_sth()]) 35 | 36 | 37 | def get_augmentation(training, input_size=224, config=None): 38 | input_mean = [0.48145466, 0.4578275, 0.40821073] 39 | input_std = [0.26862954, 0.26130258, 0.27577711] 40 | scale_size = 256 if input_size == 224 else input_size 41 | 42 | normalize = GroupNormalize(input_mean, input_std) 43 | #if 'something' in config.data.dataset: 44 | if 'SS' in config.data_set: 45 | if scale_size == 256: 46 | groupscale = GroupScale((256, 320)) 47 | else: 48 | groupscale = GroupScale(int(scale_size)) 49 | else: 50 | groupscale = GroupScale(int(scale_size)) 51 | 52 | 53 | common = torchvision.transforms.Compose([ 54 | Stack(roll=False), 55 | ToTorchFormatTensor(div=True), 56 | normalize]) 57 | 58 | if training: 59 | auto_transform = None 60 | erase_transform = None 61 | if config.aa: ###!!! 
ss for True, k400 for False 62 | print('***'*20) 63 | print('use random_augment!!!') 64 | auto_transform = create_random_augment( 65 | input_size=256, # scale_size 66 | auto_augment="rand-m7-n4-mstd0.5-inc1", 67 | interpolation="bicubic" 68 | ) 69 | # if config.rand_erase: 70 | # print('***'*20) 71 | # print('use Random_Erasing!!!') 72 | # erase_transform = RandomErasing( 73 | # 0.25, 74 | # mode='pixel', 75 | # max_count=1, 76 | # num_splits=1, 77 | # device="cpu", 78 | # ) 79 | 80 | train_aug = train_augmentation( 81 | input_size, 82 | flip=False if 'SS' in config.data_set else True) 83 | 84 | unique = torchvision.transforms.Compose([ 85 | groupscale, 86 | train_aug, 87 | GroupRandomGrayscale(p=0 if 'SS' in config.data_set else 0.2), 88 | ]) 89 | 90 | if auto_transform is not None: 91 | print('=> ########## Using RandAugment!') 92 | unique = torchvision.transforms.Compose([ 93 | SplitLabel(auto_transform), unique]) 94 | 95 | if erase_transform is not None: 96 | print('=> ########## Using RandErasing!') 97 | return torchvision.transforms.Compose([ 98 | unique, common, SplitLabel(erase_transform) 99 | ]) 100 | 101 | return torchvision.transforms.Compose([unique, common]) 102 | 103 | else: 104 | unique = torchvision.transforms.Compose([ 105 | groupscale, 106 | GroupCenterCrop(input_size)]) 107 | return torchvision.transforms.Compose([unique, common]) 108 | 109 | 110 | 111 | 112 | 113 | 114 | def multiple_samples_collate(batch): 115 | """ 116 | Collate function for repeated augmentation. Each instance in the batch has 117 | more than one sample. 118 | Args: 119 | batch (tuple or list): data batch to collate. 120 | Returns: 121 | (tuple): collated data batch. 122 | """ 123 | inputs, labels = zip(*batch) 124 | # print(inputs, flush=True) 125 | # print(labels, flush=True) 126 | inputs = [item for sublist in inputs for item in sublist] 127 | labels = [item for sublist in labels for item in sublist] 128 | inputs, labels = ( 129 | default_collate(inputs), 130 | default_collate(labels), 131 | ) 132 | return inputs, labels -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whwu95/ATM/98ba3aa2ac258cc1b91beefe9317136657ae3d8d/dataset/__init__.py -------------------------------------------------------------------------------- /dataset/charades.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import decord 4 | import os 5 | import numpy as np 6 | from numpy.random import randint 7 | import io 8 | import pandas as pd 9 | import random 10 | from PIL import Image 11 | import math 12 | import copy 13 | import csv 14 | 15 | class VideoRecord(object): 16 | def __init__(self, row): 17 | self._data = row 18 | 19 | @property 20 | def path(self): 21 | return self._data[0] 22 | 23 | @property 24 | def num_frames(self): 25 | # return int(self._data[1]) 26 | return 27 | 28 | @property 29 | def label(self): 30 | return int(self._data[1][1:]) 31 | 32 | @property 33 | def start_time(self): 34 | return float(self._data[2]) 35 | 36 | @property 37 | def end_time(self): 38 | return float(self._data[3]) 39 | 40 | @property 41 | def total_time(self): 42 | return float(self._data[4]) 43 | 44 | 45 | class Video_dataset(data.Dataset): 46 | def __init__(self, root_path, list_file, labels_file, 47 | num_segments=1, modality='RGB', new_length=1, 48 | 
image_tmpl='img_{:05d}.jpg', transform=None, 49 | random_shift=True, test_mode=False, 50 | index_bias=1, dense_sample=False, test_clips=1, 51 | num_sample=1, fps=24, mode='train'): 52 | 53 | self.root_path = root_path 54 | self.list_file = list_file 55 | self.num_segments = num_segments 56 | self.modality = modality 57 | self.seg_length = new_length 58 | self.image_tmpl = image_tmpl 59 | self.transform = transform 60 | self.random_shift = random_shift 61 | self.test_mode = test_mode 62 | self.mode = mode 63 | self.loop=False 64 | self.index_bias = index_bias 65 | self.labels_file = labels_file 66 | self.sample_range = 128 67 | self.dense_sample = dense_sample # using dense sample as I3D 68 | self.test_clips = test_clips 69 | self.num_sample = num_sample 70 | if self.dense_sample: 71 | print('=> Using dense sample for the dataset...') 72 | if self.num_sample > 1: 73 | print('=> Using repeated augmentation...') 74 | 75 | if self.index_bias is None: 76 | if self.image_tmpl == "frame{:d}.jpg": 77 | self.index_bias = 0 78 | else: 79 | self.index_bias = 1 80 | self._parse_list() 81 | self.initialized = False 82 | self.fps = fps 83 | 84 | @property 85 | def total_length(self): 86 | return self.num_segments * self.seg_length 87 | 88 | @property 89 | def classes(self): 90 | with open(self.labels_file, "r") as f: 91 | classes_all = [line.strip('\n').split(' ', 1) for line in f.readlines()] 92 | return classes_all 93 | 94 | def _parse_list(self): 95 | # check the frame number is large >3: 96 | if not self.test_mode: 97 | # read csv 98 | with open(self.list_file, "r") as f: 99 | reader = csv.reader(f) 100 | tmp = [row for row in reader][1:] 101 | 102 | tmp = [t for t in tmp if float(t[3]) > float(t[2])] 103 | self.video_list = [VideoRecord(item) for item in tmp] 104 | else: 105 | with open(self.list_file, "r") as f: 106 | reader = csv.reader(f) 107 | self.video_list = [row for row in reader][1:] 108 | print('video number:%d' % (len(self.video_list))) 109 | 110 | def _sample_indices(self, video_list_len, record): 111 | if self.dense_sample: 112 | sample_pos = max(1, 1 + video_list_len - self.sample_range) 113 | interval = self.sample_range // self.num_segments 114 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1) 115 | base_offsets = np.arange(self.num_segments) * interval 116 | offsets = (base_offsets + start_idx) % video_list_len 117 | return np.array(offsets) + self.index_bias 118 | else: 119 | if video_list_len <= self.total_length: 120 | import torch.distributed as dist 121 | print('record_id=',record.path, 122 | 'start_time===',record.start_time, 123 | 'end_time==',record.end_time, 124 | 'total_time==',record.total_time, 125 | 'label===',record.label) 126 | print('rank===',dist.get_rank()) 127 | if self.loop: 128 | return np.mod(np.arange( 129 | self.total_length) + randint(video_list_len // 2), 130 | video_list_len) + self.index_bias 131 | offsets = np.concatenate(( 132 | np.arange(video_list_len), 133 | randint(video_list_len, 134 | size=self.total_length - video_list_len))) 135 | return np.sort(offsets) + self.index_bias 136 | offsets = list() 137 | ticks = [i * video_list_len // self.num_segments 138 | for i in range(self.num_segments + 1)] 139 | 140 | for i in range(self.num_segments): 141 | tick_len = ticks[i + 1] - ticks[i] 142 | tick = ticks[i] 143 | if tick_len >= self.seg_length: 144 | tick += randint(tick_len - self.seg_length + 1) 145 | offsets.extend([j for j in range(tick, tick + self.seg_length)]) 146 | return np.array(offsets) + self.index_bias 147 | 148 
| 149 | def _get_test_indices(self, video_list): 150 | if self.dense_sample: 151 | # multi-clip for dense sampling 152 | num_clips = self.test_clips 153 | sample_pos = max(0, len(video_list) - self.sample_range) 154 | interval = self.sample_range // self.num_segments 155 | start_list = [clip_idx * math.floor(sample_pos / (num_clips -1)) for clip_idx in range(num_clips)] 156 | base_offsets = np.arange(self.num_segments) * interval 157 | offsets = [] 158 | for start_idx in start_list: 159 | offsets.extend((base_offsets + start_idx) % len(video_list)) 160 | return np.array(offsets) + self.index_bias 161 | else: 162 | # multi-clip for uniform sampling 163 | num_clips = self.test_clips 164 | tick = len(video_list) / float(self.num_segments) 165 | start_list = np.linspace(0, tick - 1, num=num_clips, dtype=int) 166 | offsets = [] 167 | for start_idx in start_list.tolist(): 168 | offsets += [ 169 | int(start_idx + tick * x) % len(video_list) 170 | for x in range(self.num_segments) 171 | ] 172 | return np.array(offsets) + self.index_bias 173 | 174 | 175 | def _decord_decode(self, video_path): 176 | try: 177 | container = decord.VideoReader(video_path) 178 | except Exception as e: 179 | print("Failed to decode {} with exception: {}".format( 180 | video_path, e)) 181 | return None 182 | 183 | return container 184 | 185 | def __getitem__(self, index): 186 | # decode frames to video_list 187 | if self.modality == 'video': 188 | _num_retries = 10 189 | for i_try in range(_num_retries): 190 | record = copy.deepcopy(self.video_list[index]) 191 | directory = os.path.join(self.root_path, record.path) 192 | video_list = self._decord_decode(directory) 193 | # video_list = self._decord_pyav(directory) 194 | if video_list is None: 195 | print("Failed to decode video idx {} from {}; trial {}".format( 196 | index, directory, i_try) 197 | ) 198 | index = random.randint(0, len(self.video_list)) 199 | continue 200 | break 201 | else: 202 | record = self.video_list[index] 203 | 204 | 205 | if not self.test_mode: # train 206 | video_list = os.listdir(os.path.join(self.root_path, record.path)) 207 | end_time = min(record.end_time, record.total_time) 208 | video_list_len = int(end_time * self.fps - record.start_time * self.fps) 209 | 210 | segment_indices = self._sample_indices(video_list_len, record) 211 | segment_indices = segment_indices + int(record.start_time * self.fps) 212 | return self.get(record, video_list, segment_indices) 213 | else: # test 214 | test_record = record 215 | video_list = os.listdir(os.path.join(self.root_path, test_record[0])) 216 | target = torch.IntTensor(157).zero_() # size=(157), all zeros, one-hot (multi-hot) encoding 217 | 218 | if test_record[9] != '': 219 | labels_mess = test_record[9].split(';') 220 | labels = [mess.split(' ')[0] for mess in labels_mess] 221 | for x in labels: 222 | target[int(x[1:])] = 1 # take each class label of the video, convert it to int, and set the corresponding one-hot position to 1 223 | segment_indices = self._get_test_indices(video_list) 224 | 225 | if self.mode == 'validation': 226 | return self.val_get(test_record, video_list, segment_indices, target) 227 | else: 228 | return self.test_get(test_record, video_list, segment_indices, target) 229 | 230 | 231 | def _load_image(self, directory, idx): 232 | if self.modality == 'RGB': 233 | try: 234 | return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(directory, idx))).convert('RGB')] 235 | except Exception: 236 | print('error loading image:', os.path.join(self.root_path, directory, self.image_tmpl.format(directory,idx))) 237 | return
[Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(directory, 1))).convert('RGB')] 238 | 239 | 240 | def get(self, record, video_list, indices): 241 | images = list() 242 | for seg_ind in indices: 243 | p = int(seg_ind) 244 | if self.modality == 'video': 245 | seg_imgs = [Image.fromarray(video_list[p - 1].asnumpy()).convert('RGB')] 246 | else: 247 | seg_imgs = self._load_image(record.path, p) 248 | images.extend(seg_imgs) 249 | if p < len(video_list): 250 | p += 1 251 | 252 | if self.num_sample > 1: 253 | frame_list = [] 254 | label_list = [] 255 | for _ in range(self.num_sample): 256 | process_data, record_label = self.transform((images, record.label)) 257 | process_data = process_data.view((self.num_segments,3)+process_data.size()[-2:]) ###!!! 258 | frame_list.append(process_data) 259 | label_list.append(record_label) 260 | return frame_list, label_list, 0, 0 261 | else: 262 | process_data, record_label = self.transform((images, record.label)) 263 | process_data = process_data.view((self.num_segments,3)+process_data.size()[-2:]) ###!!! 264 | return process_data, record_label, 0, 0 265 | 266 | def val_get(self, record, video_list, indices, test_label): 267 | images = list() 268 | for seg_ind in indices: 269 | p = int(seg_ind) 270 | if self.modality == 'video': 271 | seg_imgs = [Image.fromarray(video_list[p - 1].asnumpy()).convert('RGB')] 272 | else: 273 | seg_imgs = self._load_image(record[0], p) 274 | images.extend(seg_imgs) 275 | if p < len(video_list): 276 | p += 1 277 | 278 | process_data, _ = self.transform((images, test_label)) 279 | process_data = process_data.view((self.num_segments,3)+process_data.size()[-2:]) ###!!! 280 | return process_data, test_label, 0, 0 281 | 282 | 283 | def test_get(self, record, video_list, indices, test_label): 284 | images = list() 285 | for seg_ind in indices: 286 | p = int(seg_ind) 287 | if self.modality == 'video': 288 | seg_imgs = [Image.fromarray(video_list[p - 1].asnumpy()).convert('RGB')] 289 | else: 290 | seg_imgs = self._load_image(record[0], p) 291 | images.extend(seg_imgs) 292 | if p < len(video_list): 293 | p += 1 294 | 295 | process_data, _ = self.transform((images, test_label)) 296 | return process_data, test_label 297 | 298 | 299 | def __len__(self): 300 | return len(self.video_list) 301 | -------------------------------------------------------------------------------- /dataset/kinetics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import decord 4 | import os 5 | import numpy as np 6 | from numpy.random import randint 7 | import io 8 | import pandas as pd 9 | import random 10 | from PIL import Image 11 | import math 12 | import copy 13 | 14 | 15 | class VideoRecord(object): 16 | def __init__(self, row): 17 | self._data = row 18 | 19 | @property 20 | def path(self): 21 | return self._data[0] 22 | 23 | @property 24 | def num_frames(self): 25 | return int(self._data[1]) 26 | 27 | @property 28 | def label(self): 29 | return int(self._data[-1]) 30 | 31 | 32 | class Video_dataset(data.Dataset): 33 | def __init__(self, root_path, list_file, labels_file, 34 | num_segments=1, modality='RGB', new_length=1, 35 | image_tmpl='img_{:05d}.jpg', transform=None, 36 | random_shift=True, test_mode=False, 37 | index_bias=1, dense_sample=False, test_clips=3, 38 | num_sample=1): 39 | 40 | self.root_path = root_path 41 | self.list_file = list_file 42 | self.num_segments = num_segments 43 | self.modality = modality 44 | self.seg_length = 
new_length 45 | self.image_tmpl = image_tmpl 46 | self.transform = transform 47 | self.random_shift = random_shift 48 | self.test_mode = test_mode 49 | self.loop=False 50 | self.index_bias = index_bias 51 | self.labels_file = labels_file 52 | self.sample_range = 128 53 | self.dense_sample = dense_sample # using dense sample as I3D 54 | self.test_clips = test_clips 55 | self.num_sample = num_sample 56 | if self.dense_sample: 57 | print('=> Using dense sample for the dataset...') 58 | if self.num_sample > 1: 59 | print('=> Using repeated augmentation...') 60 | 61 | if self.index_bias is None: 62 | if self.image_tmpl == "frame{:d}.jpg": 63 | self.index_bias = 0 64 | else: 65 | self.index_bias = 1 66 | self._parse_list() 67 | self.initialized = False 68 | 69 | @property 70 | def total_length(self): 71 | return self.num_segments * self.seg_length 72 | 73 | @property 74 | def classes(self): 75 | classes_all = pd.read_csv(self.labels_file) 76 | return classes_all.values.tolist() 77 | 78 | def _parse_list(self): 79 | tmp = [x.strip().split(' ') for x in open(self.list_file)]#######for debug ###!!! 80 | if len(tmp[0]) == 3: # skip remove_missin for decording "raw_video label" type dataset_config 81 | if not self.test_mode: 82 | tmp = [item for item in tmp if int(item[1]) >= 8] 83 | self.video_list = [VideoRecord(item) for item in tmp] 84 | print('video number:%d' % (len(self.video_list))) 85 | 86 | def _sample_indices(self, video_list): 87 | if self.dense_sample: 88 | sample_pos = max(1, 1 + len(video_list) - self.sample_range) 89 | interval = self.sample_range // self.num_segments 90 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1) 91 | base_offsets = np.arange(self.num_segments) * interval 92 | offsets = (base_offsets + start_idx) % len(video_list) 93 | return np.array(offsets) + self.index_bias 94 | else: 95 | seg_size = float(len(video_list) - 1) / self.num_segments 96 | offsets = [] 97 | for i in range(self.num_segments): 98 | start = int(np.round(seg_size * i)) 99 | end = int(np.round(seg_size * (i + 1))) 100 | offsets.append(random.randint(start, end)) 101 | 102 | return np.array(offsets) + self.index_bias 103 | 104 | def _get_val_indices(self, video_list): 105 | if self.dense_sample: 106 | sample_pos = max(1, 1 + len(video_list) - self.sample_range) 107 | t_stride = self.sample_range // self.num_segments 108 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1) 109 | offsets = [(idx * t_stride + start_idx) % len(video_list) for idx in range(self.num_segments)] 110 | return np.array(offsets) + self.index_bias 111 | else: 112 | tick = len(video_list) / float(self.num_segments) 113 | offsets = [int(tick * x) % len(video_list) for x in range(self.num_segments)] 114 | 115 | return np.array(offsets) + self.index_bias 116 | 117 | 118 | def _get_test_indices(self, video_list): 119 | if self.dense_sample: 120 | # multi-clip for dense sampling 121 | num_clips = self.test_clips 122 | sample_pos = max(0, len(video_list) - self.sample_range) 123 | interval = self.sample_range // self.num_segments 124 | start_list = [clip_idx * math.floor(sample_pos / (num_clips -1)) for clip_idx in range(num_clips)] 125 | base_offsets = np.arange(self.num_segments) * interval 126 | offsets = [] 127 | for start_idx in start_list: 128 | offsets.extend((base_offsets + start_idx) % len(video_list)) 129 | return np.array(offsets) + self.index_bias 130 | else: 131 | # multi-clip for uniform sampling 132 | num_clips = self.test_clips 133 | 134 | tick = len(video_list) / 
float(self.num_segments) 135 | start_list = np.linspace(0, tick - 1, num=num_clips, dtype=int) 136 | offsets = [] 137 | for start_idx in start_list.tolist(): 138 | offsets += [ 139 | int(start_idx + tick * x) % len(video_list) 140 | for x in range(self.num_segments) 141 | ] 142 | return np.array(offsets) + self.index_bias 143 | 144 | 145 | def _decord_decode(self, video_path): 146 | try: 147 | container = decord.VideoReader(video_path) 148 | except Exception as e: 149 | print("Failed to decode {} with exception: {}".format( 150 | video_path, e)) 151 | return None 152 | 153 | return container 154 | 155 | def __getitem__(self, index): 156 | # decode frames to video_list 157 | if self.modality == 'video': 158 | _num_retries = 10 159 | for i_try in range(_num_retries): 160 | record = copy.deepcopy(self.video_list[index]) 161 | directory = os.path.join(self.root_path, record.path) 162 | video_list = self._decord_decode(directory) 163 | # video_list = self._decord_pyav(directory) 164 | if video_list is None: 165 | print("Failed to decode video idx {} from {}; trial {}".format( 166 | index, directory, i_try) 167 | ) 168 | index = random.randint(0, len(self.video_list)) 169 | continue 170 | break 171 | else: 172 | record = self.video_list[index] 173 | video_list = os.listdir(os.path.join(self.root_path, record.path)) 174 | 175 | if not self.test_mode: # train/val 176 | segment_indices = self._sample_indices(video_list) if self.random_shift else self._get_val_indices(video_list) 177 | else: # test 178 | segment_indices = self._get_test_indices(video_list) 179 | 180 | return self.get(record, video_list, segment_indices) 181 | 182 | 183 | def _load_image(self, directory, idx): 184 | if self.modality == 'RGB': 185 | try: 186 | return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(idx))).convert('RGB')] 187 | except Exception: 188 | print('error loading image:', os.path.join(self.root_path, directory, self.image_tmpl.format(idx))) 189 | return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(1))).convert('RGB')] 190 | 191 | 192 | def get(self, record, video_list, indices): 193 | images = list() 194 | for seg_ind in indices: 195 | p = int(seg_ind) 196 | if self.modality == 'video': 197 | seg_imgs = [Image.fromarray(video_list[p - 1].asnumpy()).convert('RGB')] 198 | else: 199 | seg_imgs = self._load_image(record.path, p) 200 | images.extend(seg_imgs) 201 | if p < len(video_list): 202 | p += 1 203 | 204 | if self.test_mode: 205 | process_data, record_label = self.transform((images, record.label)) 206 | return process_data, record_label 207 | else: 208 | if self.num_sample > 1: 209 | frame_list = [] 210 | label_list = [] 211 | for _ in range(self.num_sample): 212 | process_data, record_label = self.transform((images, record.label)) 213 | process_data = process_data.view((self.num_segments,3)+process_data.size()[-2:]) ###!!! 214 | 215 | frame_list.append(process_data) 216 | label_list.append(record_label) 217 | return frame_list, label_list, 0, 0 218 | else: 219 | process_data, record_label = self.transform((images, record.label)) 220 | process_data = process_data.view((self.num_segments,3)+process_data.size()[-2:]) ###!!! 
221 | return process_data, record_label, 0, 0 222 | 223 | def __len__(self): 224 | return len(self.video_list) 225 | -------------------------------------------------------------------------------- /dataset/random_erasing.py: -------------------------------------------------------------------------------- 1 | """ 2 | This implementation is based on 3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py 4 | pulished under an Apache License 2.0. 5 | """ 6 | import math 7 | import random 8 | import torch 9 | 10 | 11 | def _get_pixels( 12 | per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda" 13 | ): 14 | # NOTE I've seen CUDA illegal memory access errors being caused by the normal_() 15 | # paths, flip the order so normal is run on CPU if this becomes a problem 16 | # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 17 | if per_pixel: 18 | return torch.empty(patch_size, dtype=dtype, device=device).normal_() 19 | elif rand_color: 20 | return torch.empty( 21 | (patch_size[0], 1, 1), dtype=dtype, device=device 22 | ).normal_() 23 | else: 24 | return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) 25 | 26 | 27 | class RandomErasing: 28 | """Randomly selects a rectangle region in an image and erases its pixels. 29 | 'Random Erasing Data Augmentation' by Zhong et al. 30 | See https://arxiv.org/pdf/1708.04896.pdf 31 | This variant of RandomErasing is intended to be applied to either a batch 32 | or single image tensor after it has been normalized by dataset mean and std. 33 | Args: 34 | probability: Probability that the Random Erasing operation will be performed. 35 | min_area: Minimum percentage of erased area wrt input image area. 36 | max_area: Maximum percentage of erased area wrt input image area. 37 | min_aspect: Minimum aspect ratio of erased area. 38 | mode: pixel color mode, one of 'const', 'rand', or 'pixel' 39 | 'const' - erase block is constant color of 0 for all channels 40 | 'rand' - erase block is same per-channel random (normal) color 41 | 'pixel' - erase block is per-pixel random (normal) color 42 | max_count: maximum number of erasing blocks per image, area per box is scaled by count. 43 | per-image count is randomly chosen between 1 and this value. 
44 | """ 45 | 46 | def __init__( 47 | self, 48 | probability=0.5, 49 | min_area=0.02, 50 | max_area=1 / 3, 51 | min_aspect=0.3, 52 | max_aspect=None, 53 | mode="const", 54 | min_count=1, 55 | max_count=None, 56 | num_splits=0, 57 | device="cuda", 58 | cube=True, 59 | ): 60 | self.probability = probability 61 | self.min_area = min_area 62 | self.max_area = max_area 63 | max_aspect = max_aspect or 1 / min_aspect 64 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) 65 | self.min_count = min_count 66 | self.max_count = max_count or min_count 67 | self.num_splits = num_splits 68 | mode = mode.lower() 69 | self.rand_color = False 70 | self.per_pixel = False 71 | self.cube = cube 72 | if mode == "rand": 73 | self.rand_color = True # per block random normal 74 | elif mode == "pixel": 75 | self.per_pixel = True # per pixel random normal 76 | else: 77 | assert not mode or mode == "const" 78 | self.device = device 79 | 80 | def _erase(self, img, chan, img_h, img_w, dtype): 81 | if random.random() > self.probability: 82 | return 83 | area = img_h * img_w 84 | count = ( 85 | self.min_count 86 | if self.min_count == self.max_count 87 | else random.randint(self.min_count, self.max_count) 88 | ) 89 | for _ in range(count): 90 | for _ in range(10): 91 | target_area = ( 92 | random.uniform(self.min_area, self.max_area) * area / count 93 | ) 94 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 95 | h = int(round(math.sqrt(target_area * aspect_ratio))) 96 | w = int(round(math.sqrt(target_area / aspect_ratio))) 97 | if w < img_w and h < img_h: 98 | top = random.randint(0, img_h - h) 99 | left = random.randint(0, img_w - w) 100 | img[:, top : top + h, left : left + w] = _get_pixels( 101 | self.per_pixel, 102 | self.rand_color, 103 | (chan, h, w), 104 | dtype=dtype, 105 | device=self.device, 106 | ) 107 | break 108 | 109 | def _erase_cube( 110 | self, 111 | img, 112 | batch_start, 113 | batch_size, 114 | chan, 115 | img_h, 116 | img_w, 117 | dtype, 118 | ): 119 | if random.random() > self.probability: 120 | return 121 | area = img_h * img_w 122 | count = ( 123 | self.min_count 124 | if self.min_count == self.max_count 125 | else random.randint(self.min_count, self.max_count) 126 | ) 127 | for _ in range(count): 128 | for _ in range(100): 129 | target_area = ( 130 | random.uniform(self.min_area, self.max_area) * area / count 131 | ) 132 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 133 | h = int(round(math.sqrt(target_area * aspect_ratio))) 134 | w = int(round(math.sqrt(target_area / aspect_ratio))) 135 | if w < img_w and h < img_h: 136 | top = random.randint(0, img_h - h) 137 | left = random.randint(0, img_w - w) 138 | for i in range(batch_start, batch_size): 139 | img_instance = img[i] 140 | img_instance[ 141 | :, top : top + h, left : left + w 142 | ] = _get_pixels( 143 | self.per_pixel, 144 | self.rand_color, 145 | (chan, h, w), 146 | dtype=dtype, 147 | device=self.device, 148 | ) 149 | break 150 | 151 | def __call__(self, input): 152 | if len(input.size()) == 3: 153 | self._erase(input, *input.size(), input.dtype) 154 | else: 155 | batch_size, chan, img_h, img_w = input.size() 156 | # skip first slice of batch if num_splits is set (for clean portion of samples) 157 | batch_start = ( 158 | batch_size // self.num_splits if self.num_splits > 1 else 0 159 | ) 160 | if self.cube: 161 | self._erase_cube( 162 | input, 163 | batch_start, 164 | batch_size, 165 | chan, 166 | img_h, 167 | img_w, 168 | input.dtype, 169 | ) 170 | else: 171 | for i in 
range(batch_start, batch_size): 172 | self._erase(input[i], chan, img_h, img_w, input.dtype) 173 | return input 174 | -------------------------------------------------------------------------------- /dataset/sth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import decord 4 | import os 5 | import numpy as np 6 | from numpy.random import randint 7 | import io 8 | import pandas as pd 9 | import random 10 | from PIL import Image 11 | import math 12 | import copy 13 | 14 | 15 | class VideoRecord(object): 16 | def __init__(self, row): 17 | self._data = row 18 | 19 | @property 20 | def path(self): 21 | return self._data[0] 22 | 23 | @property 24 | def num_frames(self): 25 | return int(self._data[1]) 26 | 27 | @property 28 | def label(self): 29 | return int(self._data[-1]) 30 | 31 | 32 | class Video_dataset(data.Dataset): 33 | def __init__(self, root_path, list_file, labels_file, 34 | num_segments=1, modality='RGB', new_length=1, 35 | image_tmpl='img_{:05d}.jpg', transform=None, 36 | random_shift=True, test_mode=False, 37 | index_bias=1, dense_sample=False, test_clips=3, 38 | num_sample=1): 39 | 40 | self.root_path = root_path 41 | self.list_file = list_file 42 | self.num_segments = num_segments 43 | self.modality = modality 44 | self.seg_length = new_length 45 | self.image_tmpl = image_tmpl 46 | self.transform = transform 47 | self.random_shift = random_shift 48 | self.test_mode = test_mode 49 | self.loop=False 50 | self.index_bias = index_bias 51 | self.labels_file = labels_file 52 | self.sample_range = 128 53 | self.dense_sample = dense_sample # using dense sample as I3D 54 | self.test_clips = test_clips 55 | self.num_sample = num_sample 56 | if self.dense_sample: 57 | print('=> Using dense sample for the dataset...') 58 | if self.num_sample > 1: 59 | print('=> Using repeated augmentation...') 60 | 61 | if self.index_bias is None: 62 | if self.image_tmpl == "frame{:d}.jpg": 63 | self.index_bias = 0 64 | else: 65 | self.index_bias = 1 66 | self._parse_list() 67 | self.initialized = False 68 | 69 | @property 70 | def total_length(self): 71 | return self.num_segments * self.seg_length 72 | 73 | @property 74 | def classes(self): 75 | classes_all = pd.read_csv(self.labels_file) 76 | return classes_all.values.tolist() 77 | 78 | def _parse_list(self): 79 | # check the frame number is large >3: 80 | tmp = [x.strip().split(' ') for x in open(self.list_file)] 81 | if len(tmp[0]) == 3: # skip remove_missin for decording "raw_video label" type dataset_config 82 | if not self.test_mode: 83 | tmp = [item for item in tmp if int(item[1]) >= 8] 84 | self.video_list = [VideoRecord(item) for item in tmp] 85 | print('video number:%d' % (len(self.video_list))) 86 | 87 | def _sample_indices(self, video_list): 88 | if self.dense_sample: 89 | sample_pos = max(1, 1 + len(video_list) - self.sample_range) 90 | interval = self.sample_range // self.num_segments 91 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1) 92 | base_offsets = np.arange(self.num_segments) * interval 93 | offsets = (base_offsets + start_idx) % len(video_list) 94 | return np.array(offsets) + self.index_bias 95 | else: 96 | if len(video_list) <= self.total_length: 97 | if self.loop: 98 | return np.mod(np.arange( 99 | self.total_length) + randint(len(video_list) // 2), 100 | len(video_list)) + self.index_bias 101 | offsets = np.concatenate(( 102 | np.arange(len(video_list)), 103 | randint(len(video_list), 104 | size=self.total_length 
- len(video_list)))) 105 | return np.sort(offsets) + self.index_bias 106 | offsets = list() 107 | ticks = [i * len(video_list) // self.num_segments 108 | for i in range(self.num_segments + 1)] 109 | 110 | for i in range(self.num_segments): 111 | tick_len = ticks[i + 1] - ticks[i] 112 | tick = ticks[i] 113 | if tick_len >= self.seg_length: 114 | tick += randint(tick_len - self.seg_length + 1) 115 | offsets.extend([j for j in range(tick, tick + self.seg_length)]) 116 | return np.array(offsets) + self.index_bias 117 | 118 | def _get_val_indices(self, video_list): 119 | if self.dense_sample: 120 | sample_pos = max(1, 1 + len(video_list) - self.sample_range) 121 | t_stride = self.sample_range // self.num_segments 122 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1) 123 | offsets = [(idx * t_stride + start_idx) % len(video_list) for idx in range(self.num_segments)] 124 | return np.array(offsets) + self.index_bias 125 | else: 126 | seg_size = float(len(video_list) - 1) / self.num_segments 127 | offsets = [] 128 | for i in range(self.num_segments): 129 | start = int(np.round(seg_size * i)) 130 | frame_index = start + int(seg_size / 2) 131 | offsets.append(frame_index) 132 | 133 | return np.array(offsets) + self.index_bias 134 | 135 | 136 | def _get_test_indices(self, video_list): 137 | if self.dense_sample: 138 | # multi-clip for dense sampling 139 | num_clips = self.test_clips 140 | sample_pos = max(0, len(video_list) - self.sample_range) 141 | interval = self.sample_range // self.num_segments 142 | start_list = [clip_idx * math.floor(sample_pos / (num_clips -1)) for clip_idx in range(num_clips)] 143 | base_offsets = np.arange(self.num_segments) * interval 144 | offsets = [] 145 | for start_idx in start_list: 146 | offsets.extend((base_offsets + start_idx) % len(video_list)) 147 | return np.array(offsets) + self.index_bias 148 | else: 149 | # multi-clip for uniform sampling 150 | num_clips = self.test_clips 151 | 152 | seg_size = float(len(video_list) - 1) / self.num_segments 153 | seq = [] 154 | duration = seg_size / (num_clips + 1) 155 | for temporal_sample_index in range(num_clips): 156 | for i in range(self.num_segments): 157 | start = int(np.round(seg_size * i)) 158 | frame_index = start + int(duration * (temporal_sample_index + 1)) 159 | seq.append(frame_index) 160 | return np.array(seq) + self.index_bias 161 | 162 | 163 | def _decord_decode(self, video_path): 164 | try: 165 | container = decord.VideoReader(video_path) 166 | except Exception as e: 167 | print("Failed to decode {} with exception: {}".format( 168 | video_path, e)) 169 | return None 170 | 171 | return container 172 | 173 | def __getitem__(self, index): 174 | # decode frames to video_list 175 | if self.modality == 'video': 176 | _num_retries = 10 177 | for i_try in range(_num_retries): 178 | record = copy.deepcopy(self.video_list[index]) 179 | directory = os.path.join(self.root_path, record.path) 180 | video_list = self._decord_decode(directory) 181 | # video_list = self._decord_pyav(directory) 182 | if video_list is None: 183 | print("Failed to decode video idx {} from {}; trial {}".format( 184 | index, directory, i_try) 185 | ) 186 | index = random.randint(0, len(self.video_list)) 187 | continue 188 | break 189 | else: 190 | record = self.video_list[index] 191 | video_list = os.listdir(os.path.join(self.root_path, record.path)) 192 | 193 | if not self.test_mode: # train/val 194 | segment_indices = self._sample_indices(video_list) if self.random_shift else self._get_val_indices(video_list) 195 | 
else: # test 196 | segment_indices = self._get_test_indices(video_list) 197 | 198 | return self.get(record, video_list, segment_indices) 199 | 200 | 201 | def _load_image(self, directory, idx): 202 | if self.modality == 'RGB': 203 | try: 204 | return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(idx))).convert('RGB')] 205 | except Exception: 206 | print('error loading image:', os.path.join(self.root_path, directory, self.image_tmpl.format(idx))) 207 | return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(1))).convert('RGB')] 208 | 209 | 210 | def get(self, record, video_list, indices): 211 | images = list() 212 | for seg_ind in indices: 213 | p = int(seg_ind) 214 | if self.modality == 'video': 215 | seg_imgs = [Image.fromarray(video_list[p - 1].asnumpy()).convert('RGB')] 216 | else: 217 | seg_imgs = self._load_image(record.path, p) 218 | images.extend(seg_imgs) 219 | if p < len(video_list): 220 | p += 1 221 | if self.num_sample > 1: 222 | frame_list = [] 223 | label_list = [] 224 | for _ in range(self.num_sample): 225 | process_data, record_label = self.transform((images, record.label)) 226 | frame_list.append(process_data) 227 | label_list.append(record_label) 228 | return frame_list, label_list 229 | else: 230 | process_data, record_label = self.transform((images, record.label)) 231 | return process_data, record_label 232 | 233 | def __len__(self): 234 | return len(self.video_list) 235 | -------------------------------------------------------------------------------- /dataset/sthvideo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | import decord 4 | import os 5 | import numpy as np 6 | from numpy.random import randint 7 | import io 8 | import pandas as pd 9 | import random 10 | from PIL import Image 11 | import math 12 | import copy 13 | 14 | 15 | class VideoRecord(object): 16 | def __init__(self, row): 17 | self._data = row 18 | 19 | @property 20 | def path(self): 21 | return self._data[0] 22 | 23 | @property 24 | def num_frames(self): 25 | return int(self._data[1]) 26 | 27 | @property 28 | def label(self): 29 | return int(self._data[-1]) 30 | 31 | 32 | class Video_dataset(data.Dataset): 33 | def __init__(self, root_path, list_file, labels_file, 34 | num_segments=1, modality='RGB', new_length=1, 35 | image_tmpl='img_{:05d}.jpg', transform=None, 36 | random_shift=True, test_mode=False, 37 | index_bias=1, dense_sample=False, test_clips=3, 38 | num_sample=1): 39 | 40 | self.root_path = root_path 41 | self.list_file = list_file 42 | self.num_segments = num_segments 43 | self.modality = modality 44 | self.seg_length = new_length 45 | self.image_tmpl = image_tmpl 46 | self.transform = transform 47 | self.random_shift = random_shift 48 | self.test_mode = test_mode 49 | self.loop=False 50 | self.index_bias = index_bias 51 | self.labels_file = labels_file 52 | self.sample_range = 128 53 | self.dense_sample = dense_sample # using dense sample as I3D 54 | self.test_clips = test_clips 55 | self.num_sample = num_sample 56 | if self.dense_sample: 57 | print('=> Using dense sample for the dataset...') 58 | if self.num_sample > 1: 59 | print('=> Using repeated augmentation...') 60 | 61 | if self.index_bias is None: 62 | if self.image_tmpl == "frame{:d}.jpg": 63 | self.index_bias = 0 64 | else: 65 | self.index_bias = 1 66 | self._parse_list() 67 | self.initialized = False 68 | 69 | @property 70 | def total_length(self): 71 | return self.num_segments * 
self.seg_length 72 | 73 | @property 74 | def classes(self): 75 | classes_all = pd.read_csv(self.labels_file) 76 | return classes_all.values.tolist() 77 | 78 | def _parse_list(self): 79 | # check the frame number is large >3: 80 | tmp = [x.strip().split(' ') for x in open(self.list_file)] 81 | if len(tmp[0]) == 3: # skip remove_missin for decording "raw_video label" type dataset_config 82 | if not self.test_mode: 83 | tmp = [item for item in tmp if int(item[1]) >= 8] 84 | self.video_list = [VideoRecord(item) for item in tmp] 85 | print('video number:%d' % (len(self.video_list))) 86 | 87 | def _sample_indices(self, video_list): 88 | if self.dense_sample: 89 | sample_pos = max(1, 1 + len(video_list) - self.sample_range) 90 | interval = self.sample_range // self.num_segments 91 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1) 92 | base_offsets = np.arange(self.num_segments) * interval 93 | offsets = (base_offsets + start_idx) % len(video_list) 94 | return np.array(offsets) + self.index_bias 95 | else: 96 | if len(video_list) <= self.total_length: 97 | if self.loop: 98 | return np.mod(np.arange( 99 | self.total_length) + randint(len(video_list) // 2), 100 | len(video_list)) + self.index_bias 101 | offsets = np.concatenate(( 102 | np.arange(len(video_list)), 103 | randint(len(video_list), 104 | size=self.total_length - len(video_list)))) 105 | return np.sort(offsets) + self.index_bias 106 | offsets = list() 107 | ticks = [i * len(video_list) // self.num_segments 108 | for i in range(self.num_segments + 1)] 109 | 110 | for i in range(self.num_segments): 111 | tick_len = ticks[i + 1] - ticks[i] 112 | tick = ticks[i] 113 | if tick_len >= self.seg_length: 114 | tick += randint(tick_len - self.seg_length + 1) 115 | offsets.extend([j for j in range(tick, tick + self.seg_length)]) 116 | 117 | return np.array(offsets) + self.index_bias 118 | 119 | def _get_val_indices(self, video_list): 120 | if self.dense_sample: 121 | sample_pos = max(1, 1 + len(video_list) - self.sample_range) 122 | t_stride = self.sample_range // self.num_segments 123 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1) 124 | offsets = [(idx * t_stride + start_idx) % len(video_list) for idx in range(self.num_segments)] 125 | return np.array(offsets) + self.index_bias 126 | else: 127 | seg_size = float(len(video_list) - 1) / self.num_segments 128 | offsets = [] 129 | for i in range(self.num_segments): 130 | start = int(np.round(seg_size * i)) 131 | frame_index = start + int(seg_size / 2) 132 | offsets.append(frame_index) 133 | 134 | return np.array(offsets) + self.index_bias 135 | 136 | 137 | def _get_test_indices(self, video_list): 138 | if self.dense_sample: 139 | # multi-clip for dense sampling 140 | num_clips = self.test_clips 141 | sample_pos = max(0, len(video_list) - self.sample_range) 142 | interval = self.sample_range // self.num_segments 143 | start_list = [clip_idx * math.floor(sample_pos / (num_clips -1)) for clip_idx in range(num_clips)] 144 | base_offsets = np.arange(self.num_segments) * interval 145 | offsets = [] 146 | for start_idx in start_list: 147 | offsets.extend((base_offsets + start_idx) % len(video_list)) 148 | return np.array(offsets) + self.index_bias 149 | else: 150 | # multi-clip for uniform sampling 151 | num_clips = self.test_clips 152 | 153 | # tick = len(video_list) / float(self.num_segments) 154 | # start_list = np.linspace(0, tick - 1, num=num_clips, dtype=int) 155 | # offsets = [] 156 | # for start_idx in start_list.tolist(): 157 | # offsets += [ 
158 | # int(start_idx + tick * x) % len(video_list) 159 | # for x in range(self.num_segments) 160 | # ] 161 | # return np.array(offsets) + self.index_bias 162 | 163 | 164 | ############ ATM implementation 165 | seg_size = float(len(video_list) - 1) / self.num_segments 166 | seq = [] 167 | duration = seg_size / (num_clips + 1) 168 | for temporal_sample_index in range(num_clips): 169 | for i in range(self.num_segments): 170 | start = int(np.round(seg_size * i)) 171 | frame_index = start + int(duration * (temporal_sample_index + 1)) 172 | seq.append(frame_index) 173 | return np.array(seq) + self.index_bias 174 | 175 | 176 | def _decord_decode(self, video_path): 177 | try: 178 | container = decord.VideoReader(video_path) 179 | except Exception as e: 180 | print("Failed to decode {} with exception: {}".format( 181 | video_path, e)) 182 | return None 183 | 184 | return container 185 | 186 | def __getitem__(self, index): 187 | # decode frames to video_list 188 | if self.modality == 'video': 189 | _num_retries = 10 190 | for i_try in range(_num_retries): 191 | record = copy.deepcopy(self.video_list[index]) 192 | directory = os.path.join(self.root_path, record.path) 193 | video_list = self._decord_decode(directory) 194 | # video_list = self._decord_pyav(directory) 195 | if video_list is None: 196 | print("Failed to decode video idx {} from {}; trial {}".format( 197 | index, directory, i_try) 198 | ) 199 | index = random.randint(0, len(self.video_list)) 200 | continue 201 | break 202 | else: 203 | record = self.video_list[index] 204 | video_list = os.listdir(os.path.join(self.root_path, record.path)) 205 | 206 | if not self.test_mode: # train/val 207 | segment_indices = self._sample_indices(video_list) if self.random_shift else self._get_val_indices(video_list) 208 | else: # test 209 | segment_indices = self._get_test_indices(video_list) 210 | 211 | return self.get(record, video_list, segment_indices) 212 | 213 | 214 | def _load_image(self, directory, idx): 215 | if self.modality == 'RGB': 216 | try: 217 | return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(idx))).convert('RGB')] 218 | except Exception: 219 | print('error loading image:', os.path.join(self.root_path, directory, self.image_tmpl.format(idx))) 220 | return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(1))).convert('RGB')] 221 | 222 | 223 | def get(self, record, video_list, indices): 224 | images = list() 225 | for seg_ind in indices: 226 | p = int(seg_ind) 227 | if self.modality == 'video': 228 | seg_imgs = [Image.fromarray(video_list[p - 1].asnumpy()).convert('RGB')] 229 | else: 230 | seg_imgs = self._load_image(record.path, p) 231 | images.extend(seg_imgs) 232 | if p < len(video_list): 233 | p += 1 234 | if self.num_sample > 1: 235 | frame_list = [] 236 | label_list = [] 237 | for _ in range(self.num_sample): 238 | process_data, record_label = self.transform((images, record.label)) 239 | 240 | process_data = process_data.view((self.num_segments,3)+process_data.size()[-2:]) ###!!! 241 | frame_list.append(process_data) 242 | label_list.append(record_label) 243 | #print('frame_list',len(frame_list), frame_list[0].shape) 244 | return frame_list, label_list, 0, 0 245 | else: 246 | process_data, record_label = self.transform((images, record.label)) 247 | 248 | process_data = process_data.view((self.num_segments,3)+process_data.size()[-2:]) ###!!! 
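            # Reshape the stacked transform output from (num_segments*3, H, W) to
            # (num_segments, 3, H, W) so each sampled segment becomes one 3-channel frame
            # (this assumes the transform stacks the sampled frames along the channel axis).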
249 | 250 | return process_data, record_label, 0, 0 251 | 252 | def __len__(self): 253 | return len(self.video_list) 254 | -------------------------------------------------------------------------------- /engine_for_finetuning.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import math 4 | import sys 5 | from typing import Iterable, Optional 6 | import torch 7 | from dataset.mixup import Mixup 8 | from timm.utils import accuracy, ModelEma 9 | import utils 10 | from scipy.special import softmax ###!!! 11 | from utils import AverageMeter, gather_labels 12 | import torch.distributed as dist 13 | 14 | class AllGather(torch.autograd.Function): 15 | """An autograd function that performs allgather on a tensor.""" 16 | 17 | @staticmethod 18 | def forward(ctx, tensor): 19 | output = [torch.empty_like(tensor) for _ in range(dist.get_world_size())] 20 | torch.distributed.all_gather(output, tensor) 21 | ctx.rank = dist.get_rank() 22 | ctx.batch_size = tensor.shape[0] 23 | return torch.cat(output, dim=0) 24 | 25 | @staticmethod 26 | def backward(ctx, grad_output): 27 | return ( 28 | grad_output[ctx.batch_size * ctx.rank : ctx.batch_size * (ctx.rank + 1)], 29 | None, 30 | ) 31 | 32 | allgather = AllGather.apply 33 | 34 | def train_class_batch(model, samples, target, criterion): 35 | outputs = model(samples) 36 | loss = criterion(outputs, target) 37 | return loss, outputs 38 | 39 | 40 | def get_loss_scale_for_deepspeed(model): 41 | optimizer = model.optimizer 42 | return optimizer.loss_scale if hasattr(optimizer, "loss_scale") else optimizer.cur_scale 43 | 44 | 45 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module, 46 | data_loader: Iterable, optimizer: torch.optim.Optimizer, 47 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0, 48 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None, 49 | start_steps=None, lr_schedule_values=None, wd_schedule_values=None, 50 | num_training_steps_per_epoch=None, update_freq=None): 51 | model.train(True) 52 | metric_logger = utils.MetricLogger(delimiter=" ") 53 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 54 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) 55 | header = 'Epoch: [{}]'.format(epoch) 56 | print_freq = 10 57 | 58 | if loss_scaler is None: 59 | model.zero_grad() 60 | model.micro_steps = 0 61 | else: 62 | optimizer.zero_grad() 63 | 64 | #import pdb;pdb.set_trace() 65 | for data_iter_step, (samples, targets, _, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)): 66 | #import pdb;pdb.set_trace() 67 | 68 | if len(samples) == 2: 69 | samples = torch.cat([samples[0], samples[1]], 0) ###!!! 70 | targets = torch.cat([targets[0], targets[1]], 0) ###!!! 
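        # When repeated augmentation is enabled (Video_dataset.num_sample > 1), each clip
        # arrives as a pair of differently augmented views; the two views and their labels
        # were concatenated along the batch dimension above, so the rest of the step treats
        # them as one larger batch.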
71 | 72 | step = data_iter_step // update_freq 73 | if step >= num_training_steps_per_epoch: 74 | continue 75 | it = start_steps + step # global training iteration 76 | # Update LR & WD for the first acc 77 | if lr_schedule_values is not None or wd_schedule_values is not None and data_iter_step % update_freq == 0: 78 | for i, param_group in enumerate(optimizer.param_groups): 79 | if lr_schedule_values is not None: 80 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"] 81 | if wd_schedule_values is not None and param_group["weight_decay"] > 0: 82 | param_group["weight_decay"] = wd_schedule_values[it] 83 | 84 | samples = samples.to(device, non_blocking=True) 85 | targets = targets.to(device, non_blocking=True) 86 | 87 | if mixup_fn is not None: 88 | samples, targets = mixup_fn(samples, targets) 89 | 90 | if loss_scaler is None: 91 | samples = samples.half() 92 | loss, output = train_class_batch( 93 | model, samples, targets, criterion) 94 | else: 95 | with torch.cuda.amp.autocast(): 96 | loss, output = train_class_batch( 97 | model, samples, targets, criterion) 98 | 99 | loss_value = loss.item() 100 | 101 | if not math.isfinite(loss_value): 102 | print("Loss is {}, stopping training".format(loss_value)) 103 | sys.exit(1) 104 | 105 | if loss_scaler is None: 106 | loss /= update_freq 107 | model.backward(loss) 108 | model.step() 109 | 110 | if (data_iter_step + 1) % update_freq == 0: 111 | # model.zero_grad() 112 | # Deepspeed will call step() & model.zero_grad() automatic 113 | if model_ema is not None: 114 | model_ema.update(model) 115 | grad_norm = None 116 | loss_scale_value = get_loss_scale_for_deepspeed(model) 117 | else: 118 | # this attribute is added by timm on one optimizer (adahessian) 119 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order 120 | loss /= update_freq 121 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm, 122 | parameters=model.parameters(), create_graph=is_second_order, 123 | update_grad=(data_iter_step + 1) % update_freq == 0) 124 | if (data_iter_step + 1) % update_freq == 0: 125 | optimizer.zero_grad() 126 | if model_ema is not None: 127 | model_ema.update(model) 128 | loss_scale_value = loss_scaler.state_dict()["scale"] 129 | 130 | torch.cuda.synchronize() 131 | 132 | if mixup_fn is None: 133 | class_acc = (output.max(-1)[-1] == targets).float().mean() 134 | else: 135 | class_acc = None 136 | metric_logger.update(loss=loss_value) 137 | metric_logger.update(class_acc=class_acc) 138 | metric_logger.update(loss_scale=loss_scale_value) 139 | min_lr = 10. 140 | max_lr = 0. 
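        # Scan every parameter group to record the smallest and largest learning rates for
        # logging; groups can differ when a per-group lr_scale (e.g. layer-wise LR decay) is used.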
141 | for group in optimizer.param_groups: 142 | min_lr = min(min_lr, group["lr"]) 143 | max_lr = max(max_lr, group["lr"]) 144 | 145 | metric_logger.update(lr=max_lr) 146 | metric_logger.update(min_lr=min_lr) 147 | weight_decay_value = None 148 | for group in optimizer.param_groups: 149 | if group["weight_decay"] > 0: 150 | weight_decay_value = group["weight_decay"] 151 | metric_logger.update(weight_decay=weight_decay_value) 152 | metric_logger.update(grad_norm=grad_norm) 153 | 154 | if log_writer is not None: 155 | log_writer.update(loss=loss_value, head="loss") 156 | log_writer.update(class_acc=class_acc, head="loss") 157 | log_writer.update(loss_scale=loss_scale_value, head="opt") 158 | log_writer.update(lr=max_lr, head="opt") 159 | log_writer.update(min_lr=min_lr, head="opt") 160 | log_writer.update(weight_decay=weight_decay_value, head="opt") 161 | log_writer.update(grad_norm=grad_norm, head="opt") 162 | 163 | log_writer.set_step() 164 | 165 | # gather the stats from all processes 166 | metric_logger.synchronize_between_processes() 167 | print("Averaged stats:", metric_logger) 168 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 169 | 170 | 171 | @torch.no_grad() 172 | def validation_one_epoch(data_loader, model, device): 173 | criterion = torch.nn.CrossEntropyLoss() 174 | 175 | metric_logger = utils.MetricLogger(delimiter=" ") 176 | header = 'Val:' 177 | 178 | # switch to evaluation mode 179 | model.eval() 180 | 181 | for batch in metric_logger.log_every(data_loader, 10, header): 182 | videos = batch[0] 183 | target = batch[1] 184 | videos = videos.to(device, non_blocking=True) 185 | target = target.to(device, non_blocking=True) 186 | 187 | # compute output 188 | with torch.cuda.amp.autocast(): 189 | output = model(videos) 190 | #import pdb;pdb.set_trace() 191 | loss = criterion(output, target) 192 | 193 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 194 | 195 | batch_size = videos.shape[0] 196 | metric_logger.update(loss=loss.item()) 197 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 198 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 199 | # gather the stats from all processes 200 | metric_logger.synchronize_between_processes() 201 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 202 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 203 | 204 | 205 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 206 | 207 | @torch.no_grad() 208 | def charades_validation_one_epoch(data_loader, model, device, dataset = 'anet'): 209 | #epoch, val_loader, device, model, config, text_embedding, logger): 210 | mAP = AverageMeter() 211 | model.eval() 212 | from torchnet import meter 213 | maper = meter.mAPMeter() 214 | 215 | with torch.no_grad(): 216 | for i, batch in enumerate(data_loader): 217 | 218 | videos = batch[0] 219 | labels = batch[1] 220 | videos = videos.to(device, non_blocking=True) 221 | labels = labels.to(device, non_blocking=True) 222 | 223 | # compute output 224 | with torch.cuda.amp.autocast(): 225 | output = model(videos) 226 | 227 | ###!!! 
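            # Multi-label mAP evaluation: integer ActivityNet labels are expanded to one-hot
            # targets over 200 classes, while Charades-style labels are assumed to already be
            # multi-hot vectors and are used as-is.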
228 | if dataset == 'anet': 229 | target = torch.IntTensor(labels.shape[0],200).zero_() #全部为0,one-hot编码 230 | for b, label in enumerate(labels): 231 | target[b][int(label)] = 1 232 | else: 233 | target = labels 234 | 235 | output = allgather(output) 236 | target = gather_labels(target) 237 | 238 | maper.add(output, target) 239 | mAP.update(maper.value().numpy(),target.size(0)) 240 | 241 | if i % 10 == 0: 242 | print('Test: [{0}/{1},mAP:{map:.3f}]\t'.format(i, len(data_loader), map=mAP.avg * 100)) 243 | 244 | print('Testing Results mAP === {mAP_result:.3f}'.format(mAP_result=mAP.avg * 100)) 245 | return mAP.avg * 100 246 | 247 | 248 | @torch.no_grad() 249 | def final_test(data_loader, model, device, file): 250 | criterion = torch.nn.CrossEntropyLoss() 251 | 252 | metric_logger = utils.MetricLogger(delimiter=" ") 253 | header = 'Test:' 254 | 255 | # switch to evaluation mode 256 | model.eval() 257 | final_result = [] 258 | 259 | for batch in metric_logger.log_every(data_loader, 10, header): 260 | videos = batch[0] 261 | target = batch[1] 262 | ids = batch[2] 263 | chunk_nb = batch[3] 264 | split_nb = batch[4] 265 | videos = videos.to(device, non_blocking=True) 266 | target = target.to(device, non_blocking=True) 267 | 268 | # compute output 269 | with torch.cuda.amp.autocast(): 270 | output = model(videos) 271 | loss = criterion(output, target) 272 | 273 | for i in range(output.size(0)): 274 | string = "{} {} {} {} {}\n".format(ids[i], \ 275 | str(output.data[i].cpu().numpy().tolist()), \ 276 | str(int(target[i].cpu().numpy())), \ 277 | str(int(chunk_nb[i].cpu().numpy())), \ 278 | str(int(split_nb[i].cpu().numpy()))) 279 | final_result.append(string) 280 | 281 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 282 | 283 | batch_size = videos.shape[0] 284 | metric_logger.update(loss=loss.item()) 285 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size) 286 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size) 287 | 288 | if not os.path.exists(file): 289 | os.mknod(file) 290 | with open(file, 'w') as f: 291 | f.write("{}, {}\n".format(acc1, acc5)) 292 | for line in final_result: 293 | f.write(line) 294 | # gather the stats from all processes 295 | metric_logger.synchronize_between_processes() 296 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}' 297 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss)) 298 | 299 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()} 300 | 301 | 302 | def merge(eval_path, num_tasks): 303 | dict_feats = {} 304 | dict_label = {} 305 | dict_pos = {} 306 | print("Reading individual output files") 307 | 308 | for x in range(num_tasks): 309 | file = os.path.join(eval_path, str(x) + '.txt') 310 | lines = open(file, 'r').readlines()[1:] 311 | for line in lines: 312 | line = line.strip() 313 | name = line.split('[')[0] 314 | label = line.split(']')[1].split(' ')[1] 315 | chunk_nb = line.split(']')[1].split(' ')[2] 316 | split_nb = line.split(']')[1].split(' ')[3] 317 | data = np.fromstring(line.split('[')[1].split(']')[0], dtype=np.float, sep=',') 318 | data = softmax(data) ###!!! 
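            # Convert each view's logits to probabilities before they are averaged across
            # clips/crops in compute_video(); duplicate (chunk_nb, split_nb) entries for the
            # same video id are skipped below.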
319 | if not name in dict_feats: 320 | dict_feats[name] = [] 321 | dict_label[name] = 0 322 | dict_pos[name] = [] 323 | if chunk_nb + split_nb in dict_pos[name]: 324 | continue 325 | dict_feats[name].append(data) 326 | dict_pos[name].append(chunk_nb + split_nb) 327 | dict_label[name] = label 328 | print("Computing final results") 329 | 330 | input_lst = [] 331 | print(len(dict_feats)) 332 | for i, item in enumerate(dict_feats): 333 | input_lst.append([i, item, dict_feats[item], dict_label[item]]) 334 | from multiprocessing import Pool 335 | p = Pool(64) 336 | ans = p.map(compute_video, input_lst) 337 | top1 = [x[1] for x in ans] 338 | top5 = [x[2] for x in ans] 339 | pred = [x[0] for x in ans] 340 | label = [x[3] for x in ans] 341 | final_top1 ,final_top5 = np.mean(top1), np.mean(top5) 342 | return final_top1*100 ,final_top5*100 343 | 344 | def compute_video(lst): 345 | i, video_id, data, label = lst 346 | feat = [x for x in data] 347 | feat = np.mean(feat, axis=0) 348 | pred = np.argmax(feat) 349 | top1 = (int(pred) == int(label)) * 1.0 350 | top5 = (int(label) in np.argsort(-feat)[:5]) * 1.0 351 | return [pred, top1, top5, int(label)] 352 | -------------------------------------------------------------------------------- /eva_clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 2 | from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer 3 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint 4 | from .loss import ClipLoss 5 | from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg,\ 6 | convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype 7 | from .openai import load_openai_model, list_openai_models 8 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\ 9 | get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained 10 | from .tokenizer import SimpleTokenizer, tokenize 11 | from .transform import image_transform -------------------------------------------------------------------------------- /eva_clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whwu95/ATM/98ba3aa2ac258cc1b91beefe9317136657ae3d8d/eva_clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /eva_clip/constants.py: -------------------------------------------------------------------------------- 1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) 2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) 3 | -------------------------------------------------------------------------------- /eva_clip/hf_configs.py: -------------------------------------------------------------------------------- 1 | # HF architecture dict: 2 | arch_dict = { 3 | # https://huggingface.co/docs/transformers/model_doc/roberta#roberta 4 | "roberta": { 5 | "config_names": { 6 | "context_length": "max_position_embeddings", 7 | "vocab_size": "vocab_size", 8 | "width": "hidden_size", 9 | "heads": "num_attention_heads", 10 | "layers": "num_hidden_layers", 11 | "layer_attr": "layer", 12 | "token_embeddings_attr": "embeddings" 13 | }, 14 | "pooler": "mean_pooler", 15 | }, 16 | # 
https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig 17 | "xlm-roberta": { 18 | "config_names": { 19 | "context_length": "max_position_embeddings", 20 | "vocab_size": "vocab_size", 21 | "width": "hidden_size", 22 | "heads": "num_attention_heads", 23 | "layers": "num_hidden_layers", 24 | "layer_attr": "layer", 25 | "token_embeddings_attr": "embeddings" 26 | }, 27 | "pooler": "mean_pooler", 28 | }, 29 | # https://huggingface.co/docs/transformers/model_doc/mt5#mt5 30 | "mt5": { 31 | "config_names": { 32 | # unlimited seqlen 33 | # https://github.com/google-research/text-to-text-transfer-transformer/issues/273 34 | # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374 35 | "context_length": "", 36 | "vocab_size": "vocab_size", 37 | "width": "d_model", 38 | "heads": "num_heads", 39 | "layers": "num_layers", 40 | "layer_attr": "block", 41 | "token_embeddings_attr": "embed_tokens" 42 | }, 43 | "pooler": "mean_pooler", 44 | }, 45 | "bert": { 46 | "config_names": { 47 | "context_length": "max_position_embeddings", 48 | "vocab_size": "vocab_size", 49 | "width": "hidden_size", 50 | "heads": "num_attention_heads", 51 | "layers": "num_hidden_layers", 52 | "layer_attr": "layer", 53 | "token_embeddings_attr": "embeddings" 54 | }, 55 | "pooler": "mean_pooler", 56 | } 57 | } 58 | -------------------------------------------------------------------------------- /eva_clip/hf_model.py: -------------------------------------------------------------------------------- 1 | """ huggingface model adapter 2 | 3 | Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model. 4 | """ 5 | 6 | import re 7 | 8 | import torch 9 | import torch.nn as nn 10 | from torch.nn import functional as F 11 | from torch import TensorType 12 | try: 13 | import transformers 14 | from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, AutoConfig, PretrainedConfig 15 | from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \ 16 | BaseModelOutputWithPoolingAndCrossAttentions 17 | except ImportError as e: 18 | transformers = None 19 | 20 | 21 | class BaseModelOutput: 22 | pass 23 | 24 | 25 | class PretrainedConfig: 26 | pass 27 | 28 | from .hf_configs import arch_dict 29 | 30 | # utils 31 | def _camel2snake(s): 32 | return re.sub(r'(? 
TensorType: 140 | # image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(x.device) 141 | # attn_mask = (x != self.config.pad_token_id).long() 142 | # out = self.transformer( 143 | # input_ids=x, 144 | # attention_mask=attn_mask, 145 | # encoder_hidden_states = image_embeds, 146 | # encoder_attention_mask = image_atts, 147 | # ) 148 | # pooled_out = self.pooler(out, attn_mask) 149 | 150 | # return self.itm_proj(pooled_out) 151 | 152 | def mask(self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None): 153 | if masked_indices is None: 154 | masked_indices = torch.bernoulli(probability_matrix).bool() 155 | 156 | masked_indices[input_ids == self.tokenizer.pad_token_id] = False 157 | masked_indices[input_ids == self.tokenizer.cls_token_id] = False 158 | 159 | if targets is not None: 160 | targets[~masked_indices] = -100 # We only compute loss on masked tokens 161 | 162 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK]) 163 | indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices 164 | input_ids[indices_replaced] = self.tokenizer.mask_token_id 165 | 166 | # 10% of the time, we replace masked input tokens with random word 167 | indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced 168 | random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device) 169 | input_ids[indices_random] = random_words[indices_random] 170 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged 171 | 172 | if targets is not None: 173 | return input_ids, targets 174 | else: 175 | return input_ids 176 | 177 | def forward_mlm(self, input_ids, image_embeds, mlm_probability=0.25): 178 | labels = input_ids.clone() 179 | attn_mask = (input_ids != self.config.pad_token_id).long() 180 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(input_ids.device) 181 | vocab_size = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["vocab_size"]) 182 | probability_matrix = torch.full(labels.shape, mlm_probability) 183 | input_ids, labels = self.mask(input_ids, vocab_size, input_ids.device, targets=labels, 184 | probability_matrix = probability_matrix) 185 | mlm_output = self.transformer(input_ids, 186 | attention_mask = attn_mask, 187 | encoder_hidden_states = image_embeds, 188 | encoder_attention_mask = image_atts, 189 | return_dict = True, 190 | labels = labels, 191 | ) 192 | return mlm_output.loss 193 | # mlm_output = self.transformer(input_ids, 194 | # attention_mask = attn_mask, 195 | # encoder_hidden_states = image_embeds, 196 | # encoder_attention_mask = image_atts, 197 | # return_dict = True, 198 | # ).last_hidden_state 199 | # logits = self.mlm_proj(mlm_output) 200 | 201 | # # logits = logits[:, :-1, :].contiguous().view(-1, vocab_size) 202 | # logits = logits[:, 1:, :].contiguous().view(-1, vocab_size) 203 | # labels = labels[:, 1:].contiguous().view(-1) 204 | 205 | # mlm_loss = F.cross_entropy( 206 | # logits, 207 | # labels, 208 | # # label_smoothing=0.1, 209 | # ) 210 | # return mlm_loss 211 | 212 | 213 | def forward(self, x:TensorType) -> TensorType: 214 | attn_mask = (x != self.config.pad_token_id).long() 215 | out = self.transformer(input_ids=x, attention_mask=attn_mask) 216 | pooled_out = self.pooler(out, attn_mask) 217 | 218 | return self.proj(pooled_out) 219 | 220 | def lock(self, unlocked_layers:int=0, freeze_layer_norm:bool=True): 221 | if not 
unlocked_layers: # full freezing 222 | for n, p in self.transformer.named_parameters(): 223 | p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False 224 | return 225 | 226 | encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer 227 | layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"]) 228 | print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model") 229 | embeddings = getattr( 230 | self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"]) 231 | modules = [embeddings, *layer_list][:-unlocked_layers] 232 | # freeze layers 233 | for module in modules: 234 | for n, p in module.named_parameters(): 235 | p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False 236 | 237 | 238 | @torch.jit.ignore 239 | def set_grad_checkpointing(self, enable=True): 240 | self.transformer.gradient_checkpointing_enable() 241 | 242 | def get_num_layers(self): 243 | encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer 244 | layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"]) 245 | return len(layer_list) 246 | 247 | def init_parameters(self): 248 | pass 249 | -------------------------------------------------------------------------------- /eva_clip/loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn import functional as F 5 | 6 | try: 7 | import torch.distributed.nn 8 | from torch import distributed as dist 9 | has_distributed = True 10 | except ImportError: 11 | has_distributed = False 12 | 13 | try: 14 | import horovod.torch as hvd 15 | except ImportError: 16 | hvd = None 17 | 18 | from timm.loss import LabelSmoothingCrossEntropy 19 | 20 | 21 | def gather_features( 22 | image_features, 23 | text_features, 24 | local_loss=False, 25 | gather_with_grad=False, 26 | rank=0, 27 | world_size=1, 28 | use_horovod=False 29 | ): 30 | assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.' 
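    # Two gather paths follow: Horovod allgather or torch.distributed all_gather. With
    # gather_with_grad the differentiable torch.distributed.nn.all_gather is used; otherwise
    # the no-grad gathered tensors have the local rank's slice re-inserted so that slice
    # still carries gradients for the backward pass.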
31 | if use_horovod: 32 | assert hvd is not None, 'Please install horovod' 33 | if gather_with_grad: 34 | all_image_features = hvd.allgather(image_features) 35 | all_text_features = hvd.allgather(text_features) 36 | else: 37 | with torch.no_grad(): 38 | all_image_features = hvd.allgather(image_features) 39 | all_text_features = hvd.allgather(text_features) 40 | if not local_loss: 41 | # ensure grads for local rank when all_* features don't have a gradient 42 | gathered_image_features = list(all_image_features.chunk(world_size, dim=0)) 43 | gathered_text_features = list(all_text_features.chunk(world_size, dim=0)) 44 | gathered_image_features[rank] = image_features 45 | gathered_text_features[rank] = text_features 46 | all_image_features = torch.cat(gathered_image_features, dim=0) 47 | all_text_features = torch.cat(gathered_text_features, dim=0) 48 | else: 49 | # We gather tensors from all gpus 50 | if gather_with_grad: 51 | all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0) 52 | all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0) 53 | # all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features, async_op=True), dim=0) 54 | # all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features, async_op=True), dim=0) 55 | else: 56 | gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)] 57 | gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)] 58 | dist.all_gather(gathered_image_features, image_features) 59 | dist.all_gather(gathered_text_features, text_features) 60 | if not local_loss: 61 | # ensure grads for local rank when all_* features don't have a gradient 62 | gathered_image_features[rank] = image_features 63 | gathered_text_features[rank] = text_features 64 | all_image_features = torch.cat(gathered_image_features, dim=0) 65 | all_text_features = torch.cat(gathered_text_features, dim=0) 66 | 67 | return all_image_features, all_text_features 68 | 69 | 70 | class ClipLoss(nn.Module): 71 | 72 | def __init__( 73 | self, 74 | local_loss=False, 75 | gather_with_grad=False, 76 | cache_labels=False, 77 | rank=0, 78 | world_size=1, 79 | use_horovod=False, 80 | smoothing=0., 81 | ): 82 | super().__init__() 83 | self.local_loss = local_loss 84 | self.gather_with_grad = gather_with_grad 85 | self.cache_labels = cache_labels 86 | self.rank = rank 87 | self.world_size = world_size 88 | self.use_horovod = use_horovod 89 | self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None 90 | 91 | # cache state 92 | self.prev_num_logits = 0 93 | self.labels = {} 94 | 95 | def forward(self, image_features, text_features, logit_scale=1.): 96 | device = image_features.device 97 | if self.world_size > 1: 98 | all_image_features, all_text_features = gather_features( 99 | image_features, text_features, 100 | self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod) 101 | 102 | if self.local_loss: 103 | logits_per_image = logit_scale * image_features @ all_text_features.T 104 | logits_per_text = logit_scale * text_features @ all_image_features.T 105 | else: 106 | logits_per_image = logit_scale * all_image_features @ all_text_features.T 107 | logits_per_text = logits_per_image.T 108 | else: 109 | logits_per_image = logit_scale * image_features @ text_features.T 110 | logits_per_text = logit_scale * text_features @ image_features.T 111 | # calculated ground-truth and cache if 
enabled 112 | num_logits = logits_per_image.shape[0] 113 | if self.prev_num_logits != num_logits or device not in self.labels: 114 | labels = torch.arange(num_logits, device=device, dtype=torch.long) 115 | if self.world_size > 1 and self.local_loss: 116 | labels = labels + num_logits * self.rank 117 | if self.cache_labels: 118 | self.labels[device] = labels 119 | self.prev_num_logits = num_logits 120 | else: 121 | labels = self.labels[device] 122 | 123 | if self.label_smoothing_cross_entropy: 124 | total_loss = ( 125 | self.label_smoothing_cross_entropy(logits_per_image, labels) + 126 | self.label_smoothing_cross_entropy(logits_per_text, labels) 127 | ) / 2 128 | else: 129 | total_loss = ( 130 | F.cross_entropy(logits_per_image, labels) + 131 | F.cross_entropy(logits_per_text, labels) 132 | ) / 2 133 | 134 | acc = None 135 | i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image) 136 | t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text) 137 | acc = {"i2t": i2t_acc, "t2i": t2i_acc} 138 | return total_loss, acc -------------------------------------------------------------------------------- /eva_clip/model_configs/EVA01-CLIP-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "patch_size": 16, 8 | "eva_model_name": "eva-clip-b-16", 9 | "ls_init_value": 0.1, 10 | "drop_path_rate": 0.0 11 | }, 12 | "text_cfg": { 13 | "context_length": 77, 14 | "vocab_size": 49408, 15 | "width": 512, 16 | "heads": 8, 17 | "layers": 12 18 | } 19 | } -------------------------------------------------------------------------------- /eva_clip/model_configs/EVA01-CLIP-g-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-g-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "fusedLN": true 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 1024, 19 | "heads": 16, 20 | "layers": 24, 21 | "xattn": false, 22 | "fusedLN": true 23 | } 24 | } -------------------------------------------------------------------------------- /eva_clip/model_configs/EVA01-CLIP-g-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 40, 6 | "width": 1408, 7 | "head_width": 88, 8 | "mlp_ratio": 4.3637, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-g-14-x", 11 | "drop_path_rate": 0.4, 12 | "xattn": true, 13 | "fusedLN": true 14 | }, 15 | "text_cfg": { 16 | "context_length": 77, 17 | "vocab_size": 49408, 18 | "width": 768, 19 | "heads": 12, 20 | "layers": 12, 21 | "xattn": false, 22 | "fusedLN": true 23 | } 24 | } -------------------------------------------------------------------------------- /eva_clip/model_configs/EVA02-CLIP-B-16.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 512, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 12, 6 | "width": 768, 7 | "head_width": 64, 8 | "patch_size": 16, 9 | "mlp_ratio": 2.6667, 10 | "eva_model_name": "eva-clip-b-16-X", 11 | "drop_path_rate": 0.0, 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | 
"pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 512, 24 | "heads": 8, 25 | "layers": 12, 26 | "xattn": true, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /eva_clip/model_configs/EVA02-CLIP-L-14-336.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 336, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": "eva-clip-l-14-336", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /eva_clip/model_configs/EVA02-CLIP-L-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 768, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 24, 6 | "width": 1024, 7 | "drop_path_rate": 0, 8 | "head_width": 64, 9 | "mlp_ratio": 2.6667, 10 | "patch_size": 14, 11 | "eva_model_name": "eva-clip-l-14", 12 | "xattn": true, 13 | "fusedLN": true, 14 | "rope": true, 15 | "pt_hw_seq_len": 16, 16 | "intp_freq": true, 17 | "naiveswiglu": true, 18 | "subln": true 19 | }, 20 | "text_cfg": { 21 | "context_length": 77, 22 | "vocab_size": 49408, 23 | "width": 768, 24 | "heads": 12, 25 | "layers": 12, 26 | "xattn": false, 27 | "fusedLN": true 28 | } 29 | } -------------------------------------------------------------------------------- /eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1280, 20 | "heads": 20, 21 | "layers": 32, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } 26 | -------------------------------------------------------------------------------- /eva_clip/model_configs/EVA02-CLIP-bigE-14.json: -------------------------------------------------------------------------------- 1 | { 2 | "embed_dim": 1024, 3 | "vision_cfg": { 4 | "image_size": 224, 5 | "layers": 64, 6 | "width": 1792, 7 | "head_width": 112, 8 | "mlp_ratio": 8.571428571428571, 9 | "patch_size": 14, 10 | "eva_model_name": "eva-clip-4b-14-x", 11 | "drop_path_rate": 0, 12 | "xattn": true, 13 | "postnorm": true, 14 | "fusedLN": true 15 | }, 16 | "text_cfg": { 17 | "context_length": 77, 18 | "vocab_size": 49408, 19 | "width": 1024, 20 | "heads": 16, 21 | "layers": 24, 22 | "xattn": false, 23 | "fusedLN": true 24 | } 25 | } -------------------------------------------------------------------------------- /eva_clip/modified_resnet.py: -------------------------------------------------------------------------------- 1 | from 
collections import OrderedDict 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | 7 | from eva_clip.utils import freeze_batch_norm_2d 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.act1 = nn.ReLU(inplace=True) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.act2 = nn.ReLU(inplace=True) 24 | 25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 26 | 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | self.act3 = nn.ReLU(inplace=True) 30 | 31 | self.downsample = None 32 | self.stride = stride 33 | 34 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 36 | self.downsample = nn.Sequential(OrderedDict([ 37 | ("-1", nn.AvgPool2d(stride)), 38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 39 | ("1", nn.BatchNorm2d(planes * self.expansion)) 40 | ])) 41 | 42 | def forward(self, x: torch.Tensor): 43 | identity = x 44 | 45 | out = self.act1(self.bn1(self.conv1(x))) 46 | out = self.act2(self.bn2(self.conv2(out))) 47 | out = self.avgpool(out) 48 | out = self.bn3(self.conv3(out)) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.act3(out) 55 | return out 56 | 57 | 58 | class AttentionPool2d(nn.Module): 59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 60 | super().__init__() 61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 62 | self.k_proj = nn.Linear(embed_dim, embed_dim) 63 | self.q_proj = nn.Linear(embed_dim, embed_dim) 64 | self.v_proj = nn.Linear(embed_dim, embed_dim) 65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 66 | self.num_heads = num_heads 67 | 68 | def forward(self, x): 69 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 72 | x, _ = F.multi_head_attention_forward( 73 | query=x, key=x, value=x, 74 | embed_dim_to_check=x.shape[-1], 75 | num_heads=self.num_heads, 76 | q_proj_weight=self.q_proj.weight, 77 | k_proj_weight=self.k_proj.weight, 78 | v_proj_weight=self.v_proj.weight, 79 | in_proj_weight=None, 80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 81 | bias_k=None, 82 | bias_v=None, 83 | add_zero_attn=False, 84 | dropout_p=0., 85 | out_proj_weight=self.c_proj.weight, 86 | out_proj_bias=self.c_proj.bias, 87 | use_separate_proj_weight=True, 88 | training=self.training, 89 | need_weights=False 90 | ) 91 | 92 | return x[0] 93 | 94 | 95 | class ModifiedResNet(nn.Module): 96 | """ 97 | A ResNet class that is similar to torchvision's but contains the following changes: 98 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
99 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 100 | - The final pooling layer is a QKV attention instead of an average pool 101 | """ 102 | 103 | def __init__(self, layers, output_dim, heads, image_size=224, width=64): 104 | super().__init__() 105 | self.output_dim = output_dim 106 | self.image_size = image_size 107 | 108 | # the 3-layer stem 109 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 110 | self.bn1 = nn.BatchNorm2d(width // 2) 111 | self.act1 = nn.ReLU(inplace=True) 112 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 113 | self.bn2 = nn.BatchNorm2d(width // 2) 114 | self.act2 = nn.ReLU(inplace=True) 115 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 116 | self.bn3 = nn.BatchNorm2d(width) 117 | self.act3 = nn.ReLU(inplace=True) 118 | self.avgpool = nn.AvgPool2d(2) 119 | 120 | # residual layers 121 | self._inplanes = width # this is a *mutable* variable used during construction 122 | self.layer1 = self._make_layer(width, layers[0]) 123 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 124 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 125 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 126 | 127 | embed_dim = width * 32 # the ResNet feature dimension 128 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) 129 | 130 | self.init_parameters() 131 | 132 | def _make_layer(self, planes, blocks, stride=1): 133 | layers = [Bottleneck(self._inplanes, planes, stride)] 134 | 135 | self._inplanes = planes * Bottleneck.expansion 136 | for _ in range(1, blocks): 137 | layers.append(Bottleneck(self._inplanes, planes)) 138 | 139 | return nn.Sequential(*layers) 140 | 141 | def init_parameters(self): 142 | if self.attnpool is not None: 143 | std = self.attnpool.c_proj.in_features ** -0.5 144 | nn.init.normal_(self.attnpool.q_proj.weight, std=std) 145 | nn.init.normal_(self.attnpool.k_proj.weight, std=std) 146 | nn.init.normal_(self.attnpool.v_proj.weight, std=std) 147 | nn.init.normal_(self.attnpool.c_proj.weight, std=std) 148 | 149 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]: 150 | for name, param in resnet_block.named_parameters(): 151 | if name.endswith("bn3.weight"): 152 | nn.init.zeros_(param) 153 | 154 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 155 | assert unlocked_groups == 0, 'partial locking not currently supported for this model' 156 | for param in self.parameters(): 157 | param.requires_grad = False 158 | if freeze_bn_stats: 159 | freeze_batch_norm_2d(self) 160 | 161 | @torch.jit.ignore 162 | def set_grad_checkpointing(self, enable=True): 163 | # FIXME support for non-transformer 164 | pass 165 | 166 | def stem(self, x): 167 | x = self.act1(self.bn1(self.conv1(x))) 168 | x = self.act2(self.bn2(self.conv2(x))) 169 | x = self.act3(self.bn3(self.conv3(x))) 170 | x = self.avgpool(x) 171 | return x 172 | 173 | def forward(self, x): 174 | x = self.stem(x) 175 | x = self.layer1(x) 176 | x = self.layer2(x) 177 | x = self.layer3(x) 178 | x = self.layer4(x) 179 | x = self.attnpool(x) 180 | 181 | return x 182 | -------------------------------------------------------------------------------- /eva_clip/openai.py: -------------------------------------------------------------------------------- 1 | """ OpenAI pretrained model functions 2 | 3 | Adapted from https://github.com/openai/CLIP. 
Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | 6 | import os 7 | import warnings 8 | from typing import List, Optional, Union 9 | 10 | import torch 11 | 12 | from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype 13 | from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url 14 | 15 | __all__ = ["list_openai_models", "load_openai_model"] 16 | 17 | 18 | def list_openai_models() -> List[str]: 19 | """Returns the names of available CLIP models""" 20 | return list_pretrained_models_by_tag('openai') 21 | 22 | 23 | def load_openai_model( 24 | name: str, 25 | precision: Optional[str] = None, 26 | device: Optional[Union[str, torch.device]] = None, 27 | jit: bool = True, 28 | cache_dir: Optional[str] = None, 29 | ): 30 | """Load a CLIP model 31 | 32 | Parameters 33 | ---------- 34 | name : str 35 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 36 | precision: str 37 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'. 38 | device : Union[str, torch.device] 39 | The device to put the loaded model 40 | jit : bool 41 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 42 | cache_dir : Optional[str] 43 | The directory to cache the downloaded model weights 44 | 45 | Returns 46 | ------- 47 | model : torch.nn.Module 48 | The CLIP model 49 | preprocess : Callable[[PIL.Image], torch.Tensor] 50 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 51 | """ 52 | if device is None: 53 | device = "cuda" if torch.cuda.is_available() else "cpu" 54 | if precision is None: 55 | precision = 'fp32' if device == 'cpu' else 'fp16' 56 | 57 | if get_pretrained_url(name, 'openai'): 58 | model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir) 59 | elif os.path.isfile(name): 60 | model_path = name 61 | else: 62 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}") 63 | 64 | try: 65 | # loading JIT archive 66 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 67 | state_dict = None 68 | except RuntimeError: 69 | # loading saved state dict 70 | if jit: 71 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 72 | jit = False 73 | state_dict = torch.load(model_path, map_location="cpu") 74 | 75 | if not jit: 76 | # Build a non-jit model from the OpenAI jitted model state dict 77 | cast_dtype = get_cast_dtype(precision) 78 | try: 79 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype) 80 | except KeyError: 81 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()} 82 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype) 83 | 84 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use 85 | model = model.to(device) 86 | if precision.startswith('amp') or precision == 'fp32': 87 | model.float() 88 | elif precision == 'bf16': 89 | convert_weights_to_lp(model, dtype=torch.bfloat16) 90 | 91 | return model 92 | 93 | # patch the device names 94 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 95 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 96 | 97 | def patch_device(module): 98 | try: 99 | graphs = [module.graph] if hasattr(module, "graph") else [] 100 | except RuntimeError: 101 | graphs = [] 102 | 103 | if hasattr(module, "forward1"): 104 | graphs.append(module.forward1.graph) 105 | 106 | for graph in graphs: 107 | for node in graph.findAllNodes("prim::Constant"): 108 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 109 | node.copyAttributes(device_node) 110 | 111 | model.apply(patch_device) 112 | patch_device(model.encode_image) 113 | patch_device(model.encode_text) 114 | 115 | # patch dtype to float32 (typically for CPU) 116 | if precision == 'fp32': 117 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 118 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 119 | float_node = float_input.node() 120 | 121 | def patch_float(module): 122 | try: 123 | graphs = [module.graph] if hasattr(module, "graph") else [] 124 | except RuntimeError: 125 | graphs = [] 126 | 127 | if hasattr(module, "forward1"): 128 | graphs.append(module.forward1.graph) 129 | 130 | for graph in graphs: 131 | for node in graph.findAllNodes("aten::to"): 132 | inputs = list(node.inputs()) 133 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 134 | if inputs[i].node()["value"] == 5: 135 | inputs[i].node().copyAttributes(float_node) 136 | 137 | model.apply(patch_float) 138 | patch_float(model.encode_image) 139 | patch_float(model.encode_text) 140 | model.float() 141 | 142 | # ensure image_size attr available at consistent location for both jit and non-jit 143 | model.visual.image_size = model.input_resolution.item() 144 | return model 145 | -------------------------------------------------------------------------------- /eva_clip/pretrained.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from functools import partial 6 | from typing import Dict, Union 7 | 8 | from tqdm import tqdm 9 | 10 | try: 11 | from huggingface_hub import hf_hub_download 12 | _has_hf_hub = True 13 | except ImportError: 14 | hf_hub_download = None 15 | _has_hf_hub = False 16 | 17 | 18 | def _pcfg(url='', hf_hub='', filename='', mean=None, std=None): 19 | return dict( 20 | url=url, 21 | hf_hub=hf_hub, 22 | mean=mean, 23 | std=std, 24 | ) 25 | 26 | _VITB32 = dict( 27 | openai=_pcfg( 
28 | "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), 29 | laion400m_e31=_pcfg( 30 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), 31 | laion400m_e32=_pcfg( 32 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), 33 | laion2b_e16=_pcfg( 34 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"), 35 | laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/') 36 | ) 37 | 38 | _VITB32_quickgelu = dict( 39 | openai=_pcfg( 40 | "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"), 41 | laion400m_e31=_pcfg( 42 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"), 43 | laion400m_e32=_pcfg( 44 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"), 45 | ) 46 | 47 | _VITB16 = dict( 48 | openai=_pcfg( 49 | "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"), 50 | laion400m_e31=_pcfg( 51 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"), 52 | laion400m_e32=_pcfg( 53 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"), 54 | laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'), 55 | ) 56 | 57 | _EVAB16 = dict( 58 | eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt'), 59 | eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt'), 60 | eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt'), 61 | eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt'), 62 | ) 63 | 64 | _VITB16_PLUS_240 = dict( 65 | laion400m_e31=_pcfg( 66 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"), 67 | laion400m_e32=_pcfg( 68 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"), 69 | ) 70 | 71 | _VITL14 = dict( 72 | openai=_pcfg( 73 | "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"), 74 | laion400m_e31=_pcfg( 75 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"), 76 | laion400m_e32=_pcfg( 77 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"), 78 | laion2b_s32b_b82k=_pcfg( 79 | hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/', 80 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 81 | ) 82 | 83 | _EVAL14 = dict( 84 | eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_L_psz14.pt'), 85 | eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_L_psz14.pt'), 86 | eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt'), 87 | eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt'), 88 | ) 89 | 90 | _VITL14_336 = dict( 91 | openai=_pcfg( 92 | "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"), 93 | ) 94 | 95 | _EVAL14_336 = dict( 96 | 
eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt'), 97 | eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt'), 98 | eva_clip_224to336=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt'), 99 | eva02_clip_224to336=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt'), 100 | ) 101 | 102 | _VITH14 = dict( 103 | laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'), 104 | ) 105 | 106 | _VITg14 = dict( 107 | laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'), 108 | laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'), 109 | ) 110 | 111 | _EVAg14 = dict( 112 | eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/'), 113 | eva01=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_g_psz14.pt'), 114 | eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt'), 115 | eva01_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt'), 116 | ) 117 | 118 | _EVAg14_PLUS = dict( 119 | eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/'), 120 | eva01=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_g_psz14.pt'), 121 | eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt'), 122 | eva01_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt'), 123 | ) 124 | 125 | _VITbigG14 = dict( 126 | laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'), 127 | ) 128 | 129 | _EVAbigE14 = dict( 130 | eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'), 131 | eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'), 132 | eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt'), 133 | eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt'), 134 | ) 135 | 136 | _EVAbigE14_PLUS = dict( 137 | eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'), 138 | eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'), 139 | eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt'), 140 | eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt'), 141 | ) 142 | 143 | 144 | _PRETRAINED = { 145 | # "ViT-B-32": _VITB32, 146 | "OpenaiCLIP-B-32": _VITB32, 147 | "OpenCLIP-B-32": _VITB32, 148 | 149 | # "ViT-B-32-quickgelu": _VITB32_quickgelu, 150 | "OpenaiCLIP-B-32-quickgelu": _VITB32_quickgelu, 151 | "OpenCLIP-B-32-quickgelu": _VITB32_quickgelu, 152 | 153 | # "ViT-B-16": _VITB16, 154 | "OpenaiCLIP-B-16": _VITB16, 155 | "OpenCLIP-B-16": _VITB16, 156 | 157 | "EVA02-B-16": _EVAB16, 158 | "EVA02-CLIP-B-16": _EVAB16, 159 | 160 | # "ViT-B-16-plus-240": _VITB16_PLUS_240, 161 | "OpenCLIP-B-16-plus-240": _VITB16_PLUS_240, 162 | 163 | # "ViT-L-14": _VITL14, 164 | "OpenaiCLIP-L-14": _VITL14, 165 | "OpenCLIP-L-14": _VITL14, 166 | 167 | "EVA02-L-14": _EVAL14, 168 | "EVA02-CLIP-L-14": _EVAL14, 169 | 170 | # "ViT-L-14-336": _VITL14_336, 171 | "OpenaiCLIP-L-14-336": _VITL14_336, 172 | 173 | "EVA02-CLIP-L-14-336": _EVAL14_336, 174 | 175 | # "ViT-H-14": _VITH14, 176 | # "ViT-g-14": _VITg14, 177 | "OpenCLIP-H-14": _VITH14, 178 | "OpenCLIP-g-14": _VITg14, 179 | 180 | "EVA01-CLIP-g-14": _EVAg14, 181 | "EVA01-CLIP-g-14-plus": _EVAg14_PLUS, 182 | 183 | # "ViT-bigG-14": _VITbigG14, 184 | "OpenCLIP-bigG-14": _VITbigG14, 185 | 186 | "EVA02-CLIP-bigE-14": _EVAbigE14, 187 | "EVA02-CLIP-bigE-14-plus": _EVAbigE14_PLUS, 188 | } 189 | 190 | 191 | def _clean_tag(tag: str): 192 | # normalize pretrained tags 193 | return tag.lower().replace('-', '_') 194 | 195 | 196 | def list_pretrained(as_str: bool = False): 197 | """ returns list of pretrained models 198 | Returns a tuple 
(model_name, pretrain_tag) by default or 'name:tag' if as_str == True 199 | """ 200 | return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()] 201 | 202 | 203 | def list_pretrained_models_by_tag(tag: str): 204 | """ return all models having the specified pretrain tag """ 205 | models = [] 206 | tag = _clean_tag(tag) 207 | for k in _PRETRAINED.keys(): 208 | if tag in _PRETRAINED[k]: 209 | models.append(k) 210 | return models 211 | 212 | 213 | def list_pretrained_tags_by_model(model: str): 214 | """ return all pretrain tags for the specified model architecture """ 215 | tags = [] 216 | if model in _PRETRAINED: 217 | tags.extend(_PRETRAINED[model].keys()) 218 | return tags 219 | 220 | 221 | def is_pretrained_cfg(model: str, tag: str): 222 | if model not in _PRETRAINED: 223 | return False 224 | return _clean_tag(tag) in _PRETRAINED[model] 225 | 226 | 227 | def get_pretrained_cfg(model: str, tag: str): 228 | if model not in _PRETRAINED: 229 | return {} 230 | model_pretrained = _PRETRAINED[model] 231 | return model_pretrained.get(_clean_tag(tag), {}) 232 | 233 | 234 | def get_pretrained_url(model: str, tag: str): 235 | cfg = get_pretrained_cfg(model, _clean_tag(tag)) 236 | return cfg.get('url', '') 237 | 238 | 239 | def download_pretrained_from_url( 240 | url: str, 241 | cache_dir: Union[str, None] = None, 242 | ): 243 | if not cache_dir: 244 | cache_dir = os.path.expanduser("~/.cache/clip") 245 | os.makedirs(cache_dir, exist_ok=True) 246 | filename = os.path.basename(url) 247 | 248 | if 'openaipublic' in url: 249 | expected_sha256 = url.split("/")[-2] 250 | elif 'mlfoundations' in url: 251 | expected_sha256 = os.path.splitext(filename)[0].split("-")[-1] 252 | else: 253 | expected_sha256 = '' 254 | 255 | download_target = os.path.join(cache_dir, filename) 256 | 257 | if os.path.exists(download_target) and not os.path.isfile(download_target): 258 | raise RuntimeError(f"{download_target} exists and is not a regular file") 259 | 260 | if os.path.isfile(download_target): 261 | if expected_sha256: 262 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): 263 | return download_target 264 | else: 265 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 266 | else: 267 | return download_target 268 | 269 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 270 | with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 271 | while True: 272 | buffer = source.read(8192) 273 | if not buffer: 274 | break 275 | 276 | output.write(buffer) 277 | loop.update(len(buffer)) 278 | 279 | if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256): 280 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 281 | 282 | return download_target 283 | 284 | 285 | def has_hf_hub(necessary=False): 286 | if not _has_hf_hub and necessary: 287 | # if no HF Hub module installed, and it is necessary to continue, raise error 288 | raise RuntimeError( 289 | 'Hugging Face hub model specified but package not installed. 
Run `pip install huggingface_hub`.') 290 | return _has_hf_hub 291 | 292 | 293 | def download_pretrained_from_hf( 294 | model_id: str, 295 | filename: str = 'open_clip_pytorch_model.bin', 296 | revision=None, 297 | cache_dir: Union[str, None] = None, 298 | ): 299 | has_hf_hub(True) 300 | cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir) 301 | return cached_file 302 | 303 | 304 | def download_pretrained( 305 | cfg: Dict, 306 | force_hf_hub: bool = False, 307 | cache_dir: Union[str, None] = None, 308 | ): 309 | target = '' 310 | if not cfg: 311 | return target 312 | 313 | download_url = cfg.get('url', '') 314 | download_hf_hub = cfg.get('hf_hub', '') 315 | if download_hf_hub and force_hf_hub: 316 | # use HF hub even if url exists 317 | download_url = '' 318 | 319 | if download_url: 320 | target = download_pretrained_from_url(download_url, cache_dir=cache_dir) 321 | elif download_hf_hub: 322 | has_hf_hub(True) 323 | # we assume the hf_hub entries in pretrained config combine model_id + filename in 324 | # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and 325 | # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'. 326 | model_id, filename = os.path.split(download_hf_hub) 327 | if filename: 328 | target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir) 329 | else: 330 | target = download_pretrained_from_hf(model_id, cache_dir=cache_dir) 331 | 332 | return target 333 | -------------------------------------------------------------------------------- /eva_clip/rope.py: -------------------------------------------------------------------------------- 1 | from math import pi 2 | import torch 3 | from torch import nn 4 | from einops import rearrange, repeat 5 | import logging 6 | 7 | def broadcat(tensors, dim = -1): 8 | num_tensors = len(tensors) 9 | shape_lens = set(list(map(lambda t: len(t.shape), tensors))) 10 | assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions' 11 | shape_len = list(shape_lens)[0] 12 | dim = (dim + shape_len) if dim < 0 else dim 13 | dims = list(zip(*map(lambda t: list(t.shape), tensors))) 14 | expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] 15 | assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatentation' 16 | max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) 17 | expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) 18 | expanded_dims.insert(dim, (dim, dims[dim])) 19 | expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) 20 | tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) 21 | return torch.cat(tensors, dim = dim) 22 | 23 | def rotate_half(x): 24 | x = rearrange(x, '... (d r) -> ... d r', r = 2) 25 | x1, x2 = x.unbind(dim = -1) 26 | x = torch.stack((-x2, x1), dim = -1) 27 | return rearrange(x, '... d r -> ... (d r)') 28 | 29 | 30 | class VisionRotaryEmbedding(nn.Module): 31 | def __init__( 32 | self, 33 | dim, 34 | pt_seq_len, 35 | ft_seq_len=None, 36 | custom_freqs = None, 37 | freqs_for = 'lang', 38 | theta = 10000, 39 | max_freq = 10, 40 | num_freqs = 1, 41 | ): 42 | super().__init__() 43 | if custom_freqs: 44 | freqs = custom_freqs 45 | elif freqs_for == 'lang': 46 | freqs = 1. 
/ (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) 47 | elif freqs_for == 'pixel': 48 | freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi 49 | elif freqs_for == 'constant': 50 | freqs = torch.ones(num_freqs).float() 51 | else: 52 | raise ValueError(f'unknown modality {freqs_for}') 53 | 54 | if ft_seq_len is None: ft_seq_len = pt_seq_len 55 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len 56 | 57 | freqs_h = torch.einsum('..., f -> ... f', t, freqs) 58 | freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2) 59 | 60 | freqs_w = torch.einsum('..., f -> ... f', t, freqs) 61 | freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2) 62 | 63 | freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1) 64 | 65 | self.register_buffer("freqs_cos", freqs.cos()) 66 | self.register_buffer("freqs_sin", freqs.sin()) 67 | 68 | logging.info(f'Shape of rope freq: {self.freqs_cos.shape}') 69 | 70 | def forward(self, t, start_index = 0): 71 | rot_dim = self.freqs_cos.shape[-1] 72 | end_index = start_index + rot_dim 73 | assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}' 74 | t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] 75 | t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin) 76 | 77 | return torch.cat((t_left, t, t_right), dim = -1) 78 | 79 | class VisionRotaryEmbeddingFast(nn.Module): 80 | def __init__( 81 | self, 82 | dim, 83 | pt_seq_len, 84 | ft_seq_len=None, 85 | custom_freqs = None, 86 | freqs_for = 'lang', 87 | theta = 10000, 88 | max_freq = 10, 89 | num_freqs = 1, 90 | patch_dropout = 0. 91 | ): 92 | super().__init__() 93 | if custom_freqs: 94 | freqs = custom_freqs 95 | elif freqs_for == 'lang': 96 | freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) 97 | elif freqs_for == 'pixel': 98 | freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi 99 | elif freqs_for == 'constant': 100 | freqs = torch.ones(num_freqs).float() 101 | else: 102 | raise ValueError(f'unknown modality {freqs_for}') 103 | 104 | if ft_seq_len is None: ft_seq_len = pt_seq_len 105 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len 106 | 107 | freqs = torch.einsum('..., f -> ... f', t, freqs) 108 | freqs = repeat(freqs, '... n -> ... 
(n r)', r = 2) 109 | freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1) 110 | 111 | freqs_cos = freqs.cos().view(-1, freqs.shape[-1]) 112 | freqs_sin = freqs.sin().view(-1, freqs.shape[-1]) 113 | 114 | self.patch_dropout = patch_dropout 115 | 116 | self.register_buffer("freqs_cos", freqs_cos) 117 | self.register_buffer("freqs_sin", freqs_sin) 118 | 119 | logging.info(f'Shape of rope freq: {self.freqs_cos.shape}') 120 | 121 | def forward(self, t, patch_indices_keep=None): 122 | if patch_indices_keep is not None: 123 | batch = t.size()[0] 124 | batch_indices = torch.arange(batch) 125 | batch_indices = batch_indices[..., None] 126 | 127 | freqs_cos = repeat(self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1]) 128 | freqs_sin = repeat(self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1]) 129 | 130 | freqs_cos = freqs_cos[batch_indices, patch_indices_keep] 131 | freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j') 132 | freqs_sin = freqs_sin[batch_indices, patch_indices_keep] 133 | freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j') 134 | 135 | return t * freqs_cos + rotate_half(t) * freqs_sin 136 | 137 | return t * self.freqs_cos + rotate_half(t) * self.freqs_sin -------------------------------------------------------------------------------- /eva_clip/timm_model.py: -------------------------------------------------------------------------------- 1 | """ timm model adapter 2 | 3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model. 4 | """ 5 | import logging 6 | from collections import OrderedDict 7 | 8 | import torch 9 | import torch.nn as nn 10 | 11 | try: 12 | import timm 13 | from timm.models.layers import Mlp, to_2tuple 14 | try: 15 | # old timm imports < 0.8.1 16 | from timm.models.layers.attention_pool2d import RotAttentionPool2d 17 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d 18 | except ImportError: 19 | # new timm imports >= 0.8.1 20 | from timm.layers import RotAttentionPool2d 21 | from timm.layers import AttentionPool2d as AbsAttentionPool2d 22 | except ImportError: 23 | timm = None 24 | 25 | from .utils import freeze_batch_norm_2d 26 | 27 | 28 | class TimmModel(nn.Module): 29 | """ timm model adapter 30 | # FIXME this adapter is a work in progress, may change in ways that break weight compat 31 | """ 32 | 33 | def __init__( 34 | self, 35 | model_name, 36 | embed_dim, 37 | image_size=224, 38 | pool='avg', 39 | proj='linear', 40 | proj_bias=False, 41 | drop=0., 42 | pretrained=False): 43 | super().__init__() 44 | if timm is None: 45 | raise RuntimeError("Please `pip install timm` to use timm models.") 46 | 47 | self.image_size = to_2tuple(image_size) 48 | self.trunk = timm.create_model(model_name, pretrained=pretrained) 49 | feat_size = self.trunk.default_cfg.get('pool_size', None) 50 | feature_ndim = 1 if not feat_size else 2 51 | if pool in ('abs_attn', 'rot_attn'): 52 | assert feature_ndim == 2 53 | # if attn pooling used, remove both classifier and default pool 54 | self.trunk.reset_classifier(0, global_pool='') 55 | else: 56 | # reset global pool if pool config set, otherwise leave as network default 57 | reset_kwargs = dict(global_pool=pool) if pool else {} 58 | self.trunk.reset_classifier(0, **reset_kwargs) 59 | prev_chs = self.trunk.num_features 60 | 61 | head_layers = OrderedDict() 62 | if pool == 'abs_attn': 63 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim) 64 | prev_chs = embed_dim 65 
| elif pool == 'rot_attn': 66 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim) 67 | prev_chs = embed_dim 68 | else: 69 | assert proj, 'projection layer needed if non-attention pooling is used.' 70 | 71 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used 72 | if proj == 'linear': 73 | head_layers['drop'] = nn.Dropout(drop) 74 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias) 75 | elif proj == 'mlp': 76 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias)) 77 | 78 | self.head = nn.Sequential(head_layers) 79 | 80 | def lock(self, unlocked_groups=0, freeze_bn_stats=False): 81 | """ lock modules 82 | Args: 83 | unlocked_groups (int): leave last n layer groups unlocked (default: 0) 84 | """ 85 | if not unlocked_groups: 86 | # lock full model 87 | for param in self.trunk.parameters(): 88 | param.requires_grad = False 89 | if freeze_bn_stats: 90 | freeze_batch_norm_2d(self.trunk) 91 | else: 92 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change 93 | try: 94 | # FIXME import here until API stable and in an official release 95 | from timm.models.helpers import group_parameters, group_modules 96 | except ImportError: 97 | raise RuntimeError( 98 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`') 99 | matcher = self.trunk.group_matcher() 100 | gparams = group_parameters(self.trunk, matcher) 101 | max_layer_id = max(gparams.keys()) 102 | max_layer_id = max_layer_id - unlocked_groups 103 | for group_idx in range(max_layer_id + 1): 104 | group = gparams[group_idx] 105 | for param in group: 106 | self.trunk.get_parameter(param).requires_grad = False 107 | if freeze_bn_stats: 108 | gmodules = group_modules(self.trunk, matcher, reverse=True) 109 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id} 110 | freeze_batch_norm_2d(self.trunk, gmodules) 111 | 112 | @torch.jit.ignore 113 | def set_grad_checkpointing(self, enable=True): 114 | try: 115 | self.trunk.set_grad_checkpointing(enable) 116 | except Exception as e: 117 | logging.warning('grad checkpointing not supported for this timm image tower, continuing without...') 118 | 119 | def forward(self, x): 120 | x = self.trunk(x) 121 | x = self.head(x) 122 | return x 123 | -------------------------------------------------------------------------------- /eva_clip/tokenizer.py: -------------------------------------------------------------------------------- 1 | """ CLIP tokenizer 2 | 3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 4 | """ 5 | import gzip 6 | import html 7 | import os 8 | from functools import lru_cache 9 | from typing import Union, List 10 | 11 | import ftfy 12 | import regex as re 13 | import torch 14 | 15 | # https://stackoverflow.com/q/62691279 16 | import os 17 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 18 | 19 | 20 | @lru_cache() 21 | def default_bpe(): 22 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 23 | 24 | 25 | @lru_cache() 26 | def bytes_to_unicode(): 27 | """ 28 | Returns list of utf-8 byte and a corresponding list of unicode strings. 29 | The reversible bpe codes work on unicode strings. 30 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 
31 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 32 | This is a significant percentage of your normal, say, 32K bpe vocab. 33 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 34 | And avoids mapping to whitespace/control characters the bpe code barfs on. 35 | """ 36 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 37 | cs = bs[:] 38 | n = 0 39 | for b in range(2**8): 40 | if b not in bs: 41 | bs.append(b) 42 | cs.append(2**8+n) 43 | n += 1 44 | cs = [chr(n) for n in cs] 45 | return dict(zip(bs, cs)) 46 | 47 | 48 | def get_pairs(word): 49 | """Return set of symbol pairs in a word. 50 | Word is represented as tuple of symbols (symbols being variable-length strings). 51 | """ 52 | pairs = set() 53 | prev_char = word[0] 54 | for char in word[1:]: 55 | pairs.add((prev_char, char)) 56 | prev_char = char 57 | return pairs 58 | 59 | 60 | def basic_clean(text): 61 | text = ftfy.fix_text(text) 62 | text = html.unescape(html.unescape(text)) 63 | return text.strip() 64 | 65 | 66 | def whitespace_clean(text): 67 | text = re.sub(r'\s+', ' ', text) 68 | text = text.strip() 69 | return text 70 | 71 | 72 | class SimpleTokenizer(object): 73 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): 74 | self.byte_encoder = bytes_to_unicode() 75 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 76 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 77 | merges = merges[1:49152-256-2+1] 78 | merges = [tuple(merge.split()) for merge in merges] 79 | vocab = list(bytes_to_unicode().values()) 80 | vocab = vocab + [v+'</w>' for v in vocab] 81 | for merge in merges: 82 | vocab.append(''.join(merge)) 83 | if not special_tokens: 84 | special_tokens = ['<start_of_text>', '<end_of_text>'] 85 | else: 86 | special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens 87 | vocab.extend(special_tokens) 88 | self.encoder = dict(zip(vocab, range(len(vocab)))) 89 | self.decoder = {v: k for k, v in self.encoder.items()} 90 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 91 | self.cache = {t:t for t in special_tokens} 92 | special = "|".join(special_tokens) 93 | self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 94 | 95 | self.vocab_size = len(self.encoder) 96 | self.all_special_ids = [self.encoder[t] for t in special_tokens] 97 | 98 | def bpe(self, token): 99 | if token in self.cache: 100 | return self.cache[token] 101 | word = tuple(token[:-1]) + ( token[-1] + '</w>',) 102 | pairs = get_pairs(word) 103 | 104 | if not pairs: 105 | return token+'</w>' 106 | 107 | while True: 108 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 109 | if bigram not in self.bpe_ranks: 110 | break 111 | first, second = bigram 112 | new_word = [] 113 | i = 0 114 | while i < len(word): 115 | try: 116 | j = word.index(first, i) 117 | new_word.extend(word[i:j]) 118 | i = j 119 | except: 120 | new_word.extend(word[i:]) 121 | break 122 | 123 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 124 | new_word.append(first+second) 125 | i += 2 126 | else: 127 | new_word.append(word[i]) 128 | i += 1 129 | new_word = tuple(new_word) 130 | word = new_word 131 | if len(word) == 1: 132 | break 133 | else: 134 | pairs = get_pairs(word) 135 | word = ' '.join(word) 136 | self.cache[token] = word 137 | return word 138 | 139 | def encode(self, text): 140 | bpe_tokens = [] 141 | text =
whitespace_clean(basic_clean(text)).lower() 142 | for token in re.findall(self.pat, text): 143 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 144 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 145 | return bpe_tokens 146 | 147 | def decode(self, tokens): 148 | text = ''.join([self.decoder[token] for token in tokens]) 149 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') 150 | return text 151 | 152 | 153 | _tokenizer = SimpleTokenizer() 154 | 155 | 156 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: 157 | """ 158 | Returns the tokenized representation of given input string(s) 159 | 160 | Parameters 161 | ---------- 162 | texts : Union[str, List[str]] 163 | An input string or a list of input strings to tokenize 164 | context_length : int 165 | The context length to use; all CLIP models use 77 as the context length 166 | 167 | Returns 168 | ------- 169 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 170 | """ 171 | if isinstance(texts, str): 172 | texts = [texts] 173 | 174 | sot_token = _tokenizer.encoder["<start_of_text>"] 175 | eot_token = _tokenizer.encoder["<end_of_text>"] 176 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 177 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 178 | 179 | for i, tokens in enumerate(all_tokens): 180 | if len(tokens) > context_length: 181 | tokens = tokens[:context_length] # Truncate 182 | tokens[-1] = eot_token 183 | result[i, :len(tokens)] = torch.tensor(tokens) 184 | 185 | return result 186 | 187 | 188 | class HFTokenizer: 189 | "HuggingFace tokenizer wrapper" 190 | def __init__(self, tokenizer_name:str): 191 | from transformers import AutoTokenizer 192 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) 193 | 194 | def __call__(self, texts:Union[str, List[str]], context_length:int=77) -> torch.Tensor: 195 | # same cleaning as for default tokenizer, except lowercasing 196 | # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance 197 | if isinstance(texts, str): 198 | texts = [texts] 199 | texts = [whitespace_clean(basic_clean(text)) for text in texts] 200 | input_ids = self.tokenizer(texts, return_tensors='pt', max_length=context_length, padding='max_length', truncation=True).input_ids 201 | return input_ids 202 | -------------------------------------------------------------------------------- /eva_clip/transform.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence, Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torchvision.transforms.functional as F 6 | 7 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \ 8 | CenterCrop 9 | 10 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 11 | 12 | 13 | class ResizeMaxSize(nn.Module): 14 | 15 | def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0): 16 | super().__init__() 17 | if not isinstance(max_size, int): 18 | raise TypeError(f"Size should be int.
Got {type(max_size)}") 19 | self.max_size = max_size 20 | self.interpolation = interpolation 21 | self.fn = min if fn == 'min' else min 22 | self.fill = fill 23 | 24 | def forward(self, img): 25 | if isinstance(img, torch.Tensor): 26 | height, width = img.shape[:2] 27 | else: 28 | width, height = img.size 29 | scale = self.max_size / float(max(height, width)) 30 | if scale != 1.0: 31 | new_size = tuple(round(dim * scale) for dim in (height, width)) 32 | img = F.resize(img, new_size, self.interpolation) 33 | pad_h = self.max_size - new_size[0] 34 | pad_w = self.max_size - new_size[1] 35 | img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill) 36 | return img 37 | 38 | 39 | def _convert_to_rgb(image): 40 | return image.convert('RGB') 41 | 42 | 43 | # class CatGen(nn.Module): 44 | # def __init__(self, num=4): 45 | # self.num = num 46 | # def mixgen_batch(image, text): 47 | # batch_size = image.shape[0] 48 | # index = np.random.permutation(batch_size) 49 | 50 | # cat_images = [] 51 | # for i in range(batch_size): 52 | # # image mixup 53 | # image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:] 54 | # # text concat 55 | # text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0] 56 | # text = torch.stack(text) 57 | # return image, text 58 | 59 | 60 | def image_transform( 61 | image_size: int, 62 | is_train: bool, 63 | mean: Optional[Tuple[float, ...]] = None, 64 | std: Optional[Tuple[float, ...]] = None, 65 | resize_longest_max: bool = False, 66 | fill_color: int = 0, 67 | ): 68 | mean = mean or OPENAI_DATASET_MEAN 69 | if not isinstance(mean, (list, tuple)): 70 | mean = (mean,) * 3 71 | 72 | std = std or OPENAI_DATASET_STD 73 | if not isinstance(std, (list, tuple)): 74 | std = (std,) * 3 75 | 76 | if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: 77 | # for square size, pass size as int so that Resize() uses aspect preserving shortest edge 78 | image_size = image_size[0] 79 | 80 | normalize = Normalize(mean=mean, std=std) 81 | if is_train: 82 | return Compose([ 83 | RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC), 84 | _convert_to_rgb, 85 | ToTensor(), 86 | normalize, 87 | ]) 88 | else: 89 | if resize_longest_max: 90 | transforms = [ 91 | ResizeMaxSize(image_size, fill=fill_color) 92 | ] 93 | else: 94 | transforms = [ 95 | Resize(image_size, interpolation=InterpolationMode.BICUBIC), 96 | CenterCrop(image_size), 97 | ] 98 | transforms.extend([ 99 | _convert_to_rgb, 100 | ToTensor(), 101 | normalize, 102 | ]) 103 | return Compose(transforms) 104 | -------------------------------------------------------------------------------- /lists/kinetics_400_labels.csv: -------------------------------------------------------------------------------- 1 | id,name 2 | 0,abseiling 3 | 1,air drumming 4 | 2,answering questions 5 | 3,applauding 6 | 4,applying cream 7 | 5,archery 8 | 6,arm wrestling 9 | 7,arranging flowers 10 | 8,assembling computer 11 | 9,auctioning 12 | 10,baby waking up 13 | 11,baking cookies 14 | 12,balloon blowing 15 | 13,bandaging 16 | 14,barbequing 17 | 15,bartending 18 | 16,beatboxing 19 | 17,bee keeping 20 | 18,belly dancing 21 | 19,bench pressing 22 | 20,bending back 23 | 21,bending metal 24 | 22,biking through snow 25 | 23,blasting sand 26 | 24,blowing glass 27 | 25,blowing leaves 28 | 26,blowing nose 29 | 27,blowing out candles 30 | 28,bobsledding 31 | 29,bookbinding 32 | 30,bouncing on trampoline 33 | 31,bowling 34 | 32,braiding hair 35 | 33,breading or 
breadcrumbing 36 | 34,breakdancing 37 | 35,brush painting 38 | 36,brushing hair 39 | 37,brushing teeth 40 | 38,building cabinet 41 | 39,building shed 42 | 40,bungee jumping 43 | 41,busking 44 | 42,canoeing or kayaking 45 | 43,capoeira 46 | 44,carrying baby 47 | 45,cartwheeling 48 | 46,carving pumpkin 49 | 47,catching fish 50 | 48,catching or throwing baseball 51 | 49,catching or throwing frisbee 52 | 50,catching or throwing softball 53 | 51,celebrating 54 | 52,changing oil 55 | 53,changing wheel 56 | 54,checking tires 57 | 55,cheerleading 58 | 56,chopping wood 59 | 57,clapping 60 | 58,clay pottery making 61 | 59,clean and jerk 62 | 60,cleaning floor 63 | 61,cleaning gutters 64 | 62,cleaning pool 65 | 63,cleaning shoes 66 | 64,cleaning toilet 67 | 65,cleaning windows 68 | 66,climbing a rope 69 | 67,climbing ladder 70 | 68,climbing tree 71 | 69,contact juggling 72 | 70,cooking chicken 73 | 71,cooking egg 74 | 72,cooking on campfire 75 | 73,cooking sausages 76 | 74,counting money 77 | 75,country line dancing 78 | 76,cracking neck 79 | 77,crawling baby 80 | 78,crossing river 81 | 79,crying 82 | 80,curling hair 83 | 81,cutting nails 84 | 82,cutting pineapple 85 | 83,cutting watermelon 86 | 84,dancing ballet 87 | 85,dancing charleston 88 | 86,dancing gangnam style 89 | 87,dancing macarena 90 | 88,deadlifting 91 | 89,decorating the christmas tree 92 | 90,digging 93 | 91,dining 94 | 92,disc golfing 95 | 93,diving cliff 96 | 94,dodgeball 97 | 95,doing aerobics 98 | 96,doing laundry 99 | 97,doing nails 100 | 98,drawing 101 | 99,dribbling basketball 102 | 100,drinking 103 | 101,drinking beer 104 | 102,drinking shots 105 | 103,driving car 106 | 104,driving tractor 107 | 105,drop kicking 108 | 106,drumming fingers 109 | 107,dunking basketball 110 | 108,dying hair 111 | 109,eating burger 112 | 110,eating cake 113 | 111,eating carrots 114 | 112,eating chips 115 | 113,eating doughnuts 116 | 114,eating hotdog 117 | 115,eating ice cream 118 | 116,eating spaghetti 119 | 117,eating watermelon 120 | 118,egg hunting 121 | 119,exercising arm 122 | 120,exercising with an exercise ball 123 | 121,extinguishing fire 124 | 122,faceplanting 125 | 123,feeding birds 126 | 124,feeding fish 127 | 125,feeding goats 128 | 126,filling eyebrows 129 | 127,finger snapping 130 | 128,fixing hair 131 | 129,flipping pancake 132 | 130,flying kite 133 | 131,folding clothes 134 | 132,folding napkins 135 | 133,folding paper 136 | 134,front raises 137 | 135,frying vegetables 138 | 136,garbage collecting 139 | 137,gargling 140 | 138,getting a haircut 141 | 139,getting a tattoo 142 | 140,giving or receiving award 143 | 141,golf chipping 144 | 142,golf driving 145 | 143,golf putting 146 | 144,grinding meat 147 | 145,grooming dog 148 | 146,grooming horse 149 | 147,gymnastics tumbling 150 | 148,hammer throw 151 | 149,headbanging 152 | 150,headbutting 153 | 151,high jump 154 | 152,high kick 155 | 153,hitting baseball 156 | 154,hockey stop 157 | 155,holding snake 158 | 156,hopscotch 159 | 157,hoverboarding 160 | 158,hugging 161 | 159,hula hooping 162 | 160,hurdling 163 | 161,hurling (sport) 164 | 162,ice climbing 165 | 163,ice fishing 166 | 164,ice skating 167 | 165,ironing 168 | 166,javelin throw 169 | 167,jetskiing 170 | 168,jogging 171 | 169,juggling balls 172 | 170,juggling fire 173 | 171,juggling soccer ball 174 | 172,jumping into pool 175 | 173,jumpstyle dancing 176 | 174,kicking field goal 177 | 175,kicking soccer ball 178 | 176,kissing 179 | 177,kitesurfing 180 | 178,knitting 181 | 179,krumping 182 | 180,laughing 183 | 181,laying 
bricks 184 | 182,long jump 185 | 183,lunge 186 | 184,making a cake 187 | 185,making a sandwich 188 | 186,making bed 189 | 187,making jewelry 190 | 188,making pizza 191 | 189,making snowman 192 | 190,making sushi 193 | 191,making tea 194 | 192,marching 195 | 193,massaging back 196 | 194,massaging feet 197 | 195,massaging legs 198 | 196,massaging person's head 199 | 197,milking cow 200 | 198,mopping floor 201 | 199,motorcycling 202 | 200,moving furniture 203 | 201,mowing lawn 204 | 202,news anchoring 205 | 203,opening bottle 206 | 204,opening present 207 | 205,paragliding 208 | 206,parasailing 209 | 207,parkour 210 | 208,passing American football (in game) 211 | 209,passing American football (not in game) 212 | 210,peeling apples 213 | 211,peeling potatoes 214 | 212,petting animal (not cat) 215 | 213,petting cat 216 | 214,picking fruit 217 | 215,planting trees 218 | 216,plastering 219 | 217,playing accordion 220 | 218,playing badminton 221 | 219,playing bagpipes 222 | 220,playing basketball 223 | 221,playing bass guitar 224 | 222,playing cards 225 | 223,playing cello 226 | 224,playing chess 227 | 225,playing clarinet 228 | 226,playing controller 229 | 227,playing cricket 230 | 228,playing cymbals 231 | 229,playing didgeridoo 232 | 230,playing drums 233 | 231,playing flute 234 | 232,playing guitar 235 | 233,playing harmonica 236 | 234,playing harp 237 | 235,playing ice hockey 238 | 236,playing keyboard 239 | 237,playing kickball 240 | 238,playing monopoly 241 | 239,playing organ 242 | 240,playing paintball 243 | 241,playing piano 244 | 242,playing poker 245 | 243,playing recorder 246 | 244,playing saxophone 247 | 245,playing squash or racquetball 248 | 246,playing tennis 249 | 247,playing trombone 250 | 248,playing trumpet 251 | 249,playing ukulele 252 | 250,playing violin 253 | 251,playing volleyball 254 | 252,playing xylophone 255 | 253,pole vault 256 | 254,presenting weather forecast 257 | 255,pull ups 258 | 256,pumping fist 259 | 257,pumping gas 260 | 258,punching bag 261 | 259,punching person (boxing) 262 | 260,push up 263 | 261,pushing car 264 | 262,pushing cart 265 | 263,pushing wheelchair 266 | 264,reading book 267 | 265,reading newspaper 268 | 266,recording music 269 | 267,riding a bike 270 | 268,riding camel 271 | 269,riding elephant 272 | 270,riding mechanical bull 273 | 271,riding mountain bike 274 | 272,riding mule 275 | 273,riding or walking with horse 276 | 274,riding scooter 277 | 275,riding unicycle 278 | 276,ripping paper 279 | 277,robot dancing 280 | 278,rock climbing 281 | 279,rock scissors paper 282 | 280,roller skating 283 | 281,running on treadmill 284 | 282,sailing 285 | 283,salsa dancing 286 | 284,sanding floor 287 | 285,scrambling eggs 288 | 286,scuba diving 289 | 287,setting table 290 | 288,shaking hands 291 | 289,shaking head 292 | 290,sharpening knives 293 | 291,sharpening pencil 294 | 292,shaving head 295 | 293,shaving legs 296 | 294,shearing sheep 297 | 295,shining shoes 298 | 296,shooting basketball 299 | 297,shooting goal (soccer) 300 | 298,shot put 301 | 299,shoveling snow 302 | 300,shredding paper 303 | 301,shuffling cards 304 | 302,side kick 305 | 303,sign language interpreting 306 | 304,singing 307 | 305,situp 308 | 306,skateboarding 309 | 307,ski jumping 310 | 308,skiing (not slalom or crosscountry) 311 | 309,skiing crosscountry 312 | 310,skiing slalom 313 | 311,skipping rope 314 | 312,skydiving 315 | 313,slacklining 316 | 314,slapping 317 | 315,sled dog racing 318 | 316,smoking 319 | 317,smoking hookah 320 | 318,snatch weight lifting 321 | 319,sneezing 
322 | 320,sniffing 323 | 321,snorkeling 324 | 322,snowboarding 325 | 323,snowkiting 326 | 324,snowmobiling 327 | 325,somersaulting 328 | 326,spinning poi 329 | 327,spray painting 330 | 328,spraying 331 | 329,springboard diving 332 | 330,squat 333 | 331,sticking tongue out 334 | 332,stomping grapes 335 | 333,stretching arm 336 | 334,stretching leg 337 | 335,strumming guitar 338 | 336,surfing crowd 339 | 337,surfing water 340 | 338,sweeping floor 341 | 339,swimming backstroke 342 | 340,swimming breast stroke 343 | 341,swimming butterfly stroke 344 | 342,swing dancing 345 | 343,swinging legs 346 | 344,swinging on something 347 | 345,sword fighting 348 | 346,tai chi 349 | 347,taking a shower 350 | 348,tango dancing 351 | 349,tap dancing 352 | 350,tapping guitar 353 | 351,tapping pen 354 | 352,tasting beer 355 | 353,tasting food 356 | 354,testifying 357 | 355,texting 358 | 356,throwing axe 359 | 357,throwing ball 360 | 358,throwing discus 361 | 359,tickling 362 | 360,tobogganing 363 | 361,tossing coin 364 | 362,tossing salad 365 | 363,training dog 366 | 364,trapezing 367 | 365,trimming or shaving beard 368 | 366,trimming trees 369 | 367,triple jump 370 | 368,tying bow tie 371 | 369,tying knot (not on a tie) 372 | 370,tying tie 373 | 371,unboxing 374 | 372,unloading truck 375 | 373,using computer 376 | 374,using remote controller (not gaming) 377 | 375,using segway 378 | 376,vault 379 | 377,waiting in line 380 | 378,walking the dog 381 | 379,washing dishes 382 | 380,washing feet 383 | 381,washing hair 384 | 382,washing hands 385 | 383,water skiing 386 | 384,water sliding 387 | 385,watering plants 388 | 386,waxing back 389 | 387,waxing chest 390 | 388,waxing eyebrows 391 | 389,waxing legs 392 | 390,weaving basket 393 | 391,welding 394 | 392,whistling 395 | 393,windsurfing 396 | 394,wrapping present 397 | 395,wrestling 398 | 396,writing 399 | 397,yawning 400 | 398,yoga 401 | 399,zumba 402 | -------------------------------------------------------------------------------- /lists/sth_labels.csv: -------------------------------------------------------------------------------- 1 | id,name 2 | 0,Approaching something with your camera 3 | 1,Attaching something to something 4 | 2,Bending something so that it deforms 5 | 3,Bending something until it breaks 6 | 4,Burying something in something 7 | 5,Closing something 8 | 6,Covering something with something 9 | 7,Digging something out of something 10 | 8,Dropping something behind something 11 | 9,Dropping something in front of something 12 | 10,Dropping something into something 13 | 11,Dropping something next to something 14 | 12,Dropping something onto something 15 | 13,Failing to put something into something because something does not fit 16 | 14,Folding something 17 | 15,Hitting something with something 18 | 16,Holding something 19 | 17,Holding something behind something 20 | 18,Holding something in front of something 21 | 19,Holding something next to something 22 | 20,Holding something over something 23 | 21,Laying something on the table on its side not upright 24 | 22,Letting something roll along a flat surface 25 | 23,Letting something roll down a slanted surface 26 | 24,Letting something roll up a slanted surface so it rolls back down 27 | 25,Lifting a surface with something on it but not enough for it to slide down 28 | 26,Lifting a surface with something on it until it starts sliding down 29 | 27,Lifting something up completely without letting it drop down 30 | 28,Lifting something up completely then letting it drop down 31 | 29,Lifting something 
with something on it 32 | 30,Lifting up one end of something without letting it drop down 33 | 31,Lifting up one end of something then letting it drop down 34 | 32,Moving away from something with your camera 35 | 33,Moving part of something 36 | 34,Moving something across a surface until it falls down 37 | 35,Moving something across a surface without it falling down 38 | 36,Moving something and something away from each other 39 | 37,Moving something and something closer to each other 40 | 38,Moving something and something so they collide with each other 41 | 39,Moving something and something so they pass each other 42 | 40,Moving something away from something 43 | 41,Moving something away from the camera 44 | 42,Moving something closer to something 45 | 43,Moving something down 46 | 44,Moving something towards the camera 47 | 45,Moving something up 48 | 46,Opening something 49 | 47,Picking something up 50 | 48,Piling something up 51 | 49,Plugging something into something 52 | 50,Plugging something into something but pulling it right out as you remove your hand 53 | 51,Poking a hole into some substance 54 | 52,Poking a hole into something soft 55 | 53,Poking a stack of something so the stack collapses 56 | 54,Poking a stack of something without the stack collapsing 57 | 55,Poking something so it slightly moves 58 | 56,Poking something so lightly that it doesn't or almost doesn't move 59 | 57,Poking something so that it falls over 60 | 58,Poking something so that it spins around 61 | 59,Pouring something into something 62 | 60,Pouring something into something until it overflows 63 | 61,Pouring something onto something 64 | 62,Pouring something out of something 65 | 63,Pretending or failing to wipe something off of something 66 | 64,Pretending or trying and failing to twist something 67 | 65,Pretending to be tearing something that is not tearable 68 | 66,Pretending to close something without actually closing it 69 | 67,Pretending to open something without actually opening it 70 | 68,Pretending to pick something up 71 | 69,Pretending to poke something 72 | 70,Pretending to pour something out of something but something is empty 73 | 71,Pretending to put something behind something 74 | 72,Pretending to put something into something 75 | 73,Pretending to put something next to something 76 | 74,Pretending to put something on a surface 77 | 75,Pretending to put something onto something 78 | 76,Pretending to put something underneath something 79 | 77,Pretending to scoop something up with something 80 | 78,Pretending to spread air onto something 81 | 79,Pretending to sprinkle air onto something 82 | 80,Pretending to squeeze something 83 | 81,Pretending to take something from somewhere 84 | 82,Pretending to take something out of something 85 | 83,Pretending to throw something 86 | 84,Pretending to turn something upside down 87 | 85,Pulling something from behind of something 88 | 86,Pulling something from left to right 89 | 87,Pulling something from right to left 90 | 88,Pulling something onto something 91 | 89,Pulling something out of something 92 | 90,Pulling two ends of something but nothing happens 93 | 91,Pulling two ends of something so that it gets stretched 94 | 92,Pulling two ends of something so that it separates into two pieces 95 | 93,Pushing something from left to right 96 | 94,Pushing something from right to left 97 | 95,Pushing something off of something 98 | 96,Pushing something onto something 99 | 97,Pushing something so it spins 100 | 98,Pushing something so that it almost falls off 
but doesn't 101 | 99,Pushing something so that it falls off the table 102 | 100,Pushing something so that it slightly moves 103 | 101,Pushing something with something 104 | 102,Putting number of something onto something 105 | 103,Putting something and something on the table 106 | 104,Putting something behind something 107 | 105,Putting something in front of something 108 | 106,Putting something into something 109 | 107,Putting something next to something 110 | 108,Putting something on a flat surface without letting it roll 111 | 109,Putting something on a surface 112 | 110,Putting something on the edge of something so it is not supported and falls down 113 | 111,Putting something onto a slanted surface but it doesn't glide down 114 | 112,Putting something onto something 115 | 113,Putting something onto something else that cannot support it so it falls down 116 | 114,Putting something similar to other things that are already on the table 117 | 115,Putting something that can't roll onto a slanted surface so it slides down 118 | 116,Putting something that can't roll onto a slanted surface so it stays where it is 119 | 117,Putting something that cannot actually stand upright upright on the table so it falls on its side 120 | 118,Putting something underneath something 121 | 119,Putting something upright on the table 122 | 120,Putting something something and something on the table 123 | 121,Removing something revealing something behind 124 | 122,Rolling something on a flat surface 125 | 123,Scooping something up with something 126 | 124,Showing a photo of something to the camera 127 | 125,Showing something behind something 128 | 126,Showing something next to something 129 | 127,Showing something on top of something 130 | 128,Showing something to the camera 131 | 129,Showing that something is empty 132 | 130,Showing that something is inside something 133 | 131,Something being deflected from something 134 | 132,Something colliding with something and both are being deflected 135 | 133,Something colliding with something and both come to a halt 136 | 134,Something falling like a feather or paper 137 | 135,Something falling like a rock 138 | 136,Spilling something behind something 139 | 137,Spilling something next to something 140 | 138,Spilling something onto something 141 | 139,Spinning something so it continues spinning 142 | 140,Spinning something that quickly stops spinning 143 | 141,Spreading something onto something 144 | 142,Sprinkling something onto something 145 | 143,Squeezing something 146 | 144,Stacking number of something 147 | 145,Stuffing something into something 148 | 146,Taking one of many similar things on the table 149 | 147,Taking something from somewhere 150 | 148,Taking something out of something 151 | 149,Tearing something into two pieces 152 | 150,Tearing something just a little bit 153 | 151,Throwing something 154 | 152,Throwing something against something 155 | 153,Throwing something in the air and catching it 156 | 154,Throwing something in the air and letting it fall 157 | 155,Throwing something onto a surface 158 | 156,Tilting something with something on it slightly so it doesn't fall down 159 | 157,Tilting something with something on it until it falls off 160 | 158,Tipping something over 161 | 159,Tipping something with something in it over so something in it falls out 162 | 160,Touching (without moving) part of something 163 | 161,Trying but failing to attach something to something because it doesn't stick 164 | 162,Trying to bend something unbendable so nothing happens 
165 | 163,Trying to pour something into something but missing so it spills next to it 166 | 164,Turning something upside down 167 | 165,Turning the camera downwards while filming something 168 | 166,Turning the camera left while filming something 169 | 167,Turning the camera right while filming something 170 | 168,Turning the camera upwards while filming something 171 | 169,Twisting (wringing) something wet until water comes out 172 | 170,Twisting something 173 | 171,Uncovering something 174 | 172,Unfolding something 175 | 173,Wiping something off of something -------------------------------------------------------------------------------- /optim_factory.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import optim as optim 3 | 4 | from timm.optim.adafactor import Adafactor 5 | from timm.optim.adahessian import Adahessian 6 | from timm.optim.adamp import AdamP 7 | from timm.optim.lookahead import Lookahead 8 | from timm.optim.nadam import Nadam 9 | from timm.optim.novograd import NovoGrad 10 | from timm.optim.nvnovograd import NvNovoGrad 11 | from timm.optim.radam import RAdam 12 | from timm.optim.rmsprop_tf import RMSpropTF 13 | from timm.optim.sgdp import SGDP 14 | 15 | import json 16 | 17 | try: 18 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD 19 | has_apex = True 20 | except ImportError: 21 | has_apex = False 22 | 23 | 24 | def get_num_layer_for_vit(var_name, num_max_layer): 25 | if var_name in ("pos_embed", "temporal_embed"): ###!!! 26 | return 0 27 | elif var_name.startswith("visual.conv1"): 28 | return 0 29 | elif var_name.startswith("visual.ln_pre"): 30 | return 0 31 | elif var_name.startswith("visual.transformer.resblocks"): 32 | layer_id = int(var_name.split('.')[3]) #// 2 33 | return layer_id + 1 34 | else: 35 | return num_max_layer - 1 36 | 37 | def get_num_layer_for_eva_vit(var_name, num_max_layer): 38 | if var_name in ("pos_embed", "temporal_embed"): ###!!! 39 | return 0 40 | elif var_name.startswith("visual.patch_embed"): 41 | return 0 42 | #elif var_name.startswith("visual.ln_pre"): 43 | # return 0 44 | elif var_name.startswith("visual.blocks"): 45 | layer_id = int(var_name.split('.')[2]) #// 2 46 | return layer_id + 1 47 | else: 48 | return num_max_layer - 1 49 | 50 | class LayerDecayValueAssigner(object): 51 | def __init__(self, values, vit_arch='clip'): 52 | self.values = values 53 | self.vit_arch = vit_arch 54 | 55 | def get_scale(self, layer_id): 56 | return self.values[layer_id] 57 | 58 | def get_layer_id(self, var_name): 59 | if self.vit_arch == 'clip': 60 | return get_num_layer_for_vit(var_name, len(self.values)) 61 | elif self.vit_arch == 'eva_clip': 62 | return get_num_layer_for_eva_vit(var_name, len(self.values)) 63 | 64 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None): 65 | parameter_group_names = {} 66 | parameter_group_vars = {} 67 | 68 | for name, param in model.named_parameters(): 69 | if not param.requires_grad: 70 | continue # frozen weights 71 | 72 | 73 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: ###!!! 74 | group_name = "no_decay" 75 | this_weight_decay = 0. 
76 | elif 'control_point' in name: 77 | # ATM control-point parameters keep weight decay but are not layer-decay scaled (lr_scale = 1.0) 78 | group_name = "control_point" 79 | this_weight_decay = weight_decay 80 | else: 81 | group_name = "decay" 82 | this_weight_decay = weight_decay 83 | 84 | if get_num_layer is not None: 85 | layer_id = get_num_layer(name) 86 | group_name = "layer_%d_%s" % (layer_id, group_name) 87 | else: 88 | layer_id = None 89 | 90 | if group_name not in parameter_group_names: 91 | if get_layer_scale is not None: 92 | if 'control_point' not in group_name and 'visual' in name: #for clip part 93 | scale = get_layer_scale(layer_id) #0.001 94 | else: 95 | scale = 1. #get_layer_scale(layer_id) 96 | else: 97 | scale = 1. 98 | 99 | parameter_group_names[group_name] = { 100 | "weight_decay": this_weight_decay, 101 | "params": [], 102 | "lr_scale": scale 103 | } 104 | parameter_group_vars[group_name] = { 105 | "weight_decay": this_weight_decay, 106 | "params": [], 107 | "lr_scale": scale 108 | } 109 | 110 | parameter_group_vars[group_name]["params"].append(param) 111 | parameter_group_names[group_name]["params"].append(name) 112 | print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) 113 | return list(parameter_group_vars.values()) 114 | 115 | 116 | def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None): 117 | opt_lower = args.opt.lower() 118 | weight_decay = args.weight_decay 119 | if weight_decay and filter_bias_and_bn: 120 | skip = {} 121 | if skip_list is not None: 122 | skip = skip_list 123 | elif hasattr(model, 'no_weight_decay'): 124 | skip = model.no_weight_decay() 125 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale) 126 | weight_decay = 0. 127 | else: 128 | parameters = model.parameters() 129 | 130 | if 'fused' in opt_lower: 131 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers' 132 | 133 | opt_args = dict(lr=args.lr, weight_decay=weight_decay) 134 | if hasattr(args, 'opt_eps') and args.opt_eps is not None: 135 | opt_args['eps'] = args.opt_eps 136 | if hasattr(args, 'opt_betas') and args.opt_betas is not None: 137 | opt_args['betas'] = args.opt_betas 138 | 139 | print("optimizer settings:", opt_args) 140 | 141 | opt_split = opt_lower.split('_') 142 | opt_lower = opt_split[-1] 143 | if opt_lower == 'sgd' or opt_lower == 'nesterov': 144 | opt_args.pop('eps', None) 145 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 146 | elif opt_lower == 'momentum': 147 | opt_args.pop('eps', None) 148 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 149 | elif opt_lower == 'adam': 150 | optimizer = optim.Adam(parameters, **opt_args) 151 | elif opt_lower == 'adamw': 152 | optimizer = optim.AdamW(parameters, **opt_args) 153 | elif opt_lower == 'nadam': 154 | optimizer = Nadam(parameters, **opt_args) 155 | elif opt_lower == 'radam': 156 | optimizer = RAdam(parameters, **opt_args) 157 | elif opt_lower == 'adamp': 158 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args) 159 | elif opt_lower == 'sgdp': 160 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args) 161 | elif opt_lower == 'adadelta': 162 | optimizer = optim.Adadelta(parameters, **opt_args) 163 | elif opt_lower == 'adafactor': 164 | if not args.lr: 165 | opt_args['lr'] = None 166 | optimizer = Adafactor(parameters, **opt_args) 167 | elif opt_lower == 'adahessian': 168 | optimizer = Adahessian(parameters,
**opt_args) 169 | elif opt_lower == 'rmsprop': 170 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 171 | elif opt_lower == 'rmsproptf': 172 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args) 173 | elif opt_lower == 'novograd': 174 | optimizer = NovoGrad(parameters, **opt_args) 175 | elif opt_lower == 'nvnovograd': 176 | optimizer = NvNovoGrad(parameters, **opt_args) 177 | elif opt_lower == 'fusedsgd': 178 | opt_args.pop('eps', None) 179 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args) 180 | elif opt_lower == 'fusedmomentum': 181 | opt_args.pop('eps', None) 182 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args) 183 | elif opt_lower == 'fusedadam': 184 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args) 185 | elif opt_lower == 'fusedadamw': 186 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args) 187 | elif opt_lower == 'fusedlamb': 188 | optimizer = FusedLAMB(parameters, **opt_args) 189 | elif opt_lower == 'fusednovograd': 190 | opt_args.setdefault('betas', (0.95, 0.98)) 191 | optimizer = FusedNovoGrad(parameters, **opt_args) 192 | else: 193 | # unsupported optimizer name 194 | raise ValueError("Invalid optimizer: %s" % opt_lower) 195 | 196 | if len(opt_split) > 1: 197 | if opt_split[0] == 'lookahead': 198 | optimizer = Lookahead(optimizer) 199 | 200 | return optimizer 201 | 202 | -------------------------------------------------------------------------------- /pics/ATM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/whwu95/ATM/98ba3aa2ac258cc1b91beefe9317136657ae3d8d/pics/ATM.png -------------------------------------------------------------------------------- /scripts/k400/train_base.sh: -------------------------------------------------------------------------------- 1 | # Set the path to save checkpoints 2 | OUTPUT_DIR='output/k400/base_f8' 3 | 4 | DATA_PATH='/bpfs/v2_mnt/VIS/test/k400/train_320_frames' 5 | # path to pretrain model 6 | MODEL_PATH=' ' 7 | 8 | MASTER_ADDR='127.0.0.1' 9 | NNODES=1 10 | 11 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \ 12 | --master_port 18892 --nnodes=${NNODES} --node_rank=0 --master_addr=${MASTER_ADDR} \ 13 | run_class_finetuning.py \ 14 | --model ViT-B/16 \ 15 | --data_set Kinetics-400 \ 16 | --nb_classes 400 \ 17 | --data_path ${DATA_PATH} \ 18 | --log_dir ${OUTPUT_DIR} \ 19 | --output_dir ${OUTPUT_DIR} \ 20 | --batch_size 16 \ 21 | --input_size 224 \ 22 | --short_side_size 224 \ 23 | --save_ckpt_freq 5 \ 24 | --num_frames 8 \ 25 | --embed_dim 512 \ 26 | --opt adamw \ 27 | --lr 3e-4 \ 28 | --layer_decay 0.65 \ 29 | --opt_betas 0.9 0.999 \ 30 | --weight_decay 0.05 \ 31 | --epochs 20 \ 32 | --drop_path 0. \ 33 | --drop 0.
\ 34 | --dist_eval \ 35 | --enable_deepspeed -------------------------------------------------------------------------------- /scripts/k400/train_eva_large.sh: -------------------------------------------------------------------------------- 1 | # Set the path to save checkpoints 2 | OUTPUT_DIR='output/k400/eva_large_f8' 3 | 4 | DATA_PATH='/bpfs/v2_mnt/VIS/test/k400/train_320_frames' 5 | # path to pretrain model 6 | MODEL_PATH=' ' 7 | 8 | MASTER_ADDR='127.0.0.1' 9 | NNODES=1 10 | 11 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \ 12 | --master_port 18892 --nnodes=${NNODES} --node_rank=0 --master_addr=${MASTER_ADDR} \ 13 | run_class_finetuning.py \ 14 | --model EVA02-CLIP-L-14 \ 15 | --data_set Kinetics-400 \ 16 | --nb_classes 400 \ 17 | --data_path ${DATA_PATH} \ 18 | --log_dir ${OUTPUT_DIR} \ 19 | --output_dir ${OUTPUT_DIR} \ 20 | --batch_size 16 \ 21 | --input_size 224 \ 22 | --short_side_size 224 \ 23 | --save_ckpt_freq 5 \ 24 | --num_frames 8 \ 25 | --embed_dim 768 \ 26 | --opt adamw \ 27 | --lr 2e-4 \ 28 | --layer_decay 0.7 \ 29 | --opt_betas 0.9 0.999 \ 30 | --weight_decay 0.05 \ 31 | --epochs 20 \ 32 | --drop_path 0. \ 33 | --drop 0. \ 34 | --dist_eval \ 35 | --enable_deepspeed -------------------------------------------------------------------------------- /scripts/k400/train_eva_large_336.sh: -------------------------------------------------------------------------------- 1 | # Set the path to save checkpoints 2 | OUTPUT_DIR='output/k400/eva_large_336_f8' 3 | 4 | DATA_PATH='/bpfs/v2_mnt/VIS/test/k400/train_320_frames' 5 | # path to pretrain model 6 | MODEL_PATH=' ' 7 | 8 | MASTER_ADDR='127.0.0.1' 9 | NNODES=1 10 | 11 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \ 12 | --master_port 18892 --nnodes=${NNODES} --node_rank=0 --master_addr=${MASTER_ADDR} \ 13 | run_class_finetuning.py \ 14 | --model EVA02-CLIP-L-14-336 \ 15 | --data_set Kinetics-400 \ 16 | --nb_classes 400 \ 17 | --data_path ${DATA_PATH} \ 18 | --log_dir ${OUTPUT_DIR} \ 19 | --output_dir ${OUTPUT_DIR} \ 20 | --batch_size 16 \ 21 | --input_size 336 \ 22 | --short_side_size 336 \ 23 | --save_ckpt_freq 5 \ 24 | --num_frames 8 \ 25 | --embed_dim 768 \ 26 | --opt adamw \ 27 | --lr 2e-4 \ 28 | --layer_decay 0.7 \ 29 | --opt_betas 0.9 0.999 \ 30 | --weight_decay 0.05 \ 31 | --epochs 20 \ 32 | --drop_path 0. \ 33 | --drop 0. 
\ 34 | --dist_eval \ 35 | --enable_deepspeed -------------------------------------------------------------------------------- /scripts/ssv1/test_base_f16.sh: -------------------------------------------------------------------------------- 1 | # ViT-B/16 F16: Overall Prec@1 60.609% Prec@5 86.450% 2 | 3 | # Set the path to save checkpoints 4 | OUTPUT_DIR='./output/test' 5 | DATA_PATH='/bpfs/v2_mnt/VIS/wuwenhao/20bn-something-something-v1' 6 | 7 | MODEL_PATH='output/ssv1/base_f16_bs16_atm7/checkpoint-best/mp_rank_00_model_states.pt' 8 | 9 | MASTER_ADDR='127.0.0.1' 10 | 11 | NNODES=1 12 | 13 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \ 14 | --master_port 18890 --nnodes=${NNODES} --node_rank=0 --master_addr=${MASTER_ADDR} \ 15 | test_for_frame.py \ 16 | --model ViT-B/16 \ 17 | --data_set SSV1 \ 18 | --nb_classes 174 \ 19 | --data_path ${DATA_PATH} \ 20 | --log_dir ${OUTPUT_DIR} \ 21 | --output_dir ${OUTPUT_DIR} \ 22 | --batch_size 10 \ 23 | --input_size 224 \ 24 | --short_side_size 224 \ 25 | --save_ckpt_freq 5 \ 26 | --num_frames 16 \ 27 | --embed_dim 512 \ 28 | --dist_eval \ 29 | --test_num_segment 2 \ 30 | --test_num_crop 3 \ 31 | --resume ${MODEL_PATH} 32 | 33 | 34 | -------------------------------------------------------------------------------- /scripts/ssv1/test_base_f32.sh: -------------------------------------------------------------------------------- 1 | # ViT-B/16 F32: Overall Prec@1 61.450% Prec@5 86.251% 2 | 3 | # Set the path to save checkpoints 4 | OUTPUT_DIR='./output/test' 5 | DATA_PATH='/bpfs/v2_mnt/VIS/wuwenhao/20bn-something-something-v1' 6 | 7 | MODEL_PATH='output/ssv1/base_f32_bs10_atm7/checkpoint-best/mp_rank_00_model_states.pt' 8 | 9 | MASTER_ADDR='127.0.0.1' 10 | 11 | NNODES=1 12 | 13 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \ 14 | --master_port 18890 --nnodes=${NNODES} --node_rank=0 --master_addr=${MASTER_ADDR} \ 15 | test_for_frame.py \ 16 | --model ViT-B/16 \ 17 | --data_set SSV1 \ 18 | --nb_classes 174 \ 19 | --data_path ${DATA_PATH} \ 20 | --log_dir ${OUTPUT_DIR} \ 21 | --output_dir ${OUTPUT_DIR} \ 22 | --batch_size 10 \ 23 | --input_size 224 \ 24 | --short_side_size 224 \ 25 | --save_ckpt_freq 5 \ 26 | --num_frames 32 \ 27 | --embed_dim 512 \ 28 | --dist_eval \ 29 | --test_num_segment 2 \ 30 | --test_num_crop 3 \ 31 | --resume ${MODEL_PATH} 32 | 33 | 34 | -------------------------------------------------------------------------------- /scripts/ssv1/test_base_f8.sh: -------------------------------------------------------------------------------- 1 | # ViT-B/16 F8: Overall Prec@1 58.744% Prec@5 85.427% 2 | 3 | # Set the path to save checkpoints 4 | OUTPUT_DIR='./output/test' 5 | DATA_PATH='/bpfs/v2_mnt/VIS/wuwenhao/20bn-something-something-v1' 6 | 7 | MODEL_PATH='output/ssv1/base_f8_bs16_atm7/checkpoint-best/mp_rank_00_model_states.pt' 8 | 9 | MASTER_ADDR='127.0.0.1' 10 | 11 | NNODES=1 12 | 13 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \ 14 | --master_port 18899 --nnodes=${NNODES} --node_rank=0 --master_addr=${MASTER_ADDR} \ 15 | test_for_frame.py \ 16 | --model ViT-B/16 \ 17 | --data_set SSV1 \ 18 | --nb_classes 174 \ 19 | --data_path ${DATA_PATH} \ 20 | --log_dir ${OUTPUT_DIR} \ 21 | --output_dir ${OUTPUT_DIR} \ 22 | --batch_size 10 \ 23 | --input_size 224 \ 24 | --short_side_size 224 \ 25 | --save_ckpt_freq 5 \ 26 | --num_frames 8 \ 27 | --embed_dim 512 \ 28 | --dist_eval \ 29 | --test_num_segment 2 \ 30 | --test_num_crop 3 \ 31 | --resume ${MODEL_PATH} 32 | 33 | 34 | 
-------------------------------------------------------------------------------- /scripts/ssv1/test_large_f16.sh: -------------------------------------------------------------------------------- 1 | # ViT-L/14 F16: 2 | 3 | # Set the path to save checkpoints 4 | OUTPUT_DIR='./output/test' 5 | DATA_PATH='/bpfs/v2_mnt/VIS/wuwenhao/20bn-something-something-v1' 6 | 7 | MODEL_PATH='output/ssv1/large_f16_bs6_atm14/checkpoint-best/mp_rank_00_model_states.pt' 8 | 9 | MASTER_ADDR='127.0.0.1' 10 | 11 | NNODES=1 12 | 13 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \ 14 | --master_port 18890 --nnodes=${NNODES} --node_rank=0 --master_addr=${MASTER_ADDR} \ 15 | test_for_frame.py \ 16 | --model ViT-L/14 \ 17 | --data_set SSV1 \ 18 | --nb_classes 174 \ 19 | --data_path ${DATA_PATH} \ 20 | --log_dir ${OUTPUT_DIR} \ 21 | --output_dir ${OUTPUT_DIR} \ 22 | --batch_size 10 \ 23 | --input_size 224 \ 24 | --short_side_size 224 \ 25 | --save_ckpt_freq 5 \ 26 | --num_frames 16 \ 27 | --embed_dim 768 \ 28 | --dist_eval \ 29 | --test_num_segment 2 \ 30 | --test_num_crop 3 \ 31 | --resume ${MODEL_PATH} 32 | 33 | 34 | -------------------------------------------------------------------------------- /scripts/ssv1/train_base.sh: -------------------------------------------------------------------------------- 1 | # Set the path to save checkpoints 2 | OUTPUT_DIR='output/ssv1/base_f8' 3 | 4 | DATA_PATH='/bpfs/v2_mnt/VIS/wuwenhao/20bn-something-something-v1' 5 | # path to pretrain model 6 | MODEL_PATH=' ' 7 | 8 | MASTER_ADDR='127.0.0.1' 9 | NNODES=1 10 | 11 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \ 12 | --master_port 18892 --nnodes=${NNODES} --node_rank=0 --master_addr=${MASTER_ADDR} \ 13 | run_class_finetuning.py \ 14 | --model ViT-B/16 \ 15 | --data_set SSV1 \ 16 | --nb_classes 174 \ 17 | --data_path ${DATA_PATH} \ 18 | --log_dir ${OUTPUT_DIR} \ 19 | --output_dir ${OUTPUT_DIR} \ 20 | --batch_size 16 \ 21 | --input_size 224 \ 22 | --short_side_size 224 \ 23 | --save_ckpt_freq 5 \ 24 | --num_frames 8 \ 25 | --embed_dim 512 \ 26 | --opt adamw \ 27 | --lr 7e-4 \ 28 | --layer_decay 0.70 \ 29 | --opt_betas 0.9 0.999 \ 30 | --weight_decay 0.05 \ 31 | --epochs 20 \ 32 | --drop_path 0. \ 33 | --drop 0. \ 34 | --dist_eval \ 35 | --enable_deepspeed \ 36 | --aa True -------------------------------------------------------------------------------- /scripts/ssv1/train_eva_large.sh: -------------------------------------------------------------------------------- 1 | # Set the path to save checkpoints 2 | OUTPUT_DIR='output/ssv1/eva_large_f16' 3 | 4 | DATA_PATH='/bpfs/v2_mnt/VIS/wuwenhao/20bn-something-something-v1' 5 | # path to pretrain model 6 | MODEL_PATH=' ' 7 | 8 | MASTER_ADDR='127.0.0.1' 9 | NNODES=1 10 | 11 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \ 12 | --master_port 18892 --nnodes=${NNODES} --node_rank=0 --master_addr=${MASTER_ADDR} \ 13 | run_class_finetuning.py \ 14 | --model EVA02-CLIP-L-14 \ 15 | --data_set SSV1 \ 16 | --nb_classes 174 \ 17 | --data_path ${DATA_PATH} \ 18 | --log_dir ${OUTPUT_DIR} \ 19 | --output_dir ${OUTPUT_DIR} \ 20 | --batch_size 16 \ 21 | --input_size 224 \ 22 | --short_side_size 224 \ 23 | --save_ckpt_freq 5 \ 24 | --num_frames 16 \ 25 | --embed_dim 768 \ 26 | --opt adamw \ 27 | --lr 1e-3 \ 28 | --layer_decay 0.75 \ 29 | --opt_betas 0.9 0.999 \ 30 | --weight_decay 0.05 \ 31 | --epochs 15 \ 32 | --drop_path 0. \ 33 | --drop 0. 
\ 34 | --dist_eval \ 35 | --enable_deepspeed \ 36 | --aa True -------------------------------------------------------------------------------- /scripts/ssv2/test_base_f16.sh: -------------------------------------------------------------------------------- 1 | # ViT-B/16 F16: 2 | 3 | # Set the path to save checkpoints 4 | OUTPUT_DIR='./output/test' 5 | DATA_PATH='/bpfs/v2_mnt/VIS/wuwenhao/20bn-something-something-v2-frames' 6 | 7 | MODEL_PATH='output/ssv2/base_f16_bs16_atm7/checkpoint-best/mp_rank_00_model_states.pt' 8 | 9 | MASTER_ADDR='127.0.0.1' 10 | 11 | NNODES=1 12 | 13 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \ 14 | --master_port 18890 --nnodes=${NNODES} --node_rank=0 --master_addr=${MASTER_ADDR} \ 15 | test_for_frame.py \ 16 | --model ViT-B/16 \ 17 | --data_set SSV2 \ 18 | --nb_classes 174 \ 19 | --data_path ${DATA_PATH} \ 20 | --log_dir ${OUTPUT_DIR} \ 21 | --output_dir ${OUTPUT_DIR} \ 22 | --batch_size 10 \ 23 | --input_size 224 \ 24 | --short_side_size 224 \ 25 | --save_ckpt_freq 5 \ 26 | --num_frames 16 \ 27 | --embed_dim 512 \ 28 | --dist_eval \ 29 | --test_num_segment 2 \ 30 | --test_num_crop 3 \ 31 | --resume ${MODEL_PATH} 32 | 33 | 34 | -------------------------------------------------------------------------------- /scripts/ssv2/test_base_f32.sh: -------------------------------------------------------------------------------- 1 | # ViT-B/16 F32: 2 | 3 | # Set the path to save checkpoints 4 | OUTPUT_DIR='./output/test' 5 | DATA_PATH='/bpfs/v2_mnt/VIS/wuwenhao/20bn-something-something-v2-frames' 6 | 7 | MODEL_PATH='../ATM_ICCV23_CODE_MODELS/output/ssv2/base_f32_bs10_atm7/checkpoint-best/mp_rank_00_model_states.pt' 8 | 9 | MASTER_ADDR='127.0.0.1' 10 | 11 | NNODES=1 12 | 13 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \ 14 | --master_port 18890 --nnodes=${NNODES} --node_rank=0 --master_addr=${MASTER_ADDR} \ 15 | test_for_frame.py \ 16 | --model ViT-B/16 \ 17 | --data_set SSV2 \ 18 | --nb_classes 174 \ 19 | --data_path ${DATA_PATH} \ 20 | --log_dir ${OUTPUT_DIR} \ 21 | --output_dir ${OUTPUT_DIR} \ 22 | --batch_size 10 \ 23 | --input_size 224 \ 24 | --short_side_size 224 \ 25 | --save_ckpt_freq 5 \ 26 | --num_frames 32 \ 27 | --embed_dim 512 \ 28 | --dist_eval \ 29 | --test_num_segment 2 \ 30 | --test_num_crop 3 \ 31 | --resume ${MODEL_PATH} 32 | 33 | 34 | -------------------------------------------------------------------------------- /scripts/ssv2/test_base_f8.sh: -------------------------------------------------------------------------------- 1 | # ViT-B/16 F8: 2 | 3 | # Set the path to save checkpoints 4 | OUTPUT_DIR='./output/test' 5 | DATA_PATH='/bpfs/v2_mnt/VIS/wuwenhao/20bn-something-something-v2-frames' 6 | 7 | MODEL_PATH='output/ssv2/base_f8_bs16_atm7/checkpoint-best/mp_rank_00_model_states.pt' 8 | 9 | MASTER_ADDR='127.0.0.1' 10 | 11 | NNODES=1 12 | 13 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \ 14 | --master_port 18899 --nnodes=${NNODES} --node_rank=0 --master_addr=${MASTER_ADDR} \ 15 | test_for_frame.py \ 16 | --model ViT-B/16 \ 17 | --data_set SSV2 \ 18 | --nb_classes 174 \ 19 | --data_path ${DATA_PATH} \ 20 | --log_dir ${OUTPUT_DIR} \ 21 | --output_dir ${OUTPUT_DIR} \ 22 | --batch_size 10 \ 23 | --input_size 224 \ 24 | --short_side_size 224 \ 25 | --save_ckpt_freq 5 \ 26 | --num_frames 8 \ 27 | --embed_dim 512 \ 28 | --dist_eval \ 29 | --test_num_segment 2 \ 30 | --test_num_crop 3 \ 31 | --resume ${MODEL_PATH} 32 | 33 | 34 | 
-------------------------------------------------------------------------------- /scripts/ssv2/test_large_f16.sh: -------------------------------------------------------------------------------- 1 | # ViT-L/14 F16: 2 | 3 | # Set the path to save checkpoints 4 | OUTPUT_DIR='./output/test' 5 | DATA_PATH='/bpfs/v2_mnt/VIS/wuwenhao/20bn-something-something-v2-frames' 6 | 7 | MODEL_PATH='../ATM_ICCV23_CODE_MODELS/output/ssv2/large_f16_bs6_atm14/checkpoint-best/mp_rank_00_model_states.pt' 8 | 9 | MASTER_ADDR='127.0.0.1' 10 | 11 | NNODES=1 12 | 13 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \ 14 | --master_port 18890 --nnodes=${NNODES} --node_rank=0 --master_addr=${MASTER_ADDR} \ 15 | test_for_frame.py \ 16 | --model ViT-L/14 \ 17 | --data_set SSV2 \ 18 | --nb_classes 174 \ 19 | --data_path ${DATA_PATH} \ 20 | --log_dir ${OUTPUT_DIR} \ 21 | --output_dir ${OUTPUT_DIR} \ 22 | --batch_size 10 \ 23 | --input_size 224 \ 24 | --short_side_size 224 \ 25 | --save_ckpt_freq 5 \ 26 | --num_frames 16 \ 27 | --embed_dim 768 \ 28 | --dist_eval \ 29 | --test_num_segment 2 \ 30 | --test_num_crop 3 \ 31 | --resume ${MODEL_PATH} 32 | 33 | 34 | -------------------------------------------------------------------------------- /scripts/ssv2/train_base.sh: -------------------------------------------------------------------------------- 1 | # Set the path to save checkpoints 2 | OUTPUT_DIR='output/ssv2/base_f8' 3 | 4 | DATA_PATH='/bpfs/v2_mnt/VIS/wuwenhao/20bn-something-something-v2-frames' 5 | # path to pretrain model 6 | MODEL_PATH=' ' 7 | 8 | MASTER_ADDR='127.0.0.1' 9 | NNODES=1 10 | 11 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \ 12 | --master_port 18892 --nnodes=${NNODES} --node_rank=0 --master_addr=${MASTER_ADDR} \ 13 | run_class_finetuning.py \ 14 | --model ViT-B/16 \ 15 | --data_set SSV2 \ 16 | --nb_classes 174 \ 17 | --data_path ${DATA_PATH} \ 18 | --log_dir ${OUTPUT_DIR} \ 19 | --output_dir ${OUTPUT_DIR} \ 20 | --batch_size 16 \ 21 | --input_size 224 \ 22 | --short_side_size 224 \ 23 | --save_ckpt_freq 5 \ 24 | --num_frames 8 \ 25 | --embed_dim 512 \ 26 | --opt adamw \ 27 | --lr 7e-4 \ 28 | --layer_decay 0.70 \ 29 | --opt_betas 0.9 0.999 \ 30 | --weight_decay 0.05 \ 31 | --epochs 20 \ 32 | --drop_path 0. \ 33 | --drop 0. \ 34 | --dist_eval \ 35 | --enable_deepspeed \ 36 | --aa True -------------------------------------------------------------------------------- /scripts/ssv2/train_eva_large.sh: -------------------------------------------------------------------------------- 1 | # Set the path to save checkpoints 2 | OUTPUT_DIR='output/ssv2/eva_large_f16' 3 | 4 | DATA_PATH='/bpfs/v2_mnt/VIS/wuwenhao/20bn-something-something-v2-frames' 5 | # path to pretrain model 6 | MODEL_PATH=' ' 7 | 8 | MASTER_ADDR='127.0.0.1' 9 | NNODES=1 10 | 11 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \ 12 | --master_port 18892 --nnodes=${NNODES} --node_rank=0 --master_addr=${MASTER_ADDR} \ 13 | run_class_finetuning.py \ 14 | --model EVA02-CLIP-L-14 \ 15 | --data_set SSV2 \ 16 | --nb_classes 174 \ 17 | --data_path ${DATA_PATH} \ 18 | --log_dir ${OUTPUT_DIR} \ 19 | --output_dir ${OUTPUT_DIR} \ 20 | --batch_size 16 \ 21 | --input_size 224 \ 22 | --short_side_size 224 \ 23 | --save_ckpt_freq 5 \ 24 | --num_frames 16 \ 25 | --embed_dim 768 \ 26 | --opt adamw \ 27 | --lr 1e-3 \ 28 | --layer_decay 0.75 \ 29 | --opt_betas 0.9 0.999 \ 30 | --weight_decay 0.05 \ 31 | --epochs 15 \ 32 | --drop_path 0. \ 33 | --drop 0. 
\ 34 | --dist_eval \ 35 | --enable_deepspeed \ 36 | --aa True --------------------------------------------------------------------------------