├── .idea ├── ALIGN.iml ├── inspectionProfiles │ └── profiles_settings.xml ├── maple.iml ├── misc.xml ├── modules.xml ├── vcs.xml └── workspace.xml ├── README.md ├── clip ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── clip.cpython-37.pyc │ ├── clip.cpython-38.pyc │ ├── model.cpython-37.pyc │ ├── model.cpython-38.pyc │ ├── simple_tokenizer.cpython-37.pyc │ └── simple_tokenizer.cpython-38.pyc ├── bpe_simple_vocab_16e6.txt.gz ├── clip.py ├── model.py └── simple_tokenizer.py ├── clip_words.csv ├── configs ├── datasets │ ├── caltech101.yaml │ ├── dtd.yaml │ ├── eurosat.yaml │ ├── fgvc_aircraft.yaml │ ├── food101.yaml │ ├── imagenet.yaml │ ├── imagenet_a.yaml │ ├── imagenet_r.yaml │ ├── imagenet_sketch.yaml │ ├── imagenetv2.yaml │ ├── oxford_flowers.yaml │ ├── oxford_pets.yaml │ ├── stanford_cars.yaml │ ├── sun397.yaml │ └── ucf101.yaml └── trainers │ ├── CoCoOp │ ├── vit_b16_c16_ep10_batch1.yaml │ ├── vit_b16_c4_ep10_batch1.yaml │ ├── vit_b16_c4_ep10_batch1_ctxv1.yaml │ └── vit_b16_c8_ep10_batch1.yaml │ ├── CoOp │ ├── rn101.yaml │ ├── rn101_ep50.yaml │ ├── rn50.yaml │ ├── rn50_ctxv1.yaml │ ├── rn50_ep100.yaml │ ├── rn50_ep50.yaml │ ├── rn50_ep50_ctxv1.yaml │ ├── rn50_val.yaml │ ├── vit_b16.yaml │ ├── vit_b16_ep100.yaml │ ├── vit_b16_ep50.yaml │ ├── vit_b32.yaml │ └── vit_b32_ep50.yaml │ ├── IVLP │ ├── vit_b16_c2_ep5_batch4_2+2ctx.yaml │ └── vit_b16_c2_ep5_batch4_4ctx_language_only.yaml │ ├── MMP │ ├── sun397.yaml │ ├── vit_b16_c2_ep5_batch4_2ctx.yaml │ └── vit_h.yaml │ ├── MaPLe │ ├── vit_b16_c2_ep5_batch4_2ctx.yaml │ └── vit_b16_c2_ep5_batch4_2ctx_cross_datasets.yaml │ └── VPT │ └── vit_b16_c2_ep5_batch4_4.yaml ├── datasets ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── caltech101.cpython-37.pyc │ ├── caltech101.cpython-38.pyc │ ├── dtd.cpython-37.pyc │ ├── dtd.cpython-38.pyc │ ├── eurosat.cpython-37.pyc │ ├── eurosat.cpython-38.pyc │ ├── fgvc_aircraft.cpython-37.pyc │ ├── fgvc_aircraft.cpython-38.pyc │ ├── food101.cpython-37.pyc │ ├── food101.cpython-38.pyc │ ├── imagenet.cpython-37.pyc │ ├── imagenet.cpython-38.pyc │ ├── imagenet_a.cpython-37.pyc │ ├── imagenet_a.cpython-38.pyc │ ├── imagenet_r.cpython-37.pyc │ ├── imagenet_r.cpython-38.pyc │ ├── imagenet_sketch.cpython-37.pyc │ ├── imagenet_sketch.cpython-38.pyc │ ├── imagenetv2.cpython-37.pyc │ ├── imagenetv2.cpython-38.pyc │ ├── oxford_flowers.cpython-37.pyc │ ├── oxford_flowers.cpython-38.pyc │ ├── oxford_pets.cpython-37.pyc │ ├── oxford_pets.cpython-38.pyc │ ├── stanford_cars.cpython-37.pyc │ ├── stanford_cars.cpython-38.pyc │ ├── sun397.cpython-37.pyc │ ├── sun397.cpython-38.pyc │ ├── ucf101.cpython-37.pyc │ └── ucf101.cpython-38.pyc ├── caltech101.py ├── dtd.py ├── eurosat.py ├── fgvc_aircraft.py ├── food101.py ├── imagenet.py ├── imagenet_a.py ├── imagenet_r.py ├── imagenet_sketch.py ├── imagenetv2.py ├── oxford_flowers.py ├── oxford_pets.py ├── stanford_cars.py ├── sun397.py └── ucf101.py ├── images └── ALIGN.png ├── parse_test_res.py ├── scripts ├── cocoop │ ├── base2new_test.sh │ ├── base2new_train.sh │ ├── xd_test.sh │ └── xd_train.sh ├── coop │ ├── basenewtrain.sh │ ├── eval.sh │ └── main.sh ├── independent-vlp │ ├── base2new_test_ivlp.sh │ ├── base2new_train_ivlp.sh │ ├── reproduce_ivlp.sh │ ├── xd_test_ivlp.sh │ └── xd_train_ivlp.sh ├── language-prompting │ ├── base2new_test_lp.sh │ ├── base2new_train_lp.sh │ ├── reproduce_lp.sh │ ├── xd_test_lp.sh │ └── xd_train_lp.sh ├── maple │ ├── base2new_test_maple.sh │ ├── 
base2new_train_maple.sh │ ├── fst.sh │ ├── reproduce_maple.sh │ ├── reproduce_maple_xd.sh │ ├── xd_test_maple.sh │ └── xd_train_maple.sh ├── mmp │ ├── base_to_new_test.sh │ └── base_to_new_train.sh ├── vpt │ ├── base2new_test_vpt.sh │ ├── base2new_train_vpt.sh │ ├── reproduce_vpt.sh │ ├── xd_test_vpt.sh │ └── xd_train_vpt.sh └── zsclip │ └── zeroshot.sh ├── train.py └── trainers ├── __init__.py ├── __pycache__ ├── __init__.cpython-37.pyc ├── __init__.cpython-38.pyc ├── cocoop.cpython-37.pyc ├── cocoop.cpython-38.pyc ├── coop.cpython-37.pyc ├── coop.cpython-38.pyc ├── imagenet_templates.cpython-37.pyc ├── imagenet_templates.cpython-38.pyc ├── independentVL.cpython-37.pyc ├── independentVL.cpython-38.pyc ├── maple.cpython-37.pyc ├── maple.cpython-38.pyc ├── mmp.cpython-37.pyc ├── mmp.cpython-38.pyc ├── vpt.cpython-37.pyc ├── vpt.cpython-38.pyc ├── zsclip.cpython-37.pyc └── zsclip.cpython-38.pyc ├── cocoop.py ├── coop.py ├── imagenet_templates.py ├── independentVL.py ├── maple.py ├── mmp.py ├── vpt.py └── zsclip.py /README.md: -------------------------------------------------------------------------------- 1 | # Tuning Multi-mode Token-level Prompt Alignment across Modalities [NeurIPS 2023] 2 | 3 | This is the official implementation of our paper [Tuning Multi-mode Token-level Prompt Alignment across Modalities](https://arxiv.org/abs/2309.13847) in NeurIPS 2023. 4 | 5 | ![avatar](images/ALIGN.png) 6 | 7 | The proposed ALIGN algorithm aims to learn multiple prompts in both textual and visual domains.
Given the M visual prompts and N textual prompts, ALIGN first views the label/image as discrete distributions over the 8 | M and N prompt supports, and each distribution itself can further be modeled as a discrete distribution over its modality-specific token-level space. ALIGN then applies the Prompt-level OT and Token-level OT to align the two 9 | domains (see the illustrative sketch below). 10 | 11 | ## TODO 12 | Due to upcoming deadlines, we will add more details about the training scripts and results soon. 13 | 14 | ## Getting Started 15 | ### Install 16 | - Clone this repo: 17 | ```bash 18 | git clone https://github.com/wds2014/ALIGN.git 19 | cd ALIGN 20 | ``` 21 | - Please follow [INSTALL.md](https://github.com/muzairkhattak/multimodal-prompt-learning/tree/main/docs/INSTALL.md) to set up the Python environment. 22 | 23 | ### Dataset 24 | - Datasets used in our paper 25 | 26 | The datasets we use are the same as in previous works (CoOp and MaPLe). Please follow [DATASETS.md](https://github.com/muzairkhattak/multimodal-prompt-learning/tree/main/docs/DATASETS.md) to prepare all datasets. 27 | 28 | ### Training 29 | - Training is straightforward: 30 | ```bash 31 | cd scripts/mmp 32 | bash base_to_new_train.sh 33 | ``` 34 | Change DATASET and SEED in the .sh file to train our model on different datasets and with different seeds. 35 | 36 | ## Citation 37 | If you find this repo useful for your project, please consider citing it with the following BibTeX entry: 38 | 39 | ```bibtex 40 | @article{wang2023tuning, 41 | title={Tuning Multi-mode Token-level Prompt Alignment across Modalities}, 42 | author={Wang, Dongsheng and Li, Miaoge and Liu, Xinyang and Xu, MingSheng and Chen, Bo and Zhang, Hanwang}, 43 | journal={arXiv preprint arXiv:2309.13847}, 44 | year={2023} 45 | } 46 | ``` 47 | 48 | ## Acknowledgements 49 | Our code builds on the [CoOp](https://github.com/KaiyangZhou/CoOp) and [MaPLe](https://github.com/muzairkhattak/multimodal-prompt-learning/tree/main) repositories. 50 | We thank the authors for releasing their code.
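For readers who want a concrete picture of the alignment described in the README, below is a minimal, self-contained sketch of the two-level optimal-transport idea: a prompt-level OT whose cost matrix is itself built from token-level OT distances, both solved with a basic entropic (Sinkhorn) solver. This is not the repository's implementation (see `trainers/mmp.py` for that); the cosine token cost, uniform marginals, and the hyperparameters `eps` and `n_iters` are illustrative assumptions, and the tensors at the bottom are random toy data standing in for CLIP token features.

```python
import torch


def sinkhorn(cost, a, b, eps=0.1, n_iters=100):
    """Entropic OT: transport plan between histograms a (size M) and b (size N) for an M x N cost."""
    K = torch.exp(-cost / eps)                    # Gibbs kernel
    u = torch.ones_like(a)
    for _ in range(n_iters):
        v = b / (K.t() @ u + 1e-8)
        u = a / (K @ v + 1e-8)
    return u.unsqueeze(1) * K * v.unsqueeze(0)    # plan with (approximate) marginals a and b


def token_cost(vis_tokens, txt_tokens):
    """Cosine-distance cost between the token features of one visual and one textual prompt."""
    v = torch.nn.functional.normalize(vis_tokens, dim=-1)
    t = torch.nn.functional.normalize(txt_tokens, dim=-1)
    return 1.0 - v @ t.t()


def two_level_ot_distance(vis_prompts, txt_prompts, eps=0.1):
    """vis_prompts: M tensors of shape [S_v, D]; txt_prompts: N tensors of shape [S_t, D]."""
    M, N = len(vis_prompts), len(txt_prompts)
    C = torch.zeros(M, N)
    for i in range(M):
        for j in range(N):
            # Token-level OT: distance between the token sets of visual prompt i and textual prompt j.
            c_ij = token_cost(vis_prompts[i], txt_prompts[j])
            a = torch.full((c_ij.shape[0],), 1.0 / c_ij.shape[0])
            b = torch.full((c_ij.shape[1],), 1.0 / c_ij.shape[1])
            C[i, j] = (sinkhorn(c_ij, a, b, eps) * c_ij).sum()
    # Prompt-level OT over the M x N cost matrix built from the token-level distances.
    a = torch.full((M,), 1.0 / M)
    b = torch.full((N,), 1.0 / N)
    plan = sinkhorn(C, a, b, eps)
    return (plan * C).sum()


# Toy example: 2 visual prompts with 4 tokens each, 3 textual prompts with 5 tokens each, dim 8.
vis = [torch.randn(4, 8) for _ in range(2)]
txt = [torch.randn(5, 8) for _ in range(3)]
print(two_level_ot_distance(vis, txt))
```

In the actual model the visual/textual token features would come from the CLIP encoders and the resulting transport cost would enter the training objective; the sketch above only returns the scalar two-level distance.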
-------------------------------------------------------------------------------- /clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /clip/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/clip/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /clip/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/clip/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /clip/__pycache__/clip.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/clip/__pycache__/clip.cpython-37.pyc -------------------------------------------------------------------------------- /clip/__pycache__/clip.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/clip/__pycache__/clip.cpython-38.pyc -------------------------------------------------------------------------------- /clip/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/clip/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /clip/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/clip/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /clip/__pycache__/simple_tokenizer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/clip/__pycache__/simple_tokenizer.cpython-37.pyc -------------------------------------------------------------------------------- /clip/__pycache__/simple_tokenizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/clip/__pycache__/simple_tokenizer.cpython-38.pyc -------------------------------------------------------------------------------- /clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Union, List 6 | 7 | import torch 8 | from PIL import Image 9 | from 
torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 10 | from tqdm import tqdm 11 | 12 | from .model import build_model 13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 14 | 15 | try: 16 | from torchvision.transforms import InterpolationMode 17 | BICUBIC = InterpolationMode.BICUBIC 18 | except ImportError: 19 | BICUBIC = Image.BICUBIC 20 | 21 | 22 | if torch.__version__.split(".") < ["1", "7", "1"]: 23 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 24 | 25 | 26 | __all__ = ["available_models", "load", "tokenize"] 27 | _tokenizer = _Tokenizer() 28 | 29 | _MODELS = { 30 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 31 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 32 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 33 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 34 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 35 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 36 | } 37 | 38 | 39 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 40 | os.makedirs(root, exist_ok=True) 41 | filename = os.path.basename(url) 42 | 43 | expected_sha256 = url.split("/")[-2] 44 | download_target = os.path.join(root, filename) 45 | 46 | if os.path.exists(download_target) and not os.path.isfile(download_target): 47 | raise RuntimeError(f"{download_target} exists and is not a regular file") 48 | 49 | if os.path.isfile(download_target): 50 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 51 | return download_target 52 | else: 53 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 54 | 55 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 56 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 57 | while True: 58 | buffer = source.read(8192) 59 | if not buffer: 60 | break 61 | 62 | output.write(buffer) 63 | loop.update(len(buffer)) 64 | 65 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 66 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 67 | 68 | return download_target 69 | 70 | 71 | def _transform(n_px): 72 | return Compose([ 73 | Resize(n_px, interpolation=BICUBIC), 74 | CenterCrop(n_px), 75 | lambda image: image.convert("RGB"), 76 | ToTensor(), 77 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 78 | ]) 79 | 80 | 81 | def available_models() -> List[str]: 82 | """Returns the names of available CLIP models""" 83 | return list(_MODELS.keys()) 84 | 85 | 86 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False): 87 | """Load a CLIP model 88 | 89 | Parameters 90 | ---------- 91 | name : str 92 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 93 | 94 | device : Union[str, 
torch.device] 95 | The device to put the loaded model 96 | 97 | jit : bool 98 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 99 | 100 | Returns 101 | ------- 102 | model : torch.nn.Module 103 | The CLIP model 104 | 105 | preprocess : Callable[[PIL.Image], torch.Tensor] 106 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 107 | """ 108 | if name in _MODELS: 109 | model_path = _download(_MODELS[name]) 110 | elif os.path.isfile(name): 111 | model_path = name 112 | else: 113 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 114 | 115 | try: 116 | # loading JIT archive 117 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 118 | state_dict = None 119 | except RuntimeError: 120 | # loading saved state dict 121 | if jit: 122 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 123 | jit = False 124 | state_dict = torch.load(model_path, map_location="cpu") 125 | 126 | if not jit: 127 | model = build_model(state_dict or model.state_dict()).to(device) 128 | if str(device) == "cpu": 129 | model.float() 130 | return model, _transform(model.visual.input_resolution) 131 | 132 | # patch the device names 133 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 134 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 135 | 136 | def patch_device(module): 137 | try: 138 | graphs = [module.graph] if hasattr(module, "graph") else [] 139 | except RuntimeError: 140 | graphs = [] 141 | 142 | if hasattr(module, "forward1"): 143 | graphs.append(module.forward1.graph) 144 | 145 | for graph in graphs: 146 | for node in graph.findAllNodes("prim::Constant"): 147 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 148 | node.copyAttributes(device_node) 149 | 150 | model.apply(patch_device) 151 | patch_device(model.encode_image) 152 | patch_device(model.encode_text) 153 | 154 | # patch dtype to float32 on CPU 155 | if str(device) == "cpu": 156 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 157 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 158 | float_node = float_input.node() 159 | 160 | def patch_float(module): 161 | try: 162 | graphs = [module.graph] if hasattr(module, "graph") else [] 163 | except RuntimeError: 164 | graphs = [] 165 | 166 | if hasattr(module, "forward1"): 167 | graphs.append(module.forward1.graph) 168 | 169 | for graph in graphs: 170 | for node in graph.findAllNodes("aten::to"): 171 | inputs = list(node.inputs()) 172 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 173 | if inputs[i].node()["value"] == 5: 174 | inputs[i].node().copyAttributes(float_node) 175 | 176 | model.apply(patch_float) 177 | patch_float(model.encode_image) 178 | patch_float(model.encode_text) 179 | 180 | model.float() 181 | 182 | return model, _transform(model.input_resolution.item()) 183 | 184 | 185 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 186 | """ 187 | Returns the tokenized representation of given input string(s) 188 | 189 | Parameters 190 | ---------- 191 | texts : Union[str, List[str]] 192 | An input string or a list of input strings to tokenize 193 | 194 | context_length : int 195 | The context length to use; all CLIP 
models use 77 as the context length 196 | 197 | truncate: bool 198 | Whether to truncate the text in case its encoding is longer than the context length 199 | 200 | Returns 201 | ------- 202 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 203 | """ 204 | if isinstance(texts, str): 205 | texts = [texts] 206 | 207 | sot_token = _tokenizer.encoder["<|startoftext|>"] 208 | eot_token = _tokenizer.encoder["<|endoftext|>"] 209 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 210 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 211 | 212 | for i, tokens in enumerate(all_tokens): 213 | if len(tokens) > context_length: 214 | if truncate: 215 | tokens = tokens[:context_length] 216 | tokens[-1] = eot_token 217 | else: 218 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 219 | result[i, :len(tokens)] = torch.tensor(tokens) 220 | 221 | return result 222 | -------------------------------------------------------------------------------- /clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'</w>' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '</w>',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'</w>' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /configs/datasets/caltech101.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "Caltech101" 3 | -------------------------------------------------------------------------------- /configs/datasets/dtd.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "DescribableTextures" 3 | --------------------------------------------------------------------------------
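As a quick orientation for the `clip/` package shown above (`clip.py` and `simple_tokenizer.py`), here is a hedged usage sketch of `clip.load`, `clip.tokenize`, and the `encode_image`/`encode_text` methods referenced in `load()`. The backbone name, image path, and class names are placeholders, and the cosine-similarity scoring at the end is standard CLIP practice (it assumes the bundled `model.py` keeps the usual encoder interface) rather than anything specific to this repository.

```python
import torch
from PIL import Image

import clip  # the clip/ package bundled in this repository

device = "cuda" if torch.cuda.is_available() else "cpu"
# "ViT-B/16" is one of the keys in clip._MODELS; the checkpoint is downloaded on first use.
model, preprocess = clip.load("ViT-B/16", device=device)

# Placeholder inputs -- replace with your own image and class names.
image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)
texts = clip.tokenize([f"a photo of a {c}" for c in ["cat", "dog", "car"]]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(texts)

# Cosine similarity between the normalized image and text embeddings.
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
probs = (100.0 * image_features @ text_features.t()).softmax(dim=-1)
print(probs.cpu().numpy())  # one row of class probabilities for the image
```

Note that `clip.load` caches the downloaded weights under `~/.cache/clip` and returns both the model and its matching preprocessing transform, as shown in the `load()` docstring above.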
/configs/datasets/eurosat.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "EuroSAT" 3 | -------------------------------------------------------------------------------- /configs/datasets/fgvc_aircraft.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "FGVCAircraft" 3 | -------------------------------------------------------------------------------- /configs/datasets/food101.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "Food101" 3 | -------------------------------------------------------------------------------- /configs/datasets/imagenet.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNet" 3 | -------------------------------------------------------------------------------- /configs/datasets/imagenet_a.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetA" 3 | -------------------------------------------------------------------------------- /configs/datasets/imagenet_r.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetR" 3 | -------------------------------------------------------------------------------- /configs/datasets/imagenet_sketch.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetSketch" 3 | -------------------------------------------------------------------------------- /configs/datasets/imagenetv2.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "ImageNetV2" 3 | -------------------------------------------------------------------------------- /configs/datasets/oxford_flowers.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "OxfordFlowers" -------------------------------------------------------------------------------- /configs/datasets/oxford_pets.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "OxfordPets" -------------------------------------------------------------------------------- /configs/datasets/stanford_cars.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "StanfordCars" 3 | -------------------------------------------------------------------------------- /configs/datasets/sun397.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "SUN397" 3 | -------------------------------------------------------------------------------- /configs/datasets/ucf101.yaml: -------------------------------------------------------------------------------- 1 | DATASET: 2 | NAME: "UCF101" 3 | -------------------------------------------------------------------------------- /configs/trainers/CoCoOp/vit_b16_c16_ep10_batch1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 1 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | 
OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COCOOP: 33 | N_CTX: 16 34 | CTX_INIT: "" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/CoCoOp/vit_b16_c4_ep10_batch1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 1 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COCOOP: 33 | N_CTX: 4 34 | CTX_INIT: "" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/CoCoOp/vit_b16_c4_ep10_batch1_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 1 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COCOOP: 33 | N_CTX: 4 34 | CTX_INIT: "a photo of a" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/CoCoOp/vit_b16_c8_ep10_batch1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 1 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | COCOOP: 33 | N_CTX: 8 34 | CTX_INIT: "" 35 | PREC: "fp16" -------------------------------------------------------------------------------- /configs/trainers/CoOp/rn101.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", 
"random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN101" -------------------------------------------------------------------------------- /configs/trainers/CoOp/rn101_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN101" -------------------------------------------------------------------------------- /configs/trainers/CoOp/rn50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" -------------------------------------------------------------------------------- /configs/trainers/CoOp/rn50_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: "a photo of a" 34 | -------------------------------------------------------------------------------- /configs/trainers/CoOp/rn50_ep100.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 100 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" 30 | 
-------------------------------------------------------------------------------- /configs/trainers/CoOp/rn50_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" -------------------------------------------------------------------------------- /configs/trainers/CoOp/rn50_ep50_ctxv1.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "RN50" 30 | 31 | TRAINER: 32 | COOP: 33 | CTX_INIT: "a photo of a" -------------------------------------------------------------------------------- /configs/trainers/CoOp/rn50_val.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 200 4 | TEST: 5 | BATCH_SIZE: 200 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | MODEL: 16 | BACKBONE: 17 | NAME: "RN50" -------------------------------------------------------------------------------- /configs/trainers/CoOp/vit_b16.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" -------------------------------------------------------------------------------- /configs/trainers/CoOp/vit_b16_ep100.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: 
["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 100 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" -------------------------------------------------------------------------------- /configs/trainers/CoOp/vit_b16_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" -------------------------------------------------------------------------------- /configs/trainers/CoOp/vit_b32.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 200 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/32" -------------------------------------------------------------------------------- /configs/trainers/CoOp/vit_b32_ep50.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 32 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 50 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 5 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/32" -------------------------------------------------------------------------------- /configs/trainers/IVLP/vit_b16_c2_ep5_batch4_2+2ctx.yaml: -------------------------------------------------------------------------------- 1 | # Deep independent V-L Prompting 2 | DATALOADER: 3 | TRAIN_X: 4 | BATCH_SIZE: 4 5 | TEST: 6 | BATCH_SIZE: 100 7 | NUM_WORKERS: 8 8 | 9 | INPUT: 10 | SIZE: (224, 224) 11 | INTERPOLATION: "bicubic" 12 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 13 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 14 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 15 | 16 | OPTIM: 17 | NAME: "sgd" 18 | LR: 0.0035 19 | MAX_EPOCH: 5 20 | LR_SCHEDULER: "cosine" 21 | WARMUP_EPOCH: 1 22 | WARMUP_TYPE: "constant" 23 | WARMUP_CONS_LR: 1e-5 24 | 25 | TRAIN: 26 | PRINT_FREQ: 20 27 | 28 | MODEL: 29 | 
BACKBONE: 30 | NAME: "ViT-B/16" 31 | 32 | TRAINER: 33 | IVLP: 34 | N_CTX_VISION: 2 35 | N_CTX_TEXT: 2 36 | CTX_INIT: "a photo of a" 37 | PREC: "fp16" 38 | PROMPT_DEPTH_VISION: 12 39 | PROMPT_DEPTH_TEXT: 12 -------------------------------------------------------------------------------- /configs/trainers/IVLP/vit_b16_c2_ep5_batch4_4ctx_language_only.yaml: -------------------------------------------------------------------------------- 1 | # Deep language prompting 2 | DATALOADER: 3 | TRAIN_X: 4 | BATCH_SIZE: 4 5 | TEST: 6 | BATCH_SIZE: 100 7 | NUM_WORKERS: 8 8 | 9 | INPUT: 10 | SIZE: (224, 224) 11 | INTERPOLATION: "bicubic" 12 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 13 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 14 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 15 | 16 | OPTIM: 17 | NAME: "sgd" 18 | LR: 0.0025 19 | MAX_EPOCH: 5 20 | LR_SCHEDULER: "cosine" 21 | WARMUP_EPOCH: 1 22 | WARMUP_TYPE: "constant" 23 | WARMUP_CONS_LR: 1e-5 24 | 25 | TRAIN: 26 | PRINT_FREQ: 20 27 | 28 | MODEL: 29 | BACKBONE: 30 | NAME: "ViT-B/16" 31 | 32 | TRAINER: 33 | IVLP: 34 | N_CTX_VISION: 0 35 | N_CTX_TEXT: 4 36 | CTX_INIT: "a photo of a" 37 | PREC: "fp16" 38 | PROMPT_DEPTH_VISION: 0 39 | PROMPT_DEPTH_TEXT: 12 40 | -------------------------------------------------------------------------------- /configs/trainers/MMP/sun397.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 4 4 | TEST: 5 | BATCH_SIZE: 4 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.002 18 | MAX_EPOCH: 20 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | # CHECKPOINT_FREQ: 5 27 | 28 | #TEST: 29 | # FINAL_MODEL: best_val 30 | # NO_TEST: False 31 | 32 | MODEL: 33 | BACKBONE: 34 | NAME: "ViT-B/16" 35 | 36 | TRAINER: 37 | MMP: 38 | N_CTX: 2 39 | CTX_INIT: "a photo of a" #\ta nice photo of \ta large picture of \ta small photo of a \ta nice sketch of a" #"\t a doodle of a \t a bright photo of a \t a sketch of a \t a tattoo of a \t a drawing of a \t a painting of the \t a drawing of the" 40 | PREC: "fp16" 41 | TEXT_PROMPT_DEPTH: 9 42 | VISION_PROMPT_DEPTH: 9 43 | TEXT_PROMPT_NUMBER: 2 44 | VISION_PROMPT_NUMBER: 2 45 | HIERARCHICAL: True 46 | USECT: False 47 | # HIERARCHICAL: False 48 | # USECT: True -------------------------------------------------------------------------------- /configs/trainers/MMP/vit_b16_c2_ep5_batch4_2ctx.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 4 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.0025 18 | MAX_EPOCH: 20 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | # CHECKPOINT_FREQ: 5 27 | 28 | #TEST: 29 | # FINAL_MODEL: best_val 30 | # NO_TEST: False 31 | 32 | MODEL: 33 | BACKBONE: 34 | NAME: 
"ViT-B/16" 35 | 36 | TRAINER: 37 | MMP: 38 | N_CTX: 2 39 | CTX_INIT: "a photo of a" #\ta nice photo of \ta large picture of \ta small photo of a \ta nice sketch of a" #"\t a doodle of a \t a bright photo of a \t a sketch of a \t a tattoo of a \t a drawing of a \t a painting of the \t a drawing of the" 40 | PREC: "fp16" 41 | TEXT_PROMPT_DEPTH: 9 42 | VISION_PROMPT_DEPTH: 9 43 | TEXT_PROMPT_NUMBER: 4 44 | VISION_PROMPT_NUMBER: 4 45 | HIERARCHICAL: True 46 | USECT: False 47 | # HIERARCHICAL: False 48 | # USECT: True 49 | -------------------------------------------------------------------------------- /configs/trainers/MMP/vit_h.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 4 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.0025 18 | MAX_EPOCH: 20 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | # CHECKPOINT_FREQ: 5 27 | 28 | #TEST: 29 | # FINAL_MODEL: best_val 30 | # NO_TEST: False 31 | 32 | MODEL: 33 | BACKBONE: 34 | NAME: "ViT-H/14" 35 | 36 | TRAINER: 37 | MMP: 38 | N_CTX: 2 39 | CTX_INIT: "a photo of a" #\ta nice photo of \ta large picture of \ta small photo of a \ta nice sketch of a" #"\t a doodle of a \t a bright photo of a \t a sketch of a \t a tattoo of a \t a drawing of a \t a painting of the \t a drawing of the" 40 | PREC: "fp16" 41 | TEXT_PROMPT_DEPTH: 9 42 | VISION_PROMPT_DEPTH: 9 43 | TEXT_PROMPT_NUMBER: 4 44 | VISION_PROMPT_NUMBER: 4 45 | HIERARCHICAL: True 46 | USECT: False 47 | # HIERARCHICAL: False 48 | # USECT: True 49 | -------------------------------------------------------------------------------- /configs/trainers/MaPLe/vit_b16_c2_ep5_batch4_2ctx.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 4 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.0035 18 | MAX_EPOCH: 10 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | MAPLE: 33 | N_CTX: 2 34 | CTX_INIT: "a photo of a" 35 | PREC: "fp16" 36 | PROMPT_DEPTH: 9 -------------------------------------------------------------------------------- /configs/trainers/MaPLe/vit_b16_c2_ep5_batch4_2ctx_cross_datasets.yaml: -------------------------------------------------------------------------------- 1 | DATALOADER: 2 | TRAIN_X: 3 | BATCH_SIZE: 4 4 | TEST: 5 | BATCH_SIZE: 100 6 | NUM_WORKERS: 8 7 | 8 | INPUT: 9 | SIZE: (224, 224) 10 | INTERPOLATION: "bicubic" 11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 14 | 15 | OPTIM: 16 | NAME: "sgd" 17 | LR: 0.0026 18 | MAX_EPOCH: 2 19 | LR_SCHEDULER: "cosine" 20 | WARMUP_EPOCH: 1 21 | 
WARMUP_TYPE: "constant" 22 | WARMUP_CONS_LR: 1e-5 23 | 24 | TRAIN: 25 | PRINT_FREQ: 20 26 | 27 | MODEL: 28 | BACKBONE: 29 | NAME: "ViT-B/16" 30 | 31 | TRAINER: 32 | MAPLE: 33 | N_CTX: 2 34 | CTX_INIT: "a photo of a" 35 | PREC: "fp16" 36 | PROMPT_DEPTH: 3 -------------------------------------------------------------------------------- /configs/trainers/VPT/vit_b16_c2_ep5_batch4_4.yaml: -------------------------------------------------------------------------------- 1 | # Deep vision prompting 2 | DATALOADER: 3 | TRAIN_X: 4 | BATCH_SIZE: 4 5 | TEST: 6 | BATCH_SIZE: 100 7 | NUM_WORKERS: 8 8 | 9 | INPUT: 10 | SIZE: (224, 224) 11 | INTERPOLATION: "bicubic" 12 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073] 13 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711] 14 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"] 15 | 16 | OPTIM: 17 | NAME: "sgd" 18 | LR: 0.0025 19 | MAX_EPOCH: 5 20 | LR_SCHEDULER: "cosine" 21 | WARMUP_EPOCH: 1 22 | WARMUP_TYPE: "constant" 23 | WARMUP_CONS_LR: 1e-5 24 | 25 | TRAIN: 26 | PRINT_FREQ: 20 27 | 28 | MODEL: 29 | BACKBONE: 30 | NAME: "ViT-B/16" 31 | 32 | TRAINER: 33 | VPT: 34 | N_CTX_VISION: 8 35 | CTX_INIT: "a photo of a" 36 | PREC: "fp16" 37 | PROMPT_DEPTH_VISION: 12 -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__init__.py -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/caltech101.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/caltech101.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/caltech101.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/caltech101.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/dtd.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/dtd.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/dtd.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/dtd.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/eurosat.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/eurosat.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/eurosat.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/eurosat.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/fgvc_aircraft.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/fgvc_aircraft.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/fgvc_aircraft.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/fgvc_aircraft.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/food101.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/food101.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/food101.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/food101.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/imagenet.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/imagenet.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/imagenet.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/imagenet.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/imagenet_a.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/imagenet_a.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/imagenet_a.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/imagenet_a.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/imagenet_r.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/imagenet_r.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/imagenet_r.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/imagenet_r.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/imagenet_sketch.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/imagenet_sketch.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/imagenet_sketch.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/imagenet_sketch.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/imagenetv2.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/imagenetv2.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/imagenetv2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/imagenetv2.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/oxford_flowers.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/oxford_flowers.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/oxford_flowers.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/oxford_flowers.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/oxford_pets.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/oxford_pets.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/oxford_pets.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/oxford_pets.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/stanford_cars.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/stanford_cars.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/stanford_cars.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/stanford_cars.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/sun397.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/sun397.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/sun397.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/sun397.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/ucf101.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/ucf101.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/ucf101.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/ucf101.cpython-38.pyc -------------------------------------------------------------------------------- /datasets/caltech101.py: -------------------------------------------------------------------------------- 1 | import os 2 | # import pickle 3 | import pickle5 as pickle 4 | 5 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 6 | from dassl.utils import mkdir_if_missing 7 | 8 | from .oxford_pets import OxfordPets 9 | from .dtd import DescribableTextures as DTD 10 | 11 | IGNORED = ["BACKGROUND_Google", "Faces_easy"] 12 | NEW_CNAMES = { 13 | "airplanes": "airplane", 14 | "Faces": "face", 15 | "Leopards": "leopard", 16 | "Motorbikes": "motorbike", 17 | } 18 | 19 | 20 | @DATASET_REGISTRY.register() 21 | class Caltech101(DatasetBase): 22 | 23 | dataset_dir = "caltech-101" 24 | 25 | def __init__(self, cfg): 26 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 27 | self.dataset_dir = os.path.join(root, self.dataset_dir) 28 | self.image_dir = os.path.join(self.dataset_dir, "101_ObjectCategories") 29 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_Caltech101.json") 30 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 31 | mkdir_if_missing(self.split_fewshot_dir) 32 | 33 | if os.path.exists(self.split_path): 34 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 35 | 
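# If the pre-computed split file (split_zhou_Caltech101.json) is found it is simply loaded above;
# otherwise the branch below builds a fresh 50%/20%/30% train/val/test split with
# DTD.read_and_split_data, dropping the IGNORED folders and renaming classes via NEW_CNAMES
# (e.g. "Motorbikes" -> "motorbike"), then caches the result with OxfordPets.save_split.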
else: 36 | train, val, test = DTD.read_and_split_data(self.image_dir, ignored=IGNORED, new_cnames=NEW_CNAMES) 37 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 38 | 39 | num_shots = cfg.DATASET.NUM_SHOTS 40 | if num_shots >= 1: 41 | seed = cfg.SEED 42 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 43 | 44 | if os.path.exists(preprocessed): 45 | print(f"Loading preprocessed few-shot data from {preprocessed}") 46 | with open(preprocessed, "rb") as file: 47 | data = pickle.load(file) 48 | train, val = data["train"], data["val"] 49 | else: 50 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 51 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 52 | data = {"train": train, "val": val} 53 | print(f"Saving preprocessed few-shot data to {preprocessed}") 54 | with open(preprocessed, "wb") as file: 55 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 56 | 57 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 58 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 59 | 60 | super().__init__(train_x=train, val=val, test=test) 61 | -------------------------------------------------------------------------------- /datasets/dtd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | # import pickle5 as pickle 4 | import random 5 | 6 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 7 | from dassl.utils import listdir_nohidden, mkdir_if_missing 8 | 9 | from .oxford_pets import OxfordPets 10 | 11 | 12 | @DATASET_REGISTRY.register() 13 | class DescribableTextures(DatasetBase): 14 | 15 | dataset_dir = "dtd" 16 | 17 | def __init__(self, cfg): 18 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 19 | self.dataset_dir = os.path.join(root, self.dataset_dir) 20 | self.image_dir = os.path.join(self.dataset_dir, "images") 21 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_DescribableTextures.json") 22 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 23 | mkdir_if_missing(self.split_fewshot_dir) 24 | 25 | if os.path.exists(self.split_path): 26 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 27 | else: 28 | train, val, test = self.read_and_split_data(self.image_dir) 29 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 30 | 31 | num_shots = cfg.DATASET.NUM_SHOTS 32 | if num_shots >= 1: 33 | seed = cfg.SEED 34 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 35 | 36 | if os.path.exists(preprocessed): 37 | print(f"Loading preprocessed few-shot data from {preprocessed}") 38 | with open(preprocessed, "rb") as file: 39 | data = pickle.load(file) 40 | train, val = data["train"], data["val"] 41 | else: 42 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 43 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 44 | data = {"train": train, "val": val} 45 | print(f"Saving preprocessed few-shot data to {preprocessed}") 46 | with open(preprocessed, "wb") as file: 47 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 48 | 49 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 50 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 51 | 52 | super().__init__(train_x=train, val=val, test=test) 53 | 54 | @staticmethod 55 | def read_and_split_data(image_dir, p_trn=0.5, p_val=0.2, 
ignored=[], new_cnames=None): 56 | # The data are supposed to be organized into the following structure 57 | # ============= 58 | # images/ 59 | # dog/ 60 | # cat/ 61 | # horse/ 62 | # ============= 63 | categories = listdir_nohidden(image_dir) 64 | categories = [c for c in categories if c not in ignored] 65 | categories.sort() 66 | 67 | p_tst = 1 - p_trn - p_val 68 | print(f"Splitting into {p_trn:.0%} train, {p_val:.0%} val, and {p_tst:.0%} test") 69 | 70 | def _collate(ims, y, c): 71 | items = [] 72 | for im in ims: 73 | item = Datum(impath=im, label=y, classname=c) # is already 0-based 74 | items.append(item) 75 | return items 76 | 77 | train, val, test = [], [], [] 78 | for label, category in enumerate(categories): 79 | category_dir = os.path.join(image_dir, category) 80 | images = listdir_nohidden(category_dir) 81 | images = [os.path.join(category_dir, im) for im in images] 82 | random.shuffle(images) 83 | n_total = len(images) 84 | n_train = round(n_total * p_trn) 85 | n_val = round(n_total * p_val) 86 | n_test = n_total - n_train - n_val 87 | assert n_train > 0 and n_val > 0 and n_test > 0 88 | 89 | if new_cnames is not None and category in new_cnames: 90 | category = new_cnames[category] 91 | 92 | train.extend(_collate(images[:n_train], label, category)) 93 | val.extend(_collate(images[n_train : n_train + n_val], label, category)) 94 | test.extend(_collate(images[n_train + n_val :], label, category)) 95 | 96 | return train, val, test 97 | -------------------------------------------------------------------------------- /datasets/eurosat.py: -------------------------------------------------------------------------------- 1 | import os 2 | # import pickle 3 | import pickle5 as pickle 4 | 5 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 6 | from dassl.utils import mkdir_if_missing 7 | 8 | from .oxford_pets import OxfordPets 9 | from .dtd import DescribableTextures as DTD 10 | 11 | NEW_CNAMES = { 12 | "AnnualCrop": "Annual Crop Land", 13 | "Forest": "Forest", 14 | "HerbaceousVegetation": "Herbaceous Vegetation Land", 15 | "Highway": "Highway or Road", 16 | "Industrial": "Industrial Buildings", 17 | "Pasture": "Pasture Land", 18 | "PermanentCrop": "Permanent Crop Land", 19 | "Residential": "Residential Buildings", 20 | "River": "River", 21 | "SeaLake": "Sea or Lake", 22 | } 23 | 24 | 25 | @DATASET_REGISTRY.register() 26 | class EuroSAT(DatasetBase): 27 | 28 | dataset_dir = "eurosat" 29 | 30 | def __init__(self, cfg): 31 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 32 | self.dataset_dir = os.path.join(root, self.dataset_dir) 33 | self.image_dir = os.path.join(self.dataset_dir, "2750") 34 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_EuroSAT.json") 35 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 36 | mkdir_if_missing(self.split_fewshot_dir) 37 | 38 | if os.path.exists(self.split_path): 39 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 40 | else: 41 | train, val, test = DTD.read_and_split_data(self.image_dir, new_cnames=NEW_CNAMES) 42 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 43 | 44 | num_shots = cfg.DATASET.NUM_SHOTS 45 | if num_shots >= 1: 46 | seed = cfg.SEED 47 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 48 | 49 | if os.path.exists(preprocessed): 50 | print(f"Loading preprocessed few-shot data from {preprocessed}") 51 | with open(preprocessed, "rb") as file: 52 | data = pickle.load(file) 53 
| train, val = data["train"], data["val"] 54 | else: 55 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 56 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 57 | data = {"train": train, "val": val} 58 | print(f"Saving preprocessed few-shot data to {preprocessed}") 59 | with open(preprocessed, "wb") as file: 60 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 61 | 62 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 63 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 64 | 65 | super().__init__(train_x=train, val=val, test=test) 66 | 67 | def update_classname(self, dataset_old): 68 | dataset_new = [] 69 | for item_old in dataset_old: 70 | cname_old = item_old.classname 71 | cname_new = NEW_CLASSNAMES[cname_old] 72 | item_new = Datum(impath=item_old.impath, label=item_old.label, classname=cname_new) 73 | dataset_new.append(item_new) 74 | return dataset_new 75 | -------------------------------------------------------------------------------- /datasets/fgvc_aircraft.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class FGVCAircraft(DatasetBase): 12 | 13 | dataset_dir = "fgvc_aircraft" 14 | 15 | def __init__(self, cfg): 16 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 17 | self.dataset_dir = os.path.join(root, self.dataset_dir) 18 | self.image_dir = os.path.join(self.dataset_dir, "images") 19 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 20 | mkdir_if_missing(self.split_fewshot_dir) 21 | 22 | classnames = [] 23 | with open(os.path.join(self.dataset_dir, "variants.txt"), "r") as f: 24 | lines = f.readlines() 25 | for line in lines: 26 | classnames.append(line.strip()) 27 | cname2lab = {c: i for i, c in enumerate(classnames)} 28 | 29 | train = self.read_data(cname2lab, "images_variant_train.txt") 30 | val = self.read_data(cname2lab, "images_variant_val.txt") 31 | test = self.read_data(cname2lab, "images_variant_test.txt") 32 | 33 | num_shots = cfg.DATASET.NUM_SHOTS 34 | if num_shots >= 1: 35 | seed = cfg.SEED 36 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 37 | 38 | if os.path.exists(preprocessed): 39 | print(f"Loading preprocessed few-shot data from {preprocessed}") 40 | with open(preprocessed, "rb") as file: 41 | data = pickle.load(file) 42 | train, val = data["train"], data["val"] 43 | else: 44 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 45 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 46 | data = {"train": train, "val": val} 47 | print(f"Saving preprocessed few-shot data to {preprocessed}") 48 | with open(preprocessed, "wb") as file: 49 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 50 | 51 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 52 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 53 | 54 | super().__init__(train_x=train, val=val, test=test) 55 | 56 | def read_data(self, cname2lab, split_file): 57 | filepath = os.path.join(self.dataset_dir, split_file) 58 | items = [] 59 | 60 | with open(filepath, "r") as f: 61 | lines = f.readlines() 62 | for line in lines: 63 | line = line.strip().split(" ") 64 | imname = line[0] + ".jpg" 65 | 
classname = " ".join(line[1:]) 66 | impath = os.path.join(self.image_dir, imname) 67 | label = cname2lab[classname] 68 | item = Datum(impath=impath, label=label, classname=classname) 69 | items.append(item) 70 | 71 | return items 72 | -------------------------------------------------------------------------------- /datasets/food101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | from .dtd import DescribableTextures as DTD 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class Food101(DatasetBase): 13 | 14 | dataset_dir = "food-101" 15 | 16 | def __init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.image_dir = os.path.join(self.dataset_dir, "images") 20 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_Food101.json") 21 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 22 | mkdir_if_missing(self.split_fewshot_dir) 23 | 24 | if os.path.exists(self.split_path): 25 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 26 | else: 27 | train, val, test = DTD.read_and_split_data(self.image_dir) 28 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 29 | 30 | num_shots = cfg.DATASET.NUM_SHOTS 31 | if num_shots >= 1: 32 | seed = cfg.SEED 33 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 34 | 35 | if os.path.exists(preprocessed): 36 | print(f"Loading preprocessed few-shot data from {preprocessed}") 37 | with open(preprocessed, "rb") as file: 38 | data = pickle.load(file) 39 | train, val = data["train"], data["val"] 40 | else: 41 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 42 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 43 | data = {"train": train, "val": val} 44 | print(f"Saving preprocessed few-shot data to {preprocessed}") 45 | with open(preprocessed, "wb") as file: 46 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 47 | 48 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 49 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 50 | 51 | super().__init__(train_x=train, val=val, test=test) 52 | -------------------------------------------------------------------------------- /datasets/imagenet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from collections import OrderedDict 4 | 5 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 6 | from dassl.utils import listdir_nohidden, mkdir_if_missing 7 | 8 | from .oxford_pets import OxfordPets 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class ImageNet(DatasetBase): 13 | 14 | dataset_dir = "imagenet" 15 | 16 | def __init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.image_dir = os.path.join(self.dataset_dir, "images") 20 | self.preprocessed = os.path.join(self.dataset_dir, "preprocessed.pkl") 21 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 22 | mkdir_if_missing(self.split_fewshot_dir) 23 | 24 | if os.path.exists(self.preprocessed): 25 | with open(self.preprocessed, "rb") as f: 
26 | preprocessed = pickle.load(f) 27 | train = preprocessed["train"] 28 | test = preprocessed["test"] 29 | else: 30 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 31 | classnames = self.read_classnames(text_file) 32 | train = self.read_data(classnames, "train") 33 | # Follow standard practice to perform evaluation on the val set 34 | # Also used as the val set (so evaluate the last-step model) 35 | test = self.read_data(classnames, "val") 36 | 37 | preprocessed = {"train": train, "test": test} 38 | with open(self.preprocessed, "wb") as f: 39 | pickle.dump(preprocessed, f, protocol=pickle.HIGHEST_PROTOCOL) 40 | 41 | num_shots = cfg.DATASET.NUM_SHOTS 42 | if num_shots >= 1: 43 | seed = cfg.SEED 44 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 45 | 46 | if os.path.exists(preprocessed): 47 | print(f"Loading preprocessed few-shot data from {preprocessed}") 48 | with open(preprocessed, "rb") as file: 49 | data = pickle.load(file) 50 | train = data["train"] 51 | else: 52 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 53 | data = {"train": train} 54 | print(f"Saving preprocessed few-shot data to {preprocessed}") 55 | with open(preprocessed, "wb") as file: 56 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 57 | 58 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 59 | train, test = OxfordPets.subsample_classes(train, test, subsample=subsample) 60 | 61 | super().__init__(train_x=train, val=test, test=test) 62 | 63 | @staticmethod 64 | def read_classnames(text_file): 65 | """Return a dictionary containing 66 | key-value pairs of <folder name>: <class name>. 67 | """ 68 | classnames = OrderedDict() 69 | with open(text_file, "r") as f: 70 | lines = f.readlines() 71 | for line in lines: 72 | line = line.strip().split(" ") 73 | folder = line[0] 74 | classname = " ".join(line[1:]) 75 | classnames[folder] = classname 76 | return classnames 77 | 78 | def read_data(self, classnames, split_dir): 79 | split_dir = os.path.join(self.image_dir, split_dir) 80 | folders = sorted(f.name for f in os.scandir(split_dir) if f.is_dir()) 81 | items = [] 82 | 83 | for label, folder in enumerate(folders): 84 | imnames = listdir_nohidden(os.path.join(split_dir, folder)) 85 | classname = classnames[folder] 86 | for imname in imnames: 87 | impath = os.path.join(split_dir, folder, imname) 88 | item = Datum(impath=impath, label=label, classname=classname) 89 | items.append(item) 90 | 91 | return items 92 | -------------------------------------------------------------------------------- /datasets/imagenet_a.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 4 | from dassl.utils import listdir_nohidden 5 | 6 | from .imagenet import ImageNet 7 | 8 | TO_BE_IGNORED = ["README.txt"] 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class ImageNetA(DatasetBase): 13 | """ImageNet-A(dversarial). 14 | 15 | This dataset is used for testing only.
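Only a test split is constructed from the imagenet-a folders; train_x is set
to the same items so that the DatasetBase interface is satisfied.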
16 | """ 17 | 18 | dataset_dir = "imagenet-adversarial" 19 | 20 | def __init__(self, cfg): 21 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 22 | self.dataset_dir = os.path.join(root, self.dataset_dir) 23 | self.image_dir = os.path.join(self.dataset_dir, "imagenet-a") 24 | 25 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 26 | classnames = ImageNet.read_classnames(text_file) 27 | 28 | data = self.read_data(classnames) 29 | 30 | super().__init__(train_x=data, test=data) 31 | 32 | def read_data(self, classnames): 33 | image_dir = self.image_dir 34 | folders = listdir_nohidden(image_dir, sort=True) 35 | folders = [f for f in folders if f not in TO_BE_IGNORED] 36 | items = [] 37 | 38 | for label, folder in enumerate(folders): 39 | imnames = listdir_nohidden(os.path.join(image_dir, folder)) 40 | classname = classnames[folder] 41 | for imname in imnames: 42 | impath = os.path.join(image_dir, folder, imname) 43 | item = Datum(impath=impath, label=label, classname=classname) 44 | items.append(item) 45 | 46 | return items 47 | -------------------------------------------------------------------------------- /datasets/imagenet_r.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 4 | from dassl.utils import listdir_nohidden 5 | 6 | from .imagenet import ImageNet 7 | 8 | TO_BE_IGNORED = ["README.txt"] 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class ImageNetR(DatasetBase): 13 | """ImageNet-R(endition). 14 | 15 | This dataset is used for testing only. 16 | """ 17 | 18 | dataset_dir = "imagenet-rendition" 19 | 20 | def __init__(self, cfg): 21 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 22 | self.dataset_dir = os.path.join(root, self.dataset_dir) 23 | self.image_dir = os.path.join(self.dataset_dir, "imagenet-r") 24 | 25 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 26 | classnames = ImageNet.read_classnames(text_file) 27 | 28 | data = self.read_data(classnames) 29 | 30 | super().__init__(train_x=data, test=data) 31 | 32 | def read_data(self, classnames): 33 | image_dir = self.image_dir 34 | folders = listdir_nohidden(image_dir, sort=True) 35 | folders = [f for f in folders if f not in TO_BE_IGNORED] 36 | items = [] 37 | 38 | for label, folder in enumerate(folders): 39 | imnames = listdir_nohidden(os.path.join(image_dir, folder)) 40 | classname = classnames[folder] 41 | for imname in imnames: 42 | impath = os.path.join(image_dir, folder, imname) 43 | item = Datum(impath=impath, label=label, classname=classname) 44 | items.append(item) 45 | 46 | return items 47 | -------------------------------------------------------------------------------- /datasets/imagenet_sketch.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 4 | from dassl.utils import listdir_nohidden 5 | 6 | from .imagenet import ImageNet 7 | 8 | 9 | @DATASET_REGISTRY.register() 10 | class ImageNetSketch(DatasetBase): 11 | """ImageNet-Sketch. 12 | 13 | This dataset is used for testing only. 
14 | """ 15 | 16 | dataset_dir = "imagenet-sketch" 17 | 18 | def __init__(self, cfg): 19 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 20 | self.dataset_dir = os.path.join(root, self.dataset_dir) 21 | self.image_dir = os.path.join(self.dataset_dir, "images") 22 | 23 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 24 | classnames = ImageNet.read_classnames(text_file) 25 | 26 | data = self.read_data(classnames) 27 | 28 | super().__init__(train_x=data, test=data) 29 | 30 | def read_data(self, classnames): 31 | image_dir = self.image_dir 32 | folders = listdir_nohidden(image_dir, sort=True) 33 | items = [] 34 | 35 | for label, folder in enumerate(folders): 36 | imnames = listdir_nohidden(os.path.join(image_dir, folder)) 37 | classname = classnames[folder] 38 | for imname in imnames: 39 | impath = os.path.join(image_dir, folder, imname) 40 | item = Datum(impath=impath, label=label, classname=classname) 41 | items.append(item) 42 | 43 | return items 44 | -------------------------------------------------------------------------------- /datasets/imagenetv2.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 4 | from dassl.utils import listdir_nohidden 5 | 6 | from .imagenet import ImageNet 7 | 8 | 9 | @DATASET_REGISTRY.register() 10 | class ImageNetV2(DatasetBase): 11 | """ImageNetV2. 12 | 13 | This dataset is used for testing only. 14 | """ 15 | 16 | dataset_dir = "imagenetv2" 17 | 18 | def __init__(self, cfg): 19 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 20 | self.dataset_dir = os.path.join(root, self.dataset_dir) 21 | image_dir = "imagenetv2-matched-frequency-format-val" 22 | self.image_dir = os.path.join(self.dataset_dir, image_dir) 23 | 24 | text_file = os.path.join(self.dataset_dir, "classnames.txt") 25 | classnames = ImageNet.read_classnames(text_file) 26 | 27 | data = self.read_data(classnames) 28 | 29 | super().__init__(train_x=data, test=data) 30 | 31 | def read_data(self, classnames): 32 | image_dir = self.image_dir 33 | folders = list(classnames.keys()) 34 | items = [] 35 | 36 | for label in range(1000): 37 | class_dir = os.path.join(image_dir, str(label)) 38 | imnames = listdir_nohidden(class_dir) 39 | folder = folders[label] 40 | classname = classnames[folder] 41 | for imname in imnames: 42 | impath = os.path.join(class_dir, imname) 43 | item = Datum(impath=impath, label=label, classname=classname) 44 | items.append(item) 45 | 46 | return items 47 | -------------------------------------------------------------------------------- /datasets/oxford_flowers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import random 4 | from scipy.io import loadmat 5 | from collections import defaultdict 6 | 7 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 8 | from dassl.utils import read_json, mkdir_if_missing 9 | 10 | from .oxford_pets import OxfordPets 11 | 12 | 13 | @DATASET_REGISTRY.register() 14 | class OxfordFlowers(DatasetBase): 15 | 16 | dataset_dir = "oxford_flowers" 17 | 18 | def __init__(self, cfg): 19 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 20 | self.dataset_dir = os.path.join(root, self.dataset_dir) 21 | self.image_dir = os.path.join(self.dataset_dir, "jpg") 22 | self.label_file = os.path.join(self.dataset_dir, "imagelabels.mat") 23 | self.lab2cname_file = os.path.join(self.dataset_dir, 
"cat_to_name.json") 24 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_OxfordFlowers.json") 25 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 26 | mkdir_if_missing(self.split_fewshot_dir) 27 | 28 | if os.path.exists(self.split_path): 29 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 30 | else: 31 | train, val, test = self.read_data() 32 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 33 | 34 | num_shots = cfg.DATASET.NUM_SHOTS 35 | if num_shots >= 1: 36 | seed = cfg.SEED 37 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 38 | 39 | if os.path.exists(preprocessed): 40 | print(f"Loading preprocessed few-shot data from {preprocessed}") 41 | with open(preprocessed, "rb") as file: 42 | data = pickle.load(file) 43 | train, val = data["train"], data["val"] 44 | else: 45 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 46 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 47 | data = {"train": train, "val": val} 48 | print(f"Saving preprocessed few-shot data to {preprocessed}") 49 | with open(preprocessed, "wb") as file: 50 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 51 | 52 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 53 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 54 | 55 | super().__init__(train_x=train, val=val, test=test) 56 | 57 | def read_data(self): 58 | tracker = defaultdict(list) 59 | label_file = loadmat(self.label_file)["labels"][0] 60 | for i, label in enumerate(label_file): 61 | imname = f"image_{str(i + 1).zfill(5)}.jpg" 62 | impath = os.path.join(self.image_dir, imname) 63 | label = int(label) 64 | tracker[label].append(impath) 65 | 66 | print("Splitting data into 50% train, 20% val, and 30% test") 67 | 68 | def _collate(ims, y, c): 69 | items = [] 70 | for im in ims: 71 | item = Datum(impath=im, label=y - 1, classname=c) # convert to 0-based label 72 | items.append(item) 73 | return items 74 | 75 | lab2cname = read_json(self.lab2cname_file) 76 | train, val, test = [], [], [] 77 | for label, impaths in tracker.items(): 78 | random.shuffle(impaths) 79 | n_total = len(impaths) 80 | n_train = round(n_total * 0.5) 81 | n_val = round(n_total * 0.2) 82 | n_test = n_total - n_train - n_val 83 | assert n_train > 0 and n_val > 0 and n_test > 0 84 | cname = lab2cname[str(label)] 85 | train.extend(_collate(impaths[:n_train], label, cname)) 86 | val.extend(_collate(impaths[n_train : n_train + n_val], label, cname)) 87 | test.extend(_collate(impaths[n_train + n_val :], label, cname)) 88 | 89 | return train, val, test 90 | -------------------------------------------------------------------------------- /datasets/oxford_pets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import math 4 | import random 5 | from collections import defaultdict 6 | 7 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 8 | from dassl.utils import read_json, write_json, mkdir_if_missing 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class OxfordPets(DatasetBase): 13 | 14 | dataset_dir = "oxford_pets" 15 | 16 | def __init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.image_dir = os.path.join(self.dataset_dir, "images") 20 | self.anno_dir = os.path.join(self.dataset_dir, "annotations") 21 | 
self.split_path = os.path.join(self.dataset_dir, "split_zhou_OxfordPets.json") 22 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 23 | mkdir_if_missing(self.split_fewshot_dir) 24 | 25 | if os.path.exists(self.split_path): 26 | train, val, test = self.read_split(self.split_path, self.image_dir) 27 | else: 28 | trainval = self.read_data(split_file="trainval.txt") 29 | test = self.read_data(split_file="test.txt") 30 | train, val = self.split_trainval(trainval) 31 | self.save_split(train, val, test, self.split_path, self.image_dir) 32 | 33 | num_shots = cfg.DATASET.NUM_SHOTS 34 | if num_shots >= 1: 35 | seed = cfg.SEED 36 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 37 | 38 | if os.path.exists(preprocessed): 39 | print(f"Loading preprocessed few-shot data from {preprocessed}") 40 | with open(preprocessed, "rb") as file: 41 | data = pickle.load(file) 42 | train, val = data["train"], data["val"] 43 | else: 44 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 45 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 46 | data = {"train": train, "val": val} 47 | print(f"Saving preprocessed few-shot data to {preprocessed}") 48 | with open(preprocessed, "wb") as file: 49 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 50 | 51 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 52 | train, val, test = self.subsample_classes(train, val, test, subsample=subsample) 53 | 54 | super().__init__(train_x=train, val=val, test=test) 55 | 56 | def read_data(self, split_file): 57 | filepath = os.path.join(self.anno_dir, split_file) 58 | items = [] 59 | 60 | with open(filepath, "r") as f: 61 | lines = f.readlines() 62 | for line in lines: 63 | line = line.strip() 64 | imname, label, species, _ = line.split(" ") 65 | breed = imname.split("_")[:-1] 66 | breed = "_".join(breed) 67 | breed = breed.lower() 68 | imname += ".jpg" 69 | impath = os.path.join(self.image_dir, imname) 70 | label = int(label) - 1 # convert to 0-based index 71 | item = Datum(impath=impath, label=label, classname=breed) 72 | items.append(item) 73 | 74 | return items 75 | 76 | @staticmethod 77 | def split_trainval(trainval, p_val=0.2): 78 | p_trn = 1 - p_val 79 | print(f"Splitting trainval into {p_trn:.0%} train and {p_val:.0%} val") 80 | tracker = defaultdict(list) 81 | for idx, item in enumerate(trainval): 82 | label = item.label 83 | tracker[label].append(idx) 84 | 85 | train, val = [], [] 86 | for label, idxs in tracker.items(): 87 | n_val = round(len(idxs) * p_val) 88 | assert n_val > 0 89 | random.shuffle(idxs) 90 | for n, idx in enumerate(idxs): 91 | item = trainval[idx] 92 | if n < n_val: 93 | val.append(item) 94 | else: 95 | train.append(item) 96 | 97 | return train, val 98 | 99 | @staticmethod 100 | def save_split(train, val, test, filepath, path_prefix): 101 | def _extract(items): 102 | out = [] 103 | for item in items: 104 | impath = item.impath 105 | label = item.label 106 | classname = item.classname 107 | impath = impath.replace(path_prefix, "") 108 | if impath.startswith("/"): 109 | impath = impath[1:] 110 | out.append((impath, label, classname)) 111 | return out 112 | 113 | train = _extract(train) 114 | val = _extract(val) 115 | test = _extract(test) 116 | 117 | split = {"train": train, "val": val, "test": test} 118 | 119 | write_json(split, filepath) 120 | print(f"Saved split to {filepath}") 121 | 122 | @staticmethod 123 | def read_split(filepath, path_prefix): 124 | def _convert(items): 125 | out = [] 126 | for 
impath, label, classname in items: 127 | impath = os.path.join(path_prefix, impath) 128 | item = Datum(impath=impath, label=int(label), classname=classname) 129 | out.append(item) 130 | return out 131 | 132 | print(f"Reading split from {filepath}") 133 | split = read_json(filepath) 134 | train = _convert(split["train"]) 135 | val = _convert(split["val"]) 136 | test = _convert(split["test"]) 137 | 138 | return train, val, test 139 | 140 | @staticmethod 141 | def subsample_classes(*args, subsample="all"): 142 | """Divide classes into two groups. The first group 143 | represents base classes while the second group represents 144 | new classes. 145 | 146 | Args: 147 | args: a list of datasets, e.g. train, val and test. 148 | subsample (str): what classes to subsample. 149 | """ 150 | assert subsample in ["all", "base", "new"] 151 | 152 | if subsample == "all": 153 | return args 154 | 155 | dataset = args[0] 156 | labels = set() 157 | for item in dataset: 158 | labels.add(item.label) 159 | labels = list(labels) 160 | labels.sort() 161 | n = len(labels) 162 | # Divide classes into two halves 163 | m = math.ceil(n / 2) 164 | 165 | print(f"SUBSAMPLE {subsample.upper()} CLASSES!") 166 | if subsample == "base": 167 | selected = labels[:m] # take the first half 168 | else: 169 | selected = labels[m:] # take the second half 170 | relabeler = {y: y_new for y_new, y in enumerate(selected)} 171 | 172 | output = [] 173 | for dataset in args: 174 | dataset_new = [] 175 | for item in dataset: 176 | if item.label not in selected: 177 | continue 178 | item_new = Datum( 179 | impath=item.impath, 180 | label=relabeler[item.label], 181 | classname=item.classname 182 | ) 183 | dataset_new.append(item_new) 184 | output.append(dataset_new) 185 | 186 | return output 187 | -------------------------------------------------------------------------------- /datasets/stanford_cars.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from scipy.io import loadmat 4 | 5 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 6 | from dassl.utils import mkdir_if_missing 7 | 8 | from .oxford_pets import OxfordPets 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class StanfordCars(DatasetBase): 13 | 14 | dataset_dir = "stanford_cars" 15 | 16 | def __init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_StanfordCars.json") 20 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 21 | mkdir_if_missing(self.split_fewshot_dir) 22 | 23 | if os.path.exists(self.split_path): 24 | train, val, test = OxfordPets.read_split(self.split_path, self.dataset_dir) 25 | else: 26 | trainval_file = os.path.join(self.dataset_dir, "devkit", "cars_train_annos.mat") 27 | test_file = os.path.join(self.dataset_dir, "cars_test_annos_withlabels.mat") 28 | meta_file = os.path.join(self.dataset_dir, "devkit", "cars_meta.mat") 29 | trainval = self.read_data("cars_train", trainval_file, meta_file) 30 | test = self.read_data("cars_test", test_file, meta_file) 31 | train, val = OxfordPets.split_trainval(trainval) 32 | OxfordPets.save_split(train, val, test, self.split_path, self.dataset_dir) 33 | 34 | num_shots = cfg.DATASET.NUM_SHOTS 35 | if num_shots >= 1: 36 | seed = cfg.SEED 37 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 38 | 39 | if 
os.path.exists(preprocessed): 40 | print(f"Loading preprocessed few-shot data from {preprocessed}") 41 | with open(preprocessed, "rb") as file: 42 | data = pickle.load(file) 43 | train, val = data["train"], data["val"] 44 | else: 45 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 46 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 47 | data = {"train": train, "val": val} 48 | print(f"Saving preprocessed few-shot data to {preprocessed}") 49 | with open(preprocessed, "wb") as file: 50 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 51 | 52 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 53 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 54 | 55 | super().__init__(train_x=train, val=val, test=test) 56 | 57 | def read_data(self, image_dir, anno_file, meta_file): 58 | anno_file = loadmat(anno_file)["annotations"][0] 59 | meta_file = loadmat(meta_file)["class_names"][0] 60 | items = [] 61 | 62 | for i in range(len(anno_file)): 63 | imname = anno_file[i]["fname"][0] 64 | impath = os.path.join(self.dataset_dir, image_dir, imname) 65 | label = anno_file[i]["class"][0, 0] 66 | label = int(label) - 1 # convert to 0-based index 67 | classname = meta_file[label][0] 68 | names = classname.split(" ") 69 | year = names.pop(-1) 70 | names.insert(0, year) 71 | classname = " ".join(names) 72 | item = Datum(impath=impath, label=label, classname=classname) 73 | items.append(item) 74 | 75 | return items 76 | -------------------------------------------------------------------------------- /datasets/sun397.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | 4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 5 | from dassl.utils import mkdir_if_missing 6 | 7 | from .oxford_pets import OxfordPets 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class SUN397(DatasetBase): 12 | 13 | dataset_dir = "sun397" 14 | 15 | def __init__(self, cfg): 16 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 17 | self.dataset_dir = os.path.join(root, self.dataset_dir) 18 | self.image_dir = os.path.join(self.dataset_dir, "SUN397") 19 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_SUN397.json") 20 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 21 | mkdir_if_missing(self.split_fewshot_dir) 22 | 23 | if os.path.exists(self.split_path): 24 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 25 | else: 26 | classnames = [] 27 | with open(os.path.join(self.dataset_dir, "ClassName.txt"), "r") as f: 28 | lines = f.readlines() 29 | for line in lines: 30 | line = line.strip()[1:] # remove / 31 | classnames.append(line) 32 | cname2lab = {c: i for i, c in enumerate(classnames)} 33 | trainval = self.read_data(cname2lab, "Training_01.txt") 34 | test = self.read_data(cname2lab, "Testing_01.txt") 35 | train, val = OxfordPets.split_trainval(trainval) 36 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 37 | 38 | num_shots = cfg.DATASET.NUM_SHOTS 39 | if num_shots >= 1: 40 | seed = cfg.SEED 41 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 42 | 43 | if os.path.exists(preprocessed): 44 | print(f"Loading preprocessed few-shot data from {preprocessed}") 45 | with open(preprocessed, "rb") as file: 46 | data = pickle.load(file) 47 | train, val = data["train"], data["val"] 48 | else: 49 | train = self.generate_fewshot_dataset(train, 
num_shots=num_shots) 50 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 51 | data = {"train": train, "val": val} 52 | print(f"Saving preprocessed few-shot data to {preprocessed}") 53 | with open(preprocessed, "wb") as file: 54 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 55 | 56 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 57 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 58 | 59 | super().__init__(train_x=train, val=val, test=test) 60 | 61 | def read_data(self, cname2lab, text_file): 62 | text_file = os.path.join(self.dataset_dir, text_file) 63 | items = [] 64 | 65 | with open(text_file, "r") as f: 66 | lines = f.readlines() 67 | for line in lines: 68 | imname = line.strip()[1:] # remove / 69 | classname = os.path.dirname(imname) 70 | label = cname2lab[classname] 71 | impath = os.path.join(self.image_dir, imname) 72 | 73 | names = classname.split("/")[1:] # remove 1st letter 74 | names = names[::-1] # put words like indoor/outdoor at first 75 | classname = " ".join(names) 76 | 77 | item = Datum(impath=impath, label=label, classname=classname) 78 | items.append(item) 79 | 80 | return items 81 | -------------------------------------------------------------------------------- /datasets/ucf101.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import re 4 | 5 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase 6 | from dassl.utils import mkdir_if_missing 7 | 8 | from .oxford_pets import OxfordPets 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class UCF101(DatasetBase): 13 | 14 | dataset_dir = "ucf101" 15 | 16 | def __init__(self, cfg): 17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT)) 18 | self.dataset_dir = os.path.join(root, self.dataset_dir) 19 | self.image_dir = os.path.join(self.dataset_dir, "UCF-101-midframes") 20 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_UCF101.json") 21 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot") 22 | mkdir_if_missing(self.split_fewshot_dir) 23 | 24 | if os.path.exists(self.split_path): 25 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir) 26 | else: 27 | cname2lab = {} 28 | filepath = os.path.join(self.dataset_dir, "ucfTrainTestlist/classInd.txt") 29 | with open(filepath, "r") as f: 30 | lines = f.readlines() 31 | for line in lines: 32 | label, classname = line.strip().split(" ") 33 | label = int(label) - 1 # conver to 0-based index 34 | cname2lab[classname] = label 35 | 36 | trainval = self.read_data(cname2lab, "ucfTrainTestlist/trainlist01.txt") 37 | test = self.read_data(cname2lab, "ucfTrainTestlist/testlist01.txt") 38 | train, val = OxfordPets.split_trainval(trainval) 39 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir) 40 | 41 | num_shots = cfg.DATASET.NUM_SHOTS 42 | if num_shots >= 1: 43 | seed = cfg.SEED 44 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl") 45 | 46 | if os.path.exists(preprocessed): 47 | print(f"Loading preprocessed few-shot data from {preprocessed}") 48 | with open(preprocessed, "rb") as file: 49 | data = pickle.load(file) 50 | train, val = data["train"], data["val"] 51 | else: 52 | train = self.generate_fewshot_dataset(train, num_shots=num_shots) 53 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4)) 54 | data = {"train": train, "val": val} 55 | print(f"Saving preprocessed few-shot data to 
{preprocessed}") 56 | with open(preprocessed, "wb") as file: 57 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL) 58 | 59 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES 60 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample) 61 | 62 | super().__init__(train_x=train, val=val, test=test) 63 | 64 | def read_data(self, cname2lab, text_file): 65 | text_file = os.path.join(self.dataset_dir, text_file) 66 | items = [] 67 | 68 | with open(text_file, "r") as f: 69 | lines = f.readlines() 70 | for line in lines: 71 | line = line.strip().split(" ")[0] # trainlist: filename, label 72 | action, filename = line.split("/") 73 | label = cname2lab[action] 74 | 75 | elements = re.findall("[A-Z][^A-Z]*", action) 76 | renamed_action = "_".join(elements) 77 | 78 | filename = filename.replace(".avi", ".jpg") 79 | impath = os.path.join(self.image_dir, renamed_action, filename) 80 | 81 | item = Datum(impath=impath, label=label, classname=renamed_action) 82 | items.append(item) 83 | 84 | return items 85 | -------------------------------------------------------------------------------- /images/ALIGN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/images/ALIGN.png -------------------------------------------------------------------------------- /parse_test_res.py: -------------------------------------------------------------------------------- 1 | """ 2 | Goal 3 | --- 4 | 1. Read test results from log.txt files 5 | 2. Compute mean and std across different folders (seeds) 6 | 7 | Usage 8 | --- 9 | Assume the output files are saved under output/my_experiment, 10 | which contains results of different seeds, e.g., 11 | 12 | my_experiment/ 13 | seed1/ 14 | log.txt 15 | seed2/ 16 | log.txt 17 | seed3/ 18 | log.txt 19 | 20 | Run the following command from the root directory: 21 | 22 | $ python tools/parse_test_res.py output/my_experiment 23 | 24 | Add --ci95 to the argument if you wanna get 95% confidence 25 | interval instead of standard deviation: 26 | 27 | $ python tools/parse_test_res.py output/my_experiment --ci95 28 | 29 | If my_experiment/ has the following structure, 30 | 31 | my_experiment/ 32 | exp-1/ 33 | seed1/ 34 | log.txt 35 | ... 36 | seed2/ 37 | log.txt 38 | ... 39 | seed3/ 40 | log.txt 41 | ... 42 | exp-2/ 43 | ... 44 | exp-3/ 45 | ... 
46 | 47 | Run 48 | 49 | $ python tools/parse_test_res.py output/my_experiment --multi-exp 50 | """ 51 | import re 52 | import numpy as np 53 | import os.path as osp 54 | import argparse 55 | from collections import OrderedDict, defaultdict 56 | 57 | from dassl.utils import check_isfile, listdir_nohidden 58 | 59 | 60 | def compute_ci95(res): 61 | return 1.96 * np.std(res) / np.sqrt(len(res)) 62 | 63 | 64 | def parse_function(*metrics, directory="", args=None, end_signal=None): 65 | print(f"Parsing files in {directory}") 66 | subdirs = listdir_nohidden(directory, sort=True) 67 | 68 | outputs = [] 69 | 70 | for subdir in subdirs: 71 | fpath = osp.join(directory, subdir, "log.txt") 72 | assert check_isfile(fpath) 73 | good_to_go = False 74 | output = OrderedDict() 75 | 76 | with open(fpath, "r") as f: 77 | lines = f.readlines() 78 | 79 | for line in lines: 80 | line = line.strip() 81 | 82 | if line == end_signal: 83 | good_to_go = True 84 | 85 | for metric in metrics: 86 | match = metric["regex"].search(line) 87 | if match and good_to_go: 88 | if "file" not in output: 89 | output["file"] = fpath 90 | num = float(match.group(1)) 91 | name = metric["name"] 92 | output[name] = num 93 | 94 | if output: 95 | outputs.append(output) 96 | 97 | assert len(outputs) > 0, f"Nothing found in {directory}" 98 | 99 | metrics_results = defaultdict(list) 100 | 101 | for output in outputs: 102 | msg = "" 103 | for key, value in output.items(): 104 | if isinstance(value, float): 105 | msg += f"{key}: {value:.2f}%. " 106 | else: 107 | msg += f"{key}: {value}. " 108 | if key != "file": 109 | metrics_results[key].append(value) 110 | print(msg) 111 | 112 | output_results = OrderedDict() 113 | 114 | print("===") 115 | print(f"Summary of directory: {directory}") 116 | for key, values in metrics_results.items(): 117 | avg = np.mean(values) 118 | std = compute_ci95(values) if args.ci95 else np.std(values) 119 | print(f"* {key}: {avg:.2f}% +- {std:.2f}%") 120 | output_results[key] = avg 121 | print("===") 122 | 123 | return output_results 124 | 125 | 126 | def main(args, end_signal): 127 | metric = { 128 | "name": args.keyword, 129 | "regex": re.compile(fr"\* {args.keyword}: ([\.\deE+-]+)%"), 130 | } 131 | 132 | if args.multi_exp: 133 | final_results = defaultdict(list) 134 | 135 | for directory in listdir_nohidden(args.directory, sort=True): 136 | directory = osp.join(args.directory, directory) 137 | results = parse_function( 138 | metric, directory=directory, args=args, end_signal=end_signal 139 | ) 140 | 141 | for key, value in results.items(): 142 | final_results[key].append(value) 143 | 144 | print("Average performance") 145 | for key, values in final_results.items(): 146 | avg = np.mean(values) 147 | print(f"* {key}: {avg:.2f}%") 148 | 149 | else: 150 | parse_function( 151 | metric, directory=args.directory, args=args, end_signal=end_signal 152 | ) 153 | 154 | 155 | if __name__ == "__main__": 156 | parser = argparse.ArgumentParser() 157 | parser.add_argument("directory", type=str, help="path to directory") 158 | parser.add_argument( 159 | "--ci95", action="store_true", help=r"compute 95\% confidence interval" 160 | ) 161 | parser.add_argument("--test-log", action="store_true", help="parse test-only logs") 162 | parser.add_argument( 163 | "--multi-exp", action="store_true", help="parse multiple experiments" 164 | ) 165 | parser.add_argument( 166 | "--keyword", default="accuracy", type=str, help="which keyword to extract" 167 | ) 168 | args = parser.parse_args() 169 | 170 | end_signal = "Finished training" 171 | if 
args.test_log: 172 | end_signal = "=> result" 173 | 174 | main(args, end_signal) 175 | -------------------------------------------------------------------------------- /scripts/cocoop/base2new_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #cd ../.. 4 | 5 | # custom config 6 | DATA="/path/to/dataset/folder" 7 | TRAINER=CoCoOp 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_c4_ep10_batch1_ctxv1 13 | SHOTS=16 14 | LOADEP=10 15 | SUB=new 16 | 17 | 18 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 19 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR} 20 | DIR=output/base2new/test_${SUB}/${COMMON_DIR} 21 | if [ -d "$DIR" ]; then 22 | echo "Evaluating model" 23 | echo "Results are available in ${DIR}. Resuming..." 24 | 25 | python train.py \ 26 | --root ${DATA} \ 27 | --seed ${SEED} \ 28 | --trainer ${TRAINER} \ 29 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 30 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 31 | --output-dir ${DIR} \ 32 | --model-dir ${MODEL_DIR} \ 33 | --load-epoch ${LOADEP} \ 34 | --eval-only \ 35 | DATASET.NUM_SHOTS ${SHOTS} \ 36 | DATASET.SUBSAMPLE_CLASSES ${SUB} 37 | 38 | else 39 | echo "Evaluating model" 40 | echo "Runing the first phase job and save the output to ${DIR}" 41 | 42 | python train.py \ 43 | --root ${DATA} \ 44 | --seed ${SEED} \ 45 | --trainer ${TRAINER} \ 46 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 47 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 48 | --output-dir ${DIR} \ 49 | --model-dir ${MODEL_DIR} \ 50 | --load-epoch ${LOADEP} \ 51 | --eval-only \ 52 | DATASET.NUM_SHOTS ${SHOTS} \ 53 | DATASET.SUBSAMPLE_CLASSES ${SUB} 54 | fi -------------------------------------------------------------------------------- /scripts/cocoop/base2new_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #cd ../.. 4 | 5 | # custom config 6 | DATA="/path/to/dataset/folder" 7 | TRAINER=CoCoOp 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_c4_ep10_batch1_ctxv1 13 | SHOTS=16 14 | 15 | 16 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 17 | if [ -d "$DIR" ]; then 18 | echo "Results are available in ${DIR}. Resuming..." 19 | python train.py \ 20 | --root ${DATA} \ 21 | --seed ${SEED} \ 22 | --trainer ${TRAINER} \ 23 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 24 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 25 | --output-dir ${DIR} \ 26 | DATASET.NUM_SHOTS ${SHOTS} \ 27 | DATASET.SUBSAMPLE_CLASSES base 28 | else 29 | echo "Run this job and save the output to ${DIR}" 30 | python train.py \ 31 | --root ${DATA} \ 32 | --seed ${SEED} \ 33 | --trainer ${TRAINER} \ 34 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 35 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 36 | --output-dir ${DIR} \ 37 | DATASET.NUM_SHOTS ${SHOTS} \ 38 | DATASET.SUBSAMPLE_CLASSES base 39 | fi -------------------------------------------------------------------------------- /scripts/cocoop/xd_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #cd ../.. 
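# Cross-dataset evaluation for CoCoOp: loads the prompts trained on ImageNet by xd_train.sh
# (epoch 10) and evaluates them on the target dataset $1 with seed $2, writing results under
# output/evaluation/. Example invocation from the repo root (illustrative):
#   bash scripts/cocoop/xd_test.sh food101 1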
4 | 5 | # custom config 6 | DATA=/path/to/datasets 7 | TRAINER=CoCoOp 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_c4_ep10_batch1_ctxv1 13 | SHOTS=16 14 | 15 | 16 | DIR=output/evaluation/${TRAINER}/${CFG}_${SHOTS}shots/${DATASET}/seed${SEED} 17 | if [ -d "$DIR" ]; then 18 | echo "Results are available in ${DIR}. Skip this job" 19 | else 20 | echo "Run this job and save the output to ${DIR}" 21 | 22 | python train.py \ 23 | --root ${DATA} \ 24 | --seed ${SEED} \ 25 | --trainer ${TRAINER} \ 26 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 28 | --output-dir ${DIR} \ 29 | --model-dir output/imagenet/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED} \ 30 | --load-epoch 10 \ 31 | --eval-only 32 | fi -------------------------------------------------------------------------------- /scripts/cocoop/xd_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #cd ../.. 4 | 5 | # custom config 6 | DATA=/path/to/datasets 7 | TRAINER=CoCoOp 8 | 9 | DATASET=imagenet 10 | SEED=$1 11 | 12 | CFG=vit_b16_c4_ep10_batch1_ctxv1 13 | SHOTS=16 14 | 15 | 16 | DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED} 17 | if [ -d "$DIR" ]; then 18 | echo "Results are available in ${DIR}. Skip this job" 19 | else 20 | echo "Run this job and save the output to ${DIR}" 21 | 22 | python train.py \ 23 | --root ${DATA} \ 24 | --seed ${SEED} \ 25 | --trainer ${TRAINER} \ 26 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 28 | --output-dir ${DIR} \ 29 | DATASET.NUM_SHOTS ${SHOTS} 30 | fi -------------------------------------------------------------------------------- /scripts/coop/basenewtrain.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 4 | 5 | # custom config 6 | DATA="/data4/wds/dataset/CoOpData/" 7 | TRAINER=CoOp 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_ep50 13 | SHOTS=16 14 | NCTX=16 15 | 16 | 17 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 18 | if [ -d "$DIR" ]; then 19 | echo "Results are available in ${DIR}. Resuming..." 20 | python train.py \ 21 | --root ${DATA} \ 22 | --seed ${SEED} \ 23 | --trainer ${TRAINER} \ 24 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 25 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 26 | --output-dir ${DIR} \ 27 | TRAINER.COOP.N_CTX ${NCTX} \ 28 | DATASET.NUM_SHOTS ${SHOTS} \ 29 | DATASET.SUBSAMPLE_CLASSES base 30 | else 31 | echo "Run this job and save the output to ${DIR}" 32 | python train.py \ 33 | --root ${DATA} \ 34 | --seed ${SEED} \ 35 | --trainer ${TRAINER} \ 36 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 37 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 38 | --output-dir ${DIR} \ 39 | TRAINER.COOP.N_CTX ${NCTX} \ 40 | DATASET.NUM_SHOTS ${SHOTS} \ 41 | DATASET.SUBSAMPLE_CLASSES base 42 | fi -------------------------------------------------------------------------------- /scripts/coop/eval.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #cd ../.. 
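# Usage sketch (inferred from the $1/$2 assignments below): bash scripts/coop/eval.sh <dataset> <config>, e.g. "bash scripts/coop/eval.sh food101 rn50"
# Evaluates ImageNet-trained CoOp context vectors on the target dataset for seeds 1-3, loading the epoch-50 checkpoint.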
4 | 5 | # custom config 6 | DATA=/path/to/datasets 7 | TRAINER=CoOp 8 | SHOTS=16 9 | NCTX=16 10 | CSC=False 11 | CTP=end 12 | 13 | DATASET=$1 14 | CFG=$2 15 | 16 | for SEED in 1 2 3 17 | do 18 | python train.py \ 19 | --root ${DATA} \ 20 | --seed ${SEED} \ 21 | --trainer ${TRAINER} \ 22 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 23 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 24 | --output-dir output/evaluation/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/${DATASET}/seed${SEED} \ 25 | --model-dir output/imagenet/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/seed${SEED} \ 26 | --load-epoch 50 \ 27 | --eval-only \ 28 | TRAINER.COOP.N_CTX ${NCTX} \ 29 | TRAINER.COOP.CSC ${CSC} \ 30 | TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} 31 | done -------------------------------------------------------------------------------- /scripts/coop/main.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 4 | 5 | # custom config 6 | DATA="/data4/wds/dataset/CoOpData/" 7 | TRAINER=CoOp 8 | 9 | DATASET=eurosat 10 | CFG=rn50 # config file 11 | CTP=end # class token position (end or middle) 12 | NCTX=16 # number of context tokens 13 | SHOTS=8 # number of shots (1, 2, 4, 8, 16) 14 | CSC=False # class-specific context (False or True) 15 | 16 | for SEED in 1 17 | do 18 | DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/seed${SEED} 19 | if [ -d "$DIR" ]; then 20 | echo "Results are available in ${DIR}. Skip this job" 21 | else 22 | echo "Run this job and save the output to ${DIR}" 23 | python train.py \ 24 | --root ${DATA} \ 25 | --seed ${SEED} \ 26 | --trainer ${TRAINER} \ 27 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 28 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 29 | --output-dir ${DIR} \ 30 | TRAINER.COOP.N_CTX ${NCTX} \ 31 | TRAINER.COOP.CSC ${CSC} \ 32 | TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \ 33 | DATASET.NUM_SHOTS ${SHOTS} 34 | fi 35 | done -------------------------------------------------------------------------------- /scripts/independent-vlp/base2new_test_ivlp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #cd ../.. 4 | 5 | # custom config 6 | DATA="/path/to/dataset/folder" 7 | TRAINER=IVLP 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_c2_ep5_batch4_2+2ctx 13 | SHOTS=16 14 | LOADEP=5 15 | SUB=new 16 | 17 | 18 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 19 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR} 20 | DIR=output/base2new/test_${SUB}/${COMMON_DIR} 21 | if [ -d "$DIR" ]; then 22 | echo "Evaluating model" 23 | echo "Results are available in ${DIR}. Resuming..." 
24 | 25 | python train.py \ 26 | --root ${DATA} \ 27 | --seed ${SEED} \ 28 | --trainer ${TRAINER} \ 29 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 30 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 31 | --output-dir ${DIR} \ 32 | --model-dir ${MODEL_DIR} \ 33 | --load-epoch ${LOADEP} \ 34 | --eval-only \ 35 | DATASET.NUM_SHOTS ${SHOTS} \ 36 | DATASET.SUBSAMPLE_CLASSES ${SUB} 37 | 38 | else 39 | echo "Evaluating model" 40 | echo "Runing the first phase job and save the output to ${DIR}" 41 | 42 | python train.py \ 43 | --root ${DATA} \ 44 | --seed ${SEED} \ 45 | --trainer ${TRAINER} \ 46 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 47 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 48 | --output-dir ${DIR} \ 49 | --model-dir ${MODEL_DIR} \ 50 | --load-epoch ${LOADEP} \ 51 | --eval-only \ 52 | DATASET.NUM_SHOTS ${SHOTS} \ 53 | DATASET.SUBSAMPLE_CLASSES ${SUB} 54 | fi -------------------------------------------------------------------------------- /scripts/independent-vlp/base2new_train_ivlp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #cd ../.. 4 | 5 | # custom config 6 | DATA="/path/to/dataset/folder" 7 | TRAINER=IVLP 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_c2_ep5_batch4_2+2ctx 13 | SHOTS=16 14 | 15 | 16 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 17 | if [ -d "$DIR" ]; then 18 | echo "Results are available in ${DIR}. Resuming..." 19 | python train.py \ 20 | --root ${DATA} \ 21 | --seed ${SEED} \ 22 | --trainer ${TRAINER} \ 23 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 24 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 25 | --output-dir ${DIR} \ 26 | DATASET.NUM_SHOTS ${SHOTS} \ 27 | DATASET.SUBSAMPLE_CLASSES base 28 | else 29 | echo "Run this job and save the output to ${DIR}" 30 | python train.py \ 31 | --root ${DATA} \ 32 | --seed ${SEED} \ 33 | --trainer ${TRAINER} \ 34 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 35 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 36 | --output-dir ${DIR} \ 37 | DATASET.NUM_SHOTS ${SHOTS} \ 38 | DATASET.SUBSAMPLE_CLASSES base 39 | fi -------------------------------------------------------------------------------- /scripts/independent-vlp/reproduce_ivlp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #cd ../.. 4 | 5 | # custom config 6 | DATA="/path/to/dataset/folder" 7 | TRAINER=IVLP 8 | 9 | DATASET=$1 10 | SEED=$2 11 | WEIGHTSPATH=$3 12 | 13 | CFG=vit_b16_c2_ep5_batch4_2+2ctx 14 | SHOTS=16 15 | LOADEP=5 16 | SUB_base=base 17 | SUB_novel=new 18 | 19 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 20 | MODEL_DIR=${WEIGHTSPATH}/base/seed${SEED} 21 | DIR_base=output/base2new/test_${SUB_base}/${COMMON_DIR} 22 | DIR_novel=output/base2new/test_${SUB_novel}/${COMMON_DIR} 23 | if [ -d "$DIR" ]; then 24 | echo "Results are already available in ${DIR}. Skipping..." 
25 | else 26 | echo "Evaluating model" 27 | echo "Runing the first phase job and save the output to ${DIR}" 28 | # Evaluate on base classes 29 | python train.py \ 30 | --root ${DATA} \ 31 | --seed ${SEED} \ 32 | --trainer ${TRAINER} \ 33 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 34 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 35 | --output-dir ${DIR_base} \ 36 | --model-dir ${MODEL_DIR} \ 37 | --load-epoch ${LOADEP} \ 38 | --eval-only \ 39 | DATASET.NUM_SHOTS ${SHOTS} \ 40 | DATASET.SUBSAMPLE_CLASSES ${SUB_base} 41 | 42 | # Evaluate on novel classes 43 | python train.py \ 44 | --root ${DATA} \ 45 | --seed ${SEED} \ 46 | --trainer ${TRAINER} \ 47 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 48 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 49 | --output-dir ${DIR_novel} \ 50 | --model-dir ${MODEL_DIR} \ 51 | --load-epoch ${LOADEP} \ 52 | --eval-only \ 53 | DATASET.NUM_SHOTS ${SHOTS} \ 54 | DATASET.SUBSAMPLE_CLASSES ${SUB_novel} 55 | 56 | fi -------------------------------------------------------------------------------- /scripts/independent-vlp/xd_test_ivlp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #cd ../.. 4 | 5 | # custom config 6 | DATA="/path/to/dataset/folder" 7 | TRAINER=IVLP 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_c2_ep5_batch4_2+2ctx 13 | SHOTS=16 14 | 15 | 16 | DIR=output/evaluation/${TRAINER}/${CFG}_${SHOTS}shots/${DATASET}/seed${SEED} 17 | if [ -d "$DIR" ]; then 18 | echo "Results are available in ${DIR}. Skip this job" 19 | else 20 | echo "Run this job and save the output to ${DIR}" 21 | 22 | python train.py \ 23 | --root ${DATA} \ 24 | --seed ${SEED} \ 25 | --trainer ${TRAINER} \ 26 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 28 | --output-dir ${DIR} \ 29 | --model-dir output/imagenet/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED} \ 30 | --load-epoch 2 \ 31 | --eval-only 32 | fi -------------------------------------------------------------------------------- /scripts/independent-vlp/xd_train_ivlp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #cd ../.. 4 | 5 | # custom config 6 | DATA="/path/to/dataset/folder" 7 | TRAINER=IVLP 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_c2_ep5_batch4_2+2ctx 13 | SHOTS=16 14 | 15 | 16 | DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED} 17 | if [ -d "$DIR" ]; then 18 | echo "Results are available in ${DIR}." 19 | else 20 | echo "Run this job and save the output to ${DIR}" 21 | 22 | python train.py \ 23 | --root ${DATA} \ 24 | --seed ${SEED} \ 25 | --trainer ${TRAINER} \ 26 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 28 | --output-dir ${DIR} \ 29 | DATASET.NUM_SHOTS ${SHOTS} 30 | fi -------------------------------------------------------------------------------- /scripts/language-prompting/base2new_test_lp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #cd ../.. 
4 | 5 | # custom config 6 | DATA="/path/to/dataset/folder" 7 | TRAINER=IVLP 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_c2_ep5_batch4_4ctx_language_only 13 | SHOTS=16 14 | LOADEP=5 15 | SUB=new 16 | 17 | 18 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 19 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR} 20 | DIR=output/base2new/test_${SUB}/${COMMON_DIR} 21 | if [ -d "$DIR" ]; then 22 | echo "Evaluating model" 23 | echo "Results are available in ${DIR}. Resuming..." 24 | 25 | python train.py \ 26 | --root ${DATA} \ 27 | --seed ${SEED} \ 28 | --trainer ${TRAINER} \ 29 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 30 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 31 | --output-dir ${DIR} \ 32 | --model-dir ${MODEL_DIR} \ 33 | --load-epoch ${LOADEP} \ 34 | --eval-only \ 35 | DATASET.NUM_SHOTS ${SHOTS} \ 36 | DATASET.SUBSAMPLE_CLASSES ${SUB} 37 | 38 | else 39 | echo "Evaluating model" 40 | echo "Runing the first phase job and save the output to ${DIR}" 41 | 42 | python train.py \ 43 | --root ${DATA} \ 44 | --seed ${SEED} \ 45 | --trainer ${TRAINER} \ 46 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 47 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 48 | --output-dir ${DIR} \ 49 | --model-dir ${MODEL_DIR} \ 50 | --load-epoch ${LOADEP} \ 51 | --eval-only \ 52 | DATASET.NUM_SHOTS ${SHOTS} \ 53 | DATASET.SUBSAMPLE_CLASSES ${SUB} 54 | fi -------------------------------------------------------------------------------- /scripts/language-prompting/base2new_train_lp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #cd ../.. 4 | 5 | # custom config 6 | DATA="/path/to/dataset/folder" 7 | TRAINER=IVLP 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_c2_ep5_batch4_4ctx_language_only 13 | SHOTS=16 14 | 15 | 16 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 17 | if [ -d "$DIR" ]; then 18 | echo "Results are available in ${DIR}. Resuming..." 19 | python train.py \ 20 | --root ${DATA} \ 21 | --seed ${SEED} \ 22 | --trainer ${TRAINER} \ 23 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 24 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 25 | --output-dir ${DIR} \ 26 | DATASET.NUM_SHOTS ${SHOTS} \ 27 | DATASET.SUBSAMPLE_CLASSES base 28 | else 29 | echo "Run this job and save the output to ${DIR}" 30 | python train.py \ 31 | --root ${DATA} \ 32 | --seed ${SEED} \ 33 | --trainer ${TRAINER} \ 34 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 35 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 36 | --output-dir ${DIR} \ 37 | DATASET.NUM_SHOTS ${SHOTS} \ 38 | DATASET.SUBSAMPLE_CLASSES base 39 | fi -------------------------------------------------------------------------------- /scripts/language-prompting/reproduce_lp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #cd ../.. 
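# Usage sketch (inferred from the positional parameters below): bash scripts/language-prompting/reproduce_lp.sh <dataset> <seed> <path/to/weights>
# Evaluates a pre-trained checkpoint (epoch 5) on the base split and then on the new split; <path/to/weights> is expected to contain base/seed<seed> subdirectories.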
4 | 5 | # custom config 6 | DATA="/path/to/dataset/folder" 7 | TRAINER=IVLP 8 | 9 | DATASET=$1 10 | SEED=$2 11 | WEIGHTSPATH=$3 12 | 13 | CFG=vit_b16_c2_ep5_batch4_4ctx_language_only 14 | SHOTS=16 15 | LOADEP=5 16 | SUB_base=base 17 | SUB_novel=new 18 | 19 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 20 | MODEL_DIR=${WEIGHTSPATH}/base/seed${SEED} 21 | DIR_base=output/base2new/test_${SUB_base}/${COMMON_DIR} 22 | DIR_novel=output/base2new/test_${SUB_novel}/${COMMON_DIR} 23 | if [ -d "$DIR" ]; then 24 | echo "Results are already available in ${DIR}. Skipping..." 25 | else 26 | echo "Evaluating model" 27 | echo "Runing the first phase job and save the output to ${DIR}" 28 | # Evaluate on base classes 29 | python train.py \ 30 | --root ${DATA} \ 31 | --seed ${SEED} \ 32 | --trainer ${TRAINER} \ 33 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 34 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 35 | --output-dir ${DIR_base} \ 36 | --model-dir ${MODEL_DIR} \ 37 | --load-epoch ${LOADEP} \ 38 | --eval-only \ 39 | DATASET.NUM_SHOTS ${SHOTS} \ 40 | DATASET.SUBSAMPLE_CLASSES ${SUB_base} 41 | 42 | # Evaluate on novel classes 43 | python train.py \ 44 | --root ${DATA} \ 45 | --seed ${SEED} \ 46 | --trainer ${TRAINER} \ 47 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 48 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 49 | --output-dir ${DIR_novel} \ 50 | --model-dir ${MODEL_DIR} \ 51 | --load-epoch ${LOADEP} \ 52 | --eval-only \ 53 | DATASET.NUM_SHOTS ${SHOTS} \ 54 | DATASET.SUBSAMPLE_CLASSES ${SUB_novel} 55 | 56 | fi -------------------------------------------------------------------------------- /scripts/language-prompting/xd_test_lp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #cd ../.. 4 | 5 | # custom config 6 | DATA="/path/to/dataset/folder" 7 | TRAINER=IVLP 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_c2_ep5_batch4_4ctx_language_only 13 | SHOTS=16 14 | 15 | 16 | DIR=output/evaluation/${TRAINER}/${CFG}_${SHOTS}shots/${DATASET}/seed${SEED} 17 | if [ -d "$DIR" ]; then 18 | echo "Results are available in ${DIR}. Skip this job" 19 | else 20 | echo "Run this job and save the output to ${DIR}" 21 | 22 | python train.py \ 23 | --root ${DATA} \ 24 | --seed ${SEED} \ 25 | --trainer ${TRAINER} \ 26 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 28 | --output-dir ${DIR} \ 29 | --model-dir output/imagenet/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED} \ 30 | --load-epoch 2 \ 31 | --eval-only 32 | fi -------------------------------------------------------------------------------- /scripts/language-prompting/xd_train_lp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #cd ../.. 4 | 5 | # custom config 6 | DATA="/path/to/dataset/folder" 7 | TRAINER=IVLP 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_c2_ep5_batch4_4ctx_language_only 13 | SHOTS=16 14 | 15 | 16 | DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED} 17 | if [ -d "$DIR" ]; then 18 | echo "Results are available in ${DIR}." 
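# Note: if the output directory already exists, training is skipped entirely (there is no resume in this script); remove ${DIR} to re-run the job.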
19 | else 20 | echo "Run this job and save the output to ${DIR}" 21 | 22 | python train.py \ 23 | --root ${DATA} \ 24 | --seed ${SEED} \ 25 | --trainer ${TRAINER} \ 26 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 28 | --output-dir ${DIR} \ 29 | DATASET.NUM_SHOTS ${SHOTS} 30 | fi -------------------------------------------------------------------------------- /scripts/maple/base2new_test_maple.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 4 | 5 | # custom config 6 | DATA="/data4/wds/dataset/CoOpData/" 7 | TRAINER=MaPLe 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_c2_ep5_batch4_2ctx 13 | SHOTS=16 14 | LOADEP=10 15 | SUB=new 16 | 17 | 18 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 19 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR} 20 | DIR=output/base2new/test_${SUB}/${COMMON_DIR} 21 | if [ -d "$DIR" ]; then 22 | echo "Evaluating model" 23 | echo "Results are available in ${DIR}. Resuming..." 24 | 25 | python train.py \ 26 | --root ${DATA} \ 27 | --seed ${SEED} \ 28 | --trainer ${TRAINER} \ 29 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 30 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 31 | --output-dir ${DIR} \ 32 | --model-dir ${MODEL_DIR} \ 33 | --load-epoch ${LOADEP} \ 34 | --eval-only \ 35 | DATASET.NUM_SHOTS ${SHOTS} \ 36 | DATASET.SUBSAMPLE_CLASSES ${SUB} 37 | 38 | else 39 | echo "Evaluating model" 40 | echo "Runing the first phase job and save the output to ${DIR}" 41 | 42 | python train.py \ 43 | --root ${DATA} \ 44 | --seed ${SEED} \ 45 | --trainer ${TRAINER} \ 46 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 47 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 48 | --output-dir ${DIR} \ 49 | --model-dir ${MODEL_DIR} \ 50 | --load-epoch ${LOADEP} \ 51 | --eval-only \ 52 | DATASET.NUM_SHOTS ${SHOTS} \ 53 | DATASET.SUBSAMPLE_CLASSES ${SUB} 54 | fi -------------------------------------------------------------------------------- /scripts/maple/base2new_train_maple.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 4 | 5 | # custom config 6 | DATA="/data4/wds/dataset/CoOpData/" 7 | TRAINER=MaPLe 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_c2_ep5_batch4_2ctx 13 | SHOTS=16 14 | 15 | 16 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 17 | if [ -d "$DIR" ]; then 18 | echo "Results are available in ${DIR}. Resuming..." 19 | python train.py \ 20 | --root ${DATA} \ 21 | --seed ${SEED} \ 22 | --trainer ${TRAINER} \ 23 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 24 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 25 | --output-dir ${DIR} \ 26 | DATASET.NUM_SHOTS ${SHOTS} \ 27 | DATASET.SUBSAMPLE_CLASSES base 28 | else 29 | echo "Run this job and save the output to ${DIR}" 30 | python train.py \ 31 | --root ${DATA} \ 32 | --seed ${SEED} \ 33 | --trainer ${TRAINER} \ 34 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 35 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 36 | --output-dir ${DIR} \ 37 | DATASET.NUM_SHOTS ${SHOTS} \ 38 | DATASET.SUBSAMPLE_CLASSES base 39 | fi -------------------------------------------------------------------------------- /scripts/maple/fst.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 
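# Few-shot training sweep for MaPLe: the nested loops below cover the dataset(s) listed in the for-loop (currently ucf101; other dataset lists are commented out), shot counts 1/2/4/8/16 and seeds 1-3, writing results under output/fewshot/.
# Note: the `cd ../..` above is active (not commented out), so this script is apparently meant to be launched from scripts/maple/; DATA is a hard-coded path that must be edited.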
4 | 5 | # custom config 6 | DATA="/data4/wds/dataset/CoOpData/" 7 | TRAINER=MaPLe 8 | 9 | CFG=vit_b16_c2_ep5_batch4_2ctx 10 | #DATASET=$1 11 | 12 | #for DATASET in caltech101 dtd eurosat fgvc_aircraft 13 | #for DATASET in food101 oxford_flowers oxford_pets stanford_cars 14 | for DATASET in ucf101 15 | do 16 | for SHOTS in 1 2 4 8 16 17 | do 18 | for SEED in 1 2 3 19 | do 20 | DIR=output/fewshot/${DATASET}/${TRAINER}_2/shots_${SHOTS}/seed${SEED} 21 | if [ -d "$DIR" ]; then 22 | echo "Results are available in ${DIR}. Resuming..." 23 | python train.py \ 24 | --root ${DATA} \ 25 | --seed ${SEED} \ 26 | --trainer ${TRAINER} \ 27 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 28 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 29 | --output-dir ${DIR} \ 30 | DATASET.NUM_SHOTS ${SHOTS} 31 | else 32 | echo "Run this job and save the output to ${DIR}" 33 | python train.py \ 34 | --root ${DATA} \ 35 | --seed ${SEED} \ 36 | --trainer ${TRAINER} \ 37 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 38 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 39 | --output-dir ${DIR} \ 40 | DATASET.NUM_SHOTS ${SHOTS} 41 | fi 42 | done 43 | done 44 | done -------------------------------------------------------------------------------- /scripts/maple/reproduce_maple.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #cd ../.. 4 | 5 | # custom config 6 | DATA="/path/to/dataset/folder" 7 | TRAINER=MaPLe 8 | 9 | DATASET=$1 10 | SEED=$2 11 | WEIGHTSPATH=$3 12 | 13 | CFG=vit_b16_c2_ep5_batch4_2ctx 14 | SHOTS=16 15 | LOADEP=5 16 | SUB_base=base 17 | SUB_novel=new 18 | 19 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 20 | MODEL_DIR=${WEIGHTSPATH}/base/seed${SEED} 21 | DIR_base=output/base2new/test_${SUB_base}/${COMMON_DIR} 22 | DIR_novel=output/base2new/test_${SUB_novel}/${COMMON_DIR} 23 | if [ -d "$DIR" ]; then 24 | echo "Results are already available in ${DIR}. Skipping..." 25 | else 26 | echo "Evaluating model" 27 | echo "Runing the first phase job and save the output to ${DIR}" 28 | # Evaluate on base classes 29 | python train.py \ 30 | --root ${DATA} \ 31 | --seed ${SEED} \ 32 | --trainer ${TRAINER} \ 33 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 34 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 35 | --output-dir ${DIR_base} \ 36 | --model-dir ${MODEL_DIR} \ 37 | --load-epoch ${LOADEP} \ 38 | --eval-only \ 39 | DATASET.NUM_SHOTS ${SHOTS} \ 40 | DATASET.SUBSAMPLE_CLASSES ${SUB_base} 41 | 42 | # Evaluate on novel classes 43 | python train.py \ 44 | --root ${DATA} \ 45 | --seed ${SEED} \ 46 | --trainer ${TRAINER} \ 47 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 48 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 49 | --output-dir ${DIR_novel} \ 50 | --model-dir ${MODEL_DIR} \ 51 | --load-epoch ${LOADEP} \ 52 | --eval-only \ 53 | DATASET.NUM_SHOTS ${SHOTS} \ 54 | DATASET.SUBSAMPLE_CLASSES ${SUB_novel} 55 | 56 | 57 | 58 | fi -------------------------------------------------------------------------------- /scripts/maple/reproduce_maple_xd.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #cd ../.. 
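# Usage sketch (inferred from the positional parameters below): bash scripts/maple/reproduce_maple_xd.sh <dataset> <seed> <path/to/weights>
# Cross-dataset reproduction: loads the epoch-2 MaPLe checkpoint from ${WEIGHTSPATH}/seed<seed> and evaluates it on the target dataset with --eval-only.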
4 | 5 | # custom config 6 | DATA="/path/to/dataset/folder" 7 | TRAINER=MaPLe 8 | 9 | DATASET=$1 10 | SEED=$2 11 | WEIGHTSPATH=$3 12 | 13 | CFG=vit_b16_c2_ep5_batch4_2ctx_cross_datasets 14 | SHOTS=16 15 | LOADEP=2 16 | 17 | MODEL_DIR=${WEIGHTSPATH}/seed${SEED} 18 | 19 | DIR=output/evaluation/${TRAINER}/${CFG}_${SHOTS}shots/${DATASET}/seed${SEED} 20 | if [ -d "$DIR" ]; then 21 | echo "Results are already available in ${DIR}. Skipping..." 22 | else 23 | echo "Evaluating model" 24 | echo "Runing the first phase job and save the output to ${DIR}" 25 | # Evaluate on evaluation datasets 26 | python train.py \ 27 | --root ${DATA} \ 28 | --seed ${SEED} \ 29 | --trainer ${TRAINER} \ 30 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 31 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 32 | --output-dir ${DIR} \ 33 | --model-dir ${MODEL_DIR} \ 34 | --load-epoch ${LOADEP} \ 35 | --eval-only \ 36 | DATASET.NUM_SHOTS ${SHOTS} \ 37 | 38 | fi -------------------------------------------------------------------------------- /scripts/maple/xd_test_maple.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #cd ../.. 4 | 5 | # custom config 6 | DATA="/path/to/dataset/folder" 7 | TRAINER=MaPLe 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_c2_ep5_batch4_2ctx_cross_datasets 13 | SHOTS=16 14 | 15 | 16 | DIR=output/evaluation/${TRAINER}/${CFG}_${SHOTS}shots/${DATASET}/seed${SEED} 17 | if [ -d "$DIR" ]; then 18 | echo "Results are available in ${DIR}. Skip this job" 19 | else 20 | echo "Run this job and save the output to ${DIR}" 21 | 22 | python train.py \ 23 | --root ${DATA} \ 24 | --seed ${SEED} \ 25 | --trainer ${TRAINER} \ 26 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 28 | --output-dir ${DIR} \ 29 | --model-dir output/imagenet/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED} \ 30 | --load-epoch 2 \ 31 | --eval-only 32 | fi -------------------------------------------------------------------------------- /scripts/maple/xd_train_maple.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #cd ../.. 4 | 5 | # custom config 6 | DATA="/path/to/dataset/folder" 7 | TRAINER=MaPLe 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_c2_ep5_batch4_2ctx_cross_datasets 13 | SHOTS=16 14 | 15 | 16 | DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED} 17 | if [ -d "$DIR" ]; then 18 | echo "Results are available in ${DIR}." 19 | else 20 | echo "Run this job and save the output to ${DIR}" 21 | 22 | python train.py \ 23 | --root ${DATA} \ 24 | --seed ${SEED} \ 25 | --trainer ${TRAINER} \ 26 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 28 | --output-dir ${DIR} \ 29 | DATASET.NUM_SHOTS ${SHOTS} 30 | fi -------------------------------------------------------------------------------- /scripts/mmp/base_to_new_test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 
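# Usage sketch (inferred from the positional parameters below): bash scripts/mmp/base_to_new_test.sh <dataset> <seed>
# Note: DATA is a placeholder and MODEL_DIR/DIR below point at machine-specific absolute paths (/home/dongsheng/...); edit all three before use. With SUB=base the epoch-15 MMP checkpoint is evaluated on the base split; set SUB=new to test on the new classes.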
4 | 5 | # custom config 6 | DATA="dirs to datasets" 7 | TRAINER=MMP 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_c2_ep5_batch4_2ctx 13 | SHOTS=16 14 | LOADEP=15 15 | SUB=base 16 | 17 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 18 | MODEL_DIR=/home/dongsheng//wds/maple/output/base2new/train_vis_2/${COMMON_DIR} 19 | DIR=/home/dongsheng//wds/maple/output/base2new/train_vis_2/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 20 | 21 | 22 | if [ -d "$DIR" ]; then 23 | echo "Evaluating model" 24 | echo "Results are available in ${DIR}. Resuming..." 25 | 26 | python train.py \ 27 | --root ${DATA} \ 28 | --seed ${SEED} \ 29 | --trainer ${TRAINER} \ 30 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 31 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 32 | --output-dir ${DIR} \ 33 | --model-dir ${MODEL_DIR} \ 34 | --load-epoch ${LOADEP} \ 35 | --eval-only \ 36 | DATASET.NUM_SHOTS ${SHOTS} \ 37 | DATASET.SUBSAMPLE_CLASSES ${SUB} 38 | 39 | else 40 | echo "Evaluating model" 41 | echo "Runing the first phase job and save the output to ${DIR}" 42 | 43 | python train.py \ 44 | --root ${DATA} \ 45 | --seed ${SEED} \ 46 | --trainer ${TRAINER} \ 47 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 48 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 49 | --output-dir ${DIR} \ 50 | --model-dir ${MODEL_DIR} \ 51 | --load-epoch ${LOADEP} \ 52 | --eval-only \ 53 | DATASET.NUM_SHOTS ${SHOTS} \ 54 | DATASET.SUBSAMPLE_CLASSES ${SUB} 55 | fi -------------------------------------------------------------------------------- /scripts/mmp/base_to_new_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | cd ../.. 4 | 5 | # custom config 6 | DATA="dirs to datasets" 7 | TRAINER=MMP 8 | 9 | #DATASET=$1 10 | #SEED=$2 11 | 12 | CFG=vit_b16_c2_ep5_batch4_2ctx 13 | #CFG=sun397 14 | SHOTS=16 15 | N_CTX=2 16 | for DATASET in caltech101 17 | do 18 | for SEED in 1 19 | do 20 | 21 | DIR=/home/dongsheng/wds/maple/output/base2new/train_vis_${N_CTX}/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 22 | if [ -d "$DIR" ]; then 23 | echo "Results are available in ${DIR}. Resuming..." 24 | python train.py \ 25 | --root ${DATA} \ 26 | --seed ${SEED} \ 27 | --trainer ${TRAINER} \ 28 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 29 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 30 | --output-dir ${DIR} \ 31 | DATASET.NUM_SHOTS ${SHOTS} \ 32 | DATASET.SUBSAMPLE_CLASSES base 33 | else 34 | echo "Run this job and save the output to ${DIR}" 35 | python train.py \ 36 | --root ${DATA} \ 37 | --seed ${SEED} \ 38 | --trainer ${TRAINER} \ 39 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 40 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 41 | --output-dir ${DIR} \ 42 | DATASET.NUM_SHOTS ${SHOTS} \ 43 | DATASET.SUBSAMPLE_CLASSES base 44 | fi 45 | done 46 | done -------------------------------------------------------------------------------- /scripts/vpt/base2new_test_vpt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #cd ../.. 
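# Usage sketch (inferred from the positional parameters below): bash scripts/vpt/base2new_test_vpt.sh <dataset> <seed>
# Evaluates the vision-only-prompting (VPT) model trained on base classes (produced by base2new_train_vpt.sh) on the new-class split, loading epoch 5.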
4 | 5 | # custom config 6 | DATA="/path/to/dataset/folder" 7 | TRAINER=VPT 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_c2_ep5_batch4_4 13 | SHOTS=16 14 | LOADEP=5 15 | SUB=new 16 | 17 | 18 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 19 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR} 20 | DIR=output/base2new/test_${SUB}/${COMMON_DIR} 21 | if [ -d "$DIR" ]; then 22 | echo "Evaluating model" 23 | echo "Results are available in ${DIR}. Resuming..." 24 | 25 | python train.py \ 26 | --root ${DATA} \ 27 | --seed ${SEED} \ 28 | --trainer ${TRAINER} \ 29 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 30 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 31 | --output-dir ${DIR} \ 32 | --model-dir ${MODEL_DIR} \ 33 | --load-epoch ${LOADEP} \ 34 | --eval-only \ 35 | DATASET.NUM_SHOTS ${SHOTS} \ 36 | DATASET.SUBSAMPLE_CLASSES ${SUB} 37 | 38 | else 39 | echo "Evaluating model" 40 | echo "Runing the first phase job and save the output to ${DIR}" 41 | 42 | python train.py \ 43 | --root ${DATA} \ 44 | --seed ${SEED} \ 45 | --trainer ${TRAINER} \ 46 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 47 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 48 | --output-dir ${DIR} \ 49 | --model-dir ${MODEL_DIR} \ 50 | --load-epoch ${LOADEP} \ 51 | --eval-only \ 52 | DATASET.NUM_SHOTS ${SHOTS} \ 53 | DATASET.SUBSAMPLE_CLASSES ${SUB} 54 | fi -------------------------------------------------------------------------------- /scripts/vpt/base2new_train_vpt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #cd ../.. 4 | 5 | # custom config 6 | DATA="/path/to/dataset/folder" 7 | TRAINER=VPT 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_c2_ep5_batch4_4 13 | SHOTS=16 14 | 15 | 16 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 17 | if [ -d "$DIR" ]; then 18 | echo "Results are available in ${DIR}. Resuming..." 19 | python train.py \ 20 | --root ${DATA} \ 21 | --seed ${SEED} \ 22 | --trainer ${TRAINER} \ 23 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 24 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 25 | --output-dir ${DIR} \ 26 | DATASET.NUM_SHOTS ${SHOTS} \ 27 | DATASET.SUBSAMPLE_CLASSES base 28 | else 29 | echo "Run this job and save the output to ${DIR}" 30 | python train.py \ 31 | --root ${DATA} \ 32 | --seed ${SEED} \ 33 | --trainer ${TRAINER} \ 34 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 35 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 36 | --output-dir ${DIR} \ 37 | DATASET.NUM_SHOTS ${SHOTS} \ 38 | DATASET.SUBSAMPLE_CLASSES base 39 | fi -------------------------------------------------------------------------------- /scripts/vpt/reproduce_vpt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #cd ../.. 4 | 5 | # custom config 6 | DATA="/path/to/dataset/folder" 7 | TRAINER=VPT 8 | 9 | DATASET=$1 10 | SEED=$2 11 | WEIGHTSPATH=$3 12 | 13 | CFG=vit_b16_c2_ep5_batch4_4 14 | SHOTS=16 15 | LOADEP=5 16 | SUB_base=base 17 | SUB_novel=new 18 | 19 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED} 20 | MODEL_DIR=${WEIGHTSPATH}/base/seed${SEED} 21 | DIR_base=output/base2new/test_${SUB_base}/${COMMON_DIR} 22 | DIR_novel=output/base2new/test_${SUB_novel}/${COMMON_DIR} 23 | if [ -d "$DIR" ]; then 24 | echo "Results are already available in ${DIR}. Skipping..." 
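# Note: as in the other reproduce_* scripts, DIR is never assigned here (only DIR_base and DIR_novel), so this skip branch is effectively unreachable and both evaluations below always run.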
25 | else 26 | echo "Evaluating model" 27 | echo "Runing the first phase job and save the output to ${DIR}" 28 | # Evaluate on base classes 29 | python train.py \ 30 | --root ${DATA} \ 31 | --seed ${SEED} \ 32 | --trainer ${TRAINER} \ 33 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 34 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 35 | --output-dir ${DIR_base} \ 36 | --model-dir ${MODEL_DIR} \ 37 | --load-epoch ${LOADEP} \ 38 | --eval-only \ 39 | DATASET.NUM_SHOTS ${SHOTS} \ 40 | DATASET.SUBSAMPLE_CLASSES ${SUB_base} 41 | 42 | # Evaluate on novel classes 43 | python train.py \ 44 | --root ${DATA} \ 45 | --seed ${SEED} \ 46 | --trainer ${TRAINER} \ 47 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 48 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 49 | --output-dir ${DIR_novel} \ 50 | --model-dir ${MODEL_DIR} \ 51 | --load-epoch ${LOADEP} \ 52 | --eval-only \ 53 | DATASET.NUM_SHOTS ${SHOTS} \ 54 | DATASET.SUBSAMPLE_CLASSES ${SUB_novel} 55 | 56 | fi -------------------------------------------------------------------------------- /scripts/vpt/xd_test_vpt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #cd ../.. 4 | 5 | # custom config 6 | DATA="/path/to/dataset/folder" 7 | TRAINER=VPT 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_c2_ep5_batch4_4 13 | SHOTS=16 14 | 15 | 16 | DIR=output/evaluation/${TRAINER}/${CFG}_${SHOTS}shots/${DATASET}/seed${SEED} 17 | if [ -d "$DIR" ]; then 18 | echo "Results are available in ${DIR}. Skip this job" 19 | else 20 | echo "Run this job and save the output to ${DIR}" 21 | 22 | python train.py \ 23 | --root ${DATA} \ 24 | --seed ${SEED} \ 25 | --trainer ${TRAINER} \ 26 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 28 | --output-dir ${DIR} \ 29 | --model-dir output/imagenet/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED} \ 30 | --load-epoch 2 \ 31 | --eval-only 32 | fi -------------------------------------------------------------------------------- /scripts/vpt/xd_train_vpt.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #cd ../.. 4 | 5 | # custom config 6 | DATA="/path/to/dataset/folder" 7 | TRAINER=VPT 8 | 9 | DATASET=$1 10 | SEED=$2 11 | 12 | CFG=vit_b16_c2_ep5_batch4_4 13 | SHOTS=16 14 | 15 | 16 | DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED} 17 | if [ -d "$DIR" ]; then 18 | echo "Results are available in ${DIR}." 19 | else 20 | echo "Run this job and save the output to ${DIR}" 21 | 22 | python train.py \ 23 | --root ${DATA} \ 24 | --seed ${SEED} \ 25 | --trainer ${TRAINER} \ 26 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \ 28 | --output-dir ${DIR} \ 29 | DATASET.NUM_SHOTS ${SHOTS} 30 | fi -------------------------------------------------------------------------------- /scripts/zsclip/zeroshot.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #cd ../.. 
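# Usage sketch (inferred from the positional parameters below): bash scripts/zsclip/zeroshot.sh <dataset> <config>, where <config> is one of the CoOp backbone configs (rn50, rn101, vit_b32, vit_b16).
# Zero-shot CLIP baseline: runs --eval-only with no prompt training and no checkpoint to load, reusing configs/trainers/CoOp/<config>.yaml.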
4 | 5 | # custom config 6 | DATA=/path/to/datasets 7 | TRAINER=ZeroshotCLIP 8 | DATASET=$1 9 | CFG=$2 # rn50, rn101, vit_b32 or vit_b16 10 | 11 | python train.py \ 12 | --root ${DATA} \ 13 | --trainer ${TRAINER} \ 14 | --dataset-config-file configs/datasets/${DATASET}.yaml \ 15 | --config-file configs/trainers/CoOp/${CFG}.yaml \ 16 | --output-dir output/${TRAINER}/${CFG}/${DATASET} \ 17 | --eval-only -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | 4 | from dassl.utils import setup_logger, set_random_seed, collect_env_info 5 | from dassl.config import get_cfg_default 6 | from dassl.engine import build_trainer 7 | 8 | # custom 9 | import datasets.oxford_pets 10 | import datasets.oxford_flowers 11 | import datasets.fgvc_aircraft 12 | import datasets.dtd 13 | import datasets.eurosat 14 | import datasets.stanford_cars 15 | import datasets.food101 16 | import datasets.sun397 17 | import datasets.caltech101 18 | import datasets.ucf101 19 | import datasets.imagenet 20 | 21 | import datasets.imagenet_sketch 22 | import datasets.imagenetv2 23 | import datasets.imagenet_a 24 | import datasets.imagenet_r 25 | 26 | import trainers.coop 27 | import trainers.cocoop 28 | import trainers.zsclip 29 | import trainers.maple 30 | import trainers.mmp 31 | import trainers.independentVL 32 | import trainers.vpt 33 | 34 | def print_args(args, cfg): 35 | print("***************") 36 | print("** Arguments **") 37 | print("***************") 38 | optkeys = list(args.__dict__.keys()) 39 | optkeys.sort() 40 | for key in optkeys: 41 | print("{}: {}".format(key, args.__dict__[key])) 42 | print("************") 43 | print("** Config **") 44 | print("************") 45 | print(cfg) 46 | 47 | 48 | def reset_cfg(cfg, args): 49 | if args.root: 50 | cfg.DATASET.ROOT = args.root 51 | 52 | if args.output_dir: 53 | cfg.OUTPUT_DIR = args.output_dir 54 | 55 | if args.resume: 56 | cfg.RESUME = args.resume 57 | 58 | if args.seed: 59 | cfg.SEED = args.seed 60 | 61 | if args.source_domains: 62 | cfg.DATASET.SOURCE_DOMAINS = args.source_domains 63 | 64 | if args.target_domains: 65 | cfg.DATASET.TARGET_DOMAINS = args.target_domains 66 | 67 | if args.transforms: 68 | cfg.INPUT.TRANSFORMS = args.transforms 69 | 70 | if args.trainer: 71 | cfg.TRAINER.NAME = args.trainer 72 | 73 | if args.backbone: 74 | cfg.MODEL.BACKBONE.NAME = args.backbone 75 | 76 | if args.head: 77 | cfg.MODEL.HEAD.NAME = args.head 78 | 79 | 80 | def extend_cfg(cfg): 81 | """ 82 | Add new config variables. 83 | 84 | E.g. 85 | from yacs.config import CfgNode as CN 86 | cfg.TRAINER.MY_MODEL = CN() 87 | cfg.TRAINER.MY_MODEL.PARAM_A = 1. 
88 | cfg.TRAINER.MY_MODEL.PARAM_B = 0.5 89 | cfg.TRAINER.MY_MODEL.PARAM_C = False 90 | """ 91 | from yacs.config import CfgNode as CN 92 | 93 | cfg.TRAINER.COOP = CN() 94 | cfg.TRAINER.COOP.N_CTX = 16 # number of context vectors 95 | cfg.TRAINER.COOP.CSC = False # class-specific context 96 | cfg.TRAINER.COOP.CTX_INIT = "" # initialization words 97 | cfg.TRAINER.COOP.PREC = "fp16" # fp16, fp32, amp 98 | cfg.TRAINER.COOP.CLASS_TOKEN_POSITION = "end" # 'middle' or 'end' or 'front' 99 | 100 | cfg.TRAINER.COCOOP = CN() 101 | cfg.TRAINER.COCOOP.N_CTX = 16 # number of context vectors 102 | cfg.TRAINER.COCOOP.CTX_INIT = "" # initialization words 103 | cfg.TRAINER.COCOOP.PREC = "fp16" # fp16, fp32, amp 104 | 105 | # Config for MaPLe 106 | cfg.TRAINER.MAPLE = CN() 107 | cfg.TRAINER.MAPLE.N_CTX = 16 # number of context vectors 108 | cfg.TRAINER.MAPLE.CTX_INIT = "a photo of a" # initialization words 109 | cfg.TRAINER.MAPLE.PREC = "fp16" # fp16, fp32, amp 110 | cfg.TRAINER.MAPLE.PROMPT_DEPTH = 9 # Max 12, minimum 0, for 1 it will act as shallow MaPLe (J=1) 111 | cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new 112 | 113 | # Config for MMP 114 | cfg.TRAINER.MMP = CN() 115 | cfg.TRAINER.MMP.N_CTX = 2 # number of context vectors 116 | cfg.TRAINER.MMP.CTX_INIT = "a photo of a" # initialization words 117 | cfg.TRAINER.MMP.PREC = "fp16" # fp16, fp32, amp 118 | cfg.TRAINER.MMP.TEXT_PROMPT_DEPTH = 1 # Max 12, minimum 0, for 1 it will act as shallow MMP (J=1) 119 | cfg.TRAINER.MMP.VISION_PROMPT_DEPTH = 1 # Max 12, minimum 0, for 1 it will act as shallow MMP (J=1) 120 | cfg.TRAINER.MMP.TEXT_PROMPT_NUMBER = 4 # number of to be learned language prompts 121 | cfg.TRAINER.MMP.VISION_PROMPT_NUMBER = 4 # number of to be learned vision prompts 122 | cfg.TRAINER.MMP.HIERARCHICAL = True 123 | cfg.TRAINER.MMP.USECT = True 124 | cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new 125 | 126 | 127 | # Config for independent Vision Language prompting (independent-vlp) 128 | cfg.TRAINER.IVLP = CN() 129 | cfg.TRAINER.IVLP.N_CTX_VISION = 2 # number of context vectors at the vision branch 130 | cfg.TRAINER.IVLP.N_CTX_TEXT = 2 # number of context vectors at the language branch 131 | cfg.TRAINER.IVLP.CTX_INIT = "a photo of a" # initialization words (only for language prompts) 132 | cfg.TRAINER.IVLP.PREC = "fp16" # fp16, fp32, amp 133 | # If both variables below are set to 0, 0, will the config will degenerate to COOP model 134 | cfg.TRAINER.IVLP.PROMPT_DEPTH_VISION = 9 # Max 12, minimum 0, for 0 it will act as shallow MaPLe (J=1) 135 | cfg.TRAINER.IVLP.PROMPT_DEPTH_TEXT = 9 # Max 12, minimum 0, for 0 it will act as shallow MaPLe (J=1) 136 | cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new 137 | 138 | # Config for only vision side prompting 139 | cfg.TRAINER.VPT = CN() 140 | cfg.TRAINER.VPT.N_CTX_VISION = 2 # number of context vectors at the vision branch 141 | cfg.TRAINER.VPT.CTX_INIT = "a photo of a" # initialization words 142 | cfg.TRAINER.VPT.PREC = "fp16" # fp16, fp32, amp 143 | cfg.TRAINER.VPT.PROMPT_DEPTH_VISION = 1 # if set to 1, will represent shallow vision prompting only 144 | cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new 145 | 146 | 147 | def setup_cfg(args): 148 | cfg = get_cfg_default() 149 | extend_cfg(cfg) 150 | 151 | # 1. From the dataset config file 152 | if args.dataset_config_file: 153 | cfg.merge_from_file(args.dataset_config_file) 154 | 155 | # 2. From the method config file 156 | if args.config_file: 157 | cfg.merge_from_file(args.config_file) 158 | 159 | # 3. 
From input arguments 160 | reset_cfg(cfg, args) 161 | 162 | # 4. From optional input arguments 163 | cfg.merge_from_list(args.opts) 164 | 165 | if cfg.DATASET.SUBSAMPLE_CLASSES == 'all': ## few shot setting 166 | if cfg.DATASET.NUM_SHOTS == 1: 167 | cfg.OPTIM.MAX_EPOCH = 30 168 | elif cfg.DATASET.NUM_SHOTS == 2 or cfg.DATASET.NUM_SHOTS == 4: 169 | cfg.OPTIM.MAX_EPOCH = 50 170 | else: 171 | cfg.OPTIM.MAX_EPOCH = 80 172 | 173 | if cfg.DATASET.NAME == "ImageNet": 174 | cfg.OPTIM.MAX_EPOCH = 20 175 | 176 | # if cfg.DATASET.NAME in ['OxfordFlowers', 'FGVCAircraft', 'StanfordCars']: 177 | # cfg.DATALOADER.TRAIN_X.BATCH_SIZE = 32 # 32 for small dataset such as Car,Air,Flowers 178 | 179 | cfg.freeze() 180 | 181 | return cfg 182 | 183 | 184 | def main(args): 185 | cfg = setup_cfg(args) 186 | if cfg.SEED >= 0: 187 | print("Setting fixed seed: {}".format(cfg.SEED)) 188 | set_random_seed(cfg.SEED) 189 | setup_logger(cfg.OUTPUT_DIR) 190 | 191 | if torch.cuda.is_available() and cfg.USE_CUDA: 192 | torch.backends.cudnn.benchmark = True 193 | 194 | print_args(args, cfg) 195 | print("Collecting env info ...") 196 | print("** System info **\n{}\n".format(collect_env_info())) 197 | 198 | trainer = build_trainer(cfg) 199 | 200 | if args.eval_only: 201 | trainer.load_model(args.model_dir, epoch=args.load_epoch) 202 | trainer.test() 203 | return 204 | 205 | if not args.no_train: 206 | trainer.train() 207 | 208 | 209 | if __name__ == "__main__": 210 | parser = argparse.ArgumentParser() 211 | parser.add_argument("--root", type=str, default="", help="path to dataset") 212 | parser.add_argument("--output-dir", type=str, default="", help="output directory") 213 | parser.add_argument( 214 | "--resume", 215 | type=str, 216 | default="", 217 | help="checkpoint directory (from which the training resumes)", 218 | ) 219 | parser.add_argument( 220 | "--seed", type=int, default=-1, help="only positive value enables a fixed seed" 221 | ) 222 | parser.add_argument( 223 | "--source-domains", type=str, nargs="+", help="source domains for DA/DG" 224 | ) 225 | parser.add_argument( 226 | "--target-domains", type=str, nargs="+", help="target domains for DA/DG" 227 | ) 228 | parser.add_argument( 229 | "--transforms", type=str, nargs="+", help="data augmentation methods" 230 | ) 231 | parser.add_argument( 232 | "--config-file", type=str, default="", help="path to config file" 233 | ) 234 | parser.add_argument( 235 | "--dataset-config-file", 236 | type=str, 237 | default="", 238 | help="path to config file for dataset setup", 239 | ) 240 | parser.add_argument("--trainer", type=str, default="", help="name of trainer") 241 | parser.add_argument("--backbone", type=str, default="", help="name of CNN backbone") 242 | parser.add_argument("--head", type=str, default="", help="name of head") 243 | parser.add_argument("--eval-only", action="store_true", help="evaluation only") 244 | parser.add_argument( 245 | "--model-dir", 246 | type=str, 247 | default="", 248 | help="load model from this directory for eval-only mode", 249 | ) 250 | parser.add_argument( 251 | "--load-epoch", type=int, help="load model weights at this epoch for evaluation" 252 | ) 253 | parser.add_argument( 254 | "--no-train", action="store_true", help="do not call trainer.train()" 255 | ) 256 | parser.add_argument( 257 | "opts", 258 | default=None, 259 | nargs=argparse.REMAINDER, 260 | help="modify config options using the command-line", 261 | ) 262 | args = parser.parse_args() 263 | main(args) 264 | 
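# Example invocation (a sketch only; --root must point at your local dataset folder):
#   python train.py --root /path/to/datasets --seed 1 --trainer MMP \
#       --dataset-config-file configs/datasets/caltech101.yaml \
#       --config-file configs/trainers/MMP/vit_b16_c2_ep5_batch4_2ctx.yaml \
#       --output-dir output/base2new/train_base/caltech101/shots_16/MMP/vit_b16_c2_ep5_batch4_2ctx/seed1 \
#       DATASET.NUM_SHOTS 16 DATASET.SUBSAMPLE_CLASSES base
# Trailing KEY VALUE pairs are collected in args.opts and applied through cfg.merge_from_list() in setup_cfg().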
-------------------------------------------------------------------------------- /trainers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__init__.py -------------------------------------------------------------------------------- /trainers/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /trainers/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /trainers/__pycache__/cocoop.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/cocoop.cpython-37.pyc -------------------------------------------------------------------------------- /trainers/__pycache__/cocoop.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/cocoop.cpython-38.pyc -------------------------------------------------------------------------------- /trainers/__pycache__/coop.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/coop.cpython-37.pyc -------------------------------------------------------------------------------- /trainers/__pycache__/coop.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/coop.cpython-38.pyc -------------------------------------------------------------------------------- /trainers/__pycache__/imagenet_templates.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/imagenet_templates.cpython-37.pyc -------------------------------------------------------------------------------- /trainers/__pycache__/imagenet_templates.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/imagenet_templates.cpython-38.pyc -------------------------------------------------------------------------------- /trainers/__pycache__/independentVL.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/independentVL.cpython-37.pyc -------------------------------------------------------------------------------- /trainers/__pycache__/independentVL.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/independentVL.cpython-38.pyc -------------------------------------------------------------------------------- /trainers/__pycache__/maple.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/maple.cpython-37.pyc -------------------------------------------------------------------------------- /trainers/__pycache__/maple.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/maple.cpython-38.pyc -------------------------------------------------------------------------------- /trainers/__pycache__/mmp.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/mmp.cpython-37.pyc -------------------------------------------------------------------------------- /trainers/__pycache__/mmp.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/mmp.cpython-38.pyc -------------------------------------------------------------------------------- /trainers/__pycache__/vpt.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/vpt.cpython-37.pyc -------------------------------------------------------------------------------- /trainers/__pycache__/vpt.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/vpt.cpython-38.pyc -------------------------------------------------------------------------------- /trainers/__pycache__/zsclip.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/zsclip.cpython-37.pyc -------------------------------------------------------------------------------- /trainers/__pycache__/zsclip.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/zsclip.cpython-38.pyc -------------------------------------------------------------------------------- /trainers/cocoop.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from collections import OrderedDict 3 | import math 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.nn import functional as F 8 | from torch.cuda.amp import GradScaler, autocast 9 | 10 | from dassl.engine import TRAINER_REGISTRY, TrainerX 11 | from dassl.metrics import compute_accuracy 12 | from dassl.utils import load_pretrained_weights, load_checkpoint 13 | from dassl.optim import build_optimizer, build_lr_scheduler 
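# The `clip` imported below appears to be a locally modified CLIP implementation rather than the stock
# OpenAI package; this is an assumption based on the extra design_details argument passed to
# clip.build_model further down in this file.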
14 | 15 | from clip import clip 16 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer 17 | 18 | _tokenizer = _Tokenizer() 19 | 20 | 21 | def load_clip_to_cpu(cfg): 22 | backbone_name = cfg.MODEL.BACKBONE.NAME 23 | url = clip._MODELS[backbone_name] 24 | model_path = clip._download(url) 25 | 26 | try: 27 | # loading JIT archive 28 | model = torch.jit.load(model_path, map_location="cpu").eval() 29 | state_dict = None 30 | 31 | except RuntimeError: 32 | state_dict = torch.load(model_path, map_location="cpu") 33 | design_details = {"trainer": 'CoCoOp', 34 | "vision_depth": 0, 35 | "language_depth": 0, "vision_ctx": 0, 36 | "language_ctx": 0} 37 | model = clip.build_model(state_dict or model.state_dict(), design_details) 38 | 39 | return model 40 | 41 | 42 | class TextEncoder(nn.Module): 43 | def __init__(self, clip_model): 44 | super().__init__() 45 | self.transformer = clip_model.transformer 46 | self.positional_embedding = clip_model.positional_embedding 47 | self.ln_final = clip_model.ln_final 48 | self.text_projection = clip_model.text_projection 49 | self.dtype = clip_model.dtype 50 | 51 | def forward(self, prompts, tokenized_prompts): 52 | x = prompts + self.positional_embedding.type(self.dtype) 53 | x = x.permute(1, 0, 2) # NLD -> LND 54 | x = self.transformer(x) 55 | x = x.permute(1, 0, 2) # LND -> NLD 56 | x = self.ln_final(x).type(self.dtype) 57 | 58 | # x.shape = [batch_size, n_ctx, transformer.width] 59 | # take features from the eot embedding (eot_token is the highest number in each sequence) 60 | x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection 61 | 62 | return x 63 | 64 | 65 | class PromptLearner(nn.Module): 66 | def __init__(self, cfg, classnames, clip_model): 67 | super().__init__() 68 | n_cls = len(classnames) 69 | n_ctx = cfg.TRAINER.COCOOP.N_CTX 70 | ctx_init = cfg.TRAINER.COCOOP.CTX_INIT 71 | dtype = clip_model.dtype 72 | ctx_dim = clip_model.ln_final.weight.shape[0] 73 | vis_dim = clip_model.visual.output_dim 74 | clip_imsize = clip_model.visual.input_resolution 75 | cfg_imsize = cfg.INPUT.SIZE[0] 76 | assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})" 77 | 78 | if ctx_init: 79 | # use given words to initialize context vectors 80 | ctx_init = ctx_init.replace("_", " ") 81 | n_ctx = len(ctx_init.split(" ")) 82 | prompt = clip.tokenize(ctx_init) 83 | with torch.no_grad(): 84 | embedding = clip_model.token_embedding(prompt).type(dtype) 85 | ctx_vectors = embedding[0, 1: 1 + n_ctx, :] 86 | prompt_prefix = ctx_init 87 | else: 88 | # random initialization 89 | ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype) 90 | nn.init.normal_(ctx_vectors, std=0.02) 91 | prompt_prefix = " ".join(["X"] * n_ctx) 92 | 93 | print(f'Initial context: "{prompt_prefix}"') 94 | print(f"Number of context words (tokens): {n_ctx}") 95 | 96 | self.ctx = nn.Parameter(ctx_vectors) 97 | 98 | self.meta_net = nn.Sequential(OrderedDict([ 99 | ("linear1", nn.Linear(vis_dim, vis_dim // 16)), 100 | ("relu", nn.ReLU(inplace=True)), 101 | ("linear2", nn.Linear(vis_dim // 16, ctx_dim)) 102 | ])) 103 | 104 | if cfg.TRAINER.COCOOP.PREC == "fp16": 105 | self.meta_net.half() 106 | 107 | classnames = [name.replace("_", " ") for name in classnames] 108 | name_lens = [len(_tokenizer.encode(name)) for name in classnames] 109 | prompts = [prompt_prefix + " " + name + "." 
for name in classnames] 110 | 111 | tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]) # (n_cls, n_tkn) 112 | with torch.no_grad(): 113 | embedding = clip_model.token_embedding(tokenized_prompts).type(dtype) 114 | 115 | # These token vectors will be saved when in save_model(), 116 | # but they should be ignored in load_model() as we want to use 117 | # those computed using the current class names 118 | self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS 119 | self.register_buffer("token_suffix", embedding[:, 1 + n_ctx:, :]) # CLS, EOS 120 | 121 | self.n_cls = n_cls 122 | self.n_ctx = n_ctx 123 | self.tokenized_prompts = tokenized_prompts # torch.Tensor 124 | self.name_lens = name_lens 125 | 126 | def construct_prompts(self, ctx, prefix, suffix, label=None): 127 | # dim0 is either batch_size (during training) or n_cls (during testing) 128 | # ctx: context tokens, with shape of (dim0, n_ctx, ctx_dim) 129 | # prefix: the sos token, with shape of (n_cls, 1, ctx_dim) 130 | # suffix: remaining tokens, with shape of (n_cls, *, ctx_dim) 131 | 132 | if label is not None: 133 | prefix = prefix[label] 134 | suffix = suffix[label] 135 | 136 | prompts = torch.cat( 137 | [ 138 | prefix, # (dim0, 1, dim) 139 | ctx, # (dim0, n_ctx, dim) 140 | suffix, # (dim0, *, dim) 141 | ], 142 | dim=1, 143 | ) 144 | 145 | return prompts 146 | 147 | def forward(self, im_features): 148 | prefix = self.token_prefix 149 | suffix = self.token_suffix 150 | ctx = self.ctx # (n_ctx, ctx_dim) 151 | bias = self.meta_net(im_features) # (batch, ctx_dim) 152 | bias = bias.unsqueeze(1) # (batch, 1, ctx_dim) 153 | ctx = ctx.unsqueeze(0) # (1, n_ctx, ctx_dim) 154 | ctx_shifted = ctx + bias # (batch, n_ctx, ctx_dim) 155 | 156 | # Use instance-conditioned context tokens for all classes 157 | prompts = [] 158 | for ctx_shifted_i in ctx_shifted: 159 | ctx_i = ctx_shifted_i.unsqueeze(0).expand(self.n_cls, -1, -1) 160 | pts_i = self.construct_prompts(ctx_i, prefix, suffix) # (n_cls, n_tkn, ctx_dim) 161 | prompts.append(pts_i) 162 | prompts = torch.stack(prompts) 163 | 164 | return prompts 165 | 166 | 167 | class CustomCLIP(nn.Module): 168 | def __init__(self, cfg, classnames, clip_model): 169 | super().__init__() 170 | self.prompt_learner = PromptLearner(cfg, classnames, clip_model) 171 | self.tokenized_prompts = self.prompt_learner.tokenized_prompts 172 | self.image_encoder = clip_model.visual 173 | self.text_encoder = TextEncoder(clip_model) 174 | self.logit_scale = clip_model.logit_scale 175 | self.dtype = clip_model.dtype 176 | 177 | def forward(self, image, label=None): 178 | tokenized_prompts = self.tokenized_prompts 179 | logit_scale = self.logit_scale.exp() 180 | 181 | image_features = self.image_encoder(image.type(self.dtype)) 182 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 183 | 184 | prompts = self.prompt_learner(image_features) 185 | 186 | logits = [] 187 | for pts_i, imf_i in zip(prompts, image_features): 188 | text_features = self.text_encoder(pts_i, tokenized_prompts) 189 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 190 | l_i = logit_scale * imf_i @ text_features.t() 191 | logits.append(l_i) 192 | logits = torch.stack(logits) 193 | 194 | if self.prompt_learner.training: 195 | return F.cross_entropy(logits, label) 196 | 197 | return logits 198 | 199 | 200 | @TRAINER_REGISTRY.register() 201 | class CoCoOp(TrainerX): 202 | def check_cfg(self, cfg): 203 | assert cfg.TRAINER.COCOOP.PREC in ["fp16", "fp32", "amp"] 204 | 205 | def 
build_model(self): 206 | cfg = self.cfg 207 | classnames = self.dm.dataset.classnames 208 | 209 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") 210 | clip_model = load_clip_to_cpu(cfg) 211 | 212 | if cfg.TRAINER.COCOOP.PREC == "fp32" or cfg.TRAINER.COCOOP.PREC == "amp": 213 | # CLIP's default precision is fp16 214 | clip_model.float() 215 | 216 | print("Building custom CLIP") 217 | self.model = CustomCLIP(cfg, classnames, clip_model) 218 | 219 | print("Turning off gradients in both the image and the text encoder") 220 | name_to_update = "prompt_learner" 221 | 222 | for name, param in self.model.named_parameters(): 223 | if name_to_update not in name: 224 | param.requires_grad_(False) 225 | 226 | # Double check 227 | enabled = set() 228 | for name, param in self.model.named_parameters(): 229 | if param.requires_grad: 230 | enabled.add(name) 231 | print(f"Parameters to be updated: {enabled}") 232 | 233 | if cfg.MODEL.INIT_WEIGHTS: 234 | load_pretrained_weights(self.model.prompt_learner, cfg.MODEL.INIT_WEIGHTS) 235 | 236 | self.model.to(self.device) 237 | # NOTE: only give prompt_learner to the optimizer 238 | self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM) 239 | self.sched = build_lr_scheduler(self.optim, cfg.OPTIM) 240 | self.register_model("prompt_learner", self.model.prompt_learner, self.optim, self.sched) 241 | 242 | self.scaler = GradScaler() if cfg.TRAINER.COCOOP.PREC == "amp" else None 243 | 244 | # Note that multi-gpu training could be slow because CLIP's size is 245 | # big, which slows down the copy operation in DataParallel 246 | device_count = torch.cuda.device_count() 247 | if device_count > 1: 248 | print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!") 249 | self.model = nn.DataParallel(self.model) 250 | 251 | def forward_backward(self, batch): 252 | image, label = self.parse_batch_train(batch) 253 | 254 | model = self.model 255 | optim = self.optim 256 | scaler = self.scaler 257 | 258 | prec = self.cfg.TRAINER.COCOOP.PREC 259 | if prec == "amp": 260 | with autocast(): 261 | loss = model(image, label) 262 | optim.zero_grad() 263 | scaler.scale(loss).backward() 264 | scaler.step(optim) 265 | scaler.update() 266 | else: 267 | loss = model(image, label) 268 | optim.zero_grad() 269 | loss.backward() 270 | optim.step() 271 | 272 | loss_summary = {"loss": loss.item()} 273 | 274 | if (self.batch_idx + 1) == self.num_batches: 275 | self.update_lr() 276 | 277 | return loss_summary 278 | 279 | def parse_batch_train(self, batch): 280 | input = batch["img"] 281 | label = batch["label"] 282 | input = input.to(self.device) 283 | label = label.to(self.device) 284 | return input, label 285 | 286 | def load_model(self, directory, epoch=None): 287 | if not directory: 288 | print("Note that load_model() is skipped as no pretrained model is given") 289 | return 290 | 291 | names = self.get_model_names() 292 | 293 | # By default, the best model is loaded 294 | model_file = "model-best.pth.tar" 295 | 296 | if epoch is not None: 297 | model_file = "model.pth.tar-" + str(epoch) 298 | 299 | for name in names: 300 | model_path = osp.join(directory, name, model_file) 301 | 302 | if not osp.exists(model_path): 303 | raise FileNotFoundError('Model not found at "{}"'.format(model_path)) 304 | 305 | checkpoint = load_checkpoint(model_path) 306 | state_dict = checkpoint["state_dict"] 307 | epoch = checkpoint["epoch"] 308 | 309 | # Ignore fixed token vectors 310 | if "token_prefix" in state_dict: 311 | del state_dict["token_prefix"] 312 | 313 | if 
"token_suffix" in state_dict: 314 | del state_dict["token_suffix"] 315 | 316 | print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch)) 317 | # set strict=False 318 | self._models[name].load_state_dict(state_dict, strict=False) 319 | -------------------------------------------------------------------------------- /trainers/imagenet_templates.py: -------------------------------------------------------------------------------- 1 | # source: https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb 2 | 3 | IMAGENET_TEMPLATES = [ 4 | "a bad photo of a {}.", 5 | "a photo of many {}.", 6 | "a sculpture of a {}.", 7 | "a photo of the hard to see {}.", 8 | "a low resolution photo of the {}.", 9 | "a rendering of a {}.", 10 | "graffiti of a {}.", 11 | "a bad photo of the {}.", 12 | "a cropped photo of the {}.", 13 | "a tattoo of a {}.", 14 | "the embroidered {}.", 15 | "a photo of a hard to see {}.", 16 | "a bright photo of a {}.", 17 | "a photo of a clean {}.", 18 | "a photo of a dirty {}.", 19 | "a dark photo of the {}.", 20 | "a drawing of a {}.", 21 | "a photo of my {}.", 22 | "the plastic {}.", 23 | "a photo of the cool {}.", 24 | "a close-up photo of a {}.", 25 | "a black and white photo of the {}.", 26 | "a painting of the {}.", 27 | "a painting of a {}.", 28 | "a pixelated photo of the {}.", 29 | "a sculpture of the {}.", 30 | "a bright photo of the {}.", 31 | "a cropped photo of a {}.", 32 | "a plastic {}.", 33 | "a photo of the dirty {}.", 34 | "a jpeg corrupted photo of a {}.", 35 | "a blurry photo of the {}.", 36 | "a photo of the {}.", 37 | "a good photo of the {}.", 38 | "a rendering of the {}.", 39 | "a {} in a video game.", 40 | "a photo of one {}.", 41 | "a doodle of a {}.", 42 | "a close-up photo of the {}.", 43 | "a photo of a {}.", 44 | "the origami {}.", 45 | "the {} in a video game.", 46 | "a sketch of a {}.", 47 | "a doodle of the {}.", 48 | "a origami {}.", 49 | "a low resolution photo of a {}.", 50 | "the toy {}.", 51 | "a rendition of the {}.", 52 | "a photo of the clean {}.", 53 | "a photo of a large {}.", 54 | "a rendition of a {}.", 55 | "a photo of a nice {}.", 56 | "a photo of a weird {}.", 57 | "a blurry photo of a {}.", 58 | "a cartoon {}.", 59 | "art of a {}.", 60 | "a sketch of the {}.", 61 | "a embroidered {}.", 62 | "a pixelated photo of a {}.", 63 | "itap of the {}.", 64 | "a jpeg corrupted photo of the {}.", 65 | "a good photo of a {}.", 66 | "a plushie {}.", 67 | "a photo of the nice {}.", 68 | "a photo of the small {}.", 69 | "a photo of the weird {}.", 70 | "the cartoon {}.", 71 | "art of the {}.", 72 | "a drawing of the {}.", 73 | "a photo of the large {}.", 74 | "a black and white photo of a {}.", 75 | "the plushie {}.", 76 | "a dark photo of a {}.", 77 | "itap of a {}.", 78 | "graffiti of the {}.", 79 | "a toy {}.", 80 | "itap of my {}.", 81 | "a photo of a cool {}.", 82 | "a photo of a small {}.", 83 | "a tattoo of the {}.", 84 | ] 85 | 86 | IMAGENET_TEMPLATES_SELECT = [ 87 | "itap of a {}.", 88 | "a bad photo of the {}.", 89 | "a origami {}.", 90 | "a photo of the large {}.", 91 | "a {} in a video game.", 92 | "art of the {}.", 93 | "a photo of the small {}.", 94 | ] 95 | -------------------------------------------------------------------------------- /trainers/independentVL.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from collections import OrderedDict 3 | import math 4 | 5 | import torch 6 | import torch.nn as nn 7 | 
from torch.nn import functional as F 8 | from torch.cuda.amp import GradScaler, autocast 9 | 10 | from dassl.engine import TRAINER_REGISTRY, TrainerX 11 | from dassl.metrics import compute_accuracy 12 | from dassl.utils import load_pretrained_weights, load_checkpoint 13 | from dassl.optim import build_optimizer, build_lr_scheduler 14 | 15 | from clip import clip 16 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer 17 | 18 | _tokenizer = _Tokenizer() 19 | 20 | 21 | def load_clip_to_cpu(cfg): 22 | backbone_name = cfg.MODEL.BACKBONE.NAME 23 | url = clip._MODELS[backbone_name] 24 | model_path = clip._download(url) 25 | 26 | try: 27 | # loading JIT archive 28 | model = torch.jit.load(model_path, map_location="cpu").eval() 29 | state_dict = None 30 | 31 | except RuntimeError: 32 | state_dict = torch.load(model_path, map_location="cpu") 33 | design_details = {"trainer": 'IVLP', 34 | "vision_depth": cfg.TRAINER.IVLP.PROMPT_DEPTH_VISION, 35 | "language_depth": cfg.TRAINER.IVLP.PROMPT_DEPTH_TEXT, "vision_ctx": cfg.TRAINER.IVLP.N_CTX_VISION, 36 | "language_ctx": cfg.TRAINER.IVLP.N_CTX_TEXT} 37 | model = clip.build_model(state_dict or model.state_dict(), design_details) 38 | 39 | return model 40 | 41 | 42 | class TextEncoder(nn.Module): 43 | def __init__(self, clip_model): 44 | super().__init__() 45 | self.transformer = clip_model.transformer 46 | self.positional_embedding = clip_model.positional_embedding 47 | self.ln_final = clip_model.ln_final 48 | self.text_projection = clip_model.text_projection 49 | self.dtype = clip_model.dtype 50 | 51 | def forward(self, prompts, tokenized_prompts): 52 | x = prompts + self.positional_embedding.type(self.dtype) 53 | x = x.permute(1, 0, 2) # NLD -> LND 54 | x = self.transformer(x) 55 | x = x.permute(1, 0, 2) # LND -> NLD 56 | x = self.ln_final(x).type(self.dtype) 57 | 58 | # x.shape = [batch_size, n_ctx, transformer.width] 59 | # take features from the eot embedding (eot_token is the highest number in each sequence) 60 | x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection 61 | 62 | return x 63 | 64 | 65 | class VLPromptLearner(nn.Module): 66 | def __init__(self, cfg, classnames, clip_model): 67 | super().__init__() 68 | n_cls = len(classnames) 69 | # Make sure Language depth >= 1 70 | assert cfg.TRAINER.IVLP.PROMPT_DEPTH_TEXT >= 1, "In Independent VL prompting, Language prompt depth should be >=1" \ 71 | "\nPlease use VPT trainer if you want to learn only vision " \ 72 | "branch " 73 | n_ctx = cfg.TRAINER.IVLP.N_CTX_TEXT 74 | ctx_init = cfg.TRAINER.IVLP.CTX_INIT 75 | dtype = clip_model.dtype 76 | ctx_dim = clip_model.ln_final.weight.shape[0] 77 | vis_dim = clip_model.visual.output_dim 78 | clip_imsize = clip_model.visual.input_resolution 79 | cfg_imsize = cfg.INPUT.SIZE[0] 80 | assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})" 81 | 82 | if ctx_init and (n_ctx) <= 4: 83 | # use given words to initialize context vectors 84 | ctx_init = ctx_init.replace("_", " ") 85 | n_ctx = n_ctx 86 | prompt = clip.tokenize(ctx_init) 87 | with torch.no_grad(): 88 | embedding = clip_model.token_embedding(prompt).type(dtype) 89 | ctx_vectors = embedding[0, 1: 1 + n_ctx, :] 90 | prompt_prefix = ctx_init 91 | else: 92 | # random initialization 93 | ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype) 94 | nn.init.normal_(ctx_vectors, std=0.02) 95 | prompt_prefix = " ".join(["X"] * n_ctx) 96 | print(f"Independent V-L design") 97 | print(f'Initial text context: 
"{prompt_prefix}"') 98 | print(f"Number of context words (tokens) for Language prompting: {n_ctx}") 99 | print(f"Number of context words (tokens) for Vision prompting: {cfg.TRAINER.IVLP.N_CTX_VISION}") 100 | self.ctx = nn.Parameter(ctx_vectors) 101 | 102 | classnames = [name.replace("_", " ") for name in classnames] 103 | name_lens = [len(_tokenizer.encode(name)) for name in classnames] 104 | prompts = [prompt_prefix + " " + name + "." for name in classnames] 105 | 106 | tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]) # (n_cls, n_tkn) 107 | with torch.no_grad(): 108 | embedding = clip_model.token_embedding(tokenized_prompts).type(dtype) 109 | 110 | # These token vectors will be saved when in save_model(), 111 | # but they should be ignored in load_model() as we want to use 112 | # those computed using the current class names 113 | self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS 114 | self.register_buffer("token_suffix", embedding[:, 1 + n_ctx:, :]) # CLS, EOS 115 | 116 | self.n_cls = n_cls 117 | self.n_ctx = n_ctx 118 | self.tokenized_prompts = tokenized_prompts # torch.Tensor 119 | self.name_lens = name_lens 120 | 121 | def construct_prompts(self, ctx, prefix, suffix, label=None): 122 | # dim0 is either batch_size (during training) or n_cls (during testing) 123 | # ctx: context tokens, with shape of (dim0, n_ctx, ctx_dim) 124 | # prefix: the sos token, with shape of (n_cls, 1, ctx_dim) 125 | # suffix: remaining tokens, with shape of (n_cls, *, ctx_dim) 126 | 127 | if label is not None: 128 | prefix = prefix[label] 129 | suffix = suffix[label] 130 | 131 | prompts = torch.cat( 132 | [ 133 | prefix, # (dim0, 1, dim) 134 | ctx, # (dim0, n_ctx, dim) 135 | suffix, # (dim0, *, dim) 136 | ], 137 | dim=1, 138 | ) 139 | 140 | return prompts 141 | 142 | def forward(self): 143 | ctx = self.ctx 144 | if ctx.dim() == 2: 145 | ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1) 146 | 147 | prefix = self.token_prefix 148 | suffix = self.token_suffix 149 | prompts = self.construct_prompts(ctx, prefix, suffix) 150 | 151 | return prompts 152 | 153 | 154 | class CustomCLIP(nn.Module): 155 | def __init__(self, cfg, classnames, clip_model): 156 | super().__init__() 157 | self.prompt_learner = VLPromptLearner(cfg, classnames, clip_model) 158 | self.tokenized_prompts = self.prompt_learner.tokenized_prompts 159 | self.image_encoder = clip_model.visual 160 | self.text_encoder = TextEncoder(clip_model) 161 | self.logit_scale = clip_model.logit_scale 162 | self.dtype = clip_model.dtype 163 | 164 | def forward(self, image, label=None): 165 | tokenized_prompts = self.tokenized_prompts 166 | logit_scale = self.logit_scale.exp() 167 | 168 | prompts = self.prompt_learner() 169 | text_features = self.text_encoder(prompts, tokenized_prompts) 170 | image_features = self.image_encoder(image.type(self.dtype)) 171 | 172 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 173 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 174 | logits = logit_scale * image_features @ text_features.t() 175 | 176 | if self.prompt_learner.training: 177 | return F.cross_entropy(logits, label) 178 | 179 | return logits 180 | 181 | 182 | @TRAINER_REGISTRY.register() 183 | class IVLP(TrainerX): 184 | def check_cfg(self, cfg): 185 | assert cfg.TRAINER.IVLP.PREC in ["fp16", "fp32", "amp"] 186 | 187 | def build_model(self): 188 | cfg = self.cfg 189 | classnames = self.dm.dataset.classnames 190 | 191 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") 192 | 
clip_model = load_clip_to_cpu(cfg) 193 | 194 | if cfg.TRAINER.IVLP.PREC == "fp32" or cfg.TRAINER.IVLP.PREC == "amp": 195 | # CLIP's default precision is fp16 196 | clip_model.float() 197 | 198 | print("Building custom CLIP") 199 | self.model = CustomCLIP(cfg, classnames, clip_model) 200 | 201 | print("Turning off gradients in both the image and the text encoder") 202 | name_to_update = "prompt_learner" 203 | 204 | for name, param in self.model.named_parameters(): 205 | if name_to_update not in name: 206 | # Make sure that VPT prompts are updated 207 | if "VPT" in name: 208 | param.requires_grad_(True) 209 | else: 210 | param.requires_grad_(False) 211 | 212 | # Double check 213 | enabled = set() 214 | for name, param in self.model.named_parameters(): 215 | if param.requires_grad: 216 | enabled.add(name) 217 | print(f"Parameters to be updated: {enabled}") 218 | 219 | if cfg.MODEL.INIT_WEIGHTS: 220 | load_pretrained_weights(self.model, cfg.MODEL.INIT_WEIGHTS) 221 | 222 | self.model.to(self.device) 223 | # NOTE: only give prompt_learner to the optimizer 224 | self.optim = build_optimizer(self.model, cfg.OPTIM) 225 | self.sched = build_lr_scheduler(self.optim, cfg.OPTIM) 226 | self.register_model("VLPromptLearner", self.model, self.optim, self.sched) 227 | 228 | self.scaler = GradScaler() if cfg.TRAINER.IVLP.PREC == "amp" else None 229 | 230 | # Note that multi-gpu training could be slow because CLIP's size is 231 | # big, which slows down the copy operation in DataParallel 232 | device_count = torch.cuda.device_count() 233 | if device_count > 1: 234 | print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!") 235 | self.model = nn.DataParallel(self.model) 236 | 237 | def forward_backward(self, batch): 238 | image, label = self.parse_batch_train(batch) 239 | 240 | model = self.model 241 | optim = self.optim 242 | scaler = self.scaler 243 | 244 | prec = self.cfg.TRAINER.IVLP.PREC 245 | if prec == "amp": 246 | with autocast(): 247 | loss = model(image, label) 248 | optim.zero_grad() 249 | scaler.scale(loss).backward() 250 | scaler.step(optim) 251 | scaler.update() 252 | else: 253 | loss = model(image, label) 254 | optim.zero_grad() 255 | loss.backward() 256 | optim.step() 257 | 258 | loss_summary = {"loss": loss.item()} 259 | 260 | if (self.batch_idx + 1) == self.num_batches: 261 | self.update_lr() 262 | 263 | return loss_summary 264 | 265 | def parse_batch_train(self, batch): 266 | input = batch["img"] 267 | label = batch["label"] 268 | input = input.to(self.device) 269 | label = label.to(self.device) 270 | return input, label 271 | 272 | def load_model(self, directory, epoch=None): 273 | if not directory: 274 | print("Note that load_model() is skipped as no pretrained model is given") 275 | return 276 | 277 | names = self.get_model_names() 278 | 279 | # By default, the best model is loaded 280 | model_file = "model-best.pth.tar" 281 | 282 | if epoch is not None: 283 | model_file = "model.pth.tar-" + str(epoch) 284 | 285 | for name in names: 286 | model_path = osp.join(directory, name, model_file) 287 | 288 | if not osp.exists(model_path): 289 | raise FileNotFoundError('Model not found at "{}"'.format(model_path)) 290 | 291 | checkpoint = load_checkpoint(model_path) 292 | state_dict = checkpoint["state_dict"] 293 | epoch = checkpoint["epoch"] 294 | 295 | # Ignore fixed token vectors 296 | if "prompt_learner.token_prefix" in state_dict: 297 | del state_dict["prompt_learner.token_prefix"] 298 | 299 | if "prompt_learner.token_suffix" in state_dict: 300 | del 
state_dict["prompt_learner.token_suffix"] 301 | 302 | print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch)) 303 | # set strict=False 304 | self._models[name].load_state_dict(state_dict, strict=False) 305 | -------------------------------------------------------------------------------- /trainers/vpt.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from collections import OrderedDict 3 | import math 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.nn import functional as F 8 | from torch.cuda.amp import GradScaler, autocast 9 | 10 | from dassl.engine import TRAINER_REGISTRY, TrainerX 11 | from dassl.metrics import compute_accuracy 12 | from dassl.utils import load_pretrained_weights, load_checkpoint 13 | from dassl.optim import build_optimizer, build_lr_scheduler 14 | 15 | from clip import clip 16 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer 17 | 18 | _tokenizer = _Tokenizer() 19 | 20 | 21 | def load_clip_to_cpu(cfg): 22 | backbone_name = cfg.MODEL.BACKBONE.NAME 23 | url = clip._MODELS[backbone_name] 24 | model_path = clip._download(url) 25 | 26 | try: 27 | # loading JIT archive 28 | model = torch.jit.load(model_path, map_location="cpu").eval() 29 | state_dict = None 30 | 31 | except RuntimeError: 32 | state_dict = torch.load(model_path, map_location="cpu") 33 | design_details = { "trainer": "VPT", 34 | "vision_depth": cfg.TRAINER.VPT.PROMPT_DEPTH_VISION, 35 | "vision_ctx": cfg.TRAINER.VPT.N_CTX_VISION, 36 | "language_depth": 0, 37 | "language_ctx": 0} 38 | assert cfg.TRAINER.VPT.PROMPT_DEPTH_VISION >= 1, "For Vision Prompting, PROMPT_DEPTH_VISION should be >= 1" 39 | model = clip.build_model(state_dict or model.state_dict(), design_details) 40 | 41 | return model.float() 42 | 43 | 44 | class TextEncoder(nn.Module): 45 | def __init__(self, clip_model): 46 | super().__init__() 47 | self.transformer = clip_model.transformer 48 | self.positional_embedding = clip_model.positional_embedding 49 | self.ln_final = clip_model.ln_final 50 | self.text_projection = clip_model.text_projection 51 | self.dtype = clip_model.dtype 52 | 53 | def forward(self, prompts, tokenized_prompts): 54 | x = prompts + self.positional_embedding.type(self.dtype) 55 | x = x.permute(1, 0, 2) # NLD -> LND 56 | x = self.transformer(x) 57 | x = x.permute(1, 0, 2) # LND -> NLD 58 | x = self.ln_final(x).type(self.dtype) 59 | 60 | # x.shape = [batch_size, n_ctx, transformer.width] 61 | # take features from the eot embedding (eot_token is the highest number in each sequence) 62 | x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection 63 | 64 | return x 65 | 66 | 67 | class FixedEmbeddings(): 68 | def __init__(self, cfg, classnames, clip_model): 69 | clip_imsize = clip_model.visual.input_resolution 70 | cfg_imsize = cfg.INPUT.SIZE[0] 71 | assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal to clip_imsize ({clip_imsize})" 72 | 73 | prompt_prefix = "a photo of a" 74 | print('Vision Prompting Design') 75 | print(f'Initial context: "{prompt_prefix}"') 76 | print(f"Number of context words (tokens) for Vision prompting: {cfg.TRAINER.VPT.N_CTX_VISION}") 77 | print(f"Using fixed hand crated prompts") 78 | 79 | classnames = [name.replace("_", " ") for name in classnames] 80 | prompts = [prompt_prefix + " " + name + "." 
for name in classnames] 81 | 82 | tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]) 83 | with torch.no_grad(): 84 | text_features = clip_model.encode_text(tokenized_prompts) 85 | 86 | self.fixed_embeddings = text_features 87 | 88 | def return_fixed_embeddings(self): 89 | return self.fixed_embeddings 90 | 91 | 92 | class CustomCLIP(nn.Module): 93 | def __init__(self, cfg, classnames, clip_model): 94 | super().__init__() 95 | self.embeddings = FixedEmbeddings(cfg, classnames, clip_model) 96 | self.image_encoder = clip_model.visual 97 | self.text_encoder = TextEncoder(clip_model) 98 | self.logit_scale = clip_model.logit_scale 99 | self.dtype = clip_model.dtype 100 | 101 | def forward(self, image, label=None, training=False): 102 | logit_scale = self.logit_scale.exp() 103 | 104 | text_features = self.embeddings.return_fixed_embeddings().cuda() 105 | image_features = self.image_encoder(image.type(self.dtype)) 106 | 107 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 108 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 109 | logits = logit_scale * image_features @ text_features.t() 110 | 111 | if training: 112 | return F.cross_entropy(logits, label) 113 | 114 | return logits 115 | 116 | 117 | @TRAINER_REGISTRY.register() 118 | class VPT(TrainerX): 119 | def check_cfg(self, cfg): 120 | assert cfg.TRAINER.VPT.PREC in ["fp16", "fp32", "amp"] 121 | 122 | def build_model(self): 123 | cfg = self.cfg 124 | classnames = self.dm.dataset.classnames 125 | 126 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") 127 | clip_model = load_clip_to_cpu(cfg) 128 | 129 | if cfg.TRAINER.VPT.PREC == "fp32" or cfg.TRAINER.VPT.PREC == "amp": 130 | # CLIP's default precision is fp16 131 | clip_model.float() 132 | 133 | print("Building custom CLIP") 134 | self.model = CustomCLIP(cfg, classnames, clip_model) 135 | 136 | print("Turning off gradients in both the image and the text encoder") 137 | name_to_update = "prompt_learner" 138 | 139 | for name, param in self.model.named_parameters(): 140 | if name_to_update not in name: 141 | # Make sure that VPT prompts are updated 142 | if "VPT" in name: 143 | param.requires_grad_(True) 144 | else: 145 | param.requires_grad_(False) 146 | 147 | # Double check 148 | enabled = set() 149 | for name, param in self.model.named_parameters(): 150 | if param.requires_grad: 151 | enabled.add(name) 152 | print(f"Parameters to be updated: {enabled}") 153 | 154 | if cfg.MODEL.INIT_WEIGHTS: 155 | load_pretrained_weights(self.model, cfg.MODEL.INIT_WEIGHTS) 156 | 157 | self.model.to(self.device) 158 | # NOTE: only give prompt_learner to the optimizer 159 | self.optim = build_optimizer(self.model, cfg.OPTIM) 160 | self.sched = build_lr_scheduler(self.optim, cfg.OPTIM) 161 | self.register_model("prompt_learner", self.model, self.optim, self.sched) 162 | 163 | self.scaler = GradScaler() if cfg.TRAINER.VPT.PREC == "amp" else None 164 | 165 | # Note that multi-gpu training could be slow because CLIP's size is 166 | # big, which slows down the copy operation in DataParallel 167 | device_count = torch.cuda.device_count() 168 | if device_count > 1: 169 | print(f"Multiple GPUs detected (n_gpus={device_count}), use all of them!") 170 | self.model = nn.DataParallel(self.model) 171 | 172 | def forward_backward(self, batch): 173 | image, label = self.parse_batch_train(batch) 174 | 175 | model = self.model 176 | optim = self.optim 177 | scaler = self.scaler 178 | 179 | prec = self.cfg.TRAINER.VPT.PREC 180 | if prec == "amp": 181 
| with autocast(): 182 | loss = model(image, label) 183 | optim.zero_grad() 184 | scaler.scale(loss).backward() 185 | scaler.step(optim) 186 | scaler.update() 187 | else: 188 | loss = model(image, label, training=True) 189 | optim.zero_grad() 190 | loss.backward() 191 | optim.step() 192 | 193 | loss_summary = {"loss": loss.item()} 194 | 195 | if (self.batch_idx + 1) == self.num_batches: 196 | self.update_lr() 197 | 198 | return loss_summary 199 | 200 | def parse_batch_train(self, batch): 201 | input = batch["img"] 202 | label = batch["label"] 203 | input = input.to(self.device) 204 | label = label.to(self.device) 205 | return input, label 206 | 207 | def load_model(self, directory, epoch=None): 208 | if not directory: 209 | print("Note that load_model() is skipped as no pretrained model is given") 210 | return 211 | 212 | names = self.get_model_names() 213 | 214 | # By default, the best model is loaded 215 | model_file = "model-best.pth.tar" 216 | 217 | if epoch is not None: 218 | model_file = "model.pth.tar-" + str(epoch) 219 | 220 | for name in names: 221 | model_path = osp.join(directory, name, model_file) 222 | 223 | if not osp.exists(model_path): 224 | raise FileNotFoundError('Model not found at "{}"'.format(model_path)) 225 | 226 | checkpoint = load_checkpoint(model_path) 227 | state_dict = checkpoint["state_dict"] 228 | epoch = checkpoint["epoch"] 229 | 230 | # Ignore fixed token vectors 231 | if "prompt_learner.token_prefix" in state_dict: 232 | del state_dict["prompt_learner.token_prefix"] 233 | 234 | if "prompt_learner.token_suffix" in state_dict: 235 | del state_dict["prompt_learner.token_suffix"] 236 | 237 | print("Loading weights to {} " 'from "{}" (epoch = {})'.format(name, model_path, epoch)) 238 | # set strict=False 239 | self._models[name].load_state_dict(state_dict, strict=False) 240 | -------------------------------------------------------------------------------- /trainers/zsclip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from dassl.engine import TRAINER_REGISTRY, TrainerX 5 | from dassl.optim import build_optimizer, build_lr_scheduler 6 | 7 | from clip import clip 8 | from clip.model import convert_weights 9 | 10 | from .coop import load_clip_to_cpu 11 | from .imagenet_templates import IMAGENET_TEMPLATES, IMAGENET_TEMPLATES_SELECT 12 | 13 | CUSTOM_TEMPLATES = { 14 | "OxfordPets": "a photo of a {}, a type of pet.", 15 | "OxfordFlowers": "a photo of a {}, a type of flower.", 16 | "FGVCAircraft": "a photo of a {}, a type of aircraft.", 17 | "DescribableTextures": "{} texture.", 18 | "EuroSAT": "a centered satellite photo of {}.", 19 | "StanfordCars": "a photo of a {}.", 20 | "Food101": "a photo of {}, a type of food.", 21 | "SUN397": "a photo of a {}.", 22 | "Caltech101": "a photo of a {}.", 23 | "UCF101": "a photo of a person doing {}.", 24 | "ImageNet": "a photo of a {}.", 25 | "ImageNetSketch": "a photo of a {}.", 26 | "ImageNetV2": "a photo of a {}.", 27 | "ImageNetA": "a photo of a {}.", 28 | "ImageNetR": "a photo of a {}.", 29 | } 30 | 31 | 32 | @TRAINER_REGISTRY.register() 33 | class ZeroshotCLIP(TrainerX): 34 | def build_model(self): 35 | cfg = self.cfg 36 | classnames = self.dm.dataset.classnames 37 | 38 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") 39 | clip_model = load_clip_to_cpu(cfg) 40 | clip_model.to(self.device) 41 | 42 | temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME] 43 | prompts = [temp.format(c.replace("_", " ")) for c in classnames] 44 | 
print(f"Prompts: {prompts}") 45 | prompts = torch.cat([clip.tokenize(p) for p in prompts]) 46 | prompts = prompts.to(self.device) 47 | 48 | with torch.no_grad(): 49 | text_features = clip_model.encode_text(prompts) 50 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 51 | 52 | self.text_features = text_features 53 | self.clip_model = clip_model 54 | 55 | def model_inference(self, image): 56 | image_features = self.clip_model.encode_image(image) 57 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 58 | logit_scale = self.clip_model.logit_scale.exp() 59 | logits = logit_scale * image_features @ self.text_features.t() 60 | return logits 61 | 62 | 63 | @TRAINER_REGISTRY.register() 64 | class ZeroshotCLIP2(ZeroshotCLIP): 65 | """Prompt ensembling.""" 66 | 67 | # templates = IMAGENET_TEMPLATES 68 | templates = IMAGENET_TEMPLATES_SELECT 69 | 70 | def build_model(self): 71 | cfg = self.cfg 72 | classnames = self.dm.dataset.classnames 73 | 74 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})") 75 | clip_model = load_clip_to_cpu(cfg) 76 | clip_model.to(self.device) 77 | 78 | for params in clip_model.parameters(): 79 | params.requires_grad_(False) 80 | 81 | # add custom-made prompt 82 | if cfg.DATASET.NAME != "ImageNet": 83 | self.templates += [CUSTOM_TEMPLATES[cfg.DATASET.NAME]] 84 | 85 | num_temp = len(self.templates) 86 | print(f"Prompt ensembling (n={num_temp})") 87 | 88 | mean_text_features = 0 89 | for i, temp in enumerate(self.templates): 90 | prompts = [temp.format(c.replace("_", " ")) for c in classnames] 91 | prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(self.device) 92 | text_features = clip_model.encode_text(prompts) 93 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 94 | mean_text_features = mean_text_features + text_features 95 | mean_text_features = mean_text_features / num_temp 96 | mean_text_features = mean_text_features / mean_text_features.norm(dim=-1, keepdim=True) 97 | 98 | self.text_features = mean_text_features 99 | self.clip_model = clip_model 100 | --------------------------------------------------------------------------------