├── .gitignore
├── LICENSE
├── README.md
├── clip
│   ├── __init__.py
│   ├── bpe_simple_vocab_16e6.txt.gz
│   ├── clip.py
│   ├── model.py
│   └── simple_tokenizer.py
├── dataset
│   ├── Augmentation.py
│   ├── __init__.py
│   ├── charades.py
│   ├── kinetics.py
│   ├── mixup.py
│   ├── rand_augment.py
│   ├── random_erasing.py
│   ├── sth.py
│   ├── sthvideo.py
│   └── transforms.py
├── datasets.py
├── engine_for_finetuning.py
├── eva_clip
│   ├── __init__.py
│   ├── bpe_simple_vocab_16e6.txt.gz
│   ├── constants.py
│   ├── eva_vit_model.py
│   ├── factory.py
│   ├── hf_configs.py
│   ├── hf_model.py
│   ├── loss.py
│   ├── model.py
│   ├── model_configs
│   │   ├── EVA01-CLIP-B-16.json
│   │   ├── EVA01-CLIP-g-14-plus.json
│   │   ├── EVA01-CLIP-g-14.json
│   │   ├── EVA02-CLIP-B-16.json
│   │   ├── EVA02-CLIP-L-14-336.json
│   │   ├── EVA02-CLIP-L-14.json
│   │   ├── EVA02-CLIP-bigE-14-plus.json
│   │   └── EVA02-CLIP-bigE-14.json
│   ├── modified_resnet.py
│   ├── openai.py
│   ├── pretrained.py
│   ├── rope.py
│   ├── timm_model.py
│   ├── tokenizer.py
│   ├── transform.py
│   ├── transformer.py
│   └── utils.py
├── lists
│   ├── k400
│   │   ├── kinetics_rgb_train_se320.txt
│   │   └── kinetics_rgb_val_se320.txt
│   ├── kinetics_400_labels.csv
│   ├── sth_labels.csv
│   ├── sthv1
│   │   ├── train_rgb.txt
│   │   └── val_rgb.txt
│   └── sthv2
│       ├── train_rgb.txt
│       └── val_rgb.txt
├── modules
│   ├── ATM.py
│   └── temporal_modeling.py
├── optim_factory.py
├── pics
│   └── ATM.png
├── run_class_finetuning.py
├── scripts
│   ├── k400
│   │   ├── train_base.sh
│   │   ├── train_eva_large.sh
│   │   └── train_eva_large_336.sh
│   ├── ssv1
│   │   ├── test_base_f16.sh
│   │   ├── test_base_f32.sh
│   │   ├── test_base_f8.sh
│   │   ├── test_large_f16.sh
│   │   ├── train_base.sh
│   │   └── train_eva_large.sh
│   └── ssv2
│       ├── test_base_f16.sh
│       ├── test_base_f32.sh
│       ├── test_base_f8.sh
│       ├── test_large_f16.sh
│       ├── train_base.sh
│       └── train_eva_large.sh
├── test_for_frame.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Wenhao Wu
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
If you like our project, please give us a star ⭐ on GitHub for the latest updates.
8 |
9 |
10 |
11 | [](https://openaccess.thecvf.com/content/ICCV2023/html/Wu_What_Can_Simple_Arithmetic_Operations_Do_for_Temporal_Modeling_ICCV_2023_paper.html)
12 | [](https://arxiv.org/abs/2307.08908)
13 |
14 |
15 | [Wenhao Wu](https://whwu95.github.io/)1,2, [Yuxin Song]()2, [Zhun Sun]()2, [Jingdong Wang](https://jingdongwang2017.github.io/)3, [Chang Xu](http://changxu.xyz/)1, [Wanli Ouyang](https://wlouyang.github.io/)3,1
16 |
17 |
18 | 1[The University of Sydney](https://www.sydney.edu.au/), 2[Baidu](https://vis.baidu.com/#/), 3[Shanghai AI Lab](https://www.shlab.org.cn/)
19 |
20 |
21 |
22 |
23 | ***
24 | [](https://paperswithcode.com/sota/action-recognition-in-videos-on-something-1?p=what-can-simple-arithmetic-operations-do-for)
25 | [](https://paperswithcode.com/sota/action-classification-on-kinetics-400?p=what-can-simple-arithmetic-operations-do-for)
26 | [](https://paperswithcode.com/sota/action-recognition-in-videos-on-something?p=what-can-simple-arithmetic-operations-do-for)
27 |
28 | This is the official implementation of our **ATM** (Arithmetic Temporal Module), which explores the potential of four simple arithmetic operations for temporal modeling.
29 |
30 | Our best model can achieve **89.4%** Top-1 Acc. on Kinetics-400, **65.6%** Top-1 Acc. on Something-Something V1, **74.6%** Top-1 Acc. on Something-Something V2!
31 |
32 |
33 | 🔥 I also have other recent video recognition projects that may interest you ✨.
34 |
35 | > [**Side4Video: Spatial-Temporal Side Network for Memory-Efficient Image-to-Video Transfer Learning**](https://arxiv.org/abs/2311.15769)
36 | > Huanjin Yao, Wenhao Wu, Zhiheng Li
37 | > [](https://arxiv.org/abs/2311.15769) [](https://github.com/HJYao00/Side4Video)
38 |
39 |
40 |
41 | > [**Bidirectional Cross-Modal Knowledge Exploration for Video Recognition with Pre-trained Vision-Language Models**](https://arxiv.org/abs/2301.00182)
42 | > Wenhao Wu, Xiaohan Wang, Haipeng Luo, Jingdong Wang, Yi Yang, Wanli Ouyang
43 | > [](https://openaccess.thecvf.com/content/CVPR2023/html/Wu_Bidirectional_Cross-Modal_Knowledge_Exploration_for_Video_Recognition_With_Pre-Trained_Vision-Language_CVPR_2023_paper.html) [](https://github.com/whwu95/BIKE)
44 |
45 |
46 | > [**Revisiting Classifier: Transferring Vision-Language Models for Video Recognition**](https://arxiv.org/abs/2207.01297)
47 | > Wenhao Wu, Zhun Sun, Wanli Ouyang
48 | > [](https://ojs.aaai.org/index.php/AAAI/article/view/25386/25158) [](https://link.springer.com/article/10.1007/s11263-023-01876-w) [](https://github.com/whwu95/Text4Vis)
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
64 |
65 |
66 |
67 | ## 📣 News
68 |
69 |
70 | - [x] `Nov 29, 2023`: Training code has been released.
71 | - [x] `July 14, 2023`: 🎉Our **ATM** has been accepted by **ICCV-2023**.
72 |
73 |
74 |
75 | ## 🌈 Overview
76 | 
77 | The key motivation behind ATM is to explore the potential of simple arithmetic operations to capture auxiliary temporal clues that may be embedded in current video features, without relying on elaborate designs. ATM can be integrated into both vanilla CNN backbones (e.g., ResNet) and Vision Transformers (e.g., ViT) for video action recognition.
78 |
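For intuition, here is a minimal, self-contained sketch of the idea using one of the four operations (subtraction between neighbouring frame features). It is only an illustration under simplified assumptions, not the repository's `modules/ATM.py` implementation:

```python
import torch

def arithmetic_temporal_clues(x: torch.Tensor) -> torch.Tensor:
    """Illustrative only: mix frame features with subtraction-based temporal clues.

    x: (B, T, C) video features, one C-dim vector per frame.
    """
    # Shift by one frame so x_next[:, t] == x[:, t + 1]; the last frame is repeated.
    x_next = torch.cat([x[:, 1:], x[:, -1:]], dim=1)
    diff = x_next - x   # subtraction: frame-to-frame change as a temporal clue
    return x + diff     # addition: fuse the clue back into the appearance features

feats = torch.randn(2, 8, 512)                  # 2 clips, 8 frames, 512-dim features
print(arithmetic_temporal_clues(feats).shape)   # torch.Size([2, 8, 512])
```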
79 |
80 | ## 🚀 Training & Testing
81 | We offer training and testing scripts for Kinetics-400, Sth-Sth V1, and Sth-Sth V2. Please refer to the [*scripts*](https://github.com/whwu95/ATM/tree/main/scripts) folder for details. For example, you can run:
82 |
83 | ```sh
84 | # Train the 8 Frames ViT-B/32 model on Sth-Sth v1.
85 | sh scripts/ssv1/train_base.sh
86 |
87 | # Test the 8 Frames ViT-B/32 model on Sth-Sth v1.
88 | sh scripts/ssv1/test_base_f8.sh
89 | ```
90 |
91 |
92 |
93 |
94 | ## 📌 BibTeX & Citation
95 |
96 | If you use our code in your research or wish to refer to the baseline results, please use the following BibTeX entry😁.
97 |
98 |
99 | ```bibtex
100 | @inproceedings{atm,
101 | title={What Can Simple Arithmetic Operations Do for Temporal Modeling?},
102 | author={Wu, Wenhao and Song, Yuxin and Sun, Zhun and Wang, Jingdong and Xu, Chang and Ouyang, Wanli},
103 | booktitle={Proceedings of the IEEE International Conference on Computer Vision (ICCV)},
104 | year={2023}
105 | }
106 | ```
107 |
108 |
109 |
110 | ## 🎗️ Acknowledgement
111 |
112 | This repository is built upon portions of [VideoMAE](https://github.com/MCG-NJU/VideoMAE), [CLIP](https://github.com/openai/CLIP), and [EVA](https://github.com/baaivision/EVA). Thanks to the contributors of these great codebases.
113 |
114 |
115 | ## 👫 Contact
116 | For any questions, please file an issue or contact [Wenhao Wu](https://whwu95.github.io/).
117 |
--------------------------------------------------------------------------------
/clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip import *
2 |
3 |
4 |
5 |
--------------------------------------------------------------------------------
/clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/whwu95/ATM/98ba3aa2ac258cc1b91beefe9317136657ae3d8d/clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/clip/clip.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import urllib
4 | import warnings
5 | from typing import Any, Union, List
6 | from pkg_resources import packaging
7 |
8 | import torch
9 | from PIL import Image
10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
11 | from tqdm import tqdm
12 |
13 | from .model import build_model
14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer
15 |
16 | try:
17 | from torchvision.transforms import InterpolationMode
18 | BICUBIC = InterpolationMode.BICUBIC
19 | except ImportError:
20 | BICUBIC = Image.BICUBIC
21 |
22 |
23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended")
25 |
26 |
27 | __all__ = ["available_models", "load", "tokenize"]
28 | _tokenizer = _Tokenizer()
29 |
30 | _MODELS = {
31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
35 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
36 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
37 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
38 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
39 | "ViT-L/14-336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"
40 | }
41 |
42 |
43 | def _download(url: str, root: str = os.path.expanduser("./clip_pretrain")):
44 | os.makedirs(root, exist_ok=True)
45 | filename = os.path.basename(url)
46 |
47 | expected_sha256 = url.split("/")[-2]
48 | download_target = os.path.join(root, filename)
49 |
50 | if os.path.exists(download_target) and not os.path.isfile(download_target):
51 | raise RuntimeError(f"{download_target} exists and is not a regular file")
52 |
53 | if os.path.isfile(download_target):
54 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
55 | return download_target
56 | else:
57 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
58 |
59 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
60 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
61 | while True:
62 | buffer = source.read(8192)
63 | if not buffer:
64 | break
65 |
66 | output.write(buffer)
67 | loop.update(len(buffer))
68 |
69 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
70 |         raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
71 |
72 | return download_target
73 |
74 |
75 | def _convert_image_to_rgb(image):
76 | return image.convert("RGB")
77 |
78 |
79 | def _transform(n_px):
80 | return Compose([
81 | Resize(n_px, interpolation=BICUBIC),
82 | CenterCrop(n_px),
83 | _convert_image_to_rgb,
84 | ToTensor(),
85 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
86 | ])
87 |
88 |
89 | def available_models() -> List[str]:
90 | """Returns the names of available CLIP models"""
91 | return list(_MODELS.keys())
92 |
93 |
94 | def load(
95 | name: str,
96 | device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
97 | jit=True,
98 | tsm=False, joint=False, T=8, dropout=0.,
99 | emb_dropout=0.,
100 | pretrain=True):
101 | """Load a CLIP model
102 |
103 | Parameters
104 | ----------
105 | name : str
106 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
107 |
108 | device : Union[str, torch.device]
109 | The device to put the loaded model
110 |
111 | jit : bool
112 |         Whether to load the optimized JIT model (default) or the more hackable non-JIT model.
113 |
114 |     Note: there is no download_root argument; model files are downloaded to
115 |     "./clip_pretrain" by _download().
116 |
117 | Returns
118 | -------
119 | model : torch.nn.Module
120 | The CLIP model
121 |
122 | preprocess : Callable[[PIL.Image], torch.Tensor]
123 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
124 | """
125 | if name in _MODELS:
126 | model_path = _download(_MODELS[name])
127 | elif os.path.isfile(name):
128 | model_path = name
129 | else:
130 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
131 |
132 | try:
133 | # loading JIT archive
134 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
135 | state_dict = None
136 | except RuntimeError:
137 | # loading saved state dict
138 | if jit:
139 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
140 | jit = False
141 | state_dict = torch.load(model_path, map_location="cpu")
142 |
143 | if not jit:
144 | model = build_model(state_dict or model.state_dict(), joint=joint,tsm=tsm,T=T,dropout=dropout, emb_dropout=emb_dropout,pretrain=pretrain).to(device)
145 | if str(device) == "cpu":
146 | model.float()
147 | return model, model.state_dict()
148 |
149 | # patch the device names
150 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
151 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
152 |
153 | def patch_device(module):
154 | try:
155 | graphs = [module.graph] if hasattr(module, "graph") else []
156 | except RuntimeError:
157 | graphs = []
158 |
159 | if hasattr(module, "forward1"):
160 | graphs.append(module.forward1.graph)
161 |
162 | for graph in graphs:
163 | for node in graph.findAllNodes("prim::Constant"):
164 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
165 | node.copyAttributes(device_node)
166 |
167 | model.apply(patch_device)
168 |
169 | if str(device) == "cpu":
170 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
171 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
172 | float_node = float_input.node()
173 |
174 | def patch_float(module):
175 | try:
176 | graphs = [module.graph] if hasattr(module, "graph") else []
177 | except RuntimeError:
178 | graphs = []
179 |
180 | if hasattr(module, "forward1"):
181 | graphs.append(module.forward1.graph)
182 |
183 | for graph in graphs:
184 | for node in graph.findAllNodes("aten::to"):
185 | inputs = list(node.inputs())
186 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
187 | if inputs[i].node()["value"] == 5:
188 | inputs[i].node().copyAttributes(float_node)
189 |
190 | model.apply(patch_float)
191 | patch_float(model.encode_image)
192 | patch_float(model.encode_text)
193 |
194 | model.float()
195 |
196 | return model, _transform(model.input_resolution.item())
197 |
198 |
199 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
200 | """
201 | Returns the tokenized representation of given input string(s)
202 |
203 | Parameters
204 | ----------
205 | texts : Union[str, List[str]]
206 | An input string or a list of input strings to tokenize
207 |
208 | context_length : int
209 | The context length to use; all CLIP models use 77 as the context length
210 |
211 | truncate: bool
212 | Whether to truncate the text in case its encoding is longer than the context length
213 |
214 | Returns
215 | -------
216 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
217 | """
218 | if isinstance(texts, str):
219 | texts = [texts]
220 |
221 | sot_token = _tokenizer.encoder["<|startoftext|>"]
222 | eot_token = _tokenizer.encoder["<|endoftext|>"]
223 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
224 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
225 |
226 | for i, tokens in enumerate(all_tokens):
227 | if len(tokens) > context_length:
228 | if truncate:
229 | tokens = tokens[:context_length]
230 | tokens[-1] = eot_token
231 | else:
232 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
233 | result[i, :len(tokens)] = torch.tensor(tokens)
234 |
235 | return result
236 |
--------------------------------------------------------------------------------
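A minimal usage sketch for the loader and tokenizer defined above (the model name and the non-JIT return value follow `_MODELS` and `load()`; the prompt texts are just examples):

```python
import torch
import clip

# Non-JIT path: load() returns the built model and its state_dict (see above).
device = "cuda" if torch.cuda.is_available() else "cpu"
model, state_dict = clip.load("ViT-B/32", device=device, jit=False, T=8)

# tokenize() returns an (N, 77) LongTensor of BPE token ids.
tokens = clip.tokenize(["a video of a person dancing",
                        "a video of a person cooking"]).to(device)
with torch.no_grad():
    text_features = model.encode_text(tokens)
print(text_features.shape)
```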
/clip/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 |
14 |
15 | @lru_cache()
16 | def bytes_to_unicode():
17 | """
18 | Returns list of utf-8 byte and a corresponding list of unicode strings.
19 | The reversible bpe codes work on unicode strings.
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 |     This is a significant percentage of your normal, say, 32K bpe vocab.
23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 | And avoids mapping to whitespace/control characters the bpe code barfs on.
25 | """
26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 | cs = bs[:]
28 | n = 0
29 | for b in range(2**8):
30 | if b not in bs:
31 | bs.append(b)
32 | cs.append(2**8+n)
33 | n += 1
34 | cs = [chr(n) for n in cs]
35 | return dict(zip(bs, cs))
36 |
37 |
38 | def get_pairs(word):
39 | """Return set of symbol pairs in a word.
40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """
42 | pairs = set()
43 | prev_char = word[0]
44 | for char in word[1:]:
45 | pairs.add((prev_char, char))
46 | prev_char = char
47 | return pairs
48 |
49 |
50 | def basic_clean(text):
51 | text = ftfy.fix_text(text)
52 | text = html.unescape(html.unescape(text))
53 | return text.strip()
54 |
55 |
56 | def whitespace_clean(text):
57 | text = re.sub(r'\s+', ' ', text)
58 | text = text.strip()
59 | return text
60 |
61 |
62 | class SimpleTokenizer(object):
63 | def __init__(self, bpe_path: str = default_bpe()):
64 | self.byte_encoder = bytes_to_unicode()
65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 | merges = merges[1:49152-256-2+1]
68 | merges = [tuple(merge.split()) for merge in merges]
69 | vocab = list(bytes_to_unicode().values())
70 |         vocab = vocab + [v+'</w>' for v in vocab]
71 | for merge in merges:
72 | vocab.append(''.join(merge))
73 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74 | self.encoder = dict(zip(vocab, range(len(vocab))))
75 | self.decoder = {v: k for k, v in self.encoder.items()}
76 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79 |
80 | def bpe(self, token):
81 | if token in self.cache:
82 | return self.cache[token]
83 |         word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84 | pairs = get_pairs(word)
85 |
86 | if not pairs:
87 |             return token+'</w>'
88 |
89 | while True:
90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91 | if bigram not in self.bpe_ranks:
92 | break
93 | first, second = bigram
94 | new_word = []
95 | i = 0
96 | while i < len(word):
97 | try:
98 | j = word.index(first, i)
99 | new_word.extend(word[i:j])
100 | i = j
101 | except:
102 | new_word.extend(word[i:])
103 | break
104 |
105 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
106 | new_word.append(first+second)
107 | i += 2
108 | else:
109 | new_word.append(word[i])
110 | i += 1
111 | new_word = tuple(new_word)
112 | word = new_word
113 | if len(word) == 1:
114 | break
115 | else:
116 | pairs = get_pairs(word)
117 | word = ' '.join(word)
118 | self.cache[token] = word
119 | return word
120 |
121 | def encode(self, text):
122 | bpe_tokens = []
123 | text = whitespace_clean(basic_clean(text)).lower()
124 | for token in re.findall(self.pat, text):
125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127 | return bpe_tokens
128 |
129 | def decode(self, tokens):
130 | text = ''.join([self.decoder[token] for token in tokens])
131 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132 | return text
133 |
--------------------------------------------------------------------------------
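A quick sketch of the tokenizer round trip (encode, then decode), assuming the BPE vocab file ships alongside the module as listed in the tree above:

```python
from clip.simple_tokenizer import SimpleTokenizer

tok = SimpleTokenizer()                 # uses bpe_simple_vocab_16e6.txt.gz by default
ids = tok.encode("a photo of a cat")    # list of BPE token ids (lower-cased, cleaned)
print(ids)
print(tok.decode(ids))                  # -> "a photo of a cat " (trailing space from </w>)
```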
/dataset/Augmentation.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data._utils.collate import default_collate
2 | from dataset.transforms import *
3 | from dataset.random_erasing import RandomErasing
4 |
5 |
6 | class GroupTransform(object):
7 | def __init__(self, transform):
8 | self.worker = transform
9 |
10 | def __call__(self, img):
11 | img_group, label = img
12 | return [self.worker(img) for img in img_group], label
13 |
14 |
15 | class SplitLabel(object):
16 | def __init__(self, transform):
17 | self.worker = transform
18 |
19 | def __call__(self, img):
20 | img_group, label = img
21 | return self.worker(img_group), label
22 |
23 |
24 |
25 | def train_augmentation(input_size, flip=True):
26 | if flip:
27 | return torchvision.transforms.Compose([
28 | GroupRandomSizedCrop(input_size),
29 | GroupRandomHorizontalFlip(is_flow=False)])
30 | else:
31 | return torchvision.transforms.Compose([
32 | GroupRandomSizedCrop(input_size),
33 | # GroupMultiScaleCrop(input_size, [1, .875, .75, .66]),
34 | GroupRandomHorizontalFlip_sth()])
35 |
36 |
37 | def get_augmentation(training, input_size=224, config=None):
38 | input_mean = [0.48145466, 0.4578275, 0.40821073]
39 | input_std = [0.26862954, 0.26130258, 0.27577711]
40 | scale_size = 256 if input_size == 224 else input_size
41 |
42 | normalize = GroupNormalize(input_mean, input_std)
43 | #if 'something' in config.data.dataset:
44 | if 'SS' in config.data_set:
45 | if scale_size == 256:
46 | groupscale = GroupScale((256, 320))
47 | else:
48 | groupscale = GroupScale(int(scale_size))
49 | else:
50 | groupscale = GroupScale(int(scale_size))
51 |
52 |
53 | common = torchvision.transforms.Compose([
54 | Stack(roll=False),
55 | ToTorchFormatTensor(div=True),
56 | normalize])
57 |
58 | if training:
59 | auto_transform = None
60 | erase_transform = None
61 |         if config.aa:  # RandAugment: enabled for Sth-Sth, disabled for K400
62 | print('***'*20)
63 | print('use random_augment!!!')
64 | auto_transform = create_random_augment(
65 | input_size=256, # scale_size
66 | auto_augment="rand-m7-n4-mstd0.5-inc1",
67 | interpolation="bicubic"
68 | )
69 | # if config.rand_erase:
70 | # print('***'*20)
71 | # print('use Random_Erasing!!!')
72 | # erase_transform = RandomErasing(
73 | # 0.25,
74 | # mode='pixel',
75 | # max_count=1,
76 | # num_splits=1,
77 | # device="cpu",
78 | # )
79 |
80 | train_aug = train_augmentation(
81 | input_size,
82 | flip=False if 'SS' in config.data_set else True)
83 |
84 | unique = torchvision.transforms.Compose([
85 | groupscale,
86 | train_aug,
87 | GroupRandomGrayscale(p=0 if 'SS' in config.data_set else 0.2),
88 | ])
89 |
90 | if auto_transform is not None:
91 | print('=> ########## Using RandAugment!')
92 | unique = torchvision.transforms.Compose([
93 | SplitLabel(auto_transform), unique])
94 |
95 | if erase_transform is not None:
96 | print('=> ########## Using RandErasing!')
97 | return torchvision.transforms.Compose([
98 | unique, common, SplitLabel(erase_transform)
99 | ])
100 |
101 | return torchvision.transforms.Compose([unique, common])
102 |
103 | else:
104 | unique = torchvision.transforms.Compose([
105 | groupscale,
106 | GroupCenterCrop(input_size)])
107 | return torchvision.transforms.Compose([unique, common])
108 |
109 |
110 |
111 |
112 |
113 |
114 | def multiple_samples_collate(batch):
115 | """
116 | Collate function for repeated augmentation. Each instance in the batch has
117 | more than one sample.
118 | Args:
119 | batch (tuple or list): data batch to collate.
120 | Returns:
121 | (tuple): collated data batch.
122 | """
123 | inputs, labels = zip(*batch)
124 | # print(inputs, flush=True)
125 | # print(labels, flush=True)
126 | inputs = [item for sublist in inputs for item in sublist]
127 | labels = [item for sublist in labels for item in sublist]
128 | inputs, labels = (
129 | default_collate(inputs),
130 | default_collate(labels),
131 | )
132 | return inputs, labels
--------------------------------------------------------------------------------
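A sketch of how `get_augmentation` might be used; `config` here is a stand-in namespace that provides only the two attributes the function reads (`data_set`, `aa`), not the repository's real argument parser:

```python
from argparse import Namespace
from dataset.Augmentation import get_augmentation, multiple_samples_collate

# Stand-in config exposing only the attributes read above.
config = Namespace(data_set='SSV2', aa=True)

train_transform = get_augmentation(training=True, input_size=224, config=config)
val_transform = get_augmentation(training=False, input_size=224, config=config)

# When a dataset is built with num_sample > 1 (repeated augmentation), each item
# is a list of samples; multiple_samples_collate flattens such batches, e.g.
# DataLoader(dataset, batch_size=8, collate_fn=multiple_samples_collate).
```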
/dataset/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/whwu95/ATM/98ba3aa2ac258cc1b91beefe9317136657ae3d8d/dataset/__init__.py
--------------------------------------------------------------------------------
/dataset/charades.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data as data
3 | import decord
4 | import os
5 | import numpy as np
6 | from numpy.random import randint
7 | import io
8 | import pandas as pd
9 | import random
10 | from PIL import Image
11 | import math
12 | import copy
13 | import csv
14 |
15 | class VideoRecord(object):
16 | def __init__(self, row):
17 | self._data = row
18 |
19 | @property
20 | def path(self):
21 | return self._data[0]
22 |
23 | @property
24 | def num_frames(self):
25 | # return int(self._data[1])
26 | return
27 |
28 | @property
29 | def label(self):
30 | return int(self._data[1][1:])
31 |
32 | @property
33 | def start_time(self):
34 | return float(self._data[2])
35 |
36 | @property
37 | def end_time(self):
38 | return float(self._data[3])
39 |
40 | @property
41 | def total_time(self):
42 | return float(self._data[4])
43 |
44 |
45 | class Video_dataset(data.Dataset):
46 | def __init__(self, root_path, list_file, labels_file,
47 | num_segments=1, modality='RGB', new_length=1,
48 | image_tmpl='img_{:05d}.jpg', transform=None,
49 | random_shift=True, test_mode=False,
50 | index_bias=1, dense_sample=False, test_clips=1,
51 | num_sample=1, fps=24, mode='train'):
52 |
53 | self.root_path = root_path
54 | self.list_file = list_file
55 | self.num_segments = num_segments
56 | self.modality = modality
57 | self.seg_length = new_length
58 | self.image_tmpl = image_tmpl
59 | self.transform = transform
60 | self.random_shift = random_shift
61 | self.test_mode = test_mode
62 | self.mode = mode
63 | self.loop=False
64 | self.index_bias = index_bias
65 | self.labels_file = labels_file
66 | self.sample_range = 128
67 | self.dense_sample = dense_sample # using dense sample as I3D
68 | self.test_clips = test_clips
69 | self.num_sample = num_sample
70 | if self.dense_sample:
71 | print('=> Using dense sample for the dataset...')
72 | if self.num_sample > 1:
73 | print('=> Using repeated augmentation...')
74 |
75 | if self.index_bias is None:
76 | if self.image_tmpl == "frame{:d}.jpg":
77 | self.index_bias = 0
78 | else:
79 | self.index_bias = 1
80 | self._parse_list()
81 | self.initialized = False
82 | self.fps = fps
83 |
84 | @property
85 | def total_length(self):
86 | return self.num_segments * self.seg_length
87 |
88 | @property
89 | def classes(self):
90 | with open(self.labels_file, "r") as f:
91 | classes_all = [line.strip('\n').split(' ', 1) for line in f.readlines()]
92 | return classes_all
93 |
94 | def _parse_list(self):
95 |         # keep only records whose end_time is greater than start_time:
96 | if not self.test_mode:
97 | # read csv
98 | with open(self.list_file, "r") as f:
99 | reader = csv.reader(f)
100 | tmp = [row for row in reader][1:]
101 |
102 | tmp = [t for t in tmp if float(t[3]) > float(t[2])]
103 | self.video_list = [VideoRecord(item) for item in tmp]
104 | else:
105 | with open(self.list_file, "r") as f:
106 | reader = csv.reader(f)
107 | self.video_list = [row for row in reader][1:]
108 | print('video number:%d' % (len(self.video_list)))
109 |
110 | def _sample_indices(self, video_list_len, record):
111 | if self.dense_sample:
112 | sample_pos = max(1, 1 + video_list_len - self.sample_range)
113 | interval = self.sample_range // self.num_segments
114 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1)
115 | base_offsets = np.arange(self.num_segments) * interval
116 | offsets = (base_offsets + start_idx) % video_list_len
117 | return np.array(offsets) + self.index_bias
118 | else:
119 | if video_list_len <= self.total_length:
120 | import torch.distributed as dist
121 | print('record_id=',record.path,
122 | 'start_time===',record.start_time,
123 | 'end_time==',record.end_time,
124 | 'total_time==',record.total_time,
125 | 'label===',record.label)
126 | print('rank===',dist.get_rank())
127 | if self.loop:
128 | return np.mod(np.arange(
129 | self.total_length) + randint(video_list_len // 2),
130 | video_list_len) + self.index_bias
131 | offsets = np.concatenate((
132 | np.arange(video_list_len),
133 | randint(video_list_len,
134 | size=self.total_length - video_list_len)))
135 | return np.sort(offsets) + self.index_bias
136 | offsets = list()
137 | ticks = [i * video_list_len // self.num_segments
138 | for i in range(self.num_segments + 1)]
139 |
140 | for i in range(self.num_segments):
141 | tick_len = ticks[i + 1] - ticks[i]
142 | tick = ticks[i]
143 | if tick_len >= self.seg_length:
144 | tick += randint(tick_len - self.seg_length + 1)
145 | offsets.extend([j for j in range(tick, tick + self.seg_length)])
146 | return np.array(offsets) + self.index_bias
147 |
148 |
149 | def _get_test_indices(self, video_list):
150 | if self.dense_sample:
151 | # multi-clip for dense sampling
152 | num_clips = self.test_clips
153 | sample_pos = max(0, len(video_list) - self.sample_range)
154 | interval = self.sample_range // self.num_segments
155 | start_list = [clip_idx * math.floor(sample_pos / (num_clips -1)) for clip_idx in range(num_clips)]
156 | base_offsets = np.arange(self.num_segments) * interval
157 | offsets = []
158 | for start_idx in start_list:
159 | offsets.extend((base_offsets + start_idx) % len(video_list))
160 | return np.array(offsets) + self.index_bias
161 | else:
162 | # multi-clip for uniform sampling
163 | num_clips = self.test_clips
164 | tick = len(video_list) / float(self.num_segments)
165 | start_list = np.linspace(0, tick - 1, num=num_clips, dtype=int)
166 | offsets = []
167 | for start_idx in start_list.tolist():
168 | offsets += [
169 | int(start_idx + tick * x) % len(video_list)
170 | for x in range(self.num_segments)
171 | ]
172 | return np.array(offsets) + self.index_bias
173 |
174 |
175 | def _decord_decode(self, video_path):
176 | try:
177 | container = decord.VideoReader(video_path)
178 | except Exception as e:
179 | print("Failed to decode {} with exception: {}".format(
180 | video_path, e))
181 | return None
182 |
183 | return container
184 |
185 | def __getitem__(self, index):
186 | # decode frames to video_list
187 | if self.modality == 'video':
188 | _num_retries = 10
189 | for i_try in range(_num_retries):
190 | record = copy.deepcopy(self.video_list[index])
191 | directory = os.path.join(self.root_path, record.path)
192 | video_list = self._decord_decode(directory)
193 | # video_list = self._decord_pyav(directory)
194 | if video_list is None:
195 | print("Failed to decode video idx {} from {}; trial {}".format(
196 | index, directory, i_try)
197 | )
198 |                     index = random.randint(0, len(self.video_list) - 1)
199 | continue
200 | break
201 | else:
202 | record = self.video_list[index]
203 |
204 |
205 | if not self.test_mode: # train
206 | video_list = os.listdir(os.path.join(self.root_path, record.path))
207 | end_time = min(record.end_time, record.total_time)
208 | video_list_len = int(end_time * self.fps - record.start_time * self.fps)
209 |
210 | segment_indices = self._sample_indices(video_list_len, record)
211 | segment_indices = segment_indices + int(record.start_time * self.fps)
212 | return self.get(record, video_list, segment_indices)
213 | else: # test
214 | test_record = record
215 | video_list = os.listdir(os.path.join(self.root_path, test_record[0]))
216 |             target = torch.IntTensor(157).zero_()  # size=(157), all zeros; one-hot encoding
217 |
218 | if test_record[9] != '':
219 | labels_mess = test_record[9].split(';')
220 | labels = [mess.split(' ')[0] for mess in labels_mess]
221 | for x in labels:
222 |                     target[int(x[1:])] = 1  # parse the class label as an int and set that one-hot position to 1
223 | segment_indices = self._get_test_indices(video_list)
224 |
225 | if self.mode == 'validation':
226 | return self.val_get(test_record, video_list, segment_indices, target)
227 | else:
228 | return self.test_get(test_record, video_list, segment_indices, target)
229 |
230 |
231 | def _load_image(self, directory, idx):
232 | if self.modality == 'RGB':
233 | try:
234 | return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(directory, idx))).convert('RGB')]
235 | except Exception:
236 | print('error loading image:', os.path.join(self.root_path, directory, self.image_tmpl.format(directory,idx)))
237 | return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(directory, 1))).convert('RGB')]
238 |
239 |
240 | def get(self, record, video_list, indices):
241 | images = list()
242 | for seg_ind in indices:
243 | p = int(seg_ind)
244 | if self.modality == 'video':
245 | seg_imgs = [Image.fromarray(video_list[p - 1].asnumpy()).convert('RGB')]
246 | else:
247 | seg_imgs = self._load_image(record.path, p)
248 | images.extend(seg_imgs)
249 | if p < len(video_list):
250 | p += 1
251 |
252 | if self.num_sample > 1:
253 | frame_list = []
254 | label_list = []
255 | for _ in range(self.num_sample):
256 | process_data, record_label = self.transform((images, record.label))
257 | process_data = process_data.view((self.num_segments,3)+process_data.size()[-2:]) ###!!!
258 | frame_list.append(process_data)
259 | label_list.append(record_label)
260 | return frame_list, label_list, 0, 0
261 | else:
262 | process_data, record_label = self.transform((images, record.label))
263 | process_data = process_data.view((self.num_segments,3)+process_data.size()[-2:]) ###!!!
264 | return process_data, record_label, 0, 0
265 |
266 | def val_get(self, record, video_list, indices, test_label):
267 | images = list()
268 | for seg_ind in indices:
269 | p = int(seg_ind)
270 | if self.modality == 'video':
271 | seg_imgs = [Image.fromarray(video_list[p - 1].asnumpy()).convert('RGB')]
272 | else:
273 | seg_imgs = self._load_image(record[0], p)
274 | images.extend(seg_imgs)
275 | if p < len(video_list):
276 | p += 1
277 |
278 | process_data, _ = self.transform((images, test_label))
279 | process_data = process_data.view((self.num_segments,3)+process_data.size()[-2:]) ###!!!
280 | return process_data, test_label, 0, 0
281 |
282 |
283 | def test_get(self, record, video_list, indices, test_label):
284 | images = list()
285 | for seg_ind in indices:
286 | p = int(seg_ind)
287 | if self.modality == 'video':
288 | seg_imgs = [Image.fromarray(video_list[p - 1].asnumpy()).convert('RGB')]
289 | else:
290 | seg_imgs = self._load_image(record[0], p)
291 | images.extend(seg_imgs)
292 | if p < len(video_list):
293 | p += 1
294 |
295 | process_data, _ = self.transform((images, test_label))
296 | return process_data, test_label
297 |
298 |
299 | def __len__(self):
300 | return len(self.video_list)
301 |
--------------------------------------------------------------------------------
/dataset/kinetics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data as data
3 | import decord
4 | import os
5 | import numpy as np
6 | from numpy.random import randint
7 | import io
8 | import pandas as pd
9 | import random
10 | from PIL import Image
11 | import math
12 | import copy
13 |
14 |
15 | class VideoRecord(object):
16 | def __init__(self, row):
17 | self._data = row
18 |
19 | @property
20 | def path(self):
21 | return self._data[0]
22 |
23 | @property
24 | def num_frames(self):
25 | return int(self._data[1])
26 |
27 | @property
28 | def label(self):
29 | return int(self._data[-1])
30 |
31 |
32 | class Video_dataset(data.Dataset):
33 | def __init__(self, root_path, list_file, labels_file,
34 | num_segments=1, modality='RGB', new_length=1,
35 | image_tmpl='img_{:05d}.jpg', transform=None,
36 | random_shift=True, test_mode=False,
37 | index_bias=1, dense_sample=False, test_clips=3,
38 | num_sample=1):
39 |
40 | self.root_path = root_path
41 | self.list_file = list_file
42 | self.num_segments = num_segments
43 | self.modality = modality
44 | self.seg_length = new_length
45 | self.image_tmpl = image_tmpl
46 | self.transform = transform
47 | self.random_shift = random_shift
48 | self.test_mode = test_mode
49 | self.loop=False
50 | self.index_bias = index_bias
51 | self.labels_file = labels_file
52 | self.sample_range = 128
53 | self.dense_sample = dense_sample # using dense sample as I3D
54 | self.test_clips = test_clips
55 | self.num_sample = num_sample
56 | if self.dense_sample:
57 | print('=> Using dense sample for the dataset...')
58 | if self.num_sample > 1:
59 | print('=> Using repeated augmentation...')
60 |
61 | if self.index_bias is None:
62 | if self.image_tmpl == "frame{:d}.jpg":
63 | self.index_bias = 0
64 | else:
65 | self.index_bias = 1
66 | self._parse_list()
67 | self.initialized = False
68 |
69 | @property
70 | def total_length(self):
71 | return self.num_segments * self.seg_length
72 |
73 | @property
74 | def classes(self):
75 | classes_all = pd.read_csv(self.labels_file)
76 | return classes_all.values.tolist()
77 |
78 | def _parse_list(self):
79 |         tmp = [x.strip().split(' ') for x in open(self.list_file)]
80 |         if len(tmp[0]) == 3:  # skip remove_missing when decoding a "raw_video label"-style dataset config
81 | if not self.test_mode:
82 | tmp = [item for item in tmp if int(item[1]) >= 8]
83 | self.video_list = [VideoRecord(item) for item in tmp]
84 | print('video number:%d' % (len(self.video_list)))
85 |
86 | def _sample_indices(self, video_list):
87 | if self.dense_sample:
88 | sample_pos = max(1, 1 + len(video_list) - self.sample_range)
89 | interval = self.sample_range // self.num_segments
90 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1)
91 | base_offsets = np.arange(self.num_segments) * interval
92 | offsets = (base_offsets + start_idx) % len(video_list)
93 | return np.array(offsets) + self.index_bias
94 | else:
95 | seg_size = float(len(video_list) - 1) / self.num_segments
96 | offsets = []
97 | for i in range(self.num_segments):
98 | start = int(np.round(seg_size * i))
99 | end = int(np.round(seg_size * (i + 1)))
100 | offsets.append(random.randint(start, end))
101 |
102 | return np.array(offsets) + self.index_bias
103 |
104 | def _get_val_indices(self, video_list):
105 | if self.dense_sample:
106 | sample_pos = max(1, 1 + len(video_list) - self.sample_range)
107 | t_stride = self.sample_range // self.num_segments
108 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1)
109 | offsets = [(idx * t_stride + start_idx) % len(video_list) for idx in range(self.num_segments)]
110 | return np.array(offsets) + self.index_bias
111 | else:
112 | tick = len(video_list) / float(self.num_segments)
113 | offsets = [int(tick * x) % len(video_list) for x in range(self.num_segments)]
114 |
115 | return np.array(offsets) + self.index_bias
116 |
117 |
118 | def _get_test_indices(self, video_list):
119 | if self.dense_sample:
120 | # multi-clip for dense sampling
121 | num_clips = self.test_clips
122 | sample_pos = max(0, len(video_list) - self.sample_range)
123 | interval = self.sample_range // self.num_segments
124 | start_list = [clip_idx * math.floor(sample_pos / (num_clips -1)) for clip_idx in range(num_clips)]
125 | base_offsets = np.arange(self.num_segments) * interval
126 | offsets = []
127 | for start_idx in start_list:
128 | offsets.extend((base_offsets + start_idx) % len(video_list))
129 | return np.array(offsets) + self.index_bias
130 | else:
131 | # multi-clip for uniform sampling
132 | num_clips = self.test_clips
133 |
134 | tick = len(video_list) / float(self.num_segments)
135 | start_list = np.linspace(0, tick - 1, num=num_clips, dtype=int)
136 | offsets = []
137 | for start_idx in start_list.tolist():
138 | offsets += [
139 | int(start_idx + tick * x) % len(video_list)
140 | for x in range(self.num_segments)
141 | ]
142 | return np.array(offsets) + self.index_bias
143 |
144 |
145 | def _decord_decode(self, video_path):
146 | try:
147 | container = decord.VideoReader(video_path)
148 | except Exception as e:
149 | print("Failed to decode {} with exception: {}".format(
150 | video_path, e))
151 | return None
152 |
153 | return container
154 |
155 | def __getitem__(self, index):
156 | # decode frames to video_list
157 | if self.modality == 'video':
158 | _num_retries = 10
159 | for i_try in range(_num_retries):
160 | record = copy.deepcopy(self.video_list[index])
161 | directory = os.path.join(self.root_path, record.path)
162 | video_list = self._decord_decode(directory)
163 | # video_list = self._decord_pyav(directory)
164 | if video_list is None:
165 | print("Failed to decode video idx {} from {}; trial {}".format(
166 | index, directory, i_try)
167 | )
168 |                     index = random.randint(0, len(self.video_list) - 1)
169 | continue
170 | break
171 | else:
172 | record = self.video_list[index]
173 | video_list = os.listdir(os.path.join(self.root_path, record.path))
174 |
175 | if not self.test_mode: # train/val
176 | segment_indices = self._sample_indices(video_list) if self.random_shift else self._get_val_indices(video_list)
177 | else: # test
178 | segment_indices = self._get_test_indices(video_list)
179 |
180 | return self.get(record, video_list, segment_indices)
181 |
182 |
183 | def _load_image(self, directory, idx):
184 | if self.modality == 'RGB':
185 | try:
186 | return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(idx))).convert('RGB')]
187 | except Exception:
188 | print('error loading image:', os.path.join(self.root_path, directory, self.image_tmpl.format(idx)))
189 | return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(1))).convert('RGB')]
190 |
191 |
192 | def get(self, record, video_list, indices):
193 | images = list()
194 | for seg_ind in indices:
195 | p = int(seg_ind)
196 | if self.modality == 'video':
197 | seg_imgs = [Image.fromarray(video_list[p - 1].asnumpy()).convert('RGB')]
198 | else:
199 | seg_imgs = self._load_image(record.path, p)
200 | images.extend(seg_imgs)
201 | if p < len(video_list):
202 | p += 1
203 |
204 | if self.test_mode:
205 | process_data, record_label = self.transform((images, record.label))
206 | return process_data, record_label
207 | else:
208 | if self.num_sample > 1:
209 | frame_list = []
210 | label_list = []
211 | for _ in range(self.num_sample):
212 | process_data, record_label = self.transform((images, record.label))
213 | process_data = process_data.view((self.num_segments,3)+process_data.size()[-2:]) ###!!!
214 |
215 | frame_list.append(process_data)
216 | label_list.append(record_label)
217 | return frame_list, label_list, 0, 0
218 | else:
219 | process_data, record_label = self.transform((images, record.label))
220 | process_data = process_data.view((self.num_segments,3)+process_data.size()[-2:]) ###!!!
221 | return process_data, record_label, 0, 0
222 |
223 | def __len__(self):
224 | return len(self.video_list)
225 |
--------------------------------------------------------------------------------
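To make the uniform (non-dense) multi-clip sampling above concrete, here is a small standalone illustration that mirrors the `else` branch of `_get_test_indices`; the frame count and the default arguments are just example values:

```python
import numpy as np

def uniform_test_indices(num_frames, num_segments=8, test_clips=3, index_bias=1):
    # Mirrors the non-dense branch of _get_test_indices above.
    tick = num_frames / float(num_segments)
    start_list = np.linspace(0, tick - 1, num=test_clips, dtype=int)
    offsets = []
    for start_idx in start_list.tolist():
        offsets += [int(start_idx + tick * x) % num_frames for x in range(num_segments)]
    return np.array(offsets) + index_bias

# A 64-frame video sampled as 3 clips of 8 evenly spaced frames each:
print(uniform_test_indices(64))
# [ 1  9 17 25 33 41 49 57  4 12 20 28 36 44 52 60  8 16 24 32 40 48 56 64]
```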
/dataset/random_erasing.py:
--------------------------------------------------------------------------------
1 | """
2 | This implementation is based on
3 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py
4 | published under an Apache License 2.0.
5 | """
6 | import math
7 | import random
8 | import torch
9 |
10 |
11 | def _get_pixels(
12 | per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda"
13 | ):
14 | # NOTE I've seen CUDA illegal memory access errors being caused by the normal_()
15 | # paths, flip the order so normal is run on CPU if this becomes a problem
16 | # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508
17 | if per_pixel:
18 | return torch.empty(patch_size, dtype=dtype, device=device).normal_()
19 | elif rand_color:
20 | return torch.empty(
21 | (patch_size[0], 1, 1), dtype=dtype, device=device
22 | ).normal_()
23 | else:
24 | return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)
25 |
26 |
27 | class RandomErasing:
28 | """Randomly selects a rectangle region in an image and erases its pixels.
29 | 'Random Erasing Data Augmentation' by Zhong et al.
30 | See https://arxiv.org/pdf/1708.04896.pdf
31 | This variant of RandomErasing is intended to be applied to either a batch
32 | or single image tensor after it has been normalized by dataset mean and std.
33 | Args:
34 | probability: Probability that the Random Erasing operation will be performed.
35 | min_area: Minimum percentage of erased area wrt input image area.
36 | max_area: Maximum percentage of erased area wrt input image area.
37 | min_aspect: Minimum aspect ratio of erased area.
38 | mode: pixel color mode, one of 'const', 'rand', or 'pixel'
39 | 'const' - erase block is constant color of 0 for all channels
40 | 'rand' - erase block is same per-channel random (normal) color
41 | 'pixel' - erase block is per-pixel random (normal) color
42 | max_count: maximum number of erasing blocks per image, area per box is scaled by count.
43 | per-image count is randomly chosen between 1 and this value.
44 | """
45 |
46 | def __init__(
47 | self,
48 | probability=0.5,
49 | min_area=0.02,
50 | max_area=1 / 3,
51 | min_aspect=0.3,
52 | max_aspect=None,
53 | mode="const",
54 | min_count=1,
55 | max_count=None,
56 | num_splits=0,
57 | device="cuda",
58 | cube=True,
59 | ):
60 | self.probability = probability
61 | self.min_area = min_area
62 | self.max_area = max_area
63 | max_aspect = max_aspect or 1 / min_aspect
64 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
65 | self.min_count = min_count
66 | self.max_count = max_count or min_count
67 | self.num_splits = num_splits
68 | mode = mode.lower()
69 | self.rand_color = False
70 | self.per_pixel = False
71 | self.cube = cube
72 | if mode == "rand":
73 | self.rand_color = True # per block random normal
74 | elif mode == "pixel":
75 | self.per_pixel = True # per pixel random normal
76 | else:
77 | assert not mode or mode == "const"
78 | self.device = device
79 |
80 | def _erase(self, img, chan, img_h, img_w, dtype):
81 | if random.random() > self.probability:
82 | return
83 | area = img_h * img_w
84 | count = (
85 | self.min_count
86 | if self.min_count == self.max_count
87 | else random.randint(self.min_count, self.max_count)
88 | )
89 | for _ in range(count):
90 | for _ in range(10):
91 | target_area = (
92 | random.uniform(self.min_area, self.max_area) * area / count
93 | )
94 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
95 | h = int(round(math.sqrt(target_area * aspect_ratio)))
96 | w = int(round(math.sqrt(target_area / aspect_ratio)))
97 | if w < img_w and h < img_h:
98 | top = random.randint(0, img_h - h)
99 | left = random.randint(0, img_w - w)
100 | img[:, top : top + h, left : left + w] = _get_pixels(
101 | self.per_pixel,
102 | self.rand_color,
103 | (chan, h, w),
104 | dtype=dtype,
105 | device=self.device,
106 | )
107 | break
108 |
109 | def _erase_cube(
110 | self,
111 | img,
112 | batch_start,
113 | batch_size,
114 | chan,
115 | img_h,
116 | img_w,
117 | dtype,
118 | ):
119 | if random.random() > self.probability:
120 | return
121 | area = img_h * img_w
122 | count = (
123 | self.min_count
124 | if self.min_count == self.max_count
125 | else random.randint(self.min_count, self.max_count)
126 | )
127 | for _ in range(count):
128 | for _ in range(100):
129 | target_area = (
130 | random.uniform(self.min_area, self.max_area) * area / count
131 | )
132 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
133 | h = int(round(math.sqrt(target_area * aspect_ratio)))
134 | w = int(round(math.sqrt(target_area / aspect_ratio)))
135 | if w < img_w and h < img_h:
136 | top = random.randint(0, img_h - h)
137 | left = random.randint(0, img_w - w)
138 | for i in range(batch_start, batch_size):
139 | img_instance = img[i]
140 | img_instance[
141 | :, top : top + h, left : left + w
142 | ] = _get_pixels(
143 | self.per_pixel,
144 | self.rand_color,
145 | (chan, h, w),
146 | dtype=dtype,
147 | device=self.device,
148 | )
149 | break
150 |
151 | def __call__(self, input):
152 | if len(input.size()) == 3:
153 | self._erase(input, *input.size(), input.dtype)
154 | else:
155 | batch_size, chan, img_h, img_w = input.size()
156 | # skip first slice of batch if num_splits is set (for clean portion of samples)
157 | batch_start = (
158 | batch_size // self.num_splits if self.num_splits > 1 else 0
159 | )
160 | if self.cube:
161 | self._erase_cube(
162 | input,
163 | batch_start,
164 | batch_size,
165 | chan,
166 | img_h,
167 | img_w,
168 | input.dtype,
169 | )
170 | else:
171 | for i in range(batch_start, batch_size):
172 | self._erase(input[i], chan, img_h, img_w, input.dtype)
173 | return input
174 |
--------------------------------------------------------------------------------
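A minimal usage sketch for the `RandomErasing` class above, using the same settings as the commented-out block in `dataset/Augmentation.py`:

```python
import torch
from dataset.random_erasing import RandomErasing

erase = RandomErasing(probability=0.25, mode='pixel', max_count=1,
                      num_splits=1, device='cpu')

frames = torch.randn(8, 3, 224, 224)  # (T, C, H, W) already-normalized frames
frames = erase(frames)                # erases the same rectangle across frames (cube=True)
print(frames.shape)                   # torch.Size([8, 3, 224, 224])
```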
/dataset/sth.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data as data
3 | import decord
4 | import os
5 | import numpy as np
6 | from numpy.random import randint
7 | import io
8 | import pandas as pd
9 | import random
10 | from PIL import Image
11 | import math
12 | import copy
13 |
14 |
15 | class VideoRecord(object):
16 | def __init__(self, row):
17 | self._data = row
18 |
19 | @property
20 | def path(self):
21 | return self._data[0]
22 |
23 | @property
24 | def num_frames(self):
25 | return int(self._data[1])
26 |
27 | @property
28 | def label(self):
29 | return int(self._data[-1])
30 |
31 |
32 | class Video_dataset(data.Dataset):
33 | def __init__(self, root_path, list_file, labels_file,
34 | num_segments=1, modality='RGB', new_length=1,
35 | image_tmpl='img_{:05d}.jpg', transform=None,
36 | random_shift=True, test_mode=False,
37 | index_bias=1, dense_sample=False, test_clips=3,
38 | num_sample=1):
39 |
40 | self.root_path = root_path
41 | self.list_file = list_file
42 | self.num_segments = num_segments
43 | self.modality = modality
44 | self.seg_length = new_length
45 | self.image_tmpl = image_tmpl
46 | self.transform = transform
47 | self.random_shift = random_shift
48 | self.test_mode = test_mode
49 | self.loop=False
50 | self.index_bias = index_bias
51 | self.labels_file = labels_file
52 | self.sample_range = 128
53 | self.dense_sample = dense_sample # using dense sample as I3D
54 | self.test_clips = test_clips
55 | self.num_sample = num_sample
56 | if self.dense_sample:
57 | print('=> Using dense sample for the dataset...')
58 | if self.num_sample > 1:
59 | print('=> Using repeated augmentation...')
60 |
61 | if self.index_bias is None:
62 | if self.image_tmpl == "frame{:d}.jpg":
63 | self.index_bias = 0
64 | else:
65 | self.index_bias = 1
66 | self._parse_list()
67 | self.initialized = False
68 |
69 | @property
70 | def total_length(self):
71 | return self.num_segments * self.seg_length
72 |
73 | @property
74 | def classes(self):
75 | classes_all = pd.read_csv(self.labels_file)
76 | return classes_all.values.tolist()
77 |
78 | def _parse_list(self):
79 |         # filter out videos with too few frames:
80 |         tmp = [x.strip().split(' ') for x in open(self.list_file)]
81 |         if len(tmp[0]) == 3:  # skip remove_missing when decoding a "raw_video label"-style dataset config
82 | if not self.test_mode:
83 | tmp = [item for item in tmp if int(item[1]) >= 8]
84 | self.video_list = [VideoRecord(item) for item in tmp]
85 | print('video number:%d' % (len(self.video_list)))
86 |
87 | def _sample_indices(self, video_list):
88 | if self.dense_sample:
89 | sample_pos = max(1, 1 + len(video_list) - self.sample_range)
90 | interval = self.sample_range // self.num_segments
91 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1)
92 | base_offsets = np.arange(self.num_segments) * interval
93 | offsets = (base_offsets + start_idx) % len(video_list)
94 | return np.array(offsets) + self.index_bias
95 | else:
96 | if len(video_list) <= self.total_length:
97 | if self.loop:
98 | return np.mod(np.arange(
99 | self.total_length) + randint(len(video_list) // 2),
100 | len(video_list)) + self.index_bias
101 | offsets = np.concatenate((
102 | np.arange(len(video_list)),
103 | randint(len(video_list),
104 | size=self.total_length - len(video_list))))
105 | return np.sort(offsets) + self.index_bias
106 | offsets = list()
107 | ticks = [i * len(video_list) // self.num_segments
108 | for i in range(self.num_segments + 1)]
109 |
110 | for i in range(self.num_segments):
111 | tick_len = ticks[i + 1] - ticks[i]
112 | tick = ticks[i]
113 | if tick_len >= self.seg_length:
114 | tick += randint(tick_len - self.seg_length + 1)
115 | offsets.extend([j for j in range(tick, tick + self.seg_length)])
116 | return np.array(offsets) + self.index_bias
117 |
118 | def _get_val_indices(self, video_list):
119 | if self.dense_sample:
120 | sample_pos = max(1, 1 + len(video_list) - self.sample_range)
121 | t_stride = self.sample_range // self.num_segments
122 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1)
123 | offsets = [(idx * t_stride + start_idx) % len(video_list) for idx in range(self.num_segments)]
124 | return np.array(offsets) + self.index_bias
125 | else:
126 | seg_size = float(len(video_list) - 1) / self.num_segments
127 | offsets = []
128 | for i in range(self.num_segments):
129 | start = int(np.round(seg_size * i))
130 | frame_index = start + int(seg_size / 2)
131 | offsets.append(frame_index)
132 |
133 | return np.array(offsets) + self.index_bias
134 |
135 |
136 | def _get_test_indices(self, video_list):
137 | if self.dense_sample:
138 | # multi-clip for dense sampling
139 | num_clips = self.test_clips
140 | sample_pos = max(0, len(video_list) - self.sample_range)
141 | interval = self.sample_range // self.num_segments
142 | start_list = [clip_idx * math.floor(sample_pos / max(num_clips - 1, 1)) for clip_idx in range(num_clips)]  # guard against num_clips == 1
143 | base_offsets = np.arange(self.num_segments) * interval
144 | offsets = []
145 | for start_idx in start_list:
146 | offsets.extend((base_offsets + start_idx) % len(video_list))
147 | return np.array(offsets) + self.index_bias
148 | else:
149 | # multi-clip for uniform sampling
150 | num_clips = self.test_clips
151 |
152 | seg_size = float(len(video_list) - 1) / self.num_segments
153 | seq = []
154 | duration = seg_size / (num_clips + 1)
155 | for temporal_sample_index in range(num_clips):
156 | for i in range(self.num_segments):
157 | start = int(np.round(seg_size * i))
158 | frame_index = start + int(duration * (temporal_sample_index + 1))
159 | seq.append(frame_index)
160 | return np.array(seq) + self.index_bias
161 |
162 |
163 | def _decord_decode(self, video_path):
164 | try:
165 | container = decord.VideoReader(video_path)
166 | except Exception as e:
167 | print("Failed to decode {} with exception: {}".format(
168 | video_path, e))
169 | return None
170 |
171 | return container
172 |
173 | def __getitem__(self, index):
174 | # decode frames to video_list
175 | if self.modality == 'video':
176 | _num_retries = 10
177 | for i_try in range(_num_retries):
178 | record = copy.deepcopy(self.video_list[index])
179 | directory = os.path.join(self.root_path, record.path)
180 | video_list = self._decord_decode(directory)
181 | # video_list = self._decord_pyav(directory)
182 | if video_list is None:
183 | print("Failed to decode video idx {} from {}; trial {}".format(
184 | index, directory, i_try)
185 | )
186 | index = random.randint(0, len(self.video_list) - 1)  # randint is inclusive on both ends
187 | continue
188 | break
189 | else:
190 | record = self.video_list[index]
191 | video_list = os.listdir(os.path.join(self.root_path, record.path))
192 |
193 | if not self.test_mode: # train/val
194 | segment_indices = self._sample_indices(video_list) if self.random_shift else self._get_val_indices(video_list)
195 | else: # test
196 | segment_indices = self._get_test_indices(video_list)
197 |
198 | return self.get(record, video_list, segment_indices)
199 |
200 |
201 | def _load_image(self, directory, idx):
202 | if self.modality == 'RGB':
203 | try:
204 | return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(idx))).convert('RGB')]
205 | except Exception:
206 | print('error loading image:', os.path.join(self.root_path, directory, self.image_tmpl.format(idx)))
207 | return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(1))).convert('RGB')]
208 |
209 |
210 | def get(self, record, video_list, indices):
211 | images = list()
212 | for seg_ind in indices:
213 | p = int(seg_ind)
214 | if self.modality == 'video':
215 | seg_imgs = [Image.fromarray(video_list[p - 1].asnumpy()).convert('RGB')]
216 | else:
217 | seg_imgs = self._load_image(record.path, p)
218 | images.extend(seg_imgs)
219 | if p < len(video_list):
220 | p += 1
221 | if self.num_sample > 1:
222 | frame_list = []
223 | label_list = []
224 | for _ in range(self.num_sample):
225 | process_data, record_label = self.transform((images, record.label))
226 | frame_list.append(process_data)
227 | label_list.append(record_label)
228 | return frame_list, label_list
229 | else:
230 | process_data, record_label = self.transform((images, record.label))
231 | return process_data, record_label
232 |
233 | def __len__(self):
234 | return len(self.video_list)
235 |
--------------------------------------------------------------------------------
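Illustrative sketch (not a repository file): the non-dense branch of _get_val_indices above splits a clip into num_segments equal chunks and picks the middle frame of each chunk. A minimal standalone version of that arithmetic, using a hypothetical 30-frame clip and 8 segments:

import numpy as np

def uniform_val_indices(num_frames, num_segments, index_bias=1):
    # midpoint of each of num_segments equal chunks, mirroring Video_dataset._get_val_indices
    seg_size = float(num_frames - 1) / num_segments
    offsets = [int(np.round(seg_size * i)) + int(seg_size / 2) for i in range(num_segments)]
    return np.array(offsets) + index_bias

print(uniform_val_indices(30, 8))  # eight indices spread evenly across the clip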
/dataset/sthvideo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data as data
3 | import decord
4 | import os
5 | import numpy as np
6 | from numpy.random import randint
7 | import io
8 | import pandas as pd
9 | import random
10 | from PIL import Image
11 | import math
12 | import copy
13 |
14 |
15 | class VideoRecord(object):
16 | def __init__(self, row):
17 | self._data = row
18 |
19 | @property
20 | def path(self):
21 | return self._data[0]
22 |
23 | @property
24 | def num_frames(self):
25 | return int(self._data[1])
26 |
27 | @property
28 | def label(self):
29 | return int(self._data[-1])
30 |
31 |
32 | class Video_dataset(data.Dataset):
33 | def __init__(self, root_path, list_file, labels_file,
34 | num_segments=1, modality='RGB', new_length=1,
35 | image_tmpl='img_{:05d}.jpg', transform=None,
36 | random_shift=True, test_mode=False,
37 | index_bias=1, dense_sample=False, test_clips=3,
38 | num_sample=1):
39 |
40 | self.root_path = root_path
41 | self.list_file = list_file
42 | self.num_segments = num_segments
43 | self.modality = modality
44 | self.seg_length = new_length
45 | self.image_tmpl = image_tmpl
46 | self.transform = transform
47 | self.random_shift = random_shift
48 | self.test_mode = test_mode
49 | self.loop=False
50 | self.index_bias = index_bias
51 | self.labels_file = labels_file
52 | self.sample_range = 128
53 | self.dense_sample = dense_sample  # use dense sampling as in I3D
54 | self.test_clips = test_clips
55 | self.num_sample = num_sample
56 | if self.dense_sample:
57 | print('=> Using dense sample for the dataset...')
58 | if self.num_sample > 1:
59 | print('=> Using repeated augmentation...')
60 |
61 | if self.index_bias is None:
62 | if self.image_tmpl == "frame{:d}.jpg":
63 | self.index_bias = 0
64 | else:
65 | self.index_bias = 1
66 | self._parse_list()
67 | self.initialized = False
68 |
69 | @property
70 | def total_length(self):
71 | return self.num_segments * self.seg_length
72 |
73 | @property
74 | def classes(self):
75 | classes_all = pd.read_csv(self.labels_file)
76 | return classes_all.values.tolist()
77 |
78 | def _parse_list(self):
79 | # filter out training videos with too few frames
80 | tmp = [x.strip().split(' ') for x in open(self.list_file)]
81 | if len(tmp[0]) == 3:  # two-field "raw_video label" lists skip this missing-frame filter
82 | if not self.test_mode:
83 | tmp = [item for item in tmp if int(item[1]) >= 8]
84 | self.video_list = [VideoRecord(item) for item in tmp]
85 | print('video number:%d' % (len(self.video_list)))
86 |
87 | def _sample_indices(self, video_list):
88 | if self.dense_sample:
89 | sample_pos = max(1, 1 + len(video_list) - self.sample_range)
90 | interval = self.sample_range // self.num_segments
91 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1)
92 | base_offsets = np.arange(self.num_segments) * interval
93 | offsets = (base_offsets + start_idx) % len(video_list)
94 | return np.array(offsets) + self.index_bias
95 | else:
96 | if len(video_list) <= self.total_length:
97 | if self.loop:
98 | return np.mod(np.arange(
99 | self.total_length) + randint(len(video_list) // 2),
100 | len(video_list)) + self.index_bias
101 | offsets = np.concatenate((
102 | np.arange(len(video_list)),
103 | randint(len(video_list),
104 | size=self.total_length - len(video_list))))
105 | return np.sort(offsets) + self.index_bias
106 | offsets = list()
107 | ticks = [i * len(video_list) // self.num_segments
108 | for i in range(self.num_segments + 1)]
109 |
110 | for i in range(self.num_segments):
111 | tick_len = ticks[i + 1] - ticks[i]
112 | tick = ticks[i]
113 | if tick_len >= self.seg_length:
114 | tick += randint(tick_len - self.seg_length + 1)
115 | offsets.extend([j for j in range(tick, tick + self.seg_length)])
116 |
117 | return np.array(offsets) + self.index_bias
118 |
119 | def _get_val_indices(self, video_list):
120 | if self.dense_sample:
121 | sample_pos = max(1, 1 + len(video_list) - self.sample_range)
122 | t_stride = self.sample_range // self.num_segments
123 | start_idx = 0 if sample_pos == 1 else np.random.randint(0, sample_pos - 1)
124 | offsets = [(idx * t_stride + start_idx) % len(video_list) for idx in range(self.num_segments)]
125 | return np.array(offsets) + self.index_bias
126 | else:
127 | seg_size = float(len(video_list) - 1) / self.num_segments
128 | offsets = []
129 | for i in range(self.num_segments):
130 | start = int(np.round(seg_size * i))
131 | frame_index = start + int(seg_size / 2)
132 | offsets.append(frame_index)
133 |
134 | return np.array(offsets) + self.index_bias
135 |
136 |
137 | def _get_test_indices(self, video_list):
138 | if self.dense_sample:
139 | # multi-clip for dense sampling
140 | num_clips = self.test_clips
141 | sample_pos = max(0, len(video_list) - self.sample_range)
142 | interval = self.sample_range // self.num_segments
143 | start_list = [clip_idx * math.floor(sample_pos / max(num_clips - 1, 1)) for clip_idx in range(num_clips)]  # guard against num_clips == 1
144 | base_offsets = np.arange(self.num_segments) * interval
145 | offsets = []
146 | for start_idx in start_list:
147 | offsets.extend((base_offsets + start_idx) % len(video_list))
148 | return np.array(offsets) + self.index_bias
149 | else:
150 | # multi-clip for uniform sampling
151 | num_clips = self.test_clips
152 |
153 | # tick = len(video_list) / float(self.num_segments)
154 | # start_list = np.linspace(0, tick - 1, num=num_clips, dtype=int)
155 | # offsets = []
156 | # for start_idx in start_list.tolist():
157 | # offsets += [
158 | # int(start_idx + tick * x) % len(video_list)
159 | # for x in range(self.num_segments)
160 | # ]
161 | # return np.array(offsets) + self.index_bias
162 |
163 |
164 | ############ ATM implementation
165 | seg_size = float(len(video_list) - 1) / self.num_segments
166 | seq = []
167 | duration = seg_size / (num_clips + 1)
168 | for temporal_sample_index in range(num_clips):
169 | for i in range(self.num_segments):
170 | start = int(np.round(seg_size * i))
171 | frame_index = start + int(duration * (temporal_sample_index + 1))
172 | seq.append(frame_index)
173 | return np.array(seq) + self.index_bias
174 |
175 |
176 | def _decord_decode(self, video_path):
177 | try:
178 | container = decord.VideoReader(video_path)
179 | except Exception as e:
180 | print("Failed to decode {} with exception: {}".format(
181 | video_path, e))
182 | return None
183 |
184 | return container
185 |
186 | def __getitem__(self, index):
187 | # decode frames to video_list
188 | if self.modality == 'video':
189 | _num_retries = 10
190 | for i_try in range(_num_retries):
191 | record = copy.deepcopy(self.video_list[index])
192 | directory = os.path.join(self.root_path, record.path)
193 | video_list = self._decord_decode(directory)
194 | # video_list = self._decord_pyav(directory)
195 | if video_list is None:
196 | print("Failed to decode video idx {} from {}; trial {}".format(
197 | index, directory, i_try)
198 | )
199 | index = random.randint(0, len(self.video_list) - 1)  # randint is inclusive on both ends
200 | continue
201 | break
202 | else:
203 | record = self.video_list[index]
204 | video_list = os.listdir(os.path.join(self.root_path, record.path))
205 |
206 | if not self.test_mode: # train/val
207 | segment_indices = self._sample_indices(video_list) if self.random_shift else self._get_val_indices(video_list)
208 | else: # test
209 | segment_indices = self._get_test_indices(video_list)
210 |
211 | return self.get(record, video_list, segment_indices)
212 |
213 |
214 | def _load_image(self, directory, idx):
215 | if self.modality == 'RGB':
216 | try:
217 | return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(idx))).convert('RGB')]
218 | except Exception:
219 | print('error loading image:', os.path.join(self.root_path, directory, self.image_tmpl.format(idx)))
220 | return [Image.open(os.path.join(self.root_path, directory, self.image_tmpl.format(1))).convert('RGB')]
221 |
222 |
223 | def get(self, record, video_list, indices):
224 | images = list()
225 | for seg_ind in indices:
226 | p = int(seg_ind)
227 | if self.modality == 'video':
228 | seg_imgs = [Image.fromarray(video_list[p - 1].asnumpy()).convert('RGB')]
229 | else:
230 | seg_imgs = self._load_image(record.path, p)
231 | images.extend(seg_imgs)
232 | if p < len(video_list):
233 | p += 1
234 | if self.num_sample > 1:
235 | frame_list = []
236 | label_list = []
237 | for _ in range(self.num_sample):
238 | process_data, record_label = self.transform((images, record.label))
239 |
240 | process_data = process_data.view((self.num_segments,3)+process_data.size()[-2:]) ###!!!
241 | frame_list.append(process_data)
242 | label_list.append(record_label)
243 | #print('frame_list',len(frame_list), frame_list[0].shape)
244 | return frame_list, label_list, 0, 0
245 | else:
246 | process_data, record_label = self.transform((images, record.label))
247 |
248 | process_data = process_data.view((self.num_segments,3)+process_data.size()[-2:]) ###!!!
249 |
250 | return process_data, record_label, 0, 0
251 |
252 | def __len__(self):
253 | return len(self.video_list)
254 |
--------------------------------------------------------------------------------
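Illustrative sketch (not a repository file): unlike sth.py, the get() method here reshapes the transformed clip from a channel-stacked (T*3, H, W) tensor into (T, 3, H, W) before returning it. A minimal example of that reshape on dummy data, assuming 8 segments and 224x224 crops:

import torch

num_segments = 8
stacked = torch.randn(num_segments * 3, 224, 224)   # T RGB frames stacked along the channel axis
clip = stacked.view((num_segments, 3) + stacked.size()[-2:])
print(clip.shape)  # torch.Size([8, 3, 224, 224])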
/engine_for_finetuning.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import math
4 | import sys
5 | from typing import Iterable, Optional
6 | import torch
7 | from dataset.mixup import Mixup
8 | from timm.utils import accuracy, ModelEma
9 | import utils
10 | from scipy.special import softmax ###!!!
11 | from utils import AverageMeter, gather_labels
12 | import torch.distributed as dist
13 |
14 | class AllGather(torch.autograd.Function):
15 | """An autograd function that performs allgather on a tensor."""
16 |
17 | @staticmethod
18 | def forward(ctx, tensor):
19 | output = [torch.empty_like(tensor) for _ in range(dist.get_world_size())]
20 | torch.distributed.all_gather(output, tensor)
21 | ctx.rank = dist.get_rank()
22 | ctx.batch_size = tensor.shape[0]
23 | return torch.cat(output, dim=0)
24 |
25 | @staticmethod
26 | def backward(ctx, grad_output):
27 | return (
28 | grad_output[ctx.batch_size * ctx.rank : ctx.batch_size * (ctx.rank + 1)],
29 | None,
30 | )
31 |
32 | allgather = AllGather.apply
33 |
34 | def train_class_batch(model, samples, target, criterion):
35 | outputs = model(samples)
36 | loss = criterion(outputs, target)
37 | return loss, outputs
38 |
39 |
40 | def get_loss_scale_for_deepspeed(model):
41 | optimizer = model.optimizer
42 | return optimizer.loss_scale if hasattr(optimizer, "loss_scale") else optimizer.cur_scale
43 |
44 |
45 | def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
46 | data_loader: Iterable, optimizer: torch.optim.Optimizer,
47 | device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
48 | model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None, log_writer=None,
49 | start_steps=None, lr_schedule_values=None, wd_schedule_values=None,
50 | num_training_steps_per_epoch=None, update_freq=None):
51 | model.train(True)
52 | metric_logger = utils.MetricLogger(delimiter=" ")
53 | metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
54 | metric_logger.add_meter('min_lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
55 | header = 'Epoch: [{}]'.format(epoch)
56 | print_freq = 10
57 |
58 | if loss_scaler is None:
59 | model.zero_grad()
60 | model.micro_steps = 0
61 | else:
62 | optimizer.zero_grad()
63 |
64 | #import pdb;pdb.set_trace()
65 | for data_iter_step, (samples, targets, _, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
66 | #import pdb;pdb.set_trace()
67 |
68 | if len(samples) == 2:
69 | samples = torch.cat([samples[0], samples[1]], 0) ###!!!
70 | targets = torch.cat([targets[0], targets[1]], 0) ###!!!
71 |
72 | step = data_iter_step // update_freq
73 | if step >= num_training_steps_per_epoch:
74 | continue
75 | it = start_steps + step # global training iteration
76 | # Update LR & WD for the first acc
77 | if (lr_schedule_values is not None or wd_schedule_values is not None) and data_iter_step % update_freq == 0:
78 | for i, param_group in enumerate(optimizer.param_groups):
79 | if lr_schedule_values is not None:
80 | param_group["lr"] = lr_schedule_values[it] * param_group["lr_scale"]
81 | if wd_schedule_values is not None and param_group["weight_decay"] > 0:
82 | param_group["weight_decay"] = wd_schedule_values[it]
83 |
84 | samples = samples.to(device, non_blocking=True)
85 | targets = targets.to(device, non_blocking=True)
86 |
87 | if mixup_fn is not None:
88 | samples, targets = mixup_fn(samples, targets)
89 |
90 | if loss_scaler is None:
91 | samples = samples.half()
92 | loss, output = train_class_batch(
93 | model, samples, targets, criterion)
94 | else:
95 | with torch.cuda.amp.autocast():
96 | loss, output = train_class_batch(
97 | model, samples, targets, criterion)
98 |
99 | loss_value = loss.item()
100 |
101 | if not math.isfinite(loss_value):
102 | print("Loss is {}, stopping training".format(loss_value))
103 | sys.exit(1)
104 |
105 | if loss_scaler is None:
106 | loss /= update_freq
107 | model.backward(loss)
108 | model.step()
109 |
110 | if (data_iter_step + 1) % update_freq == 0:
111 | # model.zero_grad()
112 | # Deepspeed will call step() & model.zero_grad() automatic
113 | if model_ema is not None:
114 | model_ema.update(model)
115 | grad_norm = None
116 | loss_scale_value = get_loss_scale_for_deepspeed(model)
117 | else:
118 | # this attribute is added by timm on one optimizer (adahessian)
119 | is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
120 | loss /= update_freq
121 | grad_norm = loss_scaler(loss, optimizer, clip_grad=max_norm,
122 | parameters=model.parameters(), create_graph=is_second_order,
123 | update_grad=(data_iter_step + 1) % update_freq == 0)
124 | if (data_iter_step + 1) % update_freq == 0:
125 | optimizer.zero_grad()
126 | if model_ema is not None:
127 | model_ema.update(model)
128 | loss_scale_value = loss_scaler.state_dict()["scale"]
129 |
130 | torch.cuda.synchronize()
131 |
132 | if mixup_fn is None:
133 | class_acc = (output.max(-1)[-1] == targets).float().mean()
134 | else:
135 | class_acc = None
136 | metric_logger.update(loss=loss_value)
137 | metric_logger.update(class_acc=class_acc)
138 | metric_logger.update(loss_scale=loss_scale_value)
139 | min_lr = 10.
140 | max_lr = 0.
141 | for group in optimizer.param_groups:
142 | min_lr = min(min_lr, group["lr"])
143 | max_lr = max(max_lr, group["lr"])
144 |
145 | metric_logger.update(lr=max_lr)
146 | metric_logger.update(min_lr=min_lr)
147 | weight_decay_value = None
148 | for group in optimizer.param_groups:
149 | if group["weight_decay"] > 0:
150 | weight_decay_value = group["weight_decay"]
151 | metric_logger.update(weight_decay=weight_decay_value)
152 | metric_logger.update(grad_norm=grad_norm)
153 |
154 | if log_writer is not None:
155 | log_writer.update(loss=loss_value, head="loss")
156 | log_writer.update(class_acc=class_acc, head="loss")
157 | log_writer.update(loss_scale=loss_scale_value, head="opt")
158 | log_writer.update(lr=max_lr, head="opt")
159 | log_writer.update(min_lr=min_lr, head="opt")
160 | log_writer.update(weight_decay=weight_decay_value, head="opt")
161 | log_writer.update(grad_norm=grad_norm, head="opt")
162 |
163 | log_writer.set_step()
164 |
165 | # gather the stats from all processes
166 | metric_logger.synchronize_between_processes()
167 | print("Averaged stats:", metric_logger)
168 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
169 |
170 |
171 | @torch.no_grad()
172 | def validation_one_epoch(data_loader, model, device):
173 | criterion = torch.nn.CrossEntropyLoss()
174 |
175 | metric_logger = utils.MetricLogger(delimiter=" ")
176 | header = 'Val:'
177 |
178 | # switch to evaluation mode
179 | model.eval()
180 |
181 | for batch in metric_logger.log_every(data_loader, 10, header):
182 | videos = batch[0]
183 | target = batch[1]
184 | videos = videos.to(device, non_blocking=True)
185 | target = target.to(device, non_blocking=True)
186 |
187 | # compute output
188 | with torch.cuda.amp.autocast():
189 | output = model(videos)
190 | #import pdb;pdb.set_trace()
191 | loss = criterion(output, target)
192 |
193 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
194 |
195 | batch_size = videos.shape[0]
196 | metric_logger.update(loss=loss.item())
197 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
198 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
199 | # gather the stats from all processes
200 | metric_logger.synchronize_between_processes()
201 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
202 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
203 |
204 |
205 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
206 |
207 | @torch.no_grad()
208 | def charades_validation_one_epoch(data_loader, model, device, dataset = 'anet'):
209 | #epoch, val_loader, device, model, config, text_embedding, logger):
210 | mAP = AverageMeter()
211 | model.eval()
212 | from torchnet import meter
213 | maper = meter.mAPMeter()
214 |
215 | with torch.no_grad():
216 | for i, batch in enumerate(data_loader):
217 |
218 | videos = batch[0]
219 | labels = batch[1]
220 | videos = videos.to(device, non_blocking=True)
221 | labels = labels.to(device, non_blocking=True)
222 |
223 | # compute output
224 | with torch.cuda.amp.autocast():
225 | output = model(videos)
226 |
227 | ###!!!
228 | if dataset == 'anet':
229 | target = torch.IntTensor(labels.shape[0], 200).zero_()  # all zeros, then one-hot encode below
230 | for b, label in enumerate(labels):
231 | target[b][int(label)] = 1
232 | else:
233 | target = labels
234 |
235 | output = allgather(output)
236 | target = gather_labels(target)
237 |
238 | maper.add(output, target)
239 | mAP.update(maper.value().numpy(),target.size(0))
240 |
241 | if i % 10 == 0:
242 | print('Test: [{0}/{1},mAP:{map:.3f}]\t'.format(i, len(data_loader), map=mAP.avg * 100))
243 |
244 | print('Testing Results mAP === {mAP_result:.3f}'.format(mAP_result=mAP.avg * 100))
245 | return mAP.avg * 100
246 |
247 |
248 | @torch.no_grad()
249 | def final_test(data_loader, model, device, file):
250 | criterion = torch.nn.CrossEntropyLoss()
251 |
252 | metric_logger = utils.MetricLogger(delimiter=" ")
253 | header = 'Test:'
254 |
255 | # switch to evaluation mode
256 | model.eval()
257 | final_result = []
258 |
259 | for batch in metric_logger.log_every(data_loader, 10, header):
260 | videos = batch[0]
261 | target = batch[1]
262 | ids = batch[2]
263 | chunk_nb = batch[3]
264 | split_nb = batch[4]
265 | videos = videos.to(device, non_blocking=True)
266 | target = target.to(device, non_blocking=True)
267 |
268 | # compute output
269 | with torch.cuda.amp.autocast():
270 | output = model(videos)
271 | loss = criterion(output, target)
272 |
273 | for i in range(output.size(0)):
274 | string = "{} {} {} {} {}\n".format(ids[i], \
275 | str(output.data[i].cpu().numpy().tolist()), \
276 | str(int(target[i].cpu().numpy())), \
277 | str(int(chunk_nb[i].cpu().numpy())), \
278 | str(int(split_nb[i].cpu().numpy())))
279 | final_result.append(string)
280 |
281 | acc1, acc5 = accuracy(output, target, topk=(1, 5))
282 |
283 | batch_size = videos.shape[0]
284 | metric_logger.update(loss=loss.item())
285 | metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
286 | metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
287 |
288 | if not os.path.exists(file):
289 | os.mknod(file)
290 | with open(file, 'w') as f:
291 | f.write("{}, {}\n".format(acc1, acc5))
292 | for line in final_result:
293 | f.write(line)
294 | # gather the stats from all processes
295 | metric_logger.synchronize_between_processes()
296 | print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
297 | .format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
298 |
299 | return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
300 |
301 |
302 | def merge(eval_path, num_tasks):
303 | dict_feats = {}
304 | dict_label = {}
305 | dict_pos = {}
306 | print("Reading individual output files")
307 |
308 | for x in range(num_tasks):
309 | file = os.path.join(eval_path, str(x) + '.txt')
310 | lines = open(file, 'r').readlines()[1:]
311 | for line in lines:
312 | line = line.strip()
313 | name = line.split('[')[0]
314 | label = line.split(']')[1].split(' ')[1]
315 | chunk_nb = line.split(']')[1].split(' ')[2]
316 | split_nb = line.split(']')[1].split(' ')[3]
317 | data = np.fromstring(line.split('[')[1].split(']')[0], dtype=float, sep=',')  # np.float was removed in NumPy >= 1.24
318 | data = softmax(data) ###!!!
319 | if not name in dict_feats:
320 | dict_feats[name] = []
321 | dict_label[name] = 0
322 | dict_pos[name] = []
323 | if chunk_nb + split_nb in dict_pos[name]:
324 | continue
325 | dict_feats[name].append(data)
326 | dict_pos[name].append(chunk_nb + split_nb)
327 | dict_label[name] = label
328 | print("Computing final results")
329 |
330 | input_lst = []
331 | print(len(dict_feats))
332 | for i, item in enumerate(dict_feats):
333 | input_lst.append([i, item, dict_feats[item], dict_label[item]])
334 | from multiprocessing import Pool
335 | p = Pool(64)
336 | ans = p.map(compute_video, input_lst)
337 | top1 = [x[1] for x in ans]
338 | top5 = [x[2] for x in ans]
339 | pred = [x[0] for x in ans]
340 | label = [x[3] for x in ans]
341 | final_top1 ,final_top5 = np.mean(top1), np.mean(top5)
342 | return final_top1*100 ,final_top5*100
343 |
344 | def compute_video(lst):
345 | i, video_id, data, label = lst
346 | feat = [x for x in data]
347 | feat = np.mean(feat, axis=0)
348 | pred = np.argmax(feat)
349 | top1 = (int(pred) == int(label)) * 1.0
350 | top5 = (int(label) in np.argsort(-feat)[:5]) * 1.0
351 | return [pred, top1, top5, int(label)]
352 |
--------------------------------------------------------------------------------
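Illustrative sketch (not a repository file): merge() and compute_video() above aggregate multi-view test results by averaging the softmaxed logits of every (chunk, split) view of a video and scoring top-1/top-5 on the average. The same aggregation on made-up logits for one 5-class video:

import numpy as np
from scipy.special import softmax

views = [np.array([2.0, 0.1, 0.3, 0.0, -1.0]),   # three hypothetical views of the same video
         np.array([1.5, 0.2, 0.9, 0.1, -0.5]),
         np.array([0.8, 0.0, 1.1, 0.2, -0.2])]
label = 0

feat = np.mean([softmax(v) for v in views], axis=0)  # average per-view class probabilities
pred = int(np.argmax(feat))
top1 = float(pred == label)
top5 = float(label in np.argsort(-feat)[:5])
print(pred, top1, top5)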
/eva_clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
2 | from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer
3 | from .factory import list_models, add_model_config, get_model_config, load_checkpoint
4 | from .loss import ClipLoss
5 | from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg,\
6 | convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
7 | from .openai import load_openai_model, list_openai_models
8 | from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\
9 | get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
10 | from .tokenizer import SimpleTokenizer, tokenize
11 | from .transform import image_transform
--------------------------------------------------------------------------------
/eva_clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/whwu95/ATM/98ba3aa2ac258cc1b91beefe9317136657ae3d8d/eva_clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/eva_clip/constants.py:
--------------------------------------------------------------------------------
1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
3 |
--------------------------------------------------------------------------------
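Illustrative sketch (not a repository file): these constants are the per-channel RGB mean/std that OpenAI-CLIP-style backbones expect, typically applied as the final normalization step of the input transform. For example, with torchvision:

import torch
from torchvision import transforms

OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)

normalize = transforms.Normalize(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD)
img = torch.rand(3, 224, 224)     # dummy RGB image scaled to [0, 1]
print(normalize(img).shape)       # torch.Size([3, 224, 224])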
/eva_clip/hf_configs.py:
--------------------------------------------------------------------------------
1 | # HF architecture dict:
2 | arch_dict = {
3 | # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
4 | "roberta": {
5 | "config_names": {
6 | "context_length": "max_position_embeddings",
7 | "vocab_size": "vocab_size",
8 | "width": "hidden_size",
9 | "heads": "num_attention_heads",
10 | "layers": "num_hidden_layers",
11 | "layer_attr": "layer",
12 | "token_embeddings_attr": "embeddings"
13 | },
14 | "pooler": "mean_pooler",
15 | },
16 | # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
17 | "xlm-roberta": {
18 | "config_names": {
19 | "context_length": "max_position_embeddings",
20 | "vocab_size": "vocab_size",
21 | "width": "hidden_size",
22 | "heads": "num_attention_heads",
23 | "layers": "num_hidden_layers",
24 | "layer_attr": "layer",
25 | "token_embeddings_attr": "embeddings"
26 | },
27 | "pooler": "mean_pooler",
28 | },
29 | # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
30 | "mt5": {
31 | "config_names": {
32 | # unlimited seqlen
33 | # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
34 | # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
35 | "context_length": "",
36 | "vocab_size": "vocab_size",
37 | "width": "d_model",
38 | "heads": "num_heads",
39 | "layers": "num_layers",
40 | "layer_attr": "block",
41 | "token_embeddings_attr": "embed_tokens"
42 | },
43 | "pooler": "mean_pooler",
44 | },
45 | "bert": {
46 | "config_names": {
47 | "context_length": "max_position_embeddings",
48 | "vocab_size": "vocab_size",
49 | "width": "hidden_size",
50 | "heads": "num_attention_heads",
51 | "layers": "num_hidden_layers",
52 | "layer_attr": "layer",
53 | "token_embeddings_attr": "embeddings"
54 | },
55 | "pooler": "mean_pooler",
56 | }
57 | }
58 |
--------------------------------------------------------------------------------
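Illustrative sketch (not a repository file): hf_model.py reads HuggingFace config attributes through this arch_dict mapping instead of hard-coding per-architecture names. A minimal example of that indirection, assuming transformers is installed and using roberta-base as a stand-in checkpoint:

from transformers import AutoConfig
from eva_clip.hf_configs import arch_dict

config = AutoConfig.from_pretrained("roberta-base")
names = arch_dict[config.model_type]["config_names"]   # config.model_type == "roberta"
width = getattr(config, names["width"])                # hidden_size
layers = getattr(config, names["layers"])              # num_hidden_layers
print(width, layers)                                   # 768 12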
/eva_clip/hf_model.py:
--------------------------------------------------------------------------------
1 | """ huggingface model adapter
2 |
3 | Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
4 | """
5 |
6 | import re
7 |
8 | import torch
9 | import torch.nn as nn
10 | from torch.nn import functional as F
11 | from torch import TensorType
12 | try:
13 | import transformers
14 | from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, AutoConfig, PretrainedConfig
15 | from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
16 | BaseModelOutputWithPoolingAndCrossAttentions
17 | except ImportError as e:
18 | transformers = None
19 |
20 |
21 | class BaseModelOutput:
22 | pass
23 |
24 |
25 | class PretrainedConfig:
26 | pass
27 |
28 | from .hf_configs import arch_dict
29 |
30 | # utils
31 | def _camel2snake(s):
32 | return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
139 | # def forward_itm(self, x:TensorType, image_embeds) -> TensorType:
140 | # image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(x.device)
141 | # attn_mask = (x != self.config.pad_token_id).long()
142 | # out = self.transformer(
143 | # input_ids=x,
144 | # attention_mask=attn_mask,
145 | # encoder_hidden_states = image_embeds,
146 | # encoder_attention_mask = image_atts,
147 | # )
148 | # pooled_out = self.pooler(out, attn_mask)
149 |
150 | # return self.itm_proj(pooled_out)
151 |
152 | def mask(self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None):
153 | if masked_indices is None:
154 | masked_indices = torch.bernoulli(probability_matrix).bool()
155 |
156 | masked_indices[input_ids == self.tokenizer.pad_token_id] = False
157 | masked_indices[input_ids == self.tokenizer.cls_token_id] = False
158 |
159 | if targets is not None:
160 | targets[~masked_indices] = -100 # We only compute loss on masked tokens
161 |
162 | # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
163 | indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
164 | input_ids[indices_replaced] = self.tokenizer.mask_token_id
165 |
166 | # 10% of the time, we replace masked input tokens with random word
167 | indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced
168 | random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device)
169 | input_ids[indices_random] = random_words[indices_random]
170 | # The rest of the time (10% of the time) we keep the masked input tokens unchanged
171 |
172 | if targets is not None:
173 | return input_ids, targets
174 | else:
175 | return input_ids
176 |
177 | def forward_mlm(self, input_ids, image_embeds, mlm_probability=0.25):
178 | labels = input_ids.clone()
179 | attn_mask = (input_ids != self.config.pad_token_id).long()
180 | image_atts = torch.ones(image_embeds.size()[:-1],dtype=torch.long).to(input_ids.device)
181 | vocab_size = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["vocab_size"])
182 | probability_matrix = torch.full(labels.shape, mlm_probability)
183 | input_ids, labels = self.mask(input_ids, vocab_size, input_ids.device, targets=labels,
184 | probability_matrix = probability_matrix)
185 | mlm_output = self.transformer(input_ids,
186 | attention_mask = attn_mask,
187 | encoder_hidden_states = image_embeds,
188 | encoder_attention_mask = image_atts,
189 | return_dict = True,
190 | labels = labels,
191 | )
192 | return mlm_output.loss
193 | # mlm_output = self.transformer(input_ids,
194 | # attention_mask = attn_mask,
195 | # encoder_hidden_states = image_embeds,
196 | # encoder_attention_mask = image_atts,
197 | # return_dict = True,
198 | # ).last_hidden_state
199 | # logits = self.mlm_proj(mlm_output)
200 |
201 | # # logits = logits[:, :-1, :].contiguous().view(-1, vocab_size)
202 | # logits = logits[:, 1:, :].contiguous().view(-1, vocab_size)
203 | # labels = labels[:, 1:].contiguous().view(-1)
204 |
205 | # mlm_loss = F.cross_entropy(
206 | # logits,
207 | # labels,
208 | # # label_smoothing=0.1,
209 | # )
210 | # return mlm_loss
211 |
212 |
213 | def forward(self, x:TensorType) -> TensorType:
214 | attn_mask = (x != self.config.pad_token_id).long()
215 | out = self.transformer(input_ids=x, attention_mask=attn_mask)
216 | pooled_out = self.pooler(out, attn_mask)
217 |
218 | return self.proj(pooled_out)
219 |
220 | def lock(self, unlocked_layers:int=0, freeze_layer_norm:bool=True):
221 | if not unlocked_layers: # full freezing
222 | for n, p in self.transformer.named_parameters():
223 | p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
224 | return
225 |
226 | encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
227 | layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
228 | print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
229 | embeddings = getattr(
230 | self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
231 | modules = [embeddings, *layer_list][:-unlocked_layers]
232 | # freeze layers
233 | for module in modules:
234 | for n, p in module.named_parameters():
235 | p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
236 |
237 |
238 | @torch.jit.ignore
239 | def set_grad_checkpointing(self, enable=True):
240 | self.transformer.gradient_checkpointing_enable()
241 |
242 | def get_num_layers(self):
243 | encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
244 | layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
245 | return len(layer_list)
246 |
247 | def init_parameters(self):
248 | pass
249 |
--------------------------------------------------------------------------------
/eva_clip/loss.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn as nn
4 | from torch.nn import functional as F
5 |
6 | try:
7 | import torch.distributed.nn
8 | from torch import distributed as dist
9 | has_distributed = True
10 | except ImportError:
11 | has_distributed = False
12 |
13 | try:
14 | import horovod.torch as hvd
15 | except ImportError:
16 | hvd = None
17 |
18 | from timm.loss import LabelSmoothingCrossEntropy
19 |
20 |
21 | def gather_features(
22 | image_features,
23 | text_features,
24 | local_loss=False,
25 | gather_with_grad=False,
26 | rank=0,
27 | world_size=1,
28 | use_horovod=False
29 | ):
30 | assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
31 | if use_horovod:
32 | assert hvd is not None, 'Please install horovod'
33 | if gather_with_grad:
34 | all_image_features = hvd.allgather(image_features)
35 | all_text_features = hvd.allgather(text_features)
36 | else:
37 | with torch.no_grad():
38 | all_image_features = hvd.allgather(image_features)
39 | all_text_features = hvd.allgather(text_features)
40 | if not local_loss:
41 | # ensure grads for local rank when all_* features don't have a gradient
42 | gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
43 | gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
44 | gathered_image_features[rank] = image_features
45 | gathered_text_features[rank] = text_features
46 | all_image_features = torch.cat(gathered_image_features, dim=0)
47 | all_text_features = torch.cat(gathered_text_features, dim=0)
48 | else:
49 | # We gather tensors from all gpus
50 | if gather_with_grad:
51 | all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
52 | all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
53 | # all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features, async_op=True), dim=0)
54 | # all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features, async_op=True), dim=0)
55 | else:
56 | gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
57 | gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
58 | dist.all_gather(gathered_image_features, image_features)
59 | dist.all_gather(gathered_text_features, text_features)
60 | if not local_loss:
61 | # ensure grads for local rank when all_* features don't have a gradient
62 | gathered_image_features[rank] = image_features
63 | gathered_text_features[rank] = text_features
64 | all_image_features = torch.cat(gathered_image_features, dim=0)
65 | all_text_features = torch.cat(gathered_text_features, dim=0)
66 |
67 | return all_image_features, all_text_features
68 |
69 |
70 | class ClipLoss(nn.Module):
71 |
72 | def __init__(
73 | self,
74 | local_loss=False,
75 | gather_with_grad=False,
76 | cache_labels=False,
77 | rank=0,
78 | world_size=1,
79 | use_horovod=False,
80 | smoothing=0.,
81 | ):
82 | super().__init__()
83 | self.local_loss = local_loss
84 | self.gather_with_grad = gather_with_grad
85 | self.cache_labels = cache_labels
86 | self.rank = rank
87 | self.world_size = world_size
88 | self.use_horovod = use_horovod
89 | self.label_smoothing_cross_entropy = LabelSmoothingCrossEntropy(smoothing=smoothing) if smoothing > 0 else None
90 |
91 | # cache state
92 | self.prev_num_logits = 0
93 | self.labels = {}
94 |
95 | def forward(self, image_features, text_features, logit_scale=1.):
96 | device = image_features.device
97 | if self.world_size > 1:
98 | all_image_features, all_text_features = gather_features(
99 | image_features, text_features,
100 | self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
101 |
102 | if self.local_loss:
103 | logits_per_image = logit_scale * image_features @ all_text_features.T
104 | logits_per_text = logit_scale * text_features @ all_image_features.T
105 | else:
106 | logits_per_image = logit_scale * all_image_features @ all_text_features.T
107 | logits_per_text = logits_per_image.T
108 | else:
109 | logits_per_image = logit_scale * image_features @ text_features.T
110 | logits_per_text = logit_scale * text_features @ image_features.T
111 | # calculated ground-truth and cache if enabled
112 | num_logits = logits_per_image.shape[0]
113 | if self.prev_num_logits != num_logits or device not in self.labels:
114 | labels = torch.arange(num_logits, device=device, dtype=torch.long)
115 | if self.world_size > 1 and self.local_loss:
116 | labels = labels + num_logits * self.rank
117 | if self.cache_labels:
118 | self.labels[device] = labels
119 | self.prev_num_logits = num_logits
120 | else:
121 | labels = self.labels[device]
122 |
123 | if self.label_smoothing_cross_entropy:
124 | total_loss = (
125 | self.label_smoothing_cross_entropy(logits_per_image, labels) +
126 | self.label_smoothing_cross_entropy(logits_per_text, labels)
127 | ) / 2
128 | else:
129 | total_loss = (
130 | F.cross_entropy(logits_per_image, labels) +
131 | F.cross_entropy(logits_per_text, labels)
132 | ) / 2
133 |
134 | acc = None
135 | i2t_acc = (logits_per_image.argmax(-1) == labels).sum() / len(logits_per_image)
136 | t2i_acc = (logits_per_text.argmax(-1) == labels).sum() / len(logits_per_text)
137 | acc = {"i2t": i2t_acc, "t2i": t2i_acc}
138 | return total_loss, acc
--------------------------------------------------------------------------------
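Illustrative sketch (not a repository file): with world_size=1 the ClipLoss above reduces to a plain symmetric InfoNCE over the in-batch image/text similarity matrix. A minimal single-process example with random features:

import torch
import torch.nn.functional as F
from eva_clip.loss import ClipLoss

loss_fn = ClipLoss()                                   # defaults: world_size=1, no label smoothing
image_features = F.normalize(torch.randn(8, 512), dim=-1)
text_features = F.normalize(torch.randn(8, 512), dim=-1)
loss, acc = loss_fn(image_features, text_features, logit_scale=100.0)
print(loss.item(), acc["i2t"].item(), acc["t2i"].item())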
/eva_clip/model_configs/EVA01-CLIP-B-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "patch_size": 16,
8 | "eva_model_name": "eva-clip-b-16",
9 | "ls_init_value": 0.1,
10 | "drop_path_rate": 0.0
11 | },
12 | "text_cfg": {
13 | "context_length": 77,
14 | "vocab_size": 49408,
15 | "width": 512,
16 | "heads": 8,
17 | "layers": 12
18 | }
19 | }
--------------------------------------------------------------------------------
/eva_clip/model_configs/EVA01-CLIP-g-14-plus.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 40,
6 | "width": 1408,
7 | "head_width": 88,
8 | "mlp_ratio": 4.3637,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-g-14-x",
11 | "drop_path_rate": 0,
12 | "xattn": true,
13 | "fusedLN": true
14 | },
15 | "text_cfg": {
16 | "context_length": 77,
17 | "vocab_size": 49408,
18 | "width": 1024,
19 | "heads": 16,
20 | "layers": 24,
21 | "xattn": false,
22 | "fusedLN": true
23 | }
24 | }
--------------------------------------------------------------------------------
/eva_clip/model_configs/EVA01-CLIP-g-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 40,
6 | "width": 1408,
7 | "head_width": 88,
8 | "mlp_ratio": 4.3637,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-g-14-x",
11 | "drop_path_rate": 0.4,
12 | "xattn": true,
13 | "fusedLN": true
14 | },
15 | "text_cfg": {
16 | "context_length": 77,
17 | "vocab_size": 49408,
18 | "width": 768,
19 | "heads": 12,
20 | "layers": 12,
21 | "xattn": false,
22 | "fusedLN": true
23 | }
24 | }
--------------------------------------------------------------------------------
/eva_clip/model_configs/EVA02-CLIP-B-16.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 512,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 12,
6 | "width": 768,
7 | "head_width": 64,
8 | "patch_size": 16,
9 | "mlp_ratio": 2.6667,
10 | "eva_model_name": "eva-clip-b-16-X",
11 | "drop_path_rate": 0.0,
12 | "xattn": true,
13 | "fusedLN": true,
14 | "rope": true,
15 | "pt_hw_seq_len": 16,
16 | "intp_freq": true,
17 | "naiveswiglu": true,
18 | "subln": true
19 | },
20 | "text_cfg": {
21 | "context_length": 77,
22 | "vocab_size": 49408,
23 | "width": 512,
24 | "heads": 8,
25 | "layers": 12,
26 | "xattn": true,
27 | "fusedLN": true
28 | }
29 | }
--------------------------------------------------------------------------------
/eva_clip/model_configs/EVA02-CLIP-L-14-336.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 336,
5 | "layers": 24,
6 | "width": 1024,
7 | "drop_path_rate": 0,
8 | "head_width": 64,
9 | "mlp_ratio": 2.6667,
10 | "patch_size": 14,
11 | "eva_model_name": "eva-clip-l-14-336",
12 | "xattn": true,
13 | "fusedLN": true,
14 | "rope": true,
15 | "pt_hw_seq_len": 16,
16 | "intp_freq": true,
17 | "naiveswiglu": true,
18 | "subln": true
19 | },
20 | "text_cfg": {
21 | "context_length": 77,
22 | "vocab_size": 49408,
23 | "width": 768,
24 | "heads": 12,
25 | "layers": 12,
26 | "xattn": false,
27 | "fusedLN": true
28 | }
29 | }
--------------------------------------------------------------------------------
/eva_clip/model_configs/EVA02-CLIP-L-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 768,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 24,
6 | "width": 1024,
7 | "drop_path_rate": 0,
8 | "head_width": 64,
9 | "mlp_ratio": 2.6667,
10 | "patch_size": 14,
11 | "eva_model_name": "eva-clip-l-14",
12 | "xattn": true,
13 | "fusedLN": true,
14 | "rope": true,
15 | "pt_hw_seq_len": 16,
16 | "intp_freq": true,
17 | "naiveswiglu": true,
18 | "subln": true
19 | },
20 | "text_cfg": {
21 | "context_length": 77,
22 | "vocab_size": 49408,
23 | "width": 768,
24 | "heads": 12,
25 | "layers": 12,
26 | "xattn": false,
27 | "fusedLN": true
28 | }
29 | }
--------------------------------------------------------------------------------
/eva_clip/model_configs/EVA02-CLIP-bigE-14-plus.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 64,
6 | "width": 1792,
7 | "head_width": 112,
8 | "mlp_ratio": 8.571428571428571,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-4b-14-x",
11 | "drop_path_rate": 0,
12 | "xattn": true,
13 | "postnorm": true,
14 | "fusedLN": true
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 1280,
20 | "heads": 20,
21 | "layers": 32,
22 | "xattn": false,
23 | "fusedLN": true
24 | }
25 | }
26 |
--------------------------------------------------------------------------------
/eva_clip/model_configs/EVA02-CLIP-bigE-14.json:
--------------------------------------------------------------------------------
1 | {
2 | "embed_dim": 1024,
3 | "vision_cfg": {
4 | "image_size": 224,
5 | "layers": 64,
6 | "width": 1792,
7 | "head_width": 112,
8 | "mlp_ratio": 8.571428571428571,
9 | "patch_size": 14,
10 | "eva_model_name": "eva-clip-4b-14-x",
11 | "drop_path_rate": 0,
12 | "xattn": true,
13 | "postnorm": true,
14 | "fusedLN": true
15 | },
16 | "text_cfg": {
17 | "context_length": 77,
18 | "vocab_size": 49408,
19 | "width": 1024,
20 | "heads": 16,
21 | "layers": 24,
22 | "xattn": false,
23 | "fusedLN": true
24 | }
25 | }
--------------------------------------------------------------------------------
/eva_clip/modified_resnet.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 |
7 | from eva_clip.utils import freeze_batch_norm_2d
8 |
9 |
10 | class Bottleneck(nn.Module):
11 | expansion = 4
12 |
13 | def __init__(self, inplanes, planes, stride=1):
14 | super().__init__()
15 |
16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18 | self.bn1 = nn.BatchNorm2d(planes)
19 | self.act1 = nn.ReLU(inplace=True)
20 |
21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22 | self.bn2 = nn.BatchNorm2d(planes)
23 | self.act2 = nn.ReLU(inplace=True)
24 |
25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
26 |
27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29 | self.act3 = nn.ReLU(inplace=True)
30 |
31 | self.downsample = None
32 | self.stride = stride
33 |
34 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
36 | self.downsample = nn.Sequential(OrderedDict([
37 | ("-1", nn.AvgPool2d(stride)),
38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
39 | ("1", nn.BatchNorm2d(planes * self.expansion))
40 | ]))
41 |
42 | def forward(self, x: torch.Tensor):
43 | identity = x
44 |
45 | out = self.act1(self.bn1(self.conv1(x)))
46 | out = self.act2(self.bn2(self.conv2(out)))
47 | out = self.avgpool(out)
48 | out = self.bn3(self.conv3(out))
49 |
50 | if self.downsample is not None:
51 | identity = self.downsample(x)
52 |
53 | out += identity
54 | out = self.act3(out)
55 | return out
56 |
57 |
58 | class AttentionPool2d(nn.Module):
59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
60 | super().__init__()
61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
62 | self.k_proj = nn.Linear(embed_dim, embed_dim)
63 | self.q_proj = nn.Linear(embed_dim, embed_dim)
64 | self.v_proj = nn.Linear(embed_dim, embed_dim)
65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
66 | self.num_heads = num_heads
67 |
68 | def forward(self, x):
69 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
72 | x, _ = F.multi_head_attention_forward(
73 | query=x, key=x, value=x,
74 | embed_dim_to_check=x.shape[-1],
75 | num_heads=self.num_heads,
76 | q_proj_weight=self.q_proj.weight,
77 | k_proj_weight=self.k_proj.weight,
78 | v_proj_weight=self.v_proj.weight,
79 | in_proj_weight=None,
80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
81 | bias_k=None,
82 | bias_v=None,
83 | add_zero_attn=False,
84 | dropout_p=0.,
85 | out_proj_weight=self.c_proj.weight,
86 | out_proj_bias=self.c_proj.bias,
87 | use_separate_proj_weight=True,
88 | training=self.training,
89 | need_weights=False
90 | )
91 |
92 | return x[0]
93 |
94 |
95 | class ModifiedResNet(nn.Module):
96 | """
97 | A ResNet class that is similar to torchvision's but contains the following changes:
98 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
99 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
100 | - The final pooling layer is a QKV attention instead of an average pool
101 | """
102 |
103 | def __init__(self, layers, output_dim, heads, image_size=224, width=64):
104 | super().__init__()
105 | self.output_dim = output_dim
106 | self.image_size = image_size
107 |
108 | # the 3-layer stem
109 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
110 | self.bn1 = nn.BatchNorm2d(width // 2)
111 | self.act1 = nn.ReLU(inplace=True)
112 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
113 | self.bn2 = nn.BatchNorm2d(width // 2)
114 | self.act2 = nn.ReLU(inplace=True)
115 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
116 | self.bn3 = nn.BatchNorm2d(width)
117 | self.act3 = nn.ReLU(inplace=True)
118 | self.avgpool = nn.AvgPool2d(2)
119 |
120 | # residual layers
121 | self._inplanes = width # this is a *mutable* variable used during construction
122 | self.layer1 = self._make_layer(width, layers[0])
123 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
124 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
125 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
126 |
127 | embed_dim = width * 32 # the ResNet feature dimension
128 | self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim)
129 |
130 | self.init_parameters()
131 |
132 | def _make_layer(self, planes, blocks, stride=1):
133 | layers = [Bottleneck(self._inplanes, planes, stride)]
134 |
135 | self._inplanes = planes * Bottleneck.expansion
136 | for _ in range(1, blocks):
137 | layers.append(Bottleneck(self._inplanes, planes))
138 |
139 | return nn.Sequential(*layers)
140 |
141 | def init_parameters(self):
142 | if self.attnpool is not None:
143 | std = self.attnpool.c_proj.in_features ** -0.5
144 | nn.init.normal_(self.attnpool.q_proj.weight, std=std)
145 | nn.init.normal_(self.attnpool.k_proj.weight, std=std)
146 | nn.init.normal_(self.attnpool.v_proj.weight, std=std)
147 | nn.init.normal_(self.attnpool.c_proj.weight, std=std)
148 |
149 | for resnet_block in [self.layer1, self.layer2, self.layer3, self.layer4]:
150 | for name, param in resnet_block.named_parameters():
151 | if name.endswith("bn3.weight"):
152 | nn.init.zeros_(param)
153 |
154 | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
155 | assert unlocked_groups == 0, 'partial locking not currently supported for this model'
156 | for param in self.parameters():
157 | param.requires_grad = False
158 | if freeze_bn_stats:
159 | freeze_batch_norm_2d(self)
160 |
161 | @torch.jit.ignore
162 | def set_grad_checkpointing(self, enable=True):
163 | # FIXME support for non-transformer
164 | pass
165 |
166 | def stem(self, x):
167 | x = self.act1(self.bn1(self.conv1(x)))
168 | x = self.act2(self.bn2(self.conv2(x)))
169 | x = self.act3(self.bn3(self.conv3(x)))
170 | x = self.avgpool(x)
171 | return x
172 |
173 | def forward(self, x):
174 | x = self.stem(x)
175 | x = self.layer1(x)
176 | x = self.layer2(x)
177 | x = self.layer3(x)
178 | x = self.layer4(x)
179 | x = self.attnpool(x)
180 |
181 | return x
182 |
--------------------------------------------------------------------------------
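Illustrative sketch (not a repository file): a quick shape check for ModifiedResNet, built here with the standard CLIP-RN50 layer counts and head count (these values are assumptions for the example, not defined in this file):

import torch
from eva_clip.modified_resnet import ModifiedResNet

model = ModifiedResNet(layers=(3, 4, 6, 3), output_dim=1024, heads=32,
                       image_size=224, width=64)
with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
print(feats.shape)  # torch.Size([1, 1024])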
/eva_clip/openai.py:
--------------------------------------------------------------------------------
1 | """ OpenAI pretrained model functions
2 |
3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4 | """
5 |
6 | import os
7 | import warnings
8 | from typing import List, Optional, Union
9 |
10 | import torch
11 |
12 | from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
13 | from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
14 |
15 | __all__ = ["list_openai_models", "load_openai_model"]
16 |
17 |
18 | def list_openai_models() -> List[str]:
19 | """Returns the names of available CLIP models"""
20 | return list_pretrained_models_by_tag('openai')
21 |
22 |
23 | def load_openai_model(
24 | name: str,
25 | precision: Optional[str] = None,
26 | device: Optional[Union[str, torch.device]] = None,
27 | jit: bool = True,
28 | cache_dir: Optional[str] = None,
29 | ):
30 | """Load a CLIP model
31 |
32 | Parameters
33 | ----------
34 | name : str
35 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
36 | precision: str
37 | Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
38 | device : Union[str, torch.device]
39 | The device to put the loaded model
40 | jit : bool
41 | Whether to load the optimized JIT model (default) or more hackable non-JIT model.
42 | cache_dir : Optional[str]
43 | The directory to cache the downloaded model weights
44 |
45 |     Returns
46 |     -------
47 |     model : torch.nn.Module
48 |         The CLIP model. Unlike the original OpenAI `clip.load`, no image preprocessing
49 |         transform is returned here; build one separately (e.g. with `image_transform`
50 |         from .transform, using the model's image size).
51 |     """
52 | if device is None:
53 | device = "cuda" if torch.cuda.is_available() else "cpu"
54 | if precision is None:
55 | precision = 'fp32' if device == 'cpu' else 'fp16'
56 |
57 | if get_pretrained_url(name, 'openai'):
58 | model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
59 | elif os.path.isfile(name):
60 | model_path = name
61 | else:
62 | raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
63 |
64 | try:
65 | # loading JIT archive
66 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
67 | state_dict = None
68 | except RuntimeError:
69 | # loading saved state dict
70 | if jit:
71 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
72 | jit = False
73 | state_dict = torch.load(model_path, map_location="cpu")
74 |
75 | if not jit:
76 | # Build a non-jit model from the OpenAI jitted model state dict
77 | cast_dtype = get_cast_dtype(precision)
78 | try:
79 | model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
80 | except KeyError:
81 | sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
82 | model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
83 |
84 | # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
85 | model = model.to(device)
86 | if precision.startswith('amp') or precision == 'fp32':
87 | model.float()
88 | elif precision == 'bf16':
89 | convert_weights_to_lp(model, dtype=torch.bfloat16)
90 |
91 | return model
92 |
93 | # patch the device names
94 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
95 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
96 |
97 | def patch_device(module):
98 | try:
99 | graphs = [module.graph] if hasattr(module, "graph") else []
100 | except RuntimeError:
101 | graphs = []
102 |
103 | if hasattr(module, "forward1"):
104 | graphs.append(module.forward1.graph)
105 |
106 | for graph in graphs:
107 | for node in graph.findAllNodes("prim::Constant"):
108 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
109 | node.copyAttributes(device_node)
110 |
111 | model.apply(patch_device)
112 | patch_device(model.encode_image)
113 | patch_device(model.encode_text)
114 |
115 | # patch dtype to float32 (typically for CPU)
116 | if precision == 'fp32':
117 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
118 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
119 | float_node = float_input.node()
120 |
121 | def patch_float(module):
122 | try:
123 | graphs = [module.graph] if hasattr(module, "graph") else []
124 | except RuntimeError:
125 | graphs = []
126 |
127 | if hasattr(module, "forward1"):
128 | graphs.append(module.forward1.graph)
129 |
130 | for graph in graphs:
131 | for node in graph.findAllNodes("aten::to"):
132 | inputs = list(node.inputs())
133 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
134 | if inputs[i].node()["value"] == 5:
135 | inputs[i].node().copyAttributes(float_node)
136 |
137 | model.apply(patch_float)
138 | patch_float(model.encode_image)
139 | patch_float(model.encode_text)
140 | model.float()
141 |
142 | # ensure image_size attr available at consistent location for both jit and non-jit
143 | model.visual.image_size = model.input_resolution.item()
144 | return model
145 |
--------------------------------------------------------------------------------
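Note on /eva_clip/openai.py: a minimal sketch of loading an OpenAI CLIP checkpoint through `load_openai_model`, assuming the package is importable from the repo root and that the chosen name appears in the `_PRETRAINED` table of eva_clip/pretrained.py (weights are downloaded on first use).

import torch
from eva_clip.openai import list_openai_models, load_openai_model

print(list_openai_models())          # e.g. ['OpenaiCLIP-B-32', 'OpenCLIP-B-32', ...]
model = load_openai_model(
    'OpenaiCLIP-B-16',
    precision='fp32',                # keep everything in fp32 for this CPU-only sketch
    device='cpu',
    jit=False,                       # rebuild a non-JIT model from the JIT archive's state dict
)
with torch.no_grad():
    img_feat = model.encode_image(torch.randn(1, 3, 224, 224))
    txt_feat = model.encode_text(torch.zeros(1, 77, dtype=torch.long))
print(img_feat.shape, txt_feat.shape)
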
/eva_clip/pretrained.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import urllib
4 | import warnings
5 | from functools import partial
6 | from typing import Dict, Union
7 |
8 | from tqdm import tqdm
9 |
10 | try:
11 | from huggingface_hub import hf_hub_download
12 | _has_hf_hub = True
13 | except ImportError:
14 | hf_hub_download = None
15 | _has_hf_hub = False
16 |
17 |
18 | def _pcfg(url='', hf_hub='', filename='', mean=None, std=None):
19 | return dict(
20 | url=url,
21 | hf_hub=hf_hub,
22 | mean=mean,
23 | std=std,
24 | )
25 |
26 | _VITB32 = dict(
27 | openai=_pcfg(
28 | "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
29 | laion400m_e31=_pcfg(
30 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
31 | laion400m_e32=_pcfg(
32 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
33 | laion2b_e16=_pcfg(
34 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-laion2b_e16-af8dbd0c.pth"),
35 | laion2b_s34b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-B-32-laion2B-s34B-b79K/')
36 | )
37 |
38 | _VITB32_quickgelu = dict(
39 | openai=_pcfg(
40 | "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt"),
41 | laion400m_e31=_pcfg(
42 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt"),
43 | laion400m_e32=_pcfg(
44 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt"),
45 | )
46 |
47 | _VITB16 = dict(
48 | openai=_pcfg(
49 | "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt"),
50 | laion400m_e31=_pcfg(
51 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e31-00efa78f.pt"),
52 | laion400m_e32=_pcfg(
53 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16-laion400m_e32-55e67d44.pt"),
54 | laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-B-16-laion2B-s34B-b88K/'),
55 | )
56 |
57 | _EVAB16 = dict(
58 | eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt'),
59 | eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_B_psz14to16.pt'),
60 | eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt'),
61 | eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt'),
62 | )
63 |
64 | _VITB16_PLUS_240 = dict(
65 | laion400m_e31=_pcfg(
66 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e31-8fb26589.pt"),
67 | laion400m_e32=_pcfg(
68 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_16_plus_240-laion400m_e32-699c4b84.pt"),
69 | )
70 |
71 | _VITL14 = dict(
72 | openai=_pcfg(
73 | "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt"),
74 | laion400m_e31=_pcfg(
75 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e31-69988bb6.pt"),
76 | laion400m_e32=_pcfg(
77 | "https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_l_14-laion400m_e32-3d133497.pt"),
78 | laion2b_s32b_b82k=_pcfg(
79 | hf_hub='laion/CLIP-ViT-L-14-laion2B-s32B-b82K/',
80 | mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
81 | )
82 |
83 | _EVAL14 = dict(
84 | eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_L_psz14.pt'),
85 | eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_L_psz14.pt'),
86 | eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt'),
87 | eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_s4B.pt'),
88 | )
89 |
90 | _VITL14_336 = dict(
91 | openai=_pcfg(
92 | "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt"),
93 | )
94 |
95 | _EVAL14_336 = dict(
96 | eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt'),
97 | eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_336_psz14_s6B.pt'),
98 | eva_clip_224to336=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt'),
99 | eva02_clip_224to336=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_L_psz14_224to336.pt'),
100 | )
101 |
102 | _VITH14 = dict(
103 | laion2b_s32b_b79k=_pcfg(hf_hub='laion/CLIP-ViT-H-14-laion2B-s32B-b79K/'),
104 | )
105 |
106 | _VITg14 = dict(
107 | laion2b_s12b_b42k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s12B-b42K/'),
108 | laion2b_s34b_b88k=_pcfg(hf_hub='laion/CLIP-ViT-g-14-laion2B-s34B-b88K/'),
109 | )
110 |
111 | _EVAg14 = dict(
112 | eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/'),
113 | eva01=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_g_psz14.pt'),
114 | eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt'),
115 | eva01_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_psz14_s11B.pt'),
116 | )
117 |
118 | _EVAg14_PLUS = dict(
119 | eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/'),
120 | eva01=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_g_psz14.pt'),
121 | eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt'),
122 | eva01_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA01_CLIP_g_14_plus_psz14_s11B.pt'),
123 | )
124 |
125 | _VITbigG14 = dict(
126 | laion2b_s39b_b160k=_pcfg(hf_hub='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k/'),
127 | )
128 |
129 | _EVAbigE14 = dict(
130 | eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
131 | eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
132 | eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt'),
133 | eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_s4B.pt'),
134 | )
135 |
136 | _EVAbigE14_PLUS = dict(
137 | eva=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
138 | eva02=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_E_psz14.pt'),
139 | eva_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt'),
140 | eva02_clip=_pcfg(hf_hub='QuanSun/EVA-CLIP/EVA02_CLIP_E_psz14_plus_s9B.pt'),
141 | )
142 |
143 |
144 | _PRETRAINED = {
145 | # "ViT-B-32": _VITB32,
146 | "OpenaiCLIP-B-32": _VITB32,
147 | "OpenCLIP-B-32": _VITB32,
148 |
149 | # "ViT-B-32-quickgelu": _VITB32_quickgelu,
150 | "OpenaiCLIP-B-32-quickgelu": _VITB32_quickgelu,
151 | "OpenCLIP-B-32-quickgelu": _VITB32_quickgelu,
152 |
153 | # "ViT-B-16": _VITB16,
154 | "OpenaiCLIP-B-16": _VITB16,
155 | "OpenCLIP-B-16": _VITB16,
156 |
157 | "EVA02-B-16": _EVAB16,
158 | "EVA02-CLIP-B-16": _EVAB16,
159 |
160 | # "ViT-B-16-plus-240": _VITB16_PLUS_240,
161 | "OpenCLIP-B-16-plus-240": _VITB16_PLUS_240,
162 |
163 | # "ViT-L-14": _VITL14,
164 | "OpenaiCLIP-L-14": _VITL14,
165 | "OpenCLIP-L-14": _VITL14,
166 |
167 | "EVA02-L-14": _EVAL14,
168 | "EVA02-CLIP-L-14": _EVAL14,
169 |
170 | # "ViT-L-14-336": _VITL14_336,
171 | "OpenaiCLIP-L-14-336": _VITL14_336,
172 |
173 | "EVA02-CLIP-L-14-336": _EVAL14_336,
174 |
175 | # "ViT-H-14": _VITH14,
176 | # "ViT-g-14": _VITg14,
177 | "OpenCLIP-H-14": _VITH14,
178 | "OpenCLIP-g-14": _VITg14,
179 |
180 | "EVA01-CLIP-g-14": _EVAg14,
181 | "EVA01-CLIP-g-14-plus": _EVAg14_PLUS,
182 |
183 | # "ViT-bigG-14": _VITbigG14,
184 | "OpenCLIP-bigG-14": _VITbigG14,
185 |
186 | "EVA02-CLIP-bigE-14": _EVAbigE14,
187 | "EVA02-CLIP-bigE-14-plus": _EVAbigE14_PLUS,
188 | }
189 |
190 |
191 | def _clean_tag(tag: str):
192 | # normalize pretrained tags
193 | return tag.lower().replace('-', '_')
194 |
195 |
196 | def list_pretrained(as_str: bool = False):
197 | """ returns list of pretrained models
198 | Returns a tuple (model_name, pretrain_tag) by default or 'name:tag' if as_str == True
199 | """
200 | return [':'.join([k, t]) if as_str else (k, t) for k in _PRETRAINED.keys() for t in _PRETRAINED[k].keys()]
201 |
202 |
203 | def list_pretrained_models_by_tag(tag: str):
204 | """ return all models having the specified pretrain tag """
205 | models = []
206 | tag = _clean_tag(tag)
207 | for k in _PRETRAINED.keys():
208 | if tag in _PRETRAINED[k]:
209 | models.append(k)
210 | return models
211 |
212 |
213 | def list_pretrained_tags_by_model(model: str):
214 | """ return all pretrain tags for the specified model architecture """
215 | tags = []
216 | if model in _PRETRAINED:
217 | tags.extend(_PRETRAINED[model].keys())
218 | return tags
219 |
220 |
221 | def is_pretrained_cfg(model: str, tag: str):
222 | if model not in _PRETRAINED:
223 | return False
224 | return _clean_tag(tag) in _PRETRAINED[model]
225 |
226 |
227 | def get_pretrained_cfg(model: str, tag: str):
228 | if model not in _PRETRAINED:
229 | return {}
230 | model_pretrained = _PRETRAINED[model]
231 | return model_pretrained.get(_clean_tag(tag), {})
232 |
233 |
234 | def get_pretrained_url(model: str, tag: str):
235 | cfg = get_pretrained_cfg(model, _clean_tag(tag))
236 | return cfg.get('url', '')
237 |
238 |
239 | def download_pretrained_from_url(
240 | url: str,
241 | cache_dir: Union[str, None] = None,
242 | ):
243 | if not cache_dir:
244 | cache_dir = os.path.expanduser("~/.cache/clip")
245 | os.makedirs(cache_dir, exist_ok=True)
246 | filename = os.path.basename(url)
247 |
248 | if 'openaipublic' in url:
249 | expected_sha256 = url.split("/")[-2]
250 | elif 'mlfoundations' in url:
251 | expected_sha256 = os.path.splitext(filename)[0].split("-")[-1]
252 | else:
253 | expected_sha256 = ''
254 |
255 | download_target = os.path.join(cache_dir, filename)
256 |
257 | if os.path.exists(download_target) and not os.path.isfile(download_target):
258 | raise RuntimeError(f"{download_target} exists and is not a regular file")
259 |
260 | if os.path.isfile(download_target):
261 | if expected_sha256:
262 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
263 | return download_target
264 | else:
265 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
266 | else:
267 | return download_target
268 |
269 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
270 | with tqdm(total=int(source.headers.get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
271 | while True:
272 | buffer = source.read(8192)
273 | if not buffer:
274 | break
275 |
276 | output.write(buffer)
277 | loop.update(len(buffer))
278 |
279 | if expected_sha256 and not hashlib.sha256(open(download_target, "rb").read()).hexdigest().startswith(expected_sha256):
280 |         raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
281 |
282 | return download_target
283 |
284 |
285 | def has_hf_hub(necessary=False):
286 | if not _has_hf_hub and necessary:
287 | # if no HF Hub module installed, and it is necessary to continue, raise error
288 | raise RuntimeError(
289 | 'Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`.')
290 | return _has_hf_hub
291 |
292 |
293 | def download_pretrained_from_hf(
294 | model_id: str,
295 | filename: str = 'open_clip_pytorch_model.bin',
296 | revision=None,
297 | cache_dir: Union[str, None] = None,
298 | ):
299 | has_hf_hub(True)
300 | cached_file = hf_hub_download(model_id, filename, revision=revision, cache_dir=cache_dir)
301 | return cached_file
302 |
303 |
304 | def download_pretrained(
305 | cfg: Dict,
306 | force_hf_hub: bool = False,
307 | cache_dir: Union[str, None] = None,
308 | ):
309 | target = ''
310 | if not cfg:
311 | return target
312 |
313 | download_url = cfg.get('url', '')
314 | download_hf_hub = cfg.get('hf_hub', '')
315 | if download_hf_hub and force_hf_hub:
316 | # use HF hub even if url exists
317 | download_url = ''
318 |
319 | if download_url:
320 | target = download_pretrained_from_url(download_url, cache_dir=cache_dir)
321 | elif download_hf_hub:
322 | has_hf_hub(True)
323 | # we assume the hf_hub entries in pretrained config combine model_id + filename in
324 | # 'org/model_name/filename.pt' form. To specify just the model id w/o filename and
325 | # use 'open_clip_pytorch_model.bin' default, there must be a trailing slash 'org/model_name/'.
326 | model_id, filename = os.path.split(download_hf_hub)
327 | if filename:
328 | target = download_pretrained_from_hf(model_id, filename=filename, cache_dir=cache_dir)
329 | else:
330 | target = download_pretrained_from_hf(model_id, cache_dir=cache_dir)
331 |
332 | return target
333 |
--------------------------------------------------------------------------------
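Note on /eva_clip/pretrained.py: a minimal sketch of querying the registry above and fetching a checkpoint. HF-hub entries need `pip install huggingface_hub`; all names and tags below are taken from the `_PRETRAINED` table.

from eva_clip.pretrained import list_pretrained, get_pretrained_cfg, download_pretrained

print(list_pretrained(as_str=True)[:3])          # e.g. ['OpenaiCLIP-B-32:openai', ...]

cfg = get_pretrained_cfg('EVA02-CLIP-B-16', 'eva02_clip')
# cfg['hf_hub'] == 'QuanSun/EVA-CLIP/EVA02_CLIP_B_psz16_s8B.pt'  (model_id + filename form)
ckpt_path = download_pretrained(cfg)             # resolves via url or HF hub, returns a local path
print(ckpt_path)
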
/eva_clip/rope.py:
--------------------------------------------------------------------------------
1 | from math import pi
2 | import torch
3 | from torch import nn
4 | from einops import rearrange, repeat
5 | import logging
6 |
7 | def broadcat(tensors, dim = -1):
8 | num_tensors = len(tensors)
9 | shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
10 | assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
11 | shape_len = list(shape_lens)[0]
12 | dim = (dim + shape_len) if dim < 0 else dim
13 | dims = list(zip(*map(lambda t: list(t.shape), tensors)))
14 | expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
15 |     assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatenation'
16 | max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
17 | expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
18 | expanded_dims.insert(dim, (dim, dims[dim]))
19 | expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
20 | tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
21 | return torch.cat(tensors, dim = dim)
22 |
23 | def rotate_half(x):
24 | x = rearrange(x, '... (d r) -> ... d r', r = 2)
25 | x1, x2 = x.unbind(dim = -1)
26 | x = torch.stack((-x2, x1), dim = -1)
27 | return rearrange(x, '... d r -> ... (d r)')
28 |
29 |
30 | class VisionRotaryEmbedding(nn.Module):
31 | def __init__(
32 | self,
33 | dim,
34 | pt_seq_len,
35 | ft_seq_len=None,
36 | custom_freqs = None,
37 | freqs_for = 'lang',
38 | theta = 10000,
39 | max_freq = 10,
40 | num_freqs = 1,
41 | ):
42 | super().__init__()
43 | if custom_freqs:
44 | freqs = custom_freqs
45 | elif freqs_for == 'lang':
46 | freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
47 | elif freqs_for == 'pixel':
48 | freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
49 | elif freqs_for == 'constant':
50 | freqs = torch.ones(num_freqs).float()
51 | else:
52 | raise ValueError(f'unknown modality {freqs_for}')
53 |
54 | if ft_seq_len is None: ft_seq_len = pt_seq_len
55 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
56 |
57 | freqs_h = torch.einsum('..., f -> ... f', t, freqs)
58 | freqs_h = repeat(freqs_h, '... n -> ... (n r)', r = 2)
59 |
60 | freqs_w = torch.einsum('..., f -> ... f', t, freqs)
61 | freqs_w = repeat(freqs_w, '... n -> ... (n r)', r = 2)
62 |
63 | freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim = -1)
64 |
65 | self.register_buffer("freqs_cos", freqs.cos())
66 | self.register_buffer("freqs_sin", freqs.sin())
67 |
68 | logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
69 |
70 | def forward(self, t, start_index = 0):
71 | rot_dim = self.freqs_cos.shape[-1]
72 | end_index = start_index + rot_dim
73 | assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
74 | t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
75 | t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
76 |
77 | return torch.cat((t_left, t, t_right), dim = -1)
78 |
79 | class VisionRotaryEmbeddingFast(nn.Module):
80 | def __init__(
81 | self,
82 | dim,
83 | pt_seq_len,
84 | ft_seq_len=None,
85 | custom_freqs = None,
86 | freqs_for = 'lang',
87 | theta = 10000,
88 | max_freq = 10,
89 | num_freqs = 1,
90 | patch_dropout = 0.
91 | ):
92 | super().__init__()
93 | if custom_freqs:
94 | freqs = custom_freqs
95 | elif freqs_for == 'lang':
96 | freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
97 | elif freqs_for == 'pixel':
98 | freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
99 | elif freqs_for == 'constant':
100 | freqs = torch.ones(num_freqs).float()
101 | else:
102 | raise ValueError(f'unknown modality {freqs_for}')
103 |
104 | if ft_seq_len is None: ft_seq_len = pt_seq_len
105 | t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
106 |
107 | freqs = torch.einsum('..., f -> ... f', t, freqs)
108 | freqs = repeat(freqs, '... n -> ... (n r)', r = 2)
109 | freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim = -1)
110 |
111 | freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
112 | freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
113 |
114 | self.patch_dropout = patch_dropout
115 |
116 | self.register_buffer("freqs_cos", freqs_cos)
117 | self.register_buffer("freqs_sin", freqs_sin)
118 |
119 | logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
120 |
121 | def forward(self, t, patch_indices_keep=None):
122 | if patch_indices_keep is not None:
123 | batch = t.size()[0]
124 | batch_indices = torch.arange(batch)
125 | batch_indices = batch_indices[..., None]
126 |
127 | freqs_cos = repeat(self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
128 | freqs_sin = repeat(self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1])
129 |
130 | freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
131 | freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j')
132 | freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
133 | freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j')
134 |
135 | return t * freqs_cos + rotate_half(t) * freqs_sin
136 |
137 | return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
--------------------------------------------------------------------------------
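Note on /eva_clip/rope.py: a minimal sketch of applying the 2D rotary embedding defined above. The tensor shapes are illustrative only; in the EVA ViT the rotation is applied per attention head to the query/key patch tokens.

import torch
from eva_clip.rope import VisionRotaryEmbeddingFast

rope = VisionRotaryEmbeddingFast(dim=16, pt_seq_len=14)   # freqs_cos / freqs_sin: (14*14, 32)
tokens = torch.randn(2, 8, 14 * 14, 32)                   # (batch, heads, patches, rot_dim)
rotated = rope(tokens)                                     # broadcasts over batch and heads
assert rotated.shape == tokens.shape
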
/eva_clip/timm_model.py:
--------------------------------------------------------------------------------
1 | """ timm model adapter
2 |
3 | Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
4 | """
5 | import logging
6 | from collections import OrderedDict
7 |
8 | import torch
9 | import torch.nn as nn
10 |
11 | try:
12 | import timm
13 | from timm.models.layers import Mlp, to_2tuple
14 | try:
15 | # old timm imports < 0.8.1
16 | from timm.models.layers.attention_pool2d import RotAttentionPool2d
17 | from timm.models.layers.attention_pool2d import AttentionPool2d as AbsAttentionPool2d
18 | except ImportError:
19 | # new timm imports >= 0.8.1
20 | from timm.layers import RotAttentionPool2d
21 | from timm.layers import AttentionPool2d as AbsAttentionPool2d
22 | except ImportError:
23 | timm = None
24 |
25 | from .utils import freeze_batch_norm_2d
26 |
27 |
28 | class TimmModel(nn.Module):
29 | """ timm model adapter
30 | # FIXME this adapter is a work in progress, may change in ways that break weight compat
31 | """
32 |
33 | def __init__(
34 | self,
35 | model_name,
36 | embed_dim,
37 | image_size=224,
38 | pool='avg',
39 | proj='linear',
40 | proj_bias=False,
41 | drop=0.,
42 | pretrained=False):
43 | super().__init__()
44 | if timm is None:
45 | raise RuntimeError("Please `pip install timm` to use timm models.")
46 |
47 | self.image_size = to_2tuple(image_size)
48 | self.trunk = timm.create_model(model_name, pretrained=pretrained)
49 | feat_size = self.trunk.default_cfg.get('pool_size', None)
50 | feature_ndim = 1 if not feat_size else 2
51 | if pool in ('abs_attn', 'rot_attn'):
52 | assert feature_ndim == 2
53 | # if attn pooling used, remove both classifier and default pool
54 | self.trunk.reset_classifier(0, global_pool='')
55 | else:
56 | # reset global pool if pool config set, otherwise leave as network default
57 | reset_kwargs = dict(global_pool=pool) if pool else {}
58 | self.trunk.reset_classifier(0, **reset_kwargs)
59 | prev_chs = self.trunk.num_features
60 |
61 | head_layers = OrderedDict()
62 | if pool == 'abs_attn':
63 | head_layers['pool'] = AbsAttentionPool2d(prev_chs, feat_size=feat_size, out_features=embed_dim)
64 | prev_chs = embed_dim
65 | elif pool == 'rot_attn':
66 | head_layers['pool'] = RotAttentionPool2d(prev_chs, out_features=embed_dim)
67 | prev_chs = embed_dim
68 | else:
69 | assert proj, 'projection layer needed if non-attention pooling is used.'
70 |
71 | # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
72 | if proj == 'linear':
73 | head_layers['drop'] = nn.Dropout(drop)
74 | head_layers['proj'] = nn.Linear(prev_chs, embed_dim, bias=proj_bias)
75 | elif proj == 'mlp':
76 | head_layers['mlp'] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop, bias=(True, proj_bias))
77 |
78 | self.head = nn.Sequential(head_layers)
79 |
80 | def lock(self, unlocked_groups=0, freeze_bn_stats=False):
81 | """ lock modules
82 | Args:
83 | unlocked_groups (int): leave last n layer groups unlocked (default: 0)
84 | """
85 | if not unlocked_groups:
86 | # lock full model
87 | for param in self.trunk.parameters():
88 | param.requires_grad = False
89 | if freeze_bn_stats:
90 | freeze_batch_norm_2d(self.trunk)
91 | else:
92 | # NOTE: partial freeze requires latest timm (master) branch and is subject to change
93 | try:
94 | # FIXME import here until API stable and in an official release
95 | from timm.models.helpers import group_parameters, group_modules
96 | except ImportError:
97 | raise RuntimeError(
98 | 'Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`')
99 | matcher = self.trunk.group_matcher()
100 | gparams = group_parameters(self.trunk, matcher)
101 | max_layer_id = max(gparams.keys())
102 | max_layer_id = max_layer_id - unlocked_groups
103 | for group_idx in range(max_layer_id + 1):
104 | group = gparams[group_idx]
105 | for param in group:
106 | self.trunk.get_parameter(param).requires_grad = False
107 | if freeze_bn_stats:
108 | gmodules = group_modules(self.trunk, matcher, reverse=True)
109 | gmodules = {k for k, v in gmodules.items() if v <= max_layer_id}
110 | freeze_batch_norm_2d(self.trunk, gmodules)
111 |
112 | @torch.jit.ignore
113 | def set_grad_checkpointing(self, enable=True):
114 | try:
115 | self.trunk.set_grad_checkpointing(enable)
116 | except Exception as e:
117 | logging.warning('grad checkpointing not supported for this timm image tower, continuing without...')
118 |
119 | def forward(self, x):
120 | x = self.trunk(x)
121 | x = self.head(x)
122 | return x
123 |
--------------------------------------------------------------------------------
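Note on /eva_clip/timm_model.py: a minimal sketch of wrapping a timm backbone as a CLIP vision tower (requires `pip install timm`; 'resnet50' is just an example backbone name).

import torch
from eva_clip.timm_model import TimmModel

tower = TimmModel('resnet50', embed_dim=512, image_size=224,
                  pool='avg', proj='linear', pretrained=False)
tower.lock(unlocked_groups=0, freeze_bn_stats=True)   # freeze the trunk entirely
with torch.no_grad():
    emb = tower(torch.randn(2, 3, 224, 224))          # trunk -> head (dropout + linear proj)
print(emb.shape)                                      # expected: torch.Size([2, 512])
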
/eva_clip/tokenizer.py:
--------------------------------------------------------------------------------
1 | """ CLIP tokenizer
2 |
3 | Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4 | """
5 | import gzip
6 | import html
7 | import os
8 | from functools import lru_cache
9 | from typing import Union, List
10 |
11 | import ftfy
12 | import regex as re
13 | import torch
14 |
15 | # https://stackoverflow.com/q/62691279
16 | import os
17 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
18 |
19 |
20 | @lru_cache()
21 | def default_bpe():
22 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
23 |
24 |
25 | @lru_cache()
26 | def bytes_to_unicode():
27 | """
28 |     Returns a list of utf-8 bytes and a corresponding list of unicode strings.
29 |     The reversible bpe codes work on unicode strings.
30 |     This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
31 |     When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
32 |     This is a significant percentage of your normal, say, 32K bpe vocab.
33 |     To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
34 |     It also avoids mapping to whitespace/control characters that the bpe code barfs on.
35 | """
36 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
37 | cs = bs[:]
38 | n = 0
39 | for b in range(2**8):
40 | if b not in bs:
41 | bs.append(b)
42 | cs.append(2**8+n)
43 | n += 1
44 | cs = [chr(n) for n in cs]
45 | return dict(zip(bs, cs))
46 |
47 |
48 | def get_pairs(word):
49 | """Return set of symbol pairs in a word.
50 | Word is represented as tuple of symbols (symbols being variable-length strings).
51 | """
52 | pairs = set()
53 | prev_char = word[0]
54 | for char in word[1:]:
55 | pairs.add((prev_char, char))
56 | prev_char = char
57 | return pairs
58 |
59 |
60 | def basic_clean(text):
61 | text = ftfy.fix_text(text)
62 | text = html.unescape(html.unescape(text))
63 | return text.strip()
64 |
65 |
66 | def whitespace_clean(text):
67 | text = re.sub(r'\s+', ' ', text)
68 | text = text.strip()
69 | return text
70 |
71 |
72 | class SimpleTokenizer(object):
73 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
74 | self.byte_encoder = bytes_to_unicode()
75 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
76 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
77 | merges = merges[1:49152-256-2+1]
78 | merges = [tuple(merge.split()) for merge in merges]
79 | vocab = list(bytes_to_unicode().values())
80 |         vocab = vocab + [v+'</w>' for v in vocab]
81 | for merge in merges:
82 | vocab.append(''.join(merge))
83 | if not special_tokens:
84 |             special_tokens = ['<start_of_text>', '<end_of_text>']
85 |         else:
86 |             special_tokens = ['<start_of_text>', '<end_of_text>'] + special_tokens
87 | vocab.extend(special_tokens)
88 | self.encoder = dict(zip(vocab, range(len(vocab))))
89 | self.decoder = {v: k for k, v in self.encoder.items()}
90 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
91 | self.cache = {t:t for t in special_tokens}
92 | special = "|".join(special_tokens)
93 | self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
94 |
95 | self.vocab_size = len(self.encoder)
96 | self.all_special_ids = [self.encoder[t] for t in special_tokens]
97 |
98 | def bpe(self, token):
99 | if token in self.cache:
100 | return self.cache[token]
101 |         word = tuple(token[:-1]) + ( token[-1] + '</w>',)
102 | pairs = get_pairs(word)
103 |
104 | if not pairs:
105 |             return token+'</w>'
106 |
107 | while True:
108 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
109 | if bigram not in self.bpe_ranks:
110 | break
111 | first, second = bigram
112 | new_word = []
113 | i = 0
114 | while i < len(word):
115 | try:
116 | j = word.index(first, i)
117 | new_word.extend(word[i:j])
118 | i = j
119 | except:
120 | new_word.extend(word[i:])
121 | break
122 |
123 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
124 | new_word.append(first+second)
125 | i += 2
126 | else:
127 | new_word.append(word[i])
128 | i += 1
129 | new_word = tuple(new_word)
130 | word = new_word
131 | if len(word) == 1:
132 | break
133 | else:
134 | pairs = get_pairs(word)
135 | word = ' '.join(word)
136 | self.cache[token] = word
137 | return word
138 |
139 | def encode(self, text):
140 | bpe_tokens = []
141 | text = whitespace_clean(basic_clean(text)).lower()
142 | for token in re.findall(self.pat, text):
143 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
144 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
145 | return bpe_tokens
146 |
147 | def decode(self, tokens):
148 | text = ''.join([self.decoder[token] for token in tokens])
149 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
150 | return text
151 |
152 |
153 | _tokenizer = SimpleTokenizer()
154 |
155 |
156 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
157 | """
158 | Returns the tokenized representation of given input string(s)
159 |
160 | Parameters
161 | ----------
162 | texts : Union[str, List[str]]
163 | An input string or a list of input strings to tokenize
164 | context_length : int
165 | The context length to use; all CLIP models use 77 as the context length
166 |
167 | Returns
168 | -------
169 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
170 | """
171 | if isinstance(texts, str):
172 | texts = [texts]
173 |
174 |     sot_token = _tokenizer.encoder["<start_of_text>"]
175 |     eot_token = _tokenizer.encoder["<end_of_text>"]
176 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
177 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
178 |
179 | for i, tokens in enumerate(all_tokens):
180 | if len(tokens) > context_length:
181 | tokens = tokens[:context_length] # Truncate
182 | tokens[-1] = eot_token
183 | result[i, :len(tokens)] = torch.tensor(tokens)
184 |
185 | return result
186 |
187 |
188 | class HFTokenizer:
189 | "HuggingFace tokenizer wrapper"
190 | def __init__(self, tokenizer_name:str):
191 | from transformers import AutoTokenizer
192 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
193 |
194 | def __call__(self, texts:Union[str, List[str]], context_length:int=77) -> torch.Tensor:
195 | # same cleaning as for default tokenizer, except lowercasing
196 | # adding lower (for case-sensitive tokenizers) will make it more robust but less sensitive to nuance
197 | if isinstance(texts, str):
198 | texts = [texts]
199 | texts = [whitespace_clean(basic_clean(text)) for text in texts]
200 | input_ids = self.tokenizer(texts, return_tensors='pt', max_length=context_length, padding='max_length', truncation=True).input_ids
201 | return input_ids
202 |
--------------------------------------------------------------------------------
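Note on /eva_clip/tokenizer.py: a minimal sketch of the tokenizer API above.

from eva_clip.tokenizer import tokenize, _tokenizer

tokens = tokenize(["a photo of a cat", "a video of someone folding paper"])
print(tokens.shape)            # torch.Size([2, 77]); zero-padded, truncated to 77 if longer
print(tokens[0, :8])           # [<start_of_text>] + BPE ids + [<end_of_text>], then zero padding

# Round-trip through the underlying SimpleTokenizer:
ids = _tokenizer.encode("a photo of a cat")
print(_tokenizer.decode(ids))  # 'a photo of a cat '  (the </w> markers become spaces)
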
/eva_clip/transform.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Sequence, Tuple
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torchvision.transforms.functional as F
6 |
7 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, \
8 | CenterCrop
9 |
10 | from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
11 |
12 |
13 | class ResizeMaxSize(nn.Module):
14 |
15 | def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0):
16 | super().__init__()
17 | if not isinstance(max_size, int):
18 | raise TypeError(f"Size should be int. Got {type(max_size)}")
19 | self.max_size = max_size
20 | self.interpolation = interpolation
21 |         self.fn = min if fn == 'min' else max
22 | self.fill = fill
23 |
24 | def forward(self, img):
25 | if isinstance(img, torch.Tensor):
26 | height, width = img.shape[:2]
27 | else:
28 | width, height = img.size
29 | scale = self.max_size / float(max(height, width))
30 | if scale != 1.0:
31 | new_size = tuple(round(dim * scale) for dim in (height, width))
32 | img = F.resize(img, new_size, self.interpolation)
33 | pad_h = self.max_size - new_size[0]
34 | pad_w = self.max_size - new_size[1]
35 | img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill)
36 | return img
37 |
38 |
39 | def _convert_to_rgb(image):
40 | return image.convert('RGB')
41 |
42 |
43 | # class CatGen(nn.Module):
44 | # def __init__(self, num=4):
45 | # self.num = num
46 | # def mixgen_batch(image, text):
47 | # batch_size = image.shape[0]
48 | # index = np.random.permutation(batch_size)
49 |
50 | # cat_images = []
51 | # for i in range(batch_size):
52 | # # image mixup
53 | # image[i,:] = lam * image[i,:] + (1 - lam) * image[index[i],:]
54 | # # text concat
55 | # text[i] = tokenizer((str(text[i]) + " " + str(text[index[i]])))[0]
56 | # text = torch.stack(text)
57 | # return image, text
58 |
59 |
60 | def image_transform(
61 | image_size: int,
62 | is_train: bool,
63 | mean: Optional[Tuple[float, ...]] = None,
64 | std: Optional[Tuple[float, ...]] = None,
65 | resize_longest_max: bool = False,
66 | fill_color: int = 0,
67 | ):
68 | mean = mean or OPENAI_DATASET_MEAN
69 | if not isinstance(mean, (list, tuple)):
70 | mean = (mean,) * 3
71 |
72 | std = std or OPENAI_DATASET_STD
73 | if not isinstance(std, (list, tuple)):
74 | std = (std,) * 3
75 |
76 | if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]:
77 | # for square size, pass size as int so that Resize() uses aspect preserving shortest edge
78 | image_size = image_size[0]
79 |
80 | normalize = Normalize(mean=mean, std=std)
81 | if is_train:
82 | return Compose([
83 | RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC),
84 | _convert_to_rgb,
85 | ToTensor(),
86 | normalize,
87 | ])
88 | else:
89 | if resize_longest_max:
90 | transforms = [
91 | ResizeMaxSize(image_size, fill=fill_color)
92 | ]
93 | else:
94 | transforms = [
95 | Resize(image_size, interpolation=InterpolationMode.BICUBIC),
96 | CenterCrop(image_size),
97 | ]
98 | transforms.extend([
99 | _convert_to_rgb,
100 | ToTensor(),
101 | normalize,
102 | ])
103 | return Compose(transforms)
104 |
--------------------------------------------------------------------------------
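Note on /eva_clip/transform.py: a minimal sketch of building the train/eval preprocessing pipelines defined above.

from PIL import Image
from eva_clip.transform import image_transform

train_tf = image_transform(224, is_train=True)    # RandomResizedCrop + ToTensor + Normalize
eval_tf = image_transform(224, is_train=False)    # Resize + CenterCrop + ToTensor + Normalize

img = Image.new('RGB', (320, 240))                # stand-in for a decoded video frame
x = eval_tf(img)
print(x.shape)                                    # torch.Size([3, 224, 224])
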
/lists/kinetics_400_labels.csv:
--------------------------------------------------------------------------------
1 | id,name
2 | 0,abseiling
3 | 1,air drumming
4 | 2,answering questions
5 | 3,applauding
6 | 4,applying cream
7 | 5,archery
8 | 6,arm wrestling
9 | 7,arranging flowers
10 | 8,assembling computer
11 | 9,auctioning
12 | 10,baby waking up
13 | 11,baking cookies
14 | 12,balloon blowing
15 | 13,bandaging
16 | 14,barbequing
17 | 15,bartending
18 | 16,beatboxing
19 | 17,bee keeping
20 | 18,belly dancing
21 | 19,bench pressing
22 | 20,bending back
23 | 21,bending metal
24 | 22,biking through snow
25 | 23,blasting sand
26 | 24,blowing glass
27 | 25,blowing leaves
28 | 26,blowing nose
29 | 27,blowing out candles
30 | 28,bobsledding
31 | 29,bookbinding
32 | 30,bouncing on trampoline
33 | 31,bowling
34 | 32,braiding hair
35 | 33,breading or breadcrumbing
36 | 34,breakdancing
37 | 35,brush painting
38 | 36,brushing hair
39 | 37,brushing teeth
40 | 38,building cabinet
41 | 39,building shed
42 | 40,bungee jumping
43 | 41,busking
44 | 42,canoeing or kayaking
45 | 43,capoeira
46 | 44,carrying baby
47 | 45,cartwheeling
48 | 46,carving pumpkin
49 | 47,catching fish
50 | 48,catching or throwing baseball
51 | 49,catching or throwing frisbee
52 | 50,catching or throwing softball
53 | 51,celebrating
54 | 52,changing oil
55 | 53,changing wheel
56 | 54,checking tires
57 | 55,cheerleading
58 | 56,chopping wood
59 | 57,clapping
60 | 58,clay pottery making
61 | 59,clean and jerk
62 | 60,cleaning floor
63 | 61,cleaning gutters
64 | 62,cleaning pool
65 | 63,cleaning shoes
66 | 64,cleaning toilet
67 | 65,cleaning windows
68 | 66,climbing a rope
69 | 67,climbing ladder
70 | 68,climbing tree
71 | 69,contact juggling
72 | 70,cooking chicken
73 | 71,cooking egg
74 | 72,cooking on campfire
75 | 73,cooking sausages
76 | 74,counting money
77 | 75,country line dancing
78 | 76,cracking neck
79 | 77,crawling baby
80 | 78,crossing river
81 | 79,crying
82 | 80,curling hair
83 | 81,cutting nails
84 | 82,cutting pineapple
85 | 83,cutting watermelon
86 | 84,dancing ballet
87 | 85,dancing charleston
88 | 86,dancing gangnam style
89 | 87,dancing macarena
90 | 88,deadlifting
91 | 89,decorating the christmas tree
92 | 90,digging
93 | 91,dining
94 | 92,disc golfing
95 | 93,diving cliff
96 | 94,dodgeball
97 | 95,doing aerobics
98 | 96,doing laundry
99 | 97,doing nails
100 | 98,drawing
101 | 99,dribbling basketball
102 | 100,drinking
103 | 101,drinking beer
104 | 102,drinking shots
105 | 103,driving car
106 | 104,driving tractor
107 | 105,drop kicking
108 | 106,drumming fingers
109 | 107,dunking basketball
110 | 108,dying hair
111 | 109,eating burger
112 | 110,eating cake
113 | 111,eating carrots
114 | 112,eating chips
115 | 113,eating doughnuts
116 | 114,eating hotdog
117 | 115,eating ice cream
118 | 116,eating spaghetti
119 | 117,eating watermelon
120 | 118,egg hunting
121 | 119,exercising arm
122 | 120,exercising with an exercise ball
123 | 121,extinguishing fire
124 | 122,faceplanting
125 | 123,feeding birds
126 | 124,feeding fish
127 | 125,feeding goats
128 | 126,filling eyebrows
129 | 127,finger snapping
130 | 128,fixing hair
131 | 129,flipping pancake
132 | 130,flying kite
133 | 131,folding clothes
134 | 132,folding napkins
135 | 133,folding paper
136 | 134,front raises
137 | 135,frying vegetables
138 | 136,garbage collecting
139 | 137,gargling
140 | 138,getting a haircut
141 | 139,getting a tattoo
142 | 140,giving or receiving award
143 | 141,golf chipping
144 | 142,golf driving
145 | 143,golf putting
146 | 144,grinding meat
147 | 145,grooming dog
148 | 146,grooming horse
149 | 147,gymnastics tumbling
150 | 148,hammer throw
151 | 149,headbanging
152 | 150,headbutting
153 | 151,high jump
154 | 152,high kick
155 | 153,hitting baseball
156 | 154,hockey stop
157 | 155,holding snake
158 | 156,hopscotch
159 | 157,hoverboarding
160 | 158,hugging
161 | 159,hula hooping
162 | 160,hurdling
163 | 161,hurling (sport)
164 | 162,ice climbing
165 | 163,ice fishing
166 | 164,ice skating
167 | 165,ironing
168 | 166,javelin throw
169 | 167,jetskiing
170 | 168,jogging
171 | 169,juggling balls
172 | 170,juggling fire
173 | 171,juggling soccer ball
174 | 172,jumping into pool
175 | 173,jumpstyle dancing
176 | 174,kicking field goal
177 | 175,kicking soccer ball
178 | 176,kissing
179 | 177,kitesurfing
180 | 178,knitting
181 | 179,krumping
182 | 180,laughing
183 | 181,laying bricks
184 | 182,long jump
185 | 183,lunge
186 | 184,making a cake
187 | 185,making a sandwich
188 | 186,making bed
189 | 187,making jewelry
190 | 188,making pizza
191 | 189,making snowman
192 | 190,making sushi
193 | 191,making tea
194 | 192,marching
195 | 193,massaging back
196 | 194,massaging feet
197 | 195,massaging legs
198 | 196,massaging person's head
199 | 197,milking cow
200 | 198,mopping floor
201 | 199,motorcycling
202 | 200,moving furniture
203 | 201,mowing lawn
204 | 202,news anchoring
205 | 203,opening bottle
206 | 204,opening present
207 | 205,paragliding
208 | 206,parasailing
209 | 207,parkour
210 | 208,passing American football (in game)
211 | 209,passing American football (not in game)
212 | 210,peeling apples
213 | 211,peeling potatoes
214 | 212,petting animal (not cat)
215 | 213,petting cat
216 | 214,picking fruit
217 | 215,planting trees
218 | 216,plastering
219 | 217,playing accordion
220 | 218,playing badminton
221 | 219,playing bagpipes
222 | 220,playing basketball
223 | 221,playing bass guitar
224 | 222,playing cards
225 | 223,playing cello
226 | 224,playing chess
227 | 225,playing clarinet
228 | 226,playing controller
229 | 227,playing cricket
230 | 228,playing cymbals
231 | 229,playing didgeridoo
232 | 230,playing drums
233 | 231,playing flute
234 | 232,playing guitar
235 | 233,playing harmonica
236 | 234,playing harp
237 | 235,playing ice hockey
238 | 236,playing keyboard
239 | 237,playing kickball
240 | 238,playing monopoly
241 | 239,playing organ
242 | 240,playing paintball
243 | 241,playing piano
244 | 242,playing poker
245 | 243,playing recorder
246 | 244,playing saxophone
247 | 245,playing squash or racquetball
248 | 246,playing tennis
249 | 247,playing trombone
250 | 248,playing trumpet
251 | 249,playing ukulele
252 | 250,playing violin
253 | 251,playing volleyball
254 | 252,playing xylophone
255 | 253,pole vault
256 | 254,presenting weather forecast
257 | 255,pull ups
258 | 256,pumping fist
259 | 257,pumping gas
260 | 258,punching bag
261 | 259,punching person (boxing)
262 | 260,push up
263 | 261,pushing car
264 | 262,pushing cart
265 | 263,pushing wheelchair
266 | 264,reading book
267 | 265,reading newspaper
268 | 266,recording music
269 | 267,riding a bike
270 | 268,riding camel
271 | 269,riding elephant
272 | 270,riding mechanical bull
273 | 271,riding mountain bike
274 | 272,riding mule
275 | 273,riding or walking with horse
276 | 274,riding scooter
277 | 275,riding unicycle
278 | 276,ripping paper
279 | 277,robot dancing
280 | 278,rock climbing
281 | 279,rock scissors paper
282 | 280,roller skating
283 | 281,running on treadmill
284 | 282,sailing
285 | 283,salsa dancing
286 | 284,sanding floor
287 | 285,scrambling eggs
288 | 286,scuba diving
289 | 287,setting table
290 | 288,shaking hands
291 | 289,shaking head
292 | 290,sharpening knives
293 | 291,sharpening pencil
294 | 292,shaving head
295 | 293,shaving legs
296 | 294,shearing sheep
297 | 295,shining shoes
298 | 296,shooting basketball
299 | 297,shooting goal (soccer)
300 | 298,shot put
301 | 299,shoveling snow
302 | 300,shredding paper
303 | 301,shuffling cards
304 | 302,side kick
305 | 303,sign language interpreting
306 | 304,singing
307 | 305,situp
308 | 306,skateboarding
309 | 307,ski jumping
310 | 308,skiing (not slalom or crosscountry)
311 | 309,skiing crosscountry
312 | 310,skiing slalom
313 | 311,skipping rope
314 | 312,skydiving
315 | 313,slacklining
316 | 314,slapping
317 | 315,sled dog racing
318 | 316,smoking
319 | 317,smoking hookah
320 | 318,snatch weight lifting
321 | 319,sneezing
322 | 320,sniffing
323 | 321,snorkeling
324 | 322,snowboarding
325 | 323,snowkiting
326 | 324,snowmobiling
327 | 325,somersaulting
328 | 326,spinning poi
329 | 327,spray painting
330 | 328,spraying
331 | 329,springboard diving
332 | 330,squat
333 | 331,sticking tongue out
334 | 332,stomping grapes
335 | 333,stretching arm
336 | 334,stretching leg
337 | 335,strumming guitar
338 | 336,surfing crowd
339 | 337,surfing water
340 | 338,sweeping floor
341 | 339,swimming backstroke
342 | 340,swimming breast stroke
343 | 341,swimming butterfly stroke
344 | 342,swing dancing
345 | 343,swinging legs
346 | 344,swinging on something
347 | 345,sword fighting
348 | 346,tai chi
349 | 347,taking a shower
350 | 348,tango dancing
351 | 349,tap dancing
352 | 350,tapping guitar
353 | 351,tapping pen
354 | 352,tasting beer
355 | 353,tasting food
356 | 354,testifying
357 | 355,texting
358 | 356,throwing axe
359 | 357,throwing ball
360 | 358,throwing discus
361 | 359,tickling
362 | 360,tobogganing
363 | 361,tossing coin
364 | 362,tossing salad
365 | 363,training dog
366 | 364,trapezing
367 | 365,trimming or shaving beard
368 | 366,trimming trees
369 | 367,triple jump
370 | 368,tying bow tie
371 | 369,tying knot (not on a tie)
372 | 370,tying tie
373 | 371,unboxing
374 | 372,unloading truck
375 | 373,using computer
376 | 374,using remote controller (not gaming)
377 | 375,using segway
378 | 376,vault
379 | 377,waiting in line
380 | 378,walking the dog
381 | 379,washing dishes
382 | 380,washing feet
383 | 381,washing hair
384 | 382,washing hands
385 | 383,water skiing
386 | 384,water sliding
387 | 385,watering plants
388 | 386,waxing back
389 | 387,waxing chest
390 | 388,waxing eyebrows
391 | 389,waxing legs
392 | 390,weaving basket
393 | 391,welding
394 | 392,whistling
395 | 393,windsurfing
396 | 394,wrapping present
397 | 395,wrestling
398 | 396,writing
399 | 397,yawning
400 | 398,yoga
401 | 399,zumba
402 |
--------------------------------------------------------------------------------
/lists/sth_labels.csv:
--------------------------------------------------------------------------------
1 | id,name
2 | 0,Approaching something with your camera
3 | 1,Attaching something to something
4 | 2,Bending something so that it deforms
5 | 3,Bending something until it breaks
6 | 4,Burying something in something
7 | 5,Closing something
8 | 6,Covering something with something
9 | 7,Digging something out of something
10 | 8,Dropping something behind something
11 | 9,Dropping something in front of something
12 | 10,Dropping something into something
13 | 11,Dropping something next to something
14 | 12,Dropping something onto something
15 | 13,Failing to put something into something because something does not fit
16 | 14,Folding something
17 | 15,Hitting something with something
18 | 16,Holding something
19 | 17,Holding something behind something
20 | 18,Holding something in front of something
21 | 19,Holding something next to something
22 | 20,Holding something over something
23 | 21,Laying something on the table on its side not upright
24 | 22,Letting something roll along a flat surface
25 | 23,Letting something roll down a slanted surface
26 | 24,Letting something roll up a slanted surface so it rolls back down
27 | 25,Lifting a surface with something on it but not enough for it to slide down
28 | 26,Lifting a surface with something on it until it starts sliding down
29 | 27,Lifting something up completely without letting it drop down
30 | 28,Lifting something up completely then letting it drop down
31 | 29,Lifting something with something on it
32 | 30,Lifting up one end of something without letting it drop down
33 | 31,Lifting up one end of something then letting it drop down
34 | 32,Moving away from something with your camera
35 | 33,Moving part of something
36 | 34,Moving something across a surface until it falls down
37 | 35,Moving something across a surface without it falling down
38 | 36,Moving something and something away from each other
39 | 37,Moving something and something closer to each other
40 | 38,Moving something and something so they collide with each other
41 | 39,Moving something and something so they pass each other
42 | 40,Moving something away from something
43 | 41,Moving something away from the camera
44 | 42,Moving something closer to something
45 | 43,Moving something down
46 | 44,Moving something towards the camera
47 | 45,Moving something up
48 | 46,Opening something
49 | 47,Picking something up
50 | 48,Piling something up
51 | 49,Plugging something into something
52 | 50,Plugging something into something but pulling it right out as you remove your hand
53 | 51,Poking a hole into some substance
54 | 52,Poking a hole into something soft
55 | 53,Poking a stack of something so the stack collapses
56 | 54,Poking a stack of something without the stack collapsing
57 | 55,Poking something so it slightly moves
58 | 56,Poking something so lightly that it doesn't or almost doesn't move
59 | 57,Poking something so that it falls over
60 | 58,Poking something so that it spins around
61 | 59,Pouring something into something
62 | 60,Pouring something into something until it overflows
63 | 61,Pouring something onto something
64 | 62,Pouring something out of something
65 | 63,Pretending or failing to wipe something off of something
66 | 64,Pretending or trying and failing to twist something
67 | 65,Pretending to be tearing something that is not tearable
68 | 66,Pretending to close something without actually closing it
69 | 67,Pretending to open something without actually opening it
70 | 68,Pretending to pick something up
71 | 69,Pretending to poke something
72 | 70,Pretending to pour something out of something but something is empty
73 | 71,Pretending to put something behind something
74 | 72,Pretending to put something into something
75 | 73,Pretending to put something next to something
76 | 74,Pretending to put something on a surface
77 | 75,Pretending to put something onto something
78 | 76,Pretending to put something underneath something
79 | 77,Pretending to scoop something up with something
80 | 78,Pretending to spread air onto something
81 | 79,Pretending to sprinkle air onto something
82 | 80,Pretending to squeeze something
83 | 81,Pretending to take something from somewhere
84 | 82,Pretending to take something out of something
85 | 83,Pretending to throw something
86 | 84,Pretending to turn something upside down
87 | 85,Pulling something from behind of something
88 | 86,Pulling something from left to right
89 | 87,Pulling something from right to left
90 | 88,Pulling something onto something
91 | 89,Pulling something out of something
92 | 90,Pulling two ends of something but nothing happens
93 | 91,Pulling two ends of something so that it gets stretched
94 | 92,Pulling two ends of something so that it separates into two pieces
95 | 93,Pushing something from left to right
96 | 94,Pushing something from right to left
97 | 95,Pushing something off of something
98 | 96,Pushing something onto something
99 | 97,Pushing something so it spins
100 | 98,Pushing something so that it almost falls off but doesn't
101 | 99,Pushing something so that it falls off the table
102 | 100,Pushing something so that it slightly moves
103 | 101,Pushing something with something
104 | 102,Putting number of something onto something
105 | 103,Putting something and something on the table
106 | 104,Putting something behind something
107 | 105,Putting something in front of something
108 | 106,Putting something into something
109 | 107,Putting something next to something
110 | 108,Putting something on a flat surface without letting it roll
111 | 109,Putting something on a surface
112 | 110,Putting something on the edge of something so it is not supported and falls down
113 | 111,Putting something onto a slanted surface but it doesn't glide down
114 | 112,Putting something onto something
115 | 113,Putting something onto something else that cannot support it so it falls down
116 | 114,Putting something similar to other things that are already on the table
117 | 115,Putting something that can't roll onto a slanted surface so it slides down
118 | 116,Putting something that can't roll onto a slanted surface so it stays where it is
119 | 117,Putting something that cannot actually stand upright upright on the table so it falls on its side
120 | 118,Putting something underneath something
121 | 119,Putting something upright on the table
122 | 120,Putting something something and something on the table
123 | 121,Removing something revealing something behind
124 | 122,Rolling something on a flat surface
125 | 123,Scooping something up with something
126 | 124,Showing a photo of something to the camera
127 | 125,Showing something behind something
128 | 126,Showing something next to something
129 | 127,Showing something on top of something
130 | 128,Showing something to the camera
131 | 129,Showing that something is empty
132 | 130,Showing that something is inside something
133 | 131,Something being deflected from something
134 | 132,Something colliding with something and both are being deflected
135 | 133,Something colliding with something and both come to a halt
136 | 134,Something falling like a feather or paper
137 | 135,Something falling like a rock
138 | 136,Spilling something behind something
139 | 137,Spilling something next to something
140 | 138,Spilling something onto something
141 | 139,Spinning something so it continues spinning
142 | 140,Spinning something that quickly stops spinning
143 | 141,Spreading something onto something
144 | 142,Sprinkling something onto something
145 | 143,Squeezing something
146 | 144,Stacking number of something
147 | 145,Stuffing something into something
148 | 146,Taking one of many similar things on the table
149 | 147,Taking something from somewhere
150 | 148,Taking something out of something
151 | 149,Tearing something into two pieces
152 | 150,Tearing something just a little bit
153 | 151,Throwing something
154 | 152,Throwing something against something
155 | 153,Throwing something in the air and catching it
156 | 154,Throwing something in the air and letting it fall
157 | 155,Throwing something onto a surface
158 | 156,Tilting something with something on it slightly so it doesn't fall down
159 | 157,Tilting something with something on it until it falls off
160 | 158,Tipping something over
161 | 159,Tipping something with something in it over so something in it falls out
162 | 160,Touching (without moving) part of something
163 | 161,Trying but failing to attach something to something because it doesn't stick
164 | 162,Trying to bend something unbendable so nothing happens
165 | 163,Trying to pour something into something but missing so it spills next to it
166 | 164,Turning something upside down
167 | 165,Turning the camera downwards while filming something
168 | 166,Turning the camera left while filming something
169 | 167,Turning the camera right while filming something
170 | 168,Turning the camera upwards while filming something
171 | 169,Twisting (wringing) something wet until water comes out
172 | 170,Twisting something
173 | 171,Uncovering something
174 | 172,Unfolding something
175 | 173,Wiping something off of something
--------------------------------------------------------------------------------
/optim_factory.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import optim as optim
3 |
4 | from timm.optim.adafactor import Adafactor
5 | from timm.optim.adahessian import Adahessian
6 | from timm.optim.adamp import AdamP
7 | from timm.optim.lookahead import Lookahead
8 | from timm.optim.nadam import Nadam
9 | from timm.optim.novograd import NovoGrad
10 | from timm.optim.nvnovograd import NvNovoGrad
11 | from timm.optim.radam import RAdam
12 | from timm.optim.rmsprop_tf import RMSpropTF
13 | from timm.optim.sgdp import SGDP
14 |
15 | import json
16 |
17 | try:
18 | from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
19 | has_apex = True
20 | except ImportError:
21 | has_apex = False
22 |
23 |
24 | def get_num_layer_for_vit(var_name, num_max_layer):
25 | if var_name in ("pos_embed", "temporal_embed"): ###!!!
26 | return 0
27 | elif var_name.startswith("visual.conv1"):
28 | return 0
29 | elif var_name.startswith("visual.ln_pre"):
30 | return 0
31 | elif var_name.startswith("visual.transformer.resblocks"):
32 | layer_id = int(var_name.split('.')[3]) #// 2
33 | return layer_id + 1
34 | else:
35 | return num_max_layer - 1
36 |
37 | def get_num_layer_for_eva_vit(var_name, num_max_layer):
38 | if var_name in ("pos_embed", "temporal_embed"): ###!!!
39 | return 0
40 | elif var_name.startswith("visual.patch_embed"):
41 | return 0
42 | #elif var_name.startswith("visual.ln_pre"):
43 | # return 0
44 | elif var_name.startswith("visual.blocks"):
45 | layer_id = int(var_name.split('.')[2]) #// 2
46 | return layer_id + 1
47 | else:
48 | return num_max_layer - 1
49 |
50 | class LayerDecayValueAssigner(object):
51 | def __init__(self, values, vit_arch='clip'):
52 | self.values = values
53 | self.vit_arch = vit_arch
54 |
55 | def get_scale(self, layer_id):
56 | return self.values[layer_id]
57 |
58 | def get_layer_id(self, var_name):
59 | if self.vit_arch == 'clip':
60 | return get_num_layer_for_vit(var_name, len(self.values))
61 | elif self.vit_arch == 'eva_clip':
62 | return get_num_layer_for_eva_vit(var_name, len(self.values))
63 |
64 | def get_parameter_groups(model, weight_decay=1e-5, skip_list=(), get_num_layer=None, get_layer_scale=None):
65 | parameter_group_names = {}
66 | parameter_group_vars = {}
67 |
68 | for name, param in model.named_parameters():
69 | if not param.requires_grad:
70 | continue # frozen weights
71 |
72 |
73 | if len(param.shape) == 1 or name.endswith(".bias") or name in skip_list: ###!!!
74 | group_name = "no_decay"
75 | this_weight_decay = 0.
76 |         elif 'control_point' in name:
77 |             # parameters whose name contains 'control_point' get their own group (with weight decay)
78 |             group_name = "control_point"
79 |             this_weight_decay = weight_decay
80 | else:
81 | group_name = "decay"
82 | this_weight_decay = weight_decay
83 |
84 | if get_num_layer is not None:
85 | layer_id = get_num_layer(name)
86 | group_name = "layer_%d_%s" % (layer_id, group_name)
87 | else:
88 | layer_id = None
89 |
90 | if group_name not in parameter_group_names:
91 | if get_layer_scale is not None:
92 | if 'control_point' not in group_name and 'visual' in name: #for clip part
93 | scale = get_layer_scale(layer_id) #0.001
94 | else:
95 | scale = 1. #get_layer_scale(layer_id)
96 | else:
97 | scale = 1.
98 |
99 | parameter_group_names[group_name] = {
100 | "weight_decay": this_weight_decay,
101 | "params": [],
102 | "lr_scale": scale
103 | }
104 | parameter_group_vars[group_name] = {
105 | "weight_decay": this_weight_decay,
106 | "params": [],
107 | "lr_scale": scale
108 | }
109 |
110 | parameter_group_vars[group_name]["params"].append(param)
111 | parameter_group_names[group_name]["params"].append(name)
112 | print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
113 | return list(parameter_group_vars.values())
114 |
115 |
116 | def create_optimizer(args, model, get_num_layer=None, get_layer_scale=None, filter_bias_and_bn=True, skip_list=None):
117 | opt_lower = args.opt.lower()
118 | weight_decay = args.weight_decay
119 | if weight_decay and filter_bias_and_bn:
120 | skip = {}
121 | if skip_list is not None:
122 | skip = skip_list
123 | elif hasattr(model, 'no_weight_decay'):
124 | skip = model.no_weight_decay()
125 | parameters = get_parameter_groups(model, weight_decay, skip, get_num_layer, get_layer_scale)
126 | weight_decay = 0.
127 | else:
128 | parameters = model.parameters()
129 |
130 | if 'fused' in opt_lower:
131 | assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
132 |
133 | opt_args = dict(lr=args.lr, weight_decay=weight_decay)
134 | if hasattr(args, 'opt_eps') and args.opt_eps is not None:
135 | opt_args['eps'] = args.opt_eps
136 | if hasattr(args, 'opt_betas') and args.opt_betas is not None:
137 | opt_args['betas'] = args.opt_betas
138 |
139 | print("optimizer settings:", opt_args)
140 |
141 | opt_split = opt_lower.split('_')
142 | opt_lower = opt_split[-1]
143 | if opt_lower == 'sgd' or opt_lower == 'nesterov':
144 | opt_args.pop('eps', None)
145 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
146 | elif opt_lower == 'momentum':
147 | opt_args.pop('eps', None)
148 | optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
149 | elif opt_lower == 'adam':
150 | optimizer = optim.Adam(parameters, **opt_args)
151 | elif opt_lower == 'adamw':
152 | optimizer = optim.AdamW(parameters, **opt_args)
153 | elif opt_lower == 'nadam':
154 | optimizer = Nadam(parameters, **opt_args)
155 | elif opt_lower == 'radam':
156 | optimizer = RAdam(parameters, **opt_args)
157 | elif opt_lower == 'adamp':
158 | optimizer = AdamP(parameters, wd_ratio=0.01, nesterov=True, **opt_args)
159 | elif opt_lower == 'sgdp':
160 | optimizer = SGDP(parameters, momentum=args.momentum, nesterov=True, **opt_args)
161 | elif opt_lower == 'adadelta':
162 | optimizer = optim.Adadelta(parameters, **opt_args)
163 | elif opt_lower == 'adafactor':
164 | if not args.lr:
165 | opt_args['lr'] = None
166 | optimizer = Adafactor(parameters, **opt_args)
167 | elif opt_lower == 'adahessian':
168 | optimizer = Adahessian(parameters, **opt_args)
169 | elif opt_lower == 'rmsprop':
170 | optimizer = optim.RMSprop(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
171 | elif opt_lower == 'rmsproptf':
172 | optimizer = RMSpropTF(parameters, alpha=0.9, momentum=args.momentum, **opt_args)
173 | elif opt_lower == 'novograd':
174 | optimizer = NovoGrad(parameters, **opt_args)
175 | elif opt_lower == 'nvnovograd':
176 | optimizer = NvNovoGrad(parameters, **opt_args)
177 | elif opt_lower == 'fusedsgd':
178 | opt_args.pop('eps', None)
179 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
180 | elif opt_lower == 'fusedmomentum':
181 | opt_args.pop('eps', None)
182 | optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
183 | elif opt_lower == 'fusedadam':
184 | optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
185 | elif opt_lower == 'fusedadamw':
186 | optimizer = FusedAdam(parameters, adam_w_mode=True, **opt_args)
187 | elif opt_lower == 'fusedlamb':
188 | optimizer = FusedLAMB(parameters, **opt_args)
189 | elif opt_lower == 'fusednovograd':
190 | opt_args.setdefault('betas', (0.95, 0.98))
191 | optimizer = FusedNovoGrad(parameters, **opt_args)
192 |     else:
193 |         # unknown optimizer name
194 |         raise ValueError("Invalid optimizer: %s" % opt_lower)
195 |
196 | if len(opt_split) > 1:
197 | if opt_split[0] == 'lookahead':
198 | optimizer = Lookahead(optimizer)
199 |
200 | return optimizer
201 |
202 |
--------------------------------------------------------------------------------
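
Note on using optim_factory.py: the layer-wise learning-rate decay is driven by LayerDecayValueAssigner, whose get_layer_id/get_scale callbacks are handed to create_optimizer. The following is a minimal sketch of that flow, assuming a 12-block CLIP ViT-B/16 with layer_decay 0.65 (as in scripts/k400/train_base.sh); the scale schedule layer_decay ** (num_layers + 1 - i) and the toy nn.Linear stand-in model are illustration-only assumptions, not the exact call made by run_class_finetuning.py.

import argparse
import torch
from optim_factory import LayerDecayValueAssigner, create_optimizer

# Assumed schedule: deeper layers get scales closer to 1.0; layer 0 (embeddings,
# conv/patch stem) gets the strongest decay. Scales are only applied to parameters
# whose names contain 'visual' (see get_parameter_groups above).
num_layers = 12
layer_decay = 0.65
values = [layer_decay ** (num_layers + 1 - i) for i in range(num_layers + 2)]
assigner = LayerDecayValueAssigner(values, vit_arch='clip')

model = torch.nn.Linear(8, 8)  # stand-in for the real video model
args = argparse.Namespace(opt='adamw', lr=3e-4, weight_decay=0.05,
                          opt_eps=1e-8, opt_betas=(0.9, 0.999), momentum=0.9)
optimizer = create_optimizer(args, model,
                             get_num_layer=assigner.get_layer_id,
                             get_layer_scale=assigner.get_scale)
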
/pics/ATM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/whwu95/ATM/98ba3aa2ac258cc1b91beefe9317136657ae3d8d/pics/ATM.png
--------------------------------------------------------------------------------
/scripts/k400/train_base.sh:
--------------------------------------------------------------------------------
1 | # Set the path to save checkpoints
2 | OUTPUT_DIR='output/k400/base_f8'
3 |
4 | DATA_PATH='/bpfs/v2_mnt/VIS/test/k400/train_320_frames'
5 | # path to the pretrained model
6 | MODEL_PATH=' '
7 |
8 | MASTER_ADDR='127.0.0.1'
9 | NNODES=1
10 |
11 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \
12 | --master_port 18892 --nnodes=${NNODES} --node_rank=0 --master_addr=${MASTER_ADDR} \
13 | run_class_finetuning.py \
14 | --model ViT-B/16 \
15 | --data_set Kinetics-400 \
16 | --nb_classes 400 \
17 | --data_path ${DATA_PATH} \
18 | --log_dir ${OUTPUT_DIR} \
19 | --output_dir ${OUTPUT_DIR} \
20 | --batch_size 16 \
21 | --input_size 224 \
22 | --short_side_size 224 \
23 | --save_ckpt_freq 5 \
24 | --num_frames 8 \
25 | --embed_dim 512 \
26 | --opt adamw \
27 | --lr 3e-4 \
28 | --layer_decay 0.65 \
29 | --opt_betas 0.9 0.999 \
30 | --weight_decay 0.05 \
31 | --epochs 20 \
32 | --drop_path 0. \
33 | --drop 0. \
34 | --dist_eval \
35 | --enable_deepspeed
--------------------------------------------------------------------------------
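
The launch above runs 1 node with 8 processes (one per GPU) at a per-GPU batch size of 16, i.e. an effective batch of 128 clips per optimizer step. The one-line sketch below only makes that arithmetic explicit; whether run_class_finetuning.py additionally rescales the learning rate by batch size is not assumed here.

# Assumed from the command above: NNODES=1, --nproc_per_node=8, --batch_size 16.
nnodes, nproc_per_node, batch_size_per_gpu = 1, 8, 16
effective_batch = nnodes * nproc_per_node * batch_size_per_gpu
print(effective_batch)  # 128
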
/scripts/k400/train_eva_large.sh:
--------------------------------------------------------------------------------
1 | # Set the path to save checkpoints
2 | OUTPUT_DIR='output/k400/eva_large_f8'
3 |
4 | DATA_PATH='/bpfs/v2_mnt/VIS/test/k400/train_320_frames'
5 | # path to the pretrained model
6 | MODEL_PATH=' '
7 |
8 | MASTER_ADDR='127.0.0.1'
9 | NNODES=1
10 |
11 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \
12 | --master_port 18892 --nnodes=${NNODES} --node_rank=0 --master_addr=${MASTER_ADDR} \
13 | run_class_finetuning.py \
14 | --model EVA02-CLIP-L-14 \
15 | --data_set Kinetics-400 \
16 | --nb_classes 400 \
17 | --data_path ${DATA_PATH} \
18 | --log_dir ${OUTPUT_DIR} \
19 | --output_dir ${OUTPUT_DIR} \
20 | --batch_size 16 \
21 | --input_size 224 \
22 | --short_side_size 224 \
23 | --save_ckpt_freq 5 \
24 | --num_frames 8 \
25 | --embed_dim 768 \
26 | --opt adamw \
27 | --lr 2e-4 \
28 | --layer_decay 0.7 \
29 | --opt_betas 0.9 0.999 \
30 | --weight_decay 0.05 \
31 | --epochs 20 \
32 | --drop_path 0. \
33 | --drop 0. \
34 | --dist_eval \
35 | --enable_deepspeed
--------------------------------------------------------------------------------
/scripts/k400/train_eva_large_336.sh:
--------------------------------------------------------------------------------
1 | # Set the path to save checkpoints
2 | OUTPUT_DIR='output/k400/eva_large_336_f8'
3 |
4 | DATA_PATH='/bpfs/v2_mnt/VIS/test/k400/train_320_frames'
5 | # path to the pretrained model
6 | MODEL_PATH=' '
7 |
8 | MASTER_ADDR='127.0.0.1'
9 | NNODES=1
10 |
11 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \
12 | --master_port 18892 --nnodes=${NNODES} --node_rank=0 --master_addr=${MASTER_ADDR} \
13 | run_class_finetuning.py \
14 | --model EVA02-CLIP-L-14-336 \
15 | --data_set Kinetics-400 \
16 | --nb_classes 400 \
17 | --data_path ${DATA_PATH} \
18 | --log_dir ${OUTPUT_DIR} \
19 | --output_dir ${OUTPUT_DIR} \
20 | --batch_size 16 \
21 | --input_size 336 \
22 | --short_side_size 336 \
23 | --save_ckpt_freq 5 \
24 | --num_frames 8 \
25 | --embed_dim 768 \
26 | --opt adamw \
27 | --lr 2e-4 \
28 | --layer_decay 0.7 \
29 | --opt_betas 0.9 0.999 \
30 | --weight_decay 0.05 \
31 | --epochs 20 \
32 | --drop_path 0. \
33 | --drop 0. \
34 | --dist_eval \
35 | --enable_deepspeed
--------------------------------------------------------------------------------
/scripts/ssv1/test_base_f16.sh:
--------------------------------------------------------------------------------
1 | # ViT-B/16 F16: Overall Prec@1 60.609% Prec@5 86.450%
2 |
3 | # Set the path to save checkpoints
4 | OUTPUT_DIR='./output/test'
5 | DATA_PATH='/bpfs/v2_mnt/VIS/wuwenhao/20bn-something-something-v1'
6 |
7 | MODEL_PATH='output/ssv1/base_f16_bs16_atm7/checkpoint-best/mp_rank_00_model_states.pt'
8 |
9 | MASTER_ADDR='127.0.0.1'
10 |
11 | NNODES=1
12 |
13 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \
14 | --master_port 18890 --nnodes=${NNODES} --node_rank=0 --master_addr=${MASTER_ADDR} \
15 | test_for_frame.py \
16 | --model ViT-B/16 \
17 | --data_set SSV1 \
18 | --nb_classes 174 \
19 | --data_path ${DATA_PATH} \
20 | --log_dir ${OUTPUT_DIR} \
21 | --output_dir ${OUTPUT_DIR} \
22 | --batch_size 10 \
23 | --input_size 224 \
24 | --short_side_size 224 \
25 | --save_ckpt_freq 5 \
26 | --num_frames 16 \
27 | --embed_dim 512 \
28 | --dist_eval \
29 | --test_num_segment 2 \
30 | --test_num_crop 3 \
31 | --resume ${MODEL_PATH}
32 |
33 |
34 |
--------------------------------------------------------------------------------
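
The test command above evaluates each video over --test_num_segment 2 temporal clips and --test_num_crop 3 spatial crops, i.e. 2 x 3 = 6 views per video. The actual aggregation lives in test_for_frame.py; the snippet below is only a sketch of the usual multi-view recipe, under the assumption that per-view predictions are averaged before the arg-max.

import torch

def aggregate_views(view_logits: torch.Tensor) -> torch.Tensor:
    """Average class probabilities over the (num_views, num_classes) logits of one video."""
    return view_logits.softmax(dim=-1).mean(dim=0)

# Toy usage: 6 views over the 174 Something-Something classes.
logits = torch.randn(6, 174)
top1 = aggregate_views(logits).argmax().item()
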
/scripts/ssv1/test_base_f32.sh:
--------------------------------------------------------------------------------
1 | # ViT-B/16 F32: Overall Prec@1 61.450% Prec@5 86.251%
2 |
3 | # Set the path to save checkpoints
4 | OUTPUT_DIR='./output/test'
5 | DATA_PATH='/bpfs/v2_mnt/VIS/wuwenhao/20bn-something-something-v1'
6 |
7 | MODEL_PATH='output/ssv1/base_f32_bs10_atm7/checkpoint-best/mp_rank_00_model_states.pt'
8 |
9 | MASTER_ADDR='127.0.0.1'
10 |
11 | NNODES=1
12 |
13 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \
14 | --master_port 18890 --nnodes=${NNODES} --node_rank=0 --master_addr=${MASTER_ADDR} \
15 | test_for_frame.py \
16 | --model ViT-B/16 \
17 | --data_set SSV1 \
18 | --nb_classes 174 \
19 | --data_path ${DATA_PATH} \
20 | --log_dir ${OUTPUT_DIR} \
21 | --output_dir ${OUTPUT_DIR} \
22 | --batch_size 10 \
23 | --input_size 224 \
24 | --short_side_size 224 \
25 | --save_ckpt_freq 5 \
26 | --num_frames 32 \
27 | --embed_dim 512 \
28 | --dist_eval \
29 | --test_num_segment 2 \
30 | --test_num_crop 3 \
31 | --resume ${MODEL_PATH}
32 |
33 |
34 |
--------------------------------------------------------------------------------
/scripts/ssv1/test_base_f8.sh:
--------------------------------------------------------------------------------
1 | # ViT-B/16 F8: Overall Prec@1 58.744% Prec@5 85.427%
2 |
3 | # Set the path to save checkpoints
4 | OUTPUT_DIR='./output/test'
5 | DATA_PATH='/bpfs/v2_mnt/VIS/wuwenhao/20bn-something-something-v1'
6 |
7 | MODEL_PATH='output/ssv1/base_f8_bs16_atm7/checkpoint-best/mp_rank_00_model_states.pt'
8 |
9 | MASTER_ADDR='127.0.0.1'
10 |
11 | NNODES=1
12 |
13 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \
14 | --master_port 18899 --nnodes=${NNODES} --node_rank=0 --master_addr=${MASTER_ADDR} \
15 | test_for_frame.py \
16 | --model ViT-B/16 \
17 | --data_set SSV1 \
18 | --nb_classes 174 \
19 | --data_path ${DATA_PATH} \
20 | --log_dir ${OUTPUT_DIR} \
21 | --output_dir ${OUTPUT_DIR} \
22 | --batch_size 10 \
23 | --input_size 224 \
24 | --short_side_size 224 \
25 | --save_ckpt_freq 5 \
26 | --num_frames 8 \
27 | --embed_dim 512 \
28 | --dist_eval \
29 | --test_num_segment 2 \
30 | --test_num_crop 3 \
31 | --resume ${MODEL_PATH}
32 |
33 |
34 |
--------------------------------------------------------------------------------
/scripts/ssv1/test_large_f16.sh:
--------------------------------------------------------------------------------
1 | # ViT-L/14 F16:
2 |
3 | # Set the path to save checkpoints
4 | OUTPUT_DIR='./output/test'
5 | DATA_PATH='/bpfs/v2_mnt/VIS/wuwenhao/20bn-something-something-v1'
6 |
7 | MODEL_PATH='output/ssv1/large_f16_bs6_atm14/checkpoint-best/mp_rank_00_model_states.pt'
8 |
9 | MASTER_ADDR='127.0.0.1'
10 |
11 | NNODES=1
12 |
13 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \
14 | --master_port 18890 --nnodes=${NNODES} --node_rank=0 --master_addr=${MASTER_ADDR} \
15 | test_for_frame.py \
16 | --model ViT-L/14 \
17 | --data_set SSV1 \
18 | --nb_classes 174 \
19 | --data_path ${DATA_PATH} \
20 | --log_dir ${OUTPUT_DIR} \
21 | --output_dir ${OUTPUT_DIR} \
22 | --batch_size 10 \
23 | --input_size 224 \
24 | --short_side_size 224 \
25 | --save_ckpt_freq 5 \
26 | --num_frames 16 \
27 | --embed_dim 768 \
28 | --dist_eval \
29 | --test_num_segment 2 \
30 | --test_num_crop 3 \
31 | --resume ${MODEL_PATH}
32 |
33 |
34 |
--------------------------------------------------------------------------------
/scripts/ssv1/train_base.sh:
--------------------------------------------------------------------------------
1 | # Set the path to save checkpoints
2 | OUTPUT_DIR='output/ssv1/base_f8'
3 |
4 | DATA_PATH='/bpfs/v2_mnt/VIS/wuwenhao/20bn-something-something-v1'
5 | # path to the pretrained model
6 | MODEL_PATH=' '
7 |
8 | MASTER_ADDR='127.0.0.1'
9 | NNODES=1
10 |
11 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \
12 | --master_port 18892 --nnodes=${NNODES} --node_rank=0 --master_addr=${MASTER_ADDR} \
13 | run_class_finetuning.py \
14 | --model ViT-B/16 \
15 | --data_set SSV1 \
16 | --nb_classes 174 \
17 | --data_path ${DATA_PATH} \
18 | --log_dir ${OUTPUT_DIR} \
19 | --output_dir ${OUTPUT_DIR} \
20 | --batch_size 16 \
21 | --input_size 224 \
22 | --short_side_size 224 \
23 | --save_ckpt_freq 5 \
24 | --num_frames 8 \
25 | --embed_dim 512 \
26 | --opt adamw \
27 | --lr 7e-4 \
28 | --layer_decay 0.70 \
29 | --opt_betas 0.9 0.999 \
30 | --weight_decay 0.05 \
31 | --epochs 20 \
32 | --drop_path 0. \
33 | --drop 0. \
34 | --dist_eval \
35 | --enable_deepspeed \
36 | --aa True
--------------------------------------------------------------------------------
/scripts/ssv1/train_eva_large.sh:
--------------------------------------------------------------------------------
1 | # Set the path to save checkpoints
2 | OUTPUT_DIR='output/ssv1/eva_large_f16'
3 |
4 | DATA_PATH='/bpfs/v2_mnt/VIS/wuwenhao/20bn-something-something-v1'
5 | # path to the pretrained model
6 | MODEL_PATH=' '
7 |
8 | MASTER_ADDR='127.0.0.1'
9 | NNODES=1
10 |
11 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \
12 | --master_port 18892 --nnodes=${NNODES} --node_rank=0 --master_addr=${MASTER_ADDR} \
13 | run_class_finetuning.py \
14 | --model EVA02-CLIP-L-14 \
15 | --data_set SSV1 \
16 | --nb_classes 174 \
17 | --data_path ${DATA_PATH} \
18 | --log_dir ${OUTPUT_DIR} \
19 | --output_dir ${OUTPUT_DIR} \
20 | --batch_size 16 \
21 | --input_size 224 \
22 | --short_side_size 224 \
23 | --save_ckpt_freq 5 \
24 | --num_frames 16 \
25 | --embed_dim 768 \
26 | --opt adamw \
27 | --lr 1e-3 \
28 | --layer_decay 0.75 \
29 | --opt_betas 0.9 0.999 \
30 | --weight_decay 0.05 \
31 | --epochs 15 \
32 | --drop_path 0. \
33 | --drop 0. \
34 | --dist_eval \
35 | --enable_deepspeed \
36 | --aa True
--------------------------------------------------------------------------------
/scripts/ssv2/test_base_f16.sh:
--------------------------------------------------------------------------------
1 | # ViT-B/16 F16:
2 |
3 | # Set the path to save checkpoints
4 | OUTPUT_DIR='./output/test'
5 | DATA_PATH='/bpfs/v2_mnt/VIS/wuwenhao/20bn-something-something-v2-frames'
6 |
7 | MODEL_PATH='output/ssv2/base_f16_bs16_atm7/checkpoint-best/mp_rank_00_model_states.pt'
8 |
9 | MASTER_ADDR='127.0.0.1'
10 |
11 | NNODES=1
12 |
13 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \
14 | --master_port 18890 --nnodes=${NNODES} --node_rank=0 --master_addr=${MASTER_ADDR} \
15 | test_for_frame.py \
16 | --model ViT-B/16 \
17 | --data_set SSV2 \
18 | --nb_classes 174 \
19 | --data_path ${DATA_PATH} \
20 | --log_dir ${OUTPUT_DIR} \
21 | --output_dir ${OUTPUT_DIR} \
22 | --batch_size 10 \
23 | --input_size 224 \
24 | --short_side_size 224 \
25 | --save_ckpt_freq 5 \
26 | --num_frames 16 \
27 | --embed_dim 512 \
28 | --dist_eval \
29 | --test_num_segment 2 \
30 | --test_num_crop 3 \
31 | --resume ${MODEL_PATH}
32 |
33 |
34 |
--------------------------------------------------------------------------------
/scripts/ssv2/test_base_f32.sh:
--------------------------------------------------------------------------------
1 | # ViT-B/16 F32:
2 |
3 | # Set the path to save checkpoints
4 | OUTPUT_DIR='./output/test'
5 | DATA_PATH='/bpfs/v2_mnt/VIS/wuwenhao/20bn-something-something-v2-frames'
6 |
7 | MODEL_PATH='../ATM_ICCV23_CODE_MODELS/output/ssv2/base_f32_bs10_atm7/checkpoint-best/mp_rank_00_model_states.pt'
8 |
9 | MASTER_ADDR='127.0.0.1'
10 |
11 | NNODES=1
12 |
13 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \
14 | --master_port 18890 --nnodes=${NNODES} --node_rank=0 --master_addr=${MASTER_ADDR} \
15 | test_for_frame.py \
16 | --model ViT-B/16 \
17 | --data_set SSV2 \
18 | --nb_classes 174 \
19 | --data_path ${DATA_PATH} \
20 | --log_dir ${OUTPUT_DIR} \
21 | --output_dir ${OUTPUT_DIR} \
22 | --batch_size 10 \
23 | --input_size 224 \
24 | --short_side_size 224 \
25 | --save_ckpt_freq 5 \
26 | --num_frames 32 \
27 | --embed_dim 512 \
28 | --dist_eval \
29 | --test_num_segment 2 \
30 | --test_num_crop 3 \
31 | --resume ${MODEL_PATH}
32 |
33 |
34 |
--------------------------------------------------------------------------------
/scripts/ssv2/test_base_f8.sh:
--------------------------------------------------------------------------------
1 | # ViT-B/16 F8:
2 |
3 | # Set the path to save checkpoints
4 | OUTPUT_DIR='./output/test'
5 | DATA_PATH='/bpfs/v2_mnt/VIS/wuwenhao/20bn-something-something-v2-frames'
6 |
7 | MODEL_PATH='output/ssv2/base_f8_bs16_atm7/checkpoint-best/mp_rank_00_model_states.pt'
8 |
9 | MASTER_ADDR='127.0.0.1'
10 |
11 | NNODES=1
12 |
13 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \
14 | --master_port 18899 --nnodes=${NNODES} --node_rank=0 --master_addr=${MASTER_ADDR} \
15 | test_for_frame.py \
16 | --model ViT-B/16 \
17 | --data_set SSV2 \
18 | --nb_classes 174 \
19 | --data_path ${DATA_PATH} \
20 | --log_dir ${OUTPUT_DIR} \
21 | --output_dir ${OUTPUT_DIR} \
22 | --batch_size 10 \
23 | --input_size 224 \
24 | --short_side_size 224 \
25 | --save_ckpt_freq 5 \
26 | --num_frames 8 \
27 | --embed_dim 512 \
28 | --dist_eval \
29 | --test_num_segment 2 \
30 | --test_num_crop 3 \
31 | --resume ${MODEL_PATH}
32 |
33 |
34 |
--------------------------------------------------------------------------------
/scripts/ssv2/test_large_f16.sh:
--------------------------------------------------------------------------------
1 | # ViT-L/14 F16:
2 |
3 | # Set the path to save checkpoints
4 | OUTPUT_DIR='./output/test'
5 | DATA_PATH='/bpfs/v2_mnt/VIS/wuwenhao/20bn-something-something-v2-frames'
6 |
7 | MODEL_PATH='../ATM_ICCV23_CODE_MODELS/output/ssv2/large_f16_bs6_atm14/checkpoint-best/mp_rank_00_model_states.pt'
8 |
9 | MASTER_ADDR='127.0.0.1'
10 |
11 | NNODES=1
12 |
13 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \
14 | --master_port 18890 --nnodes=${NNODES} --node_rank=0 --master_addr=${MASTER_ADDR} \
15 | test_for_frame.py \
16 | --model ViT-L/14 \
17 | --data_set SSV2 \
18 | --nb_classes 174 \
19 | --data_path ${DATA_PATH} \
20 | --log_dir ${OUTPUT_DIR} \
21 | --output_dir ${OUTPUT_DIR} \
22 | --batch_size 10 \
23 | --input_size 224 \
24 | --short_side_size 224 \
25 | --save_ckpt_freq 5 \
26 | --num_frames 16 \
27 | --embed_dim 768 \
28 | --dist_eval \
29 | --test_num_segment 2 \
30 | --test_num_crop 3 \
31 | --resume ${MODEL_PATH}
32 |
33 |
34 |
--------------------------------------------------------------------------------
/scripts/ssv2/train_base.sh:
--------------------------------------------------------------------------------
1 | # Set the path to save checkpoints
2 | OUTPUT_DIR='output/ssv2/base_f8'
3 |
4 | DATA_PATH='/bpfs/v2_mnt/VIS/wuwenhao/20bn-something-something-v2-frames'
5 | # path to the pretrained model
6 | MODEL_PATH=' '
7 |
8 | MASTER_ADDR='127.0.0.1'
9 | NNODES=1
10 |
11 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \
12 | --master_port 18892 --nnodes=${NNODES} --node_rank=0 --master_addr=${MASTER_ADDR} \
13 | run_class_finetuning.py \
14 | --model ViT-B/16 \
15 | --data_set SSV2 \
16 | --nb_classes 174 \
17 | --data_path ${DATA_PATH} \
18 | --log_dir ${OUTPUT_DIR} \
19 | --output_dir ${OUTPUT_DIR} \
20 | --batch_size 16 \
21 | --input_size 224 \
22 | --short_side_size 224 \
23 | --save_ckpt_freq 5 \
24 | --num_frames 8 \
25 | --embed_dim 512 \
26 | --opt adamw \
27 | --lr 7e-4 \
28 | --layer_decay 0.70 \
29 | --opt_betas 0.9 0.999 \
30 | --weight_decay 0.05 \
31 | --epochs 20 \
32 | --drop_path 0. \
33 | --drop 0. \
34 | --dist_eval \
35 | --enable_deepspeed \
36 | --aa True
--------------------------------------------------------------------------------
/scripts/ssv2/train_eva_large.sh:
--------------------------------------------------------------------------------
1 | # Set the path to save checkpoints
2 | OUTPUT_DIR='output/ssv2/eva_large_f16'
3 |
4 | DATA_PATH='/bpfs/v2_mnt/VIS/wuwenhao/20bn-something-something-v2-frames'
5 | # path to the pretrained model
6 | MODEL_PATH=' '
7 |
8 | MASTER_ADDR='127.0.0.1'
9 | NNODES=1
10 |
11 | OMP_NUM_THREADS=1 python -m torch.distributed.launch --nproc_per_node=8 \
12 | --master_port 18892 --nnodes=${NNODES} --node_rank=0 --master_addr=${MASTER_ADDR} \
13 | run_class_finetuning.py \
14 | --model EVA02-CLIP-L-14 \
15 | --data_set SSV2 \
16 | --nb_classes 174 \
17 | --data_path ${DATA_PATH} \
18 | --log_dir ${OUTPUT_DIR} \
19 | --output_dir ${OUTPUT_DIR} \
20 | --batch_size 16 \
21 | --input_size 224 \
22 | --short_side_size 224 \
23 | --save_ckpt_freq 5 \
24 | --num_frames 16 \
25 | --embed_dim 768 \
26 | --opt adamw \
27 | --lr 1e-3 \
28 | --layer_decay 0.75 \
29 | --opt_betas 0.9 0.999 \
30 | --weight_decay 0.05 \
31 | --epochs 15 \
32 | --drop_path 0. \
33 | --drop 0. \
34 | --dist_eval \
35 | --enable_deepspeed \
36 | --aa True
--------------------------------------------------------------------------------