├── .idea
│   ├── ALIGN.iml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── maple.iml
│   ├── misc.xml
│   ├── modules.xml
│   ├── vcs.xml
│   └── workspace.xml
├── README.md
├── clip
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── clip.cpython-37.pyc
│   │   ├── clip.cpython-38.pyc
│   │   ├── model.cpython-37.pyc
│   │   ├── model.cpython-38.pyc
│   │   ├── simple_tokenizer.cpython-37.pyc
│   │   └── simple_tokenizer.cpython-38.pyc
│   ├── bpe_simple_vocab_16e6.txt.gz
│   ├── clip.py
│   ├── model.py
│   └── simple_tokenizer.py
├── clip_words.csv
├── configs
│   ├── datasets
│   │   ├── caltech101.yaml
│   │   ├── dtd.yaml
│   │   ├── eurosat.yaml
│   │   ├── fgvc_aircraft.yaml
│   │   ├── food101.yaml
│   │   ├── imagenet.yaml
│   │   ├── imagenet_a.yaml
│   │   ├── imagenet_r.yaml
│   │   ├── imagenet_sketch.yaml
│   │   ├── imagenetv2.yaml
│   │   ├── oxford_flowers.yaml
│   │   ├── oxford_pets.yaml
│   │   ├── stanford_cars.yaml
│   │   ├── sun397.yaml
│   │   └── ucf101.yaml
│   └── trainers
│       ├── CoCoOp
│       │   ├── vit_b16_c16_ep10_batch1.yaml
│       │   ├── vit_b16_c4_ep10_batch1.yaml
│       │   ├── vit_b16_c4_ep10_batch1_ctxv1.yaml
│       │   └── vit_b16_c8_ep10_batch1.yaml
│       ├── CoOp
│       │   ├── rn101.yaml
│       │   ├── rn101_ep50.yaml
│       │   ├── rn50.yaml
│       │   ├── rn50_ctxv1.yaml
│       │   ├── rn50_ep100.yaml
│       │   ├── rn50_ep50.yaml
│       │   ├── rn50_ep50_ctxv1.yaml
│       │   ├── rn50_val.yaml
│       │   ├── vit_b16.yaml
│       │   ├── vit_b16_ep100.yaml
│       │   ├── vit_b16_ep50.yaml
│       │   ├── vit_b32.yaml
│       │   └── vit_b32_ep50.yaml
│       ├── IVLP
│       │   ├── vit_b16_c2_ep5_batch4_2+2ctx.yaml
│       │   └── vit_b16_c2_ep5_batch4_4ctx_language_only.yaml
│       ├── MMP
│       │   ├── sun397.yaml
│       │   ├── vit_b16_c2_ep5_batch4_2ctx.yaml
│       │   └── vit_h.yaml
│       ├── MaPLe
│       │   ├── vit_b16_c2_ep5_batch4_2ctx.yaml
│       │   └── vit_b16_c2_ep5_batch4_2ctx_cross_datasets.yaml
│       └── VPT
│           └── vit_b16_c2_ep5_batch4_4.yaml
├── datasets
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── caltech101.cpython-37.pyc
│   │   ├── caltech101.cpython-38.pyc
│   │   ├── dtd.cpython-37.pyc
│   │   ├── dtd.cpython-38.pyc
│   │   ├── eurosat.cpython-37.pyc
│   │   ├── eurosat.cpython-38.pyc
│   │   ├── fgvc_aircraft.cpython-37.pyc
│   │   ├── fgvc_aircraft.cpython-38.pyc
│   │   ├── food101.cpython-37.pyc
│   │   ├── food101.cpython-38.pyc
│   │   ├── imagenet.cpython-37.pyc
│   │   ├── imagenet.cpython-38.pyc
│   │   ├── imagenet_a.cpython-37.pyc
│   │   ├── imagenet_a.cpython-38.pyc
│   │   ├── imagenet_r.cpython-37.pyc
│   │   ├── imagenet_r.cpython-38.pyc
│   │   ├── imagenet_sketch.cpython-37.pyc
│   │   ├── imagenet_sketch.cpython-38.pyc
│   │   ├── imagenetv2.cpython-37.pyc
│   │   ├── imagenetv2.cpython-38.pyc
│   │   ├── oxford_flowers.cpython-37.pyc
│   │   ├── oxford_flowers.cpython-38.pyc
│   │   ├── oxford_pets.cpython-37.pyc
│   │   ├── oxford_pets.cpython-38.pyc
│   │   ├── stanford_cars.cpython-37.pyc
│   │   ├── stanford_cars.cpython-38.pyc
│   │   ├── sun397.cpython-37.pyc
│   │   ├── sun397.cpython-38.pyc
│   │   ├── ucf101.cpython-37.pyc
│   │   └── ucf101.cpython-38.pyc
│   ├── caltech101.py
│   ├── dtd.py
│   ├── eurosat.py
│   ├── fgvc_aircraft.py
│   ├── food101.py
│   ├── imagenet.py
│   ├── imagenet_a.py
│   ├── imagenet_r.py
│   ├── imagenet_sketch.py
│   ├── imagenetv2.py
│   ├── oxford_flowers.py
│   ├── oxford_pets.py
│   ├── stanford_cars.py
│   ├── sun397.py
│   └── ucf101.py
├── images
│   └── ALIGN.png
├── parse_test_res.py
├── scripts
│   ├── cocoop
│   │   ├── base2new_test.sh
│   │   ├── base2new_train.sh
│   │   ├── xd_test.sh
│   │   └── xd_train.sh
│   ├── coop
│   │   ├── basenewtrain.sh
│   │   ├── eval.sh
│   │   └── main.sh
│   ├── independent-vlp
│   │   ├── base2new_test_ivlp.sh
│   │   ├── base2new_train_ivlp.sh
│   │   ├── reproduce_ivlp.sh
│   │   ├── xd_test_ivlp.sh
│   │   └── xd_train_ivlp.sh
│   ├── language-prompting
│   │   ├── base2new_test_lp.sh
│   │   ├── base2new_train_lp.sh
│   │   ├── reproduce_lp.sh
│   │   ├── xd_test_lp.sh
│   │   └── xd_train_lp.sh
│   ├── maple
│   │   ├── base2new_test_maple.sh
│   │   ├── base2new_train_maple.sh
│   │   ├── fst.sh
│   │   ├── reproduce_maple.sh
│   │   ├── reproduce_maple_xd.sh
│   │   ├── xd_test_maple.sh
│   │   └── xd_train_maple.sh
│   ├── mmp
│   │   ├── base_to_new_test.sh
│   │   └── base_to_new_train.sh
│   ├── vpt
│   │   ├── base2new_test_vpt.sh
│   │   ├── base2new_train_vpt.sh
│   │   ├── reproduce_vpt.sh
│   │   ├── xd_test_vpt.sh
│   │   └── xd_train_vpt.sh
│   └── zsclip
│       └── zeroshot.sh
├── train.py
└── trainers
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-37.pyc
    │   ├── __init__.cpython-38.pyc
    │   ├── cocoop.cpython-37.pyc
    │   ├── cocoop.cpython-38.pyc
    │   ├── coop.cpython-37.pyc
    │   ├── coop.cpython-38.pyc
    │   ├── imagenet_templates.cpython-37.pyc
    │   ├── imagenet_templates.cpython-38.pyc
    │   ├── independentVL.cpython-37.pyc
    │   ├── independentVL.cpython-38.pyc
    │   ├── maple.cpython-37.pyc
    │   ├── maple.cpython-38.pyc
    │   ├── mmp.cpython-37.pyc
    │   ├── mmp.cpython-38.pyc
    │   ├── vpt.cpython-37.pyc
    │   ├── vpt.cpython-38.pyc
    │   ├── zsclip.cpython-37.pyc
    │   └── zsclip.cpython-38.pyc
    ├── cocoop.py
    ├── coop.py
    ├── imagenet_templates.py
    ├── independentVL.py
    ├── maple.py
    ├── mmp.py
    ├── vpt.py
    └── zsclip.py
/README.md:
--------------------------------------------------------------------------------
1 | # Tuning Multi-mode Token-level Prompt Alignment across Modalities [NeurIPS 2023]
2 |
3 | This is the official implementation of our NeurIPS 2023 paper [Tuning Multi-mode Token-level Prompt Alignment across Modalities](https://arxiv.org/abs/2309.13847).
4 |
5 | 
6 |
7 | The proposed ALIGN algorithm learns multiple prompts in both the textual and visual domains. Given M visual prompts and N textual prompts, ALIGN first views each label/image as a discrete distribution over
8 | the M and N prompt supports, and each of these distributions can further be modeled as a discrete distribution over its modality-specific token-level space. ALIGN then applies prompt-level OT and token-level OT to align the two
9 | domains.
10 |
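For reference, here is a minimal, self-contained PyTorch sketch of the entropic-OT (Sinkhorn) computation that underlies this kind of prompt-level alignment. It is illustrative only: the uniform marginals, cosine cost, and hyperparameters are assumptions made for the sketch, not the exact formulation in the paper (which additionally couples a token-level plan to each prompt pair).

```python
# Illustrative Sinkhorn sketch (not the repo's implementation): computes an
# entropic-OT transport plan between M visual-prompt and N textual-prompt
# features under uniform marginals and a cosine-distance cost.
import torch


def sinkhorn(cost: torch.Tensor, eps: float = 0.1, n_iters: int = 50) -> torch.Tensor:
    """Return an M x N transport plan for `cost` with uniform marginals."""
    m, n = cost.shape
    mu = torch.full((m,), 1.0 / m, device=cost.device)
    nu = torch.full((n,), 1.0 / n, device=cost.device)
    K = torch.exp(-cost / eps)                 # Gibbs kernel
    u = torch.ones_like(mu)
    for _ in range(n_iters):
        v = nu / (K.t() @ u)
        u = mu / (K @ v)
    return u[:, None] * K * v[None, :]         # diag(u) @ K @ diag(v)


# Toy example: M = 4 visual prompts, N = 4 textual prompts, 512-d features.
vis = torch.nn.functional.normalize(torch.randn(4, 512), dim=-1)
txt = torch.nn.functional.normalize(torch.randn(4, 512), dim=-1)
cost = 1.0 - vis @ txt.t()                     # cosine distance
plan = sinkhorn(cost)                          # marginals are uniform over M and N
alignment_cost = (plan * cost).sum()           # prompt-level transport cost
```

Roughly speaking, a transport cost of this form plays the role of the image-class alignment score that the learned prompts are trained to minimize for the correct class.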
11 | ## TODO
12 | Due to upcoming deadlines, we will add more details about the training scripts and results soon.
13 |
14 | ## Getting Started
15 | ### Install
16 | - Clone this repo:
17 | ```bash
18 | git clone https://github.com/wds2014/ALIGN.git
19 | cd ALIGN
20 | ```
21 | - Please follow [INSTALL.md](https://github.com/muzairkhattak/multimodal-prompt-learning/tree/main/docs/INSTALL.md) to set up the Python environment.
22 |
23 | ### Dataset
24 | - Datasets used in our paper
25 |
26 | The datasets we use are the same as in previous works (CoOp and MAPLE). Please follow [DATASETS.md](https://github.com/muzairkhattak/multimodal-prompt-learning/tree/main/docs/DATASETS.md) to prepare all datasets.
27 |
28 | ### Training
29 | - To train, simply run:
30 | ```bash
31 | cd scripts/mmp
32 | bash base_to_new_train.sh
33 | ```
34 | Change DATASET and SEED in the .sh file to train our model on different datasets and with different seeds (a hypothetical sweep helper is sketched below).
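Below is a hypothetical sweep helper (not part of this repo) that rewrites those two assignments in base_to_new_train.sh and launches the script once per dataset/seed pair; the variable names `DATASET` and `SEED` inside the script are assumed from the instructions above.

```python
# Hypothetical sweep helper: edits DATASET/SEED in base_to_new_train.sh and runs it.
# The assignment names inside the script are assumptions; adjust the patterns if needed.
import re
import subprocess
from pathlib import Path

SCRIPT = Path("scripts/mmp/base_to_new_train.sh")


def run(dataset: str, seed: int) -> None:
    text = SCRIPT.read_text()
    text = re.sub(r"(?m)^DATASET=.*$", f"DATASET={dataset}", text)
    text = re.sub(r"(?m)^SEED=.*$", f"SEED={seed}", text)
    SCRIPT.write_text(text)
    subprocess.run(["bash", SCRIPT.name], cwd=SCRIPT.parent, check=True)


for dataset in ["caltech101", "dtd", "eurosat"]:
    for seed in [1, 2, 3]:
        run(dataset, seed)
```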
35 |
36 | ## Citation
37 | If you find this repo useful for your project, please consider citing it with the following BibTeX entry:
38 |
39 | ```bibtex
40 | @article{wang2023tuning,
41 | title={Tuning Multi-mode Token-level Prompt Alignment across Modalities},
42 | author={Wang, Dongsheng and Li, Miaoge and Liu, Xinyang and Xu, MingSheng and Chen, Bo and Zhang, Hanwang},
43 | journal={arXiv preprint arXiv:2309.13847},
44 | year={2023}
45 | }
46 | ```
47 |
48 | ## Acknowledgements
49 | Our code is built on the [CoOp](https://github.com/KaiyangZhou/CoOp) and [MAPLE](https://github.com/muzairkhattak/multimodal-prompt-learning/tree/main) repositories.
50 | We thank the authors for releasing their code.
--------------------------------------------------------------------------------
/clip/__init__.py:
--------------------------------------------------------------------------------
1 | from .clip import *
2 |
--------------------------------------------------------------------------------
/clip/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/clip/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/clip/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/clip/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/clip/__pycache__/clip.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/clip/__pycache__/clip.cpython-37.pyc
--------------------------------------------------------------------------------
/clip/__pycache__/clip.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/clip/__pycache__/clip.cpython-38.pyc
--------------------------------------------------------------------------------
/clip/__pycache__/model.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/clip/__pycache__/model.cpython-37.pyc
--------------------------------------------------------------------------------
/clip/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/clip/__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/clip/__pycache__/simple_tokenizer.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/clip/__pycache__/simple_tokenizer.cpython-37.pyc
--------------------------------------------------------------------------------
/clip/__pycache__/simple_tokenizer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/clip/__pycache__/simple_tokenizer.cpython-38.pyc
--------------------------------------------------------------------------------
/clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/clip/clip.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import urllib
4 | import warnings
5 | from typing import Union, List
6 |
7 | import torch
8 | from PIL import Image
9 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
10 | from tqdm import tqdm
11 |
12 | from .model import build_model
13 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer
14 |
15 | try:
16 | from torchvision.transforms import InterpolationMode
17 | BICUBIC = InterpolationMode.BICUBIC
18 | except ImportError:
19 | BICUBIC = Image.BICUBIC
20 |
21 |
22 | if torch.__version__.split(".") < ["1", "7", "1"]:
23 | warnings.warn("PyTorch version 1.7.1 or higher is recommended")
24 |
25 |
26 | __all__ = ["available_models", "load", "tokenize"]
27 | _tokenizer = _Tokenizer()
28 |
29 | _MODELS = {
30 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
31 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
32 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
33 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
34 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
35 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
36 | }
37 |
38 |
39 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
40 | os.makedirs(root, exist_ok=True)
41 | filename = os.path.basename(url)
42 |
43 | expected_sha256 = url.split("/")[-2]
44 | download_target = os.path.join(root, filename)
45 |
46 | if os.path.exists(download_target) and not os.path.isfile(download_target):
47 | raise RuntimeError(f"{download_target} exists and is not a regular file")
48 |
49 | if os.path.isfile(download_target):
50 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
51 | return download_target
52 | else:
53 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
54 |
55 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
56 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
57 | while True:
58 | buffer = source.read(8192)
59 | if not buffer:
60 | break
61 |
62 | output.write(buffer)
63 | loop.update(len(buffer))
64 |
65 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
66 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not match")
67 |
68 | return download_target
69 |
70 |
71 | def _transform(n_px):
72 | return Compose([
73 | Resize(n_px, interpolation=BICUBIC),
74 | CenterCrop(n_px),
75 | lambda image: image.convert("RGB"),
76 | ToTensor(),
77 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
78 | ])
79 |
80 |
81 | def available_models() -> List[str]:
82 | """Returns the names of available CLIP models"""
83 | return list(_MODELS.keys())
84 |
85 |
86 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=False):
87 | """Load a CLIP model
88 |
89 | Parameters
90 | ----------
91 | name : str
92 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
93 |
94 | device : Union[str, torch.device]
95 | The device to put the loaded model
96 |
97 | jit : bool
98 | Whether to load the optimized JIT model or more hackable non-JIT model (default).
99 |
100 | Returns
101 | -------
102 | model : torch.nn.Module
103 | The CLIP model
104 |
105 | preprocess : Callable[[PIL.Image], torch.Tensor]
106 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
107 | """
108 | if name in _MODELS:
109 | model_path = _download(_MODELS[name])
110 | elif os.path.isfile(name):
111 | model_path = name
112 | else:
113 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
114 |
115 | try:
116 | # loading JIT archive
117 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
118 | state_dict = None
119 | except RuntimeError:
120 | # loading saved state dict
121 | if jit:
122 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
123 | jit = False
124 | state_dict = torch.load(model_path, map_location="cpu")
125 |
126 | if not jit:
127 | model = build_model(state_dict or model.state_dict()).to(device)
128 | if str(device) == "cpu":
129 | model.float()
130 | return model, _transform(model.visual.input_resolution)
131 |
132 | # patch the device names
133 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
134 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
135 |
136 | def patch_device(module):
137 | try:
138 | graphs = [module.graph] if hasattr(module, "graph") else []
139 | except RuntimeError:
140 | graphs = []
141 |
142 | if hasattr(module, "forward1"):
143 | graphs.append(module.forward1.graph)
144 |
145 | for graph in graphs:
146 | for node in graph.findAllNodes("prim::Constant"):
147 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
148 | node.copyAttributes(device_node)
149 |
150 | model.apply(patch_device)
151 | patch_device(model.encode_image)
152 | patch_device(model.encode_text)
153 |
154 | # patch dtype to float32 on CPU
155 | if str(device) == "cpu":
156 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
157 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
158 | float_node = float_input.node()
159 |
160 | def patch_float(module):
161 | try:
162 | graphs = [module.graph] if hasattr(module, "graph") else []
163 | except RuntimeError:
164 | graphs = []
165 |
166 | if hasattr(module, "forward1"):
167 | graphs.append(module.forward1.graph)
168 |
169 | for graph in graphs:
170 | for node in graph.findAllNodes("aten::to"):
171 | inputs = list(node.inputs())
172 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
173 | if inputs[i].node()["value"] == 5:
174 | inputs[i].node().copyAttributes(float_node)
175 |
176 | model.apply(patch_float)
177 | patch_float(model.encode_image)
178 | patch_float(model.encode_text)
179 |
180 | model.float()
181 |
182 | return model, _transform(model.input_resolution.item())
183 |
184 |
185 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
186 | """
187 | Returns the tokenized representation of given input string(s)
188 |
189 | Parameters
190 | ----------
191 | texts : Union[str, List[str]]
192 | An input string or a list of input strings to tokenize
193 |
194 | context_length : int
195 | The context length to use; all CLIP models use 77 as the context length
196 |
197 | truncate: bool
198 | Whether to truncate the text in case its encoding is longer than the context length
199 |
200 | Returns
201 | -------
202 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
203 | """
204 | if isinstance(texts, str):
205 | texts = [texts]
206 |
207 | sot_token = _tokenizer.encoder["<|startoftext|>"]
208 | eot_token = _tokenizer.encoder["<|endoftext|>"]
209 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
210 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
211 |
212 | for i, tokens in enumerate(all_tokens):
213 | if len(tokens) > context_length:
214 | if truncate:
215 | tokens = tokens[:context_length]
216 | tokens[-1] = eot_token
217 | else:
218 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
219 | result[i, :len(tokens)] = torch.tensor(tokens)
220 |
221 | return result
222 |
--------------------------------------------------------------------------------
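For reference, a minimal usage sketch (not part of the repository) of the `load` and `tokenize` helpers defined in clip/clip.py above; the model name, image path, and prompt strings are arbitrary examples.

```python
# Minimal usage sketch of clip.load / clip.tokenize as defined above.
# "ViT-B/16", "example.jpg", and the prompt strings are arbitrary examples.
import torch
from PIL import Image

import clip  # the local package; clip/__init__.py re-exports load and tokenize

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/16", device=device)

image = preprocess(Image.open("example.jpg")).unsqueeze(0).to(device)
text = clip.tokenize(["a photo of a dog", "a photo of a cat"]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    # cosine similarity between the image and each prompt
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    similarity = image_features @ text_features.t()
```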
/clip/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 |
14 |
15 | @lru_cache()
16 | def bytes_to_unicode():
17 | """
18 | Returns list of utf-8 byte and a corresponding list of unicode strings.
19 | The reversible bpe codes work on unicode strings.
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 | This is a signficant percentage of your normal, say, 32K bpe vocab.
23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 | And avoids mapping to whitespace/control characters the bpe code barfs on.
25 | """
26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 | cs = bs[:]
28 | n = 0
29 | for b in range(2**8):
30 | if b not in bs:
31 | bs.append(b)
32 | cs.append(2**8+n)
33 | n += 1
34 | cs = [chr(n) for n in cs]
35 | return dict(zip(bs, cs))
36 |
37 |
38 | def get_pairs(word):
39 | """Return set of symbol pairs in a word.
40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """
42 | pairs = set()
43 | prev_char = word[0]
44 | for char in word[1:]:
45 | pairs.add((prev_char, char))
46 | prev_char = char
47 | return pairs
48 |
49 |
50 | def basic_clean(text):
51 | text = ftfy.fix_text(text)
52 | text = html.unescape(html.unescape(text))
53 | return text.strip()
54 |
55 |
56 | def whitespace_clean(text):
57 | text = re.sub(r'\s+', ' ', text)
58 | text = text.strip()
59 | return text
60 |
61 |
62 | class SimpleTokenizer(object):
63 | def __init__(self, bpe_path: str = default_bpe()):
64 | self.byte_encoder = bytes_to_unicode()
65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 | merges = merges[1:49152-256-2+1]
68 | merges = [tuple(merge.split()) for merge in merges]
69 | vocab = list(bytes_to_unicode().values())
70 | vocab = vocab + [v+'</w>' for v in vocab]
71 | for merge in merges:
72 | vocab.append(''.join(merge))
73 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74 | self.encoder = dict(zip(vocab, range(len(vocab))))
75 | self.decoder = {v: k for k, v in self.encoder.items()}
76 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79 |
80 | def bpe(self, token):
81 | if token in self.cache:
82 | return self.cache[token]
83 | word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84 | pairs = get_pairs(word)
85 |
86 | if not pairs:
87 | return token+'</w>'
88 |
89 | while True:
90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91 | if bigram not in self.bpe_ranks:
92 | break
93 | first, second = bigram
94 | new_word = []
95 | i = 0
96 | while i < len(word):
97 | try:
98 | j = word.index(first, i)
99 | new_word.extend(word[i:j])
100 | i = j
101 | except:
102 | new_word.extend(word[i:])
103 | break
104 |
105 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
106 | new_word.append(first+second)
107 | i += 2
108 | else:
109 | new_word.append(word[i])
110 | i += 1
111 | new_word = tuple(new_word)
112 | word = new_word
113 | if len(word) == 1:
114 | break
115 | else:
116 | pairs = get_pairs(word)
117 | word = ' '.join(word)
118 | self.cache[token] = word
119 | return word
120 |
121 | def encode(self, text):
122 | bpe_tokens = []
123 | text = whitespace_clean(basic_clean(text)).lower()
124 | for token in re.findall(self.pat, text):
125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127 | return bpe_tokens
128 |
129 | def decode(self, tokens):
130 | text = ''.join([self.decoder[token] for token in tokens])
131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132 | return text
133 |
--------------------------------------------------------------------------------
/configs/datasets/caltech101.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "Caltech101"
3 |
--------------------------------------------------------------------------------
/configs/datasets/dtd.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "DescribableTextures"
3 |
--------------------------------------------------------------------------------
/configs/datasets/eurosat.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "EuroSAT"
3 |
--------------------------------------------------------------------------------
/configs/datasets/fgvc_aircraft.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "FGVCAircraft"
3 |
--------------------------------------------------------------------------------
/configs/datasets/food101.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "Food101"
3 |
--------------------------------------------------------------------------------
/configs/datasets/imagenet.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "ImageNet"
3 |
--------------------------------------------------------------------------------
/configs/datasets/imagenet_a.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "ImageNetA"
3 |
--------------------------------------------------------------------------------
/configs/datasets/imagenet_r.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "ImageNetR"
3 |
--------------------------------------------------------------------------------
/configs/datasets/imagenet_sketch.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "ImageNetSketch"
3 |
--------------------------------------------------------------------------------
/configs/datasets/imagenetv2.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "ImageNetV2"
3 |
--------------------------------------------------------------------------------
/configs/datasets/oxford_flowers.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "OxfordFlowers"
--------------------------------------------------------------------------------
/configs/datasets/oxford_pets.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "OxfordPets"
--------------------------------------------------------------------------------
/configs/datasets/stanford_cars.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "StanfordCars"
3 |
--------------------------------------------------------------------------------
/configs/datasets/sun397.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "SUN397"
3 |
--------------------------------------------------------------------------------
/configs/datasets/ucf101.yaml:
--------------------------------------------------------------------------------
1 | DATASET:
2 | NAME: "UCF101"
3 |
--------------------------------------------------------------------------------
/configs/trainers/CoCoOp/vit_b16_c16_ep10_batch1.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 1
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 10
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 20
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COCOOP:
33 | N_CTX: 16
34 | CTX_INIT: ""
35 | PREC: "fp16"
--------------------------------------------------------------------------------
/configs/trainers/CoCoOp/vit_b16_c4_ep10_batch1.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 1
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 10
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 20
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COCOOP:
33 | N_CTX: 4
34 | CTX_INIT: ""
35 | PREC: "fp16"
--------------------------------------------------------------------------------
/configs/trainers/CoCoOp/vit_b16_c4_ep10_batch1_ctxv1.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 1
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 10
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 20
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COCOOP:
33 | N_CTX: 4
34 | CTX_INIT: "a photo of a"
35 | PREC: "fp16"
--------------------------------------------------------------------------------
/configs/trainers/CoCoOp/vit_b16_c8_ep10_batch1.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 1
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 10
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 20
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | COCOOP:
33 | N_CTX: 8
34 | CTX_INIT: ""
35 | PREC: "fp16"
--------------------------------------------------------------------------------
/configs/trainers/CoOp/rn101.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 200
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "RN101"
--------------------------------------------------------------------------------
/configs/trainers/CoOp/rn101_ep50.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 50
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "RN101"
--------------------------------------------------------------------------------
/configs/trainers/CoOp/rn50.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 200
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "RN50"
--------------------------------------------------------------------------------
/configs/trainers/CoOp/rn50_ctxv1.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 200
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "RN50"
30 |
31 | TRAINER:
32 | COOP:
33 | CTX_INIT: "a photo of a"
34 |
--------------------------------------------------------------------------------
/configs/trainers/CoOp/rn50_ep100.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 100
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "RN50"
30 |
--------------------------------------------------------------------------------
/configs/trainers/CoOp/rn50_ep50.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 50
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "RN50"
--------------------------------------------------------------------------------
/configs/trainers/CoOp/rn50_ep50_ctxv1.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 50
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "RN50"
30 |
31 | TRAINER:
32 | COOP:
33 | CTX_INIT: "a photo of a"
--------------------------------------------------------------------------------
/configs/trainers/CoOp/rn50_val.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 200
4 | TEST:
5 | BATCH_SIZE: 200
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | MODEL:
16 | BACKBONE:
17 | NAME: "RN50"
--------------------------------------------------------------------------------
/configs/trainers/CoOp/vit_b16.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 10
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
--------------------------------------------------------------------------------
/configs/trainers/CoOp/vit_b16_ep100.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 100
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
--------------------------------------------------------------------------------
/configs/trainers/CoOp/vit_b16_ep50.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 50
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
--------------------------------------------------------------------------------
/configs/trainers/CoOp/vit_b32.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 200
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/32"
--------------------------------------------------------------------------------
/configs/trainers/CoOp/vit_b32_ep50.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 32
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 50
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 5
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/32"
--------------------------------------------------------------------------------
/configs/trainers/IVLP/vit_b16_c2_ep5_batch4_2+2ctx.yaml:
--------------------------------------------------------------------------------
1 | # Deep independent V-L Prompting
2 | DATALOADER:
3 | TRAIN_X:
4 | BATCH_SIZE: 4
5 | TEST:
6 | BATCH_SIZE: 100
7 | NUM_WORKERS: 8
8 |
9 | INPUT:
10 | SIZE: (224, 224)
11 | INTERPOLATION: "bicubic"
12 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
13 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
14 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
15 |
16 | OPTIM:
17 | NAME: "sgd"
18 | LR: 0.0035
19 | MAX_EPOCH: 5
20 | LR_SCHEDULER: "cosine"
21 | WARMUP_EPOCH: 1
22 | WARMUP_TYPE: "constant"
23 | WARMUP_CONS_LR: 1e-5
24 |
25 | TRAIN:
26 | PRINT_FREQ: 20
27 |
28 | MODEL:
29 | BACKBONE:
30 | NAME: "ViT-B/16"
31 |
32 | TRAINER:
33 | IVLP:
34 | N_CTX_VISION: 2
35 | N_CTX_TEXT: 2
36 | CTX_INIT: "a photo of a"
37 | PREC: "fp16"
38 | PROMPT_DEPTH_VISION: 12
39 | PROMPT_DEPTH_TEXT: 12
--------------------------------------------------------------------------------
/configs/trainers/IVLP/vit_b16_c2_ep5_batch4_4ctx_language_only.yaml:
--------------------------------------------------------------------------------
1 | # Deep language prompting
2 | DATALOADER:
3 | TRAIN_X:
4 | BATCH_SIZE: 4
5 | TEST:
6 | BATCH_SIZE: 100
7 | NUM_WORKERS: 8
8 |
9 | INPUT:
10 | SIZE: (224, 224)
11 | INTERPOLATION: "bicubic"
12 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
13 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
14 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
15 |
16 | OPTIM:
17 | NAME: "sgd"
18 | LR: 0.0025
19 | MAX_EPOCH: 5
20 | LR_SCHEDULER: "cosine"
21 | WARMUP_EPOCH: 1
22 | WARMUP_TYPE: "constant"
23 | WARMUP_CONS_LR: 1e-5
24 |
25 | TRAIN:
26 | PRINT_FREQ: 20
27 |
28 | MODEL:
29 | BACKBONE:
30 | NAME: "ViT-B/16"
31 |
32 | TRAINER:
33 | IVLP:
34 | N_CTX_VISION: 0
35 | N_CTX_TEXT: 4
36 | CTX_INIT: "a photo of a"
37 | PREC: "fp16"
38 | PROMPT_DEPTH_VISION: 0
39 | PROMPT_DEPTH_TEXT: 12
40 |
--------------------------------------------------------------------------------
/configs/trainers/MMP/sun397.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 4
4 | TEST:
5 | BATCH_SIZE: 4
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.002
18 | MAX_EPOCH: 20
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 20
26 | # CHECKPOINT_FREQ: 5
27 |
28 | #TEST:
29 | # FINAL_MODEL: best_val
30 | # NO_TEST: False
31 |
32 | MODEL:
33 | BACKBONE:
34 | NAME: "ViT-B/16"
35 |
36 | TRAINER:
37 | MMP:
38 | N_CTX: 2
39 | CTX_INIT: "a photo of a" #\ta nice photo of \ta large picture of \ta small photo of a \ta nice sketch of a" #"\t a doodle of a \t a bright photo of a \t a sketch of a \t a tattoo of a \t a drawing of a \t a painting of the \t a drawing of the"
40 | PREC: "fp16"
41 | TEXT_PROMPT_DEPTH: 9
42 | VISION_PROMPT_DEPTH: 9
43 | TEXT_PROMPT_NUMBER: 2
44 | VISION_PROMPT_NUMBER: 2
45 | HIERARCHICAL: True
46 | USECT: False
47 | # HIERARCHICAL: False
48 | # USECT: True
--------------------------------------------------------------------------------
/configs/trainers/MMP/vit_b16_c2_ep5_batch4_2ctx.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 4
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.0025
18 | MAX_EPOCH: 20
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 20
26 | # CHECKPOINT_FREQ: 5
27 |
28 | #TEST:
29 | # FINAL_MODEL: best_val
30 | # NO_TEST: False
31 |
32 | MODEL:
33 | BACKBONE:
34 | NAME: "ViT-B/16"
35 |
36 | TRAINER:
37 | MMP:
38 | N_CTX: 2
39 | CTX_INIT: "a photo of a" #\ta nice photo of \ta large picture of \ta small photo of a \ta nice sketch of a" #"\t a doodle of a \t a bright photo of a \t a sketch of a \t a tattoo of a \t a drawing of a \t a painting of the \t a drawing of the"
40 | PREC: "fp16"
41 | TEXT_PROMPT_DEPTH: 9
42 | VISION_PROMPT_DEPTH: 9
43 | TEXT_PROMPT_NUMBER: 4
44 | VISION_PROMPT_NUMBER: 4
45 | HIERARCHICAL: True
46 | USECT: False
47 | # HIERARCHICAL: False
48 | # USECT: True
49 |
--------------------------------------------------------------------------------
/configs/trainers/MMP/vit_h.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 4
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.0025
18 | MAX_EPOCH: 20
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 20
26 | # CHECKPOINT_FREQ: 5
27 |
28 | #TEST:
29 | # FINAL_MODEL: best_val
30 | # NO_TEST: False
31 |
32 | MODEL:
33 | BACKBONE:
34 | NAME: "ViT-H/14"
35 |
36 | TRAINER:
37 | MMP:
38 | N_CTX: 2
39 | CTX_INIT: "a photo of a" #\ta nice photo of \ta large picture of \ta small photo of a \ta nice sketch of a" #"\t a doodle of a \t a bright photo of a \t a sketch of a \t a tattoo of a \t a drawing of a \t a painting of the \t a drawing of the"
40 | PREC: "fp16"
41 | TEXT_PROMPT_DEPTH: 9
42 | VISION_PROMPT_DEPTH: 9
43 | TEXT_PROMPT_NUMBER: 4
44 | VISION_PROMPT_NUMBER: 4
45 | HIERARCHICAL: True
46 | USECT: False
47 | # HIERARCHICAL: False
48 | # USECT: True
49 |
--------------------------------------------------------------------------------
/configs/trainers/MaPLe/vit_b16_c2_ep5_batch4_2ctx.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 4
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.0035
18 | MAX_EPOCH: 10
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 20
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | MAPLE:
33 | N_CTX: 2
34 | CTX_INIT: "a photo of a"
35 | PREC: "fp16"
36 | PROMPT_DEPTH: 9
--------------------------------------------------------------------------------
/configs/trainers/MaPLe/vit_b16_c2_ep5_batch4_2ctx_cross_datasets.yaml:
--------------------------------------------------------------------------------
1 | DATALOADER:
2 | TRAIN_X:
3 | BATCH_SIZE: 4
4 | TEST:
5 | BATCH_SIZE: 100
6 | NUM_WORKERS: 8
7 |
8 | INPUT:
9 | SIZE: (224, 224)
10 | INTERPOLATION: "bicubic"
11 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
12 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
13 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
14 |
15 | OPTIM:
16 | NAME: "sgd"
17 | LR: 0.0026
18 | MAX_EPOCH: 2
19 | LR_SCHEDULER: "cosine"
20 | WARMUP_EPOCH: 1
21 | WARMUP_TYPE: "constant"
22 | WARMUP_CONS_LR: 1e-5
23 |
24 | TRAIN:
25 | PRINT_FREQ: 20
26 |
27 | MODEL:
28 | BACKBONE:
29 | NAME: "ViT-B/16"
30 |
31 | TRAINER:
32 | MAPLE:
33 | N_CTX: 2
34 | CTX_INIT: "a photo of a"
35 | PREC: "fp16"
36 | PROMPT_DEPTH: 3
--------------------------------------------------------------------------------
/configs/trainers/VPT/vit_b16_c2_ep5_batch4_4.yaml:
--------------------------------------------------------------------------------
1 | # Deep vision prompting
2 | DATALOADER:
3 | TRAIN_X:
4 | BATCH_SIZE: 4
5 | TEST:
6 | BATCH_SIZE: 100
7 | NUM_WORKERS: 8
8 |
9 | INPUT:
10 | SIZE: (224, 224)
11 | INTERPOLATION: "bicubic"
12 | PIXEL_MEAN: [0.48145466, 0.4578275, 0.40821073]
13 | PIXEL_STD: [0.26862954, 0.26130258, 0.27577711]
14 | TRANSFORMS: ["random_resized_crop", "random_flip", "normalize"]
15 |
16 | OPTIM:
17 | NAME: "sgd"
18 | LR: 0.0025
19 | MAX_EPOCH: 5
20 | LR_SCHEDULER: "cosine"
21 | WARMUP_EPOCH: 1
22 | WARMUP_TYPE: "constant"
23 | WARMUP_CONS_LR: 1e-5
24 |
25 | TRAIN:
26 | PRINT_FREQ: 20
27 |
28 | MODEL:
29 | BACKBONE:
30 | NAME: "ViT-B/16"
31 |
32 | TRAINER:
33 | VPT:
34 | N_CTX_VISION: 8
35 | CTX_INIT: "a photo of a"
36 | PREC: "fp16"
37 | PROMPT_DEPTH_VISION: 12
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__init__.py
--------------------------------------------------------------------------------
/datasets/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/caltech101.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/caltech101.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/caltech101.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/caltech101.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/dtd.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/dtd.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/dtd.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/dtd.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/eurosat.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/eurosat.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/eurosat.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/eurosat.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/fgvc_aircraft.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/fgvc_aircraft.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/fgvc_aircraft.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/fgvc_aircraft.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/food101.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/food101.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/food101.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/food101.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/imagenet.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/imagenet.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/imagenet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/imagenet.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/imagenet_a.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/imagenet_a.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/imagenet_a.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/imagenet_a.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/imagenet_r.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/imagenet_r.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/imagenet_r.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/imagenet_r.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/imagenet_sketch.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/imagenet_sketch.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/imagenet_sketch.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/imagenet_sketch.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/imagenetv2.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/imagenetv2.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/imagenetv2.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/imagenetv2.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/oxford_flowers.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/oxford_flowers.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/oxford_flowers.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/oxford_flowers.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/oxford_pets.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/oxford_pets.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/oxford_pets.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/oxford_pets.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/stanford_cars.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/stanford_cars.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/stanford_cars.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/stanford_cars.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/sun397.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/sun397.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/sun397.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/sun397.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/ucf101.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/ucf101.cpython-37.pyc
--------------------------------------------------------------------------------
/datasets/__pycache__/ucf101.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/datasets/__pycache__/ucf101.cpython-38.pyc
--------------------------------------------------------------------------------
/datasets/caltech101.py:
--------------------------------------------------------------------------------
1 | import os
2 | # import pickle
3 | import pickle5 as pickle
4 |
5 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
6 | from dassl.utils import mkdir_if_missing
7 |
8 | from .oxford_pets import OxfordPets
9 | from .dtd import DescribableTextures as DTD
10 |
11 | IGNORED = ["BACKGROUND_Google", "Faces_easy"]
12 | NEW_CNAMES = {
13 | "airplanes": "airplane",
14 | "Faces": "face",
15 | "Leopards": "leopard",
16 | "Motorbikes": "motorbike",
17 | }
18 |
19 |
20 | @DATASET_REGISTRY.register()
21 | class Caltech101(DatasetBase):
22 |
23 | dataset_dir = "caltech-101"
24 |
25 | def __init__(self, cfg):
26 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
27 | self.dataset_dir = os.path.join(root, self.dataset_dir)
28 | self.image_dir = os.path.join(self.dataset_dir, "101_ObjectCategories")
29 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_Caltech101.json")
30 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
31 | mkdir_if_missing(self.split_fewshot_dir)
32 |
33 | if os.path.exists(self.split_path):
34 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
35 | else:
36 | train, val, test = DTD.read_and_split_data(self.image_dir, ignored=IGNORED, new_cnames=NEW_CNAMES)
37 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
38 |
39 | num_shots = cfg.DATASET.NUM_SHOTS
40 | if num_shots >= 1:
41 | seed = cfg.SEED
42 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
43 |
44 | if os.path.exists(preprocessed):
45 | print(f"Loading preprocessed few-shot data from {preprocessed}")
46 | with open(preprocessed, "rb") as file:
47 | data = pickle.load(file)
48 | train, val = data["train"], data["val"]
49 | else:
50 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
51 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
52 | data = {"train": train, "val": val}
53 | print(f"Saving preprocessed few-shot data to {preprocessed}")
54 | with open(preprocessed, "wb") as file:
55 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
56 |
57 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES
58 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample)
59 |
60 | super().__init__(train_x=train, val=val, test=test)
61 |
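The few-shot branch above follows a load-or-build-and-cache pattern that every dataset class in this folder repeats: if a pickle for the current (shots, seed) pair exists it is loaded, otherwise a few-shot subset is sampled and written back. A minimal, self-contained sketch of that pattern using only the standard library (the real sampling is done by generate_fewshot_dataset from Dassl's DatasetBase; sample_per_class below is a hypothetical stand-in):

    import os
    import pickle
    import random
    from collections import defaultdict

    def sample_per_class(items, k, label_fn):
        # Hypothetical stand-in for generate_fewshot_dataset: k random items per class.
        by_class = defaultdict(list)
        for it in items:
            by_class[label_fn(it)].append(it)
        return [x for group in by_class.values() for x in random.sample(group, min(k, len(group)))]

    def load_or_build_fewshot(cache_path, train, val, num_shots, label_fn):
        if os.path.exists(cache_path):
            # Reuse the cached split for this (shots, seed) pair.
            with open(cache_path, "rb") as f:
                data = pickle.load(f)
        else:
            data = {"train": sample_per_class(train, num_shots, label_fn),
                    "val": sample_per_class(val, min(num_shots, 4), label_fn)}
            # Cache the split so later runs with the same settings reuse it.
            with open(cache_path, "wb") as f:
                pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
        return data["train"], data["val"]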
--------------------------------------------------------------------------------
/datasets/dtd.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | # import pickle5 as pickle
4 | import random
5 |
6 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
7 | from dassl.utils import listdir_nohidden, mkdir_if_missing
8 |
9 | from .oxford_pets import OxfordPets
10 |
11 |
12 | @DATASET_REGISTRY.register()
13 | class DescribableTextures(DatasetBase):
14 |
15 | dataset_dir = "dtd"
16 |
17 | def __init__(self, cfg):
18 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
19 | self.dataset_dir = os.path.join(root, self.dataset_dir)
20 | self.image_dir = os.path.join(self.dataset_dir, "images")
21 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_DescribableTextures.json")
22 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
23 | mkdir_if_missing(self.split_fewshot_dir)
24 |
25 | if os.path.exists(self.split_path):
26 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
27 | else:
28 | train, val, test = self.read_and_split_data(self.image_dir)
29 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
30 |
31 | num_shots = cfg.DATASET.NUM_SHOTS
32 | if num_shots >= 1:
33 | seed = cfg.SEED
34 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
35 |
36 | if os.path.exists(preprocessed):
37 | print(f"Loading preprocessed few-shot data from {preprocessed}")
38 | with open(preprocessed, "rb") as file:
39 | data = pickle.load(file)
40 | train, val = data["train"], data["val"]
41 | else:
42 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
43 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
44 | data = {"train": train, "val": val}
45 | print(f"Saving preprocessed few-shot data to {preprocessed}")
46 | with open(preprocessed, "wb") as file:
47 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
48 |
49 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES
50 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample)
51 |
52 | super().__init__(train_x=train, val=val, test=test)
53 |
54 | @staticmethod
55 | def read_and_split_data(image_dir, p_trn=0.5, p_val=0.2, ignored=[], new_cnames=None):
56 | # The data are supposed to be organized into the following structure
57 | # =============
58 | # images/
59 | # dog/
60 | # cat/
61 | # horse/
62 | # =============
63 | categories = listdir_nohidden(image_dir)
64 | categories = [c for c in categories if c not in ignored]
65 | categories.sort()
66 |
67 | p_tst = 1 - p_trn - p_val
68 | print(f"Splitting into {p_trn:.0%} train, {p_val:.0%} val, and {p_tst:.0%} test")
69 |
70 | def _collate(ims, y, c):
71 | items = []
72 | for im in ims:
73 | item = Datum(impath=im, label=y, classname=c) # is already 0-based
74 | items.append(item)
75 | return items
76 |
77 | train, val, test = [], [], []
78 | for label, category in enumerate(categories):
79 | category_dir = os.path.join(image_dir, category)
80 | images = listdir_nohidden(category_dir)
81 | images = [os.path.join(category_dir, im) for im in images]
82 | random.shuffle(images)
83 | n_total = len(images)
84 | n_train = round(n_total * p_trn)
85 | n_val = round(n_total * p_val)
86 | n_test = n_total - n_train - n_val
87 | assert n_train > 0 and n_val > 0 and n_test > 0
88 |
89 | if new_cnames is not None and category in new_cnames:
90 | category = new_cnames[category]
91 |
92 | train.extend(_collate(images[:n_train], label, category))
93 | val.extend(_collate(images[n_train : n_train + n_val], label, category))
94 | test.extend(_collate(images[n_train + n_val :], label, category))
95 |
96 | return train, val, test
97 |
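read_and_split_data shuffles each category folder and cuts it into 50% train, 20% val, and 30% test by default. The same arithmetic on a made-up 47-image category, with no Dassl or image files needed:

    import random

    random.seed(0)
    images = [f"images/dog/{i:03d}.jpg" for i in range(47)]   # hypothetical category folder
    random.shuffle(images)

    p_trn, p_val = 0.5, 0.2
    n_total = len(images)
    n_train = round(n_total * p_trn)                 # 24
    n_val = round(n_total * p_val)                   # 9
    n_test = n_total - n_train - n_val               # 14
    train = images[:n_train]
    val = images[n_train:n_train + n_val]
    test = images[n_train + n_val:]
    print(len(train), len(val), len(test))           # 24 9 14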
--------------------------------------------------------------------------------
/datasets/eurosat.py:
--------------------------------------------------------------------------------
1 | import os
2 | # import pickle
3 | import pickle5 as pickle
4 |
5 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
6 | from dassl.utils import mkdir_if_missing
7 |
8 | from .oxford_pets import OxfordPets
9 | from .dtd import DescribableTextures as DTD
10 |
11 | NEW_CNAMES = {
12 | "AnnualCrop": "Annual Crop Land",
13 | "Forest": "Forest",
14 | "HerbaceousVegetation": "Herbaceous Vegetation Land",
15 | "Highway": "Highway or Road",
16 | "Industrial": "Industrial Buildings",
17 | "Pasture": "Pasture Land",
18 | "PermanentCrop": "Permanent Crop Land",
19 | "Residential": "Residential Buildings",
20 | "River": "River",
21 | "SeaLake": "Sea or Lake",
22 | }
23 |
24 |
25 | @DATASET_REGISTRY.register()
26 | class EuroSAT(DatasetBase):
27 |
28 | dataset_dir = "eurosat"
29 |
30 | def __init__(self, cfg):
31 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
32 | self.dataset_dir = os.path.join(root, self.dataset_dir)
33 | self.image_dir = os.path.join(self.dataset_dir, "2750")
34 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_EuroSAT.json")
35 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
36 | mkdir_if_missing(self.split_fewshot_dir)
37 |
38 | if os.path.exists(self.split_path):
39 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
40 | else:
41 | train, val, test = DTD.read_and_split_data(self.image_dir, new_cnames=NEW_CNAMES)
42 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
43 |
44 | num_shots = cfg.DATASET.NUM_SHOTS
45 | if num_shots >= 1:
46 | seed = cfg.SEED
47 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
48 |
49 | if os.path.exists(preprocessed):
50 | print(f"Loading preprocessed few-shot data from {preprocessed}")
51 | with open(preprocessed, "rb") as file:
52 | data = pickle.load(file)
53 | train, val = data["train"], data["val"]
54 | else:
55 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
56 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
57 | data = {"train": train, "val": val}
58 | print(f"Saving preprocessed few-shot data to {preprocessed}")
59 | with open(preprocessed, "wb") as file:
60 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
61 |
62 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES
63 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample)
64 |
65 | super().__init__(train_x=train, val=val, test=test)
66 |
67 | def update_classname(self, dataset_old):
68 | dataset_new = []
69 | for item_old in dataset_old:
70 | cname_old = item_old.classname
71 | cname_new = NEW_CLASSNAMES[cname_old]
72 | item_new = Datum(impath=item_old.impath, label=item_old.label, classname=cname_new)
73 | dataset_new.append(item_new)
74 | return dataset_new
75 |
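NEW_CNAMES replaces EuroSAT's terse folder names with phrases that read more naturally inside a text prompt; DTD.read_and_split_data applies the mapping through its new_cnames argument. A small illustration of the lookup (the prompt wording here is only an example, not the repo's template):

    NEW_CNAMES = {"AnnualCrop": "Annual Crop Land", "SeaLake": "Sea or Lake"}

    for folder in ["AnnualCrop", "SeaLake", "Forest"]:
        classname = NEW_CNAMES.get(folder, folder)   # fall back to the folder name
        print(f"a satellite photo of {classname}.")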
--------------------------------------------------------------------------------
/datasets/fgvc_aircraft.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 |
4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
5 | from dassl.utils import mkdir_if_missing
6 |
7 | from .oxford_pets import OxfordPets
8 |
9 |
10 | @DATASET_REGISTRY.register()
11 | class FGVCAircraft(DatasetBase):
12 |
13 | dataset_dir = "fgvc_aircraft"
14 |
15 | def __init__(self, cfg):
16 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
17 | self.dataset_dir = os.path.join(root, self.dataset_dir)
18 | self.image_dir = os.path.join(self.dataset_dir, "images")
19 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
20 | mkdir_if_missing(self.split_fewshot_dir)
21 |
22 | classnames = []
23 | with open(os.path.join(self.dataset_dir, "variants.txt"), "r") as f:
24 | lines = f.readlines()
25 | for line in lines:
26 | classnames.append(line.strip())
27 | cname2lab = {c: i for i, c in enumerate(classnames)}
28 |
29 | train = self.read_data(cname2lab, "images_variant_train.txt")
30 | val = self.read_data(cname2lab, "images_variant_val.txt")
31 | test = self.read_data(cname2lab, "images_variant_test.txt")
32 |
33 | num_shots = cfg.DATASET.NUM_SHOTS
34 | if num_shots >= 1:
35 | seed = cfg.SEED
36 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
37 |
38 | if os.path.exists(preprocessed):
39 | print(f"Loading preprocessed few-shot data from {preprocessed}")
40 | with open(preprocessed, "rb") as file:
41 | data = pickle.load(file)
42 | train, val = data["train"], data["val"]
43 | else:
44 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
45 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
46 | data = {"train": train, "val": val}
47 | print(f"Saving preprocessed few-shot data to {preprocessed}")
48 | with open(preprocessed, "wb") as file:
49 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
50 |
51 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES
52 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample)
53 |
54 | super().__init__(train_x=train, val=val, test=test)
55 |
56 | def read_data(self, cname2lab, split_file):
57 | filepath = os.path.join(self.dataset_dir, split_file)
58 | items = []
59 |
60 | with open(filepath, "r") as f:
61 | lines = f.readlines()
62 | for line in lines:
63 | line = line.strip().split(" ")
64 | imname = line[0] + ".jpg"
65 | classname = " ".join(line[1:])
66 | impath = os.path.join(self.image_dir, imname)
67 | label = cname2lab[classname]
68 | item = Datum(impath=impath, label=label, classname=classname)
69 | items.append(item)
70 |
71 | return items
72 |
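Each line of images_variant_*.txt starts with an image id and is followed by the variant name, which can itself contain spaces, so read_data rejoins everything after the first token. The same parsing on two sample lines in that format (the ids and label table are made up):

    sample_lines = ["1025794 707-320", "0056978 A340-300"]
    cname2lab = {"707-320": 0, "A340-300": 1}

    for line in sample_lines:
        tokens = line.strip().split(" ")
        imname = tokens[0] + ".jpg"                  # e.g. 1025794.jpg
        classname = " ".join(tokens[1:])             # variant names may contain spaces
        print(imname, cname2lab[classname], classname)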
--------------------------------------------------------------------------------
/datasets/food101.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 |
4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
5 | from dassl.utils import mkdir_if_missing
6 |
7 | from .oxford_pets import OxfordPets
8 | from .dtd import DescribableTextures as DTD
9 |
10 |
11 | @DATASET_REGISTRY.register()
12 | class Food101(DatasetBase):
13 |
14 | dataset_dir = "food-101"
15 |
16 | def __init__(self, cfg):
17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
18 | self.dataset_dir = os.path.join(root, self.dataset_dir)
19 | self.image_dir = os.path.join(self.dataset_dir, "images")
20 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_Food101.json")
21 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
22 | mkdir_if_missing(self.split_fewshot_dir)
23 |
24 | if os.path.exists(self.split_path):
25 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
26 | else:
27 | train, val, test = DTD.read_and_split_data(self.image_dir)
28 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
29 |
30 | num_shots = cfg.DATASET.NUM_SHOTS
31 | if num_shots >= 1:
32 | seed = cfg.SEED
33 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
34 |
35 | if os.path.exists(preprocessed):
36 | print(f"Loading preprocessed few-shot data from {preprocessed}")
37 | with open(preprocessed, "rb") as file:
38 | data = pickle.load(file)
39 | train, val = data["train"], data["val"]
40 | else:
41 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
42 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
43 | data = {"train": train, "val": val}
44 | print(f"Saving preprocessed few-shot data to {preprocessed}")
45 | with open(preprocessed, "wb") as file:
46 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
47 |
48 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES
49 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample)
50 |
51 | super().__init__(train_x=train, val=val, test=test)
52 |
--------------------------------------------------------------------------------
/datasets/imagenet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | from collections import OrderedDict
4 |
5 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
6 | from dassl.utils import listdir_nohidden, mkdir_if_missing
7 |
8 | from .oxford_pets import OxfordPets
9 |
10 |
11 | @DATASET_REGISTRY.register()
12 | class ImageNet(DatasetBase):
13 |
14 | dataset_dir = "imagenet"
15 |
16 | def __init__(self, cfg):
17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
18 | self.dataset_dir = os.path.join(root, self.dataset_dir)
19 | self.image_dir = os.path.join(self.dataset_dir, "images")
20 | self.preprocessed = os.path.join(self.dataset_dir, "preprocessed.pkl")
21 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
22 | mkdir_if_missing(self.split_fewshot_dir)
23 |
24 | if os.path.exists(self.preprocessed):
25 | with open(self.preprocessed, "rb") as f:
26 | preprocessed = pickle.load(f)
27 | train = preprocessed["train"]
28 | test = preprocessed["test"]
29 | else:
30 | text_file = os.path.join(self.dataset_dir, "classnames.txt")
31 | classnames = self.read_classnames(text_file)
32 | train = self.read_data(classnames, "train")
33 |             # Following standard practice, the val split is used for evaluation
34 |             # (it also serves as the val set, so the last-step model is evaluated)

35 | test = self.read_data(classnames, "val")
36 |
37 | preprocessed = {"train": train, "test": test}
38 | with open(self.preprocessed, "wb") as f:
39 | pickle.dump(preprocessed, f, protocol=pickle.HIGHEST_PROTOCOL)
40 |
41 | num_shots = cfg.DATASET.NUM_SHOTS
42 | if num_shots >= 1:
43 | seed = cfg.SEED
44 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
45 |
46 | if os.path.exists(preprocessed):
47 | print(f"Loading preprocessed few-shot data from {preprocessed}")
48 | with open(preprocessed, "rb") as file:
49 | data = pickle.load(file)
50 | train = data["train"]
51 | else:
52 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
53 | data = {"train": train}
54 | print(f"Saving preprocessed few-shot data to {preprocessed}")
55 | with open(preprocessed, "wb") as file:
56 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
57 |
58 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES
59 | train, test = OxfordPets.subsample_classes(train, test, subsample=subsample)
60 |
61 | super().__init__(train_x=train, val=test, test=test)
62 |
63 | @staticmethod
64 | def read_classnames(text_file):
65 | """Return a dictionary containing
66 |         key-value pairs of <folder name>: <class name>.
67 | """
68 | classnames = OrderedDict()
69 | with open(text_file, "r") as f:
70 | lines = f.readlines()
71 | for line in lines:
72 | line = line.strip().split(" ")
73 | folder = line[0]
74 | classname = " ".join(line[1:])
75 | classnames[folder] = classname
76 | return classnames
77 |
78 | def read_data(self, classnames, split_dir):
79 | split_dir = os.path.join(self.image_dir, split_dir)
80 | folders = sorted(f.name for f in os.scandir(split_dir) if f.is_dir())
81 | items = []
82 |
83 | for label, folder in enumerate(folders):
84 | imnames = listdir_nohidden(os.path.join(split_dir, folder))
85 | classname = classnames[folder]
86 | for imname in imnames:
87 | impath = os.path.join(split_dir, folder, imname)
88 | item = Datum(impath=impath, label=label, classname=classname)
89 | items.append(item)
90 |
91 | return items
92 |
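classnames.txt has one synset folder and its human-readable name per line; read_classnames keeps the file order in an OrderedDict, which ImageNetV2.read_data below relies on. A standalone version of the parsing on two sample lines:

    from collections import OrderedDict

    sample_lines = ["n01440764 tench", "n01443537 goldfish"]
    classnames = OrderedDict()
    for line in sample_lines:
        tokens = line.strip().split(" ")
        classnames[tokens[0]] = " ".join(tokens[1:])   # class names may contain spaces
    print(classnames["n01440764"])                     # tench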
--------------------------------------------------------------------------------
/datasets/imagenet_a.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
4 | from dassl.utils import listdir_nohidden
5 |
6 | from .imagenet import ImageNet
7 |
8 | TO_BE_IGNORED = ["README.txt"]
9 |
10 |
11 | @DATASET_REGISTRY.register()
12 | class ImageNetA(DatasetBase):
13 | """ImageNet-A(dversarial).
14 |
15 | This dataset is used for testing only.
16 | """
17 |
18 | dataset_dir = "imagenet-adversarial"
19 |
20 | def __init__(self, cfg):
21 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
22 | self.dataset_dir = os.path.join(root, self.dataset_dir)
23 | self.image_dir = os.path.join(self.dataset_dir, "imagenet-a")
24 |
25 | text_file = os.path.join(self.dataset_dir, "classnames.txt")
26 | classnames = ImageNet.read_classnames(text_file)
27 |
28 | data = self.read_data(classnames)
29 |
30 | super().__init__(train_x=data, test=data)
31 |
32 | def read_data(self, classnames):
33 | image_dir = self.image_dir
34 | folders = listdir_nohidden(image_dir, sort=True)
35 | folders = [f for f in folders if f not in TO_BE_IGNORED]
36 | items = []
37 |
38 | for label, folder in enumerate(folders):
39 | imnames = listdir_nohidden(os.path.join(image_dir, folder))
40 | classname = classnames[folder]
41 | for imname in imnames:
42 | impath = os.path.join(image_dir, folder, imname)
43 | item = Datum(impath=impath, label=label, classname=classname)
44 | items.append(item)
45 |
46 | return items
47 |
--------------------------------------------------------------------------------
/datasets/imagenet_r.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
4 | from dassl.utils import listdir_nohidden
5 |
6 | from .imagenet import ImageNet
7 |
8 | TO_BE_IGNORED = ["README.txt"]
9 |
10 |
11 | @DATASET_REGISTRY.register()
12 | class ImageNetR(DatasetBase):
13 | """ImageNet-R(endition).
14 |
15 | This dataset is used for testing only.
16 | """
17 |
18 | dataset_dir = "imagenet-rendition"
19 |
20 | def __init__(self, cfg):
21 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
22 | self.dataset_dir = os.path.join(root, self.dataset_dir)
23 | self.image_dir = os.path.join(self.dataset_dir, "imagenet-r")
24 |
25 | text_file = os.path.join(self.dataset_dir, "classnames.txt")
26 | classnames = ImageNet.read_classnames(text_file)
27 |
28 | data = self.read_data(classnames)
29 |
30 | super().__init__(train_x=data, test=data)
31 |
32 | def read_data(self, classnames):
33 | image_dir = self.image_dir
34 | folders = listdir_nohidden(image_dir, sort=True)
35 | folders = [f for f in folders if f not in TO_BE_IGNORED]
36 | items = []
37 |
38 | for label, folder in enumerate(folders):
39 | imnames = listdir_nohidden(os.path.join(image_dir, folder))
40 | classname = classnames[folder]
41 | for imname in imnames:
42 | impath = os.path.join(image_dir, folder, imname)
43 | item = Datum(impath=impath, label=label, classname=classname)
44 | items.append(item)
45 |
46 | return items
47 |
--------------------------------------------------------------------------------
/datasets/imagenet_sketch.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
4 | from dassl.utils import listdir_nohidden
5 |
6 | from .imagenet import ImageNet
7 |
8 |
9 | @DATASET_REGISTRY.register()
10 | class ImageNetSketch(DatasetBase):
11 | """ImageNet-Sketch.
12 |
13 | This dataset is used for testing only.
14 | """
15 |
16 | dataset_dir = "imagenet-sketch"
17 |
18 | def __init__(self, cfg):
19 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
20 | self.dataset_dir = os.path.join(root, self.dataset_dir)
21 | self.image_dir = os.path.join(self.dataset_dir, "images")
22 |
23 | text_file = os.path.join(self.dataset_dir, "classnames.txt")
24 | classnames = ImageNet.read_classnames(text_file)
25 |
26 | data = self.read_data(classnames)
27 |
28 | super().__init__(train_x=data, test=data)
29 |
30 | def read_data(self, classnames):
31 | image_dir = self.image_dir
32 | folders = listdir_nohidden(image_dir, sort=True)
33 | items = []
34 |
35 | for label, folder in enumerate(folders):
36 | imnames = listdir_nohidden(os.path.join(image_dir, folder))
37 | classname = classnames[folder]
38 | for imname in imnames:
39 | impath = os.path.join(image_dir, folder, imname)
40 | item = Datum(impath=impath, label=label, classname=classname)
41 | items.append(item)
42 |
43 | return items
44 |
--------------------------------------------------------------------------------
/datasets/imagenetv2.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
4 | from dassl.utils import listdir_nohidden
5 |
6 | from .imagenet import ImageNet
7 |
8 |
9 | @DATASET_REGISTRY.register()
10 | class ImageNetV2(DatasetBase):
11 | """ImageNetV2.
12 |
13 | This dataset is used for testing only.
14 | """
15 |
16 | dataset_dir = "imagenetv2"
17 |
18 | def __init__(self, cfg):
19 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
20 | self.dataset_dir = os.path.join(root, self.dataset_dir)
21 | image_dir = "imagenetv2-matched-frequency-format-val"
22 | self.image_dir = os.path.join(self.dataset_dir, image_dir)
23 |
24 | text_file = os.path.join(self.dataset_dir, "classnames.txt")
25 | classnames = ImageNet.read_classnames(text_file)
26 |
27 | data = self.read_data(classnames)
28 |
29 | super().__init__(train_x=data, test=data)
30 |
31 | def read_data(self, classnames):
32 | image_dir = self.image_dir
33 | folders = list(classnames.keys())
34 | items = []
35 |
36 | for label in range(1000):
37 | class_dir = os.path.join(image_dir, str(label))
38 | imnames = listdir_nohidden(class_dir)
39 | folder = folders[label]
40 | classname = classnames[folder]
41 | for imname in imnames:
42 | impath = os.path.join(class_dir, imname)
43 | item = Datum(impath=impath, label=label, classname=classname)
44 | items.append(item)
45 |
46 | return items
47 |
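In the matched-frequency layout the class folders are simply named "0" through "999", so read_data uses the folder name as the label and recovers the class name by position in classnames, which is why ImageNet.read_classnames preserves file order. A toy version with two classes standing in for the full 1000:

    from collections import OrderedDict

    classnames = OrderedDict([("n01440764", "tench"), ("n01443537", "goldfish")])
    folders = list(classnames.keys())

    for label in range(len(classnames)):             # folder names are "0", "1", ...
        class_dir = str(label)
        print(class_dir, classnames[folders[label]]) # 0 tench / 1 goldfish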
--------------------------------------------------------------------------------
/datasets/oxford_flowers.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import random
4 | from scipy.io import loadmat
5 | from collections import defaultdict
6 |
7 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
8 | from dassl.utils import read_json, mkdir_if_missing
9 |
10 | from .oxford_pets import OxfordPets
11 |
12 |
13 | @DATASET_REGISTRY.register()
14 | class OxfordFlowers(DatasetBase):
15 |
16 | dataset_dir = "oxford_flowers"
17 |
18 | def __init__(self, cfg):
19 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
20 | self.dataset_dir = os.path.join(root, self.dataset_dir)
21 | self.image_dir = os.path.join(self.dataset_dir, "jpg")
22 | self.label_file = os.path.join(self.dataset_dir, "imagelabels.mat")
23 | self.lab2cname_file = os.path.join(self.dataset_dir, "cat_to_name.json")
24 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_OxfordFlowers.json")
25 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
26 | mkdir_if_missing(self.split_fewshot_dir)
27 |
28 | if os.path.exists(self.split_path):
29 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
30 | else:
31 | train, val, test = self.read_data()
32 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
33 |
34 | num_shots = cfg.DATASET.NUM_SHOTS
35 | if num_shots >= 1:
36 | seed = cfg.SEED
37 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
38 |
39 | if os.path.exists(preprocessed):
40 | print(f"Loading preprocessed few-shot data from {preprocessed}")
41 | with open(preprocessed, "rb") as file:
42 | data = pickle.load(file)
43 | train, val = data["train"], data["val"]
44 | else:
45 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
46 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
47 | data = {"train": train, "val": val}
48 | print(f"Saving preprocessed few-shot data to {preprocessed}")
49 | with open(preprocessed, "wb") as file:
50 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
51 |
52 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES
53 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample)
54 |
55 | super().__init__(train_x=train, val=val, test=test)
56 |
57 | def read_data(self):
58 | tracker = defaultdict(list)
59 | label_file = loadmat(self.label_file)["labels"][0]
60 | for i, label in enumerate(label_file):
61 | imname = f"image_{str(i + 1).zfill(5)}.jpg"
62 | impath = os.path.join(self.image_dir, imname)
63 | label = int(label)
64 | tracker[label].append(impath)
65 |
66 | print("Splitting data into 50% train, 20% val, and 30% test")
67 |
68 | def _collate(ims, y, c):
69 | items = []
70 | for im in ims:
71 | item = Datum(impath=im, label=y - 1, classname=c) # convert to 0-based label
72 | items.append(item)
73 | return items
74 |
75 | lab2cname = read_json(self.lab2cname_file)
76 | train, val, test = [], [], []
77 | for label, impaths in tracker.items():
78 | random.shuffle(impaths)
79 | n_total = len(impaths)
80 | n_train = round(n_total * 0.5)
81 | n_val = round(n_total * 0.2)
82 | n_test = n_total - n_train - n_val
83 | assert n_train > 0 and n_val > 0 and n_test > 0
84 | cname = lab2cname[str(label)]
85 | train.extend(_collate(impaths[:n_train], label, cname))
86 | val.extend(_collate(impaths[n_train : n_train + n_val], label, cname))
87 | test.extend(_collate(impaths[n_train + n_val :], label, cname))
88 |
89 | return train, val, test
90 |
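imagelabels.mat stores one 1-based label per image in file order, and cat_to_name.json maps that label (as a string key) to a flower name, so _collate subtracts 1 to get 0-based labels. A sketch of the two conversions without SciPy or the real files (the label values and names are made up):

    labels = [77, 77, 3]                             # 1-based, as read from imagelabels.mat
    lab2cname = {"77": "passion flower", "3": "canterbury bells"}

    for i, label in enumerate(labels):
        imname = f"image_{str(i + 1).zfill(5)}.jpg"  # image_00001.jpg, ...
        print(imname, label - 1, lab2cname[str(label)])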
--------------------------------------------------------------------------------
/datasets/oxford_pets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import math
4 | import random
5 | from collections import defaultdict
6 |
7 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
8 | from dassl.utils import read_json, write_json, mkdir_if_missing
9 |
10 |
11 | @DATASET_REGISTRY.register()
12 | class OxfordPets(DatasetBase):
13 |
14 | dataset_dir = "oxford_pets"
15 |
16 | def __init__(self, cfg):
17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
18 | self.dataset_dir = os.path.join(root, self.dataset_dir)
19 | self.image_dir = os.path.join(self.dataset_dir, "images")
20 | self.anno_dir = os.path.join(self.dataset_dir, "annotations")
21 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_OxfordPets.json")
22 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
23 | mkdir_if_missing(self.split_fewshot_dir)
24 |
25 | if os.path.exists(self.split_path):
26 | train, val, test = self.read_split(self.split_path, self.image_dir)
27 | else:
28 | trainval = self.read_data(split_file="trainval.txt")
29 | test = self.read_data(split_file="test.txt")
30 | train, val = self.split_trainval(trainval)
31 | self.save_split(train, val, test, self.split_path, self.image_dir)
32 |
33 | num_shots = cfg.DATASET.NUM_SHOTS
34 | if num_shots >= 1:
35 | seed = cfg.SEED
36 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
37 |
38 | if os.path.exists(preprocessed):
39 | print(f"Loading preprocessed few-shot data from {preprocessed}")
40 | with open(preprocessed, "rb") as file:
41 | data = pickle.load(file)
42 | train, val = data["train"], data["val"]
43 | else:
44 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
45 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
46 | data = {"train": train, "val": val}
47 | print(f"Saving preprocessed few-shot data to {preprocessed}")
48 | with open(preprocessed, "wb") as file:
49 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
50 |
51 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES
52 | train, val, test = self.subsample_classes(train, val, test, subsample=subsample)
53 |
54 | super().__init__(train_x=train, val=val, test=test)
55 |
56 | def read_data(self, split_file):
57 | filepath = os.path.join(self.anno_dir, split_file)
58 | items = []
59 |
60 | with open(filepath, "r") as f:
61 | lines = f.readlines()
62 | for line in lines:
63 | line = line.strip()
64 | imname, label, species, _ = line.split(" ")
65 | breed = imname.split("_")[:-1]
66 | breed = "_".join(breed)
67 | breed = breed.lower()
68 | imname += ".jpg"
69 | impath = os.path.join(self.image_dir, imname)
70 | label = int(label) - 1 # convert to 0-based index
71 | item = Datum(impath=impath, label=label, classname=breed)
72 | items.append(item)
73 |
74 | return items
75 |
76 | @staticmethod
77 | def split_trainval(trainval, p_val=0.2):
78 | p_trn = 1 - p_val
79 | print(f"Splitting trainval into {p_trn:.0%} train and {p_val:.0%} val")
80 | tracker = defaultdict(list)
81 | for idx, item in enumerate(trainval):
82 | label = item.label
83 | tracker[label].append(idx)
84 |
85 | train, val = [], []
86 | for label, idxs in tracker.items():
87 | n_val = round(len(idxs) * p_val)
88 | assert n_val > 0
89 | random.shuffle(idxs)
90 | for n, idx in enumerate(idxs):
91 | item = trainval[idx]
92 | if n < n_val:
93 | val.append(item)
94 | else:
95 | train.append(item)
96 |
97 | return train, val
98 |
99 | @staticmethod
100 | def save_split(train, val, test, filepath, path_prefix):
101 | def _extract(items):
102 | out = []
103 | for item in items:
104 | impath = item.impath
105 | label = item.label
106 | classname = item.classname
107 | impath = impath.replace(path_prefix, "")
108 | if impath.startswith("/"):
109 | impath = impath[1:]
110 | out.append((impath, label, classname))
111 | return out
112 |
113 | train = _extract(train)
114 | val = _extract(val)
115 | test = _extract(test)
116 |
117 | split = {"train": train, "val": val, "test": test}
118 |
119 | write_json(split, filepath)
120 | print(f"Saved split to {filepath}")
121 |
122 | @staticmethod
123 | def read_split(filepath, path_prefix):
124 | def _convert(items):
125 | out = []
126 | for impath, label, classname in items:
127 | impath = os.path.join(path_prefix, impath)
128 | item = Datum(impath=impath, label=int(label), classname=classname)
129 | out.append(item)
130 | return out
131 |
132 | print(f"Reading split from {filepath}")
133 | split = read_json(filepath)
134 | train = _convert(split["train"])
135 | val = _convert(split["val"])
136 | test = _convert(split["test"])
137 |
138 | return train, val, test
139 |
140 | @staticmethod
141 | def subsample_classes(*args, subsample="all"):
142 | """Divide classes into two groups. The first group
143 | represents base classes while the second group represents
144 | new classes.
145 |
146 | Args:
147 | args: a list of datasets, e.g. train, val and test.
148 | subsample (str): what classes to subsample.
149 | """
150 | assert subsample in ["all", "base", "new"]
151 |
152 | if subsample == "all":
153 | return args
154 |
155 | dataset = args[0]
156 | labels = set()
157 | for item in dataset:
158 | labels.add(item.label)
159 | labels = list(labels)
160 | labels.sort()
161 | n = len(labels)
162 | # Divide classes into two halves
163 | m = math.ceil(n / 2)
164 |
165 | print(f"SUBSAMPLE {subsample.upper()} CLASSES!")
166 | if subsample == "base":
167 | selected = labels[:m] # take the first half
168 | else:
169 | selected = labels[m:] # take the second half
170 | relabeler = {y: y_new for y_new, y in enumerate(selected)}
171 |
172 | output = []
173 | for dataset in args:
174 | dataset_new = []
175 | for item in dataset:
176 | if item.label not in selected:
177 | continue
178 | item_new = Datum(
179 | impath=item.impath,
180 | label=relabeler[item.label],
181 | classname=item.classname
182 | )
183 | dataset_new.append(item_new)
184 | output.append(dataset_new)
185 |
186 | return output
187 |
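subsample_classes is what implements the base-to-new protocol used by the base2new scripts: the sorted label set is cut in half, "base" keeps the first half, "new" keeps the second, and the surviving classes are relabeled from 0. The relabeling on a toy 10-class label set:

    import math

    labels = sorted(range(10))                       # 10 classes: 0..9
    m = math.ceil(len(labels) / 2)                   # 5
    base, new = labels[:m], labels[m:]               # [0..4], [5..9]

    relabeler = {y: y_new for y_new, y in enumerate(new)}
    print(relabeler)                                 # {5: 0, 6: 1, 7: 2, 8: 3, 9: 4}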
--------------------------------------------------------------------------------
/datasets/stanford_cars.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | from scipy.io import loadmat
4 |
5 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
6 | from dassl.utils import mkdir_if_missing
7 |
8 | from .oxford_pets import OxfordPets
9 |
10 |
11 | @DATASET_REGISTRY.register()
12 | class StanfordCars(DatasetBase):
13 |
14 | dataset_dir = "stanford_cars"
15 |
16 | def __init__(self, cfg):
17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
18 | self.dataset_dir = os.path.join(root, self.dataset_dir)
19 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_StanfordCars.json")
20 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
21 | mkdir_if_missing(self.split_fewshot_dir)
22 |
23 | if os.path.exists(self.split_path):
24 | train, val, test = OxfordPets.read_split(self.split_path, self.dataset_dir)
25 | else:
26 | trainval_file = os.path.join(self.dataset_dir, "devkit", "cars_train_annos.mat")
27 | test_file = os.path.join(self.dataset_dir, "cars_test_annos_withlabels.mat")
28 | meta_file = os.path.join(self.dataset_dir, "devkit", "cars_meta.mat")
29 | trainval = self.read_data("cars_train", trainval_file, meta_file)
30 | test = self.read_data("cars_test", test_file, meta_file)
31 | train, val = OxfordPets.split_trainval(trainval)
32 | OxfordPets.save_split(train, val, test, self.split_path, self.dataset_dir)
33 |
34 | num_shots = cfg.DATASET.NUM_SHOTS
35 | if num_shots >= 1:
36 | seed = cfg.SEED
37 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
38 |
39 | if os.path.exists(preprocessed):
40 | print(f"Loading preprocessed few-shot data from {preprocessed}")
41 | with open(preprocessed, "rb") as file:
42 | data = pickle.load(file)
43 | train, val = data["train"], data["val"]
44 | else:
45 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
46 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
47 | data = {"train": train, "val": val}
48 | print(f"Saving preprocessed few-shot data to {preprocessed}")
49 | with open(preprocessed, "wb") as file:
50 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
51 |
52 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES
53 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample)
54 |
55 | super().__init__(train_x=train, val=val, test=test)
56 |
57 | def read_data(self, image_dir, anno_file, meta_file):
58 | anno_file = loadmat(anno_file)["annotations"][0]
59 | meta_file = loadmat(meta_file)["class_names"][0]
60 | items = []
61 |
62 | for i in range(len(anno_file)):
63 | imname = anno_file[i]["fname"][0]
64 | impath = os.path.join(self.dataset_dir, image_dir, imname)
65 | label = anno_file[i]["class"][0, 0]
66 | label = int(label) - 1 # convert to 0-based index
67 | classname = meta_file[label][0]
68 | names = classname.split(" ")
69 | year = names.pop(-1)
70 | names.insert(0, year)
71 | classname = " ".join(names)
72 | item = Datum(impath=impath, label=label, classname=classname)
73 | items.append(item)
74 |
75 | return items
76 |
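cars_meta.mat stores class names with the model year last (e.g. "AM General Hummer SUV 2000"), and read_data moves the year to the front before building each Datum. The same string manipulation in isolation:

    classname = "AM General Hummer SUV 2000"         # format used by cars_meta.mat
    names = classname.split(" ")
    year = names.pop(-1)
    names.insert(0, year)
    print(" ".join(names))                           # 2000 AM General Hummer SUV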
--------------------------------------------------------------------------------
/datasets/sun397.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 |
4 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
5 | from dassl.utils import mkdir_if_missing
6 |
7 | from .oxford_pets import OxfordPets
8 |
9 |
10 | @DATASET_REGISTRY.register()
11 | class SUN397(DatasetBase):
12 |
13 | dataset_dir = "sun397"
14 |
15 | def __init__(self, cfg):
16 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
17 | self.dataset_dir = os.path.join(root, self.dataset_dir)
18 | self.image_dir = os.path.join(self.dataset_dir, "SUN397")
19 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_SUN397.json")
20 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
21 | mkdir_if_missing(self.split_fewshot_dir)
22 |
23 | if os.path.exists(self.split_path):
24 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
25 | else:
26 | classnames = []
27 | with open(os.path.join(self.dataset_dir, "ClassName.txt"), "r") as f:
28 | lines = f.readlines()
29 | for line in lines:
30 | line = line.strip()[1:] # remove /
31 | classnames.append(line)
32 | cname2lab = {c: i for i, c in enumerate(classnames)}
33 | trainval = self.read_data(cname2lab, "Training_01.txt")
34 | test = self.read_data(cname2lab, "Testing_01.txt")
35 | train, val = OxfordPets.split_trainval(trainval)
36 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
37 |
38 | num_shots = cfg.DATASET.NUM_SHOTS
39 | if num_shots >= 1:
40 | seed = cfg.SEED
41 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
42 |
43 | if os.path.exists(preprocessed):
44 | print(f"Loading preprocessed few-shot data from {preprocessed}")
45 | with open(preprocessed, "rb") as file:
46 | data = pickle.load(file)
47 | train, val = data["train"], data["val"]
48 | else:
49 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
50 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
51 | data = {"train": train, "val": val}
52 | print(f"Saving preprocessed few-shot data to {preprocessed}")
53 | with open(preprocessed, "wb") as file:
54 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
55 |
56 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES
57 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample)
58 |
59 | super().__init__(train_x=train, val=val, test=test)
60 |
61 | def read_data(self, cname2lab, text_file):
62 | text_file = os.path.join(self.dataset_dir, text_file)
63 | items = []
64 |
65 | with open(text_file, "r") as f:
66 | lines = f.readlines()
67 | for line in lines:
68 | imname = line.strip()[1:] # remove /
69 | classname = os.path.dirname(imname)
70 | label = cname2lab[classname]
71 | impath = os.path.join(self.image_dir, imname)
72 |
73 |                 names = classname.split("/")[1:]  # drop the leading single-letter folder
74 |                 names = names[::-1]  # put modifiers like indoor/outdoor first
75 | classname = " ".join(names)
76 |
77 | item = Datum(impath=impath, label=label, classname=classname)
78 | items.append(item)
79 |
80 | return items
81 |
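SUN397 paths look like /a/abbey/sun_xxx.jpg or /b/balcony/interior/sun_xxx.jpg; after stripping the leading slash, read_data drops the single-letter folder and reverses the remaining parts so modifiers such as "interior" come first. The transformation on two sample paths (the file names are made up):

    import os

    for line in ["/a/abbey/sun_0001.jpg", "/b/balcony/interior/sun_0002.jpg"]:
        imname = line.strip()[1:]                    # drop the leading "/"
        classname = os.path.dirname(imname)          # e.g. "a/abbey"
        names = classname.split("/")[1:]             # drop the single-letter folder
        print(" ".join(names[::-1]))                 # abbey / interior balcony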
--------------------------------------------------------------------------------
/datasets/ucf101.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import re
4 |
5 | from dassl.data.datasets import DATASET_REGISTRY, Datum, DatasetBase
6 | from dassl.utils import mkdir_if_missing
7 |
8 | from .oxford_pets import OxfordPets
9 |
10 |
11 | @DATASET_REGISTRY.register()
12 | class UCF101(DatasetBase):
13 |
14 | dataset_dir = "ucf101"
15 |
16 | def __init__(self, cfg):
17 | root = os.path.abspath(os.path.expanduser(cfg.DATASET.ROOT))
18 | self.dataset_dir = os.path.join(root, self.dataset_dir)
19 | self.image_dir = os.path.join(self.dataset_dir, "UCF-101-midframes")
20 | self.split_path = os.path.join(self.dataset_dir, "split_zhou_UCF101.json")
21 | self.split_fewshot_dir = os.path.join(self.dataset_dir, "split_fewshot")
22 | mkdir_if_missing(self.split_fewshot_dir)
23 |
24 | if os.path.exists(self.split_path):
25 | train, val, test = OxfordPets.read_split(self.split_path, self.image_dir)
26 | else:
27 | cname2lab = {}
28 | filepath = os.path.join(self.dataset_dir, "ucfTrainTestlist/classInd.txt")
29 | with open(filepath, "r") as f:
30 | lines = f.readlines()
31 | for line in lines:
32 | label, classname = line.strip().split(" ")
33 |                     label = int(label) - 1  # convert to 0-based index
34 | cname2lab[classname] = label
35 |
36 | trainval = self.read_data(cname2lab, "ucfTrainTestlist/trainlist01.txt")
37 | test = self.read_data(cname2lab, "ucfTrainTestlist/testlist01.txt")
38 | train, val = OxfordPets.split_trainval(trainval)
39 | OxfordPets.save_split(train, val, test, self.split_path, self.image_dir)
40 |
41 | num_shots = cfg.DATASET.NUM_SHOTS
42 | if num_shots >= 1:
43 | seed = cfg.SEED
44 | preprocessed = os.path.join(self.split_fewshot_dir, f"shot_{num_shots}-seed_{seed}.pkl")
45 |
46 | if os.path.exists(preprocessed):
47 | print(f"Loading preprocessed few-shot data from {preprocessed}")
48 | with open(preprocessed, "rb") as file:
49 | data = pickle.load(file)
50 | train, val = data["train"], data["val"]
51 | else:
52 | train = self.generate_fewshot_dataset(train, num_shots=num_shots)
53 | val = self.generate_fewshot_dataset(val, num_shots=min(num_shots, 4))
54 | data = {"train": train, "val": val}
55 | print(f"Saving preprocessed few-shot data to {preprocessed}")
56 | with open(preprocessed, "wb") as file:
57 | pickle.dump(data, file, protocol=pickle.HIGHEST_PROTOCOL)
58 |
59 | subsample = cfg.DATASET.SUBSAMPLE_CLASSES
60 | train, val, test = OxfordPets.subsample_classes(train, val, test, subsample=subsample)
61 |
62 | super().__init__(train_x=train, val=val, test=test)
63 |
64 | def read_data(self, cname2lab, text_file):
65 | text_file = os.path.join(self.dataset_dir, text_file)
66 | items = []
67 |
68 | with open(text_file, "r") as f:
69 | lines = f.readlines()
70 | for line in lines:
71 | line = line.strip().split(" ")[0] # trainlist: filename, label
72 | action, filename = line.split("/")
73 | label = cname2lab[action]
74 |
75 | elements = re.findall("[A-Z][^A-Z]*", action)
76 | renamed_action = "_".join(elements)
77 |
78 | filename = filename.replace(".avi", ".jpg")
79 | impath = os.path.join(self.image_dir, renamed_action, filename)
80 |
81 | item = Datum(impath=impath, label=label, classname=renamed_action)
82 | items.append(item)
83 |
84 | return items
85 |
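UCF101 action names are CamelCase (trainlist entries look like ApplyEyeMakeup/v_ApplyEyeMakeup_g08_c01.avi 1), and read_data splits them at capital letters to match the midframe directory names. The regex in isolation:

    import re

    line = "ApplyEyeMakeup/v_ApplyEyeMakeup_g08_c01.avi 1"
    action, filename = line.strip().split(" ")[0].split("/")
    elements = re.findall("[A-Z][^A-Z]*", action)
    print("_".join(elements))                        # Apply_Eye_Makeup
    print(filename.replace(".avi", ".jpg"))          # v_ApplyEyeMakeup_g08_c01.jpg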
--------------------------------------------------------------------------------
/images/ALIGN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/images/ALIGN.png
--------------------------------------------------------------------------------
/parse_test_res.py:
--------------------------------------------------------------------------------
1 | """
2 | Goal
3 | ---
4 | 1. Read test results from log.txt files
5 | 2. Compute mean and std across different folders (seeds)
6 |
7 | Usage
8 | ---
9 | Assume the output files are saved under output/my_experiment,
10 | which contains results of different seeds, e.g.,
11 |
12 | my_experiment/
13 | seed1/
14 | log.txt
15 | seed2/
16 | log.txt
17 | seed3/
18 | log.txt
19 |
20 | Run the following command from the root directory:
21 |
22 | $ python parse_test_res.py output/my_experiment
23 |
24 | Add --ci95 to the command if you want the 95% confidence
25 | interval instead of the standard deviation:
26 |
27 | $ python parse_test_res.py output/my_experiment --ci95
28 |
29 | If my_experiment/ has the following structure,
30 |
31 | my_experiment/
32 | exp-1/
33 | seed1/
34 | log.txt
35 | ...
36 | seed2/
37 | log.txt
38 | ...
39 | seed3/
40 | log.txt
41 | ...
42 | exp-2/
43 | ...
44 | exp-3/
45 | ...
46 |
47 | Run
48 |
49 | $ python parse_test_res.py output/my_experiment --multi-exp
50 | """
51 | import re
52 | import numpy as np
53 | import os.path as osp
54 | import argparse
55 | from collections import OrderedDict, defaultdict
56 |
57 | from dassl.utils import check_isfile, listdir_nohidden
58 |
59 |
60 | def compute_ci95(res):
61 | return 1.96 * np.std(res) / np.sqrt(len(res))
62 |
63 |
64 | def parse_function(*metrics, directory="", args=None, end_signal=None):
65 | print(f"Parsing files in {directory}")
66 | subdirs = listdir_nohidden(directory, sort=True)
67 |
68 | outputs = []
69 |
70 | for subdir in subdirs:
71 | fpath = osp.join(directory, subdir, "log.txt")
72 | assert check_isfile(fpath)
73 | good_to_go = False
74 | output = OrderedDict()
75 |
76 | with open(fpath, "r") as f:
77 | lines = f.readlines()
78 |
79 | for line in lines:
80 | line = line.strip()
81 |
82 | if line == end_signal:
83 | good_to_go = True
84 |
85 | for metric in metrics:
86 | match = metric["regex"].search(line)
87 | if match and good_to_go:
88 | if "file" not in output:
89 | output["file"] = fpath
90 | num = float(match.group(1))
91 | name = metric["name"]
92 | output[name] = num
93 |
94 | if output:
95 | outputs.append(output)
96 |
97 | assert len(outputs) > 0, f"Nothing found in {directory}"
98 |
99 | metrics_results = defaultdict(list)
100 |
101 | for output in outputs:
102 | msg = ""
103 | for key, value in output.items():
104 | if isinstance(value, float):
105 | msg += f"{key}: {value:.2f}%. "
106 | else:
107 | msg += f"{key}: {value}. "
108 | if key != "file":
109 | metrics_results[key].append(value)
110 | print(msg)
111 |
112 | output_results = OrderedDict()
113 |
114 | print("===")
115 | print(f"Summary of directory: {directory}")
116 | for key, values in metrics_results.items():
117 | avg = np.mean(values)
118 | std = compute_ci95(values) if args.ci95 else np.std(values)
119 | print(f"* {key}: {avg:.2f}% +- {std:.2f}%")
120 | output_results[key] = avg
121 | print("===")
122 |
123 | return output_results
124 |
125 |
126 | def main(args, end_signal):
127 | metric = {
128 | "name": args.keyword,
129 | "regex": re.compile(fr"\* {args.keyword}: ([\.\deE+-]+)%"),
130 | }
131 |
132 | if args.multi_exp:
133 | final_results = defaultdict(list)
134 |
135 | for directory in listdir_nohidden(args.directory, sort=True):
136 | directory = osp.join(args.directory, directory)
137 | results = parse_function(
138 | metric, directory=directory, args=args, end_signal=end_signal
139 | )
140 |
141 | for key, value in results.items():
142 | final_results[key].append(value)
143 |
144 | print("Average performance")
145 | for key, values in final_results.items():
146 | avg = np.mean(values)
147 | print(f"* {key}: {avg:.2f}%")
148 |
149 | else:
150 | parse_function(
151 | metric, directory=args.directory, args=args, end_signal=end_signal
152 | )
153 |
154 |
155 | if __name__ == "__main__":
156 | parser = argparse.ArgumentParser()
157 | parser.add_argument("directory", type=str, help="path to directory")
158 | parser.add_argument(
159 |         "--ci95", action="store_true", help="compute 95%% confidence interval"
160 | )
161 | parser.add_argument("--test-log", action="store_true", help="parse test-only logs")
162 | parser.add_argument(
163 | "--multi-exp", action="store_true", help="parse multiple experiments"
164 | )
165 | parser.add_argument(
166 | "--keyword", default="accuracy", type=str, help="which keyword to extract"
167 | )
168 | args = parser.parse_args()
169 |
170 | end_signal = "Finished training"
171 | if args.test_log:
172 | end_signal = "=> result"
173 |
174 | main(args, end_signal)
175 |
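The parser only trusts lines that appear after the end signal ("Finished training", or "=> result" with --test-log) and that match the pattern "* <keyword>: <number>%", which is the summary format written to log.txt. A quick check of the default regex and of compute_ci95 on made-up accuracies from three seeds:

    import re
    import numpy as np

    regex = re.compile(r"\* accuracy: ([\.\deE+-]+)%")
    print(regex.search("* accuracy: 71.32%").group(1))   # 71.32

    accs = [71.32, 70.85, 71.90]                          # hypothetical per-seed results
    ci95 = 1.96 * np.std(accs) / np.sqrt(len(accs))
    print(f"{np.mean(accs):.2f}% +- {ci95:.2f}%")         # 71.36% +- 0.49%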
--------------------------------------------------------------------------------
/scripts/cocoop/base2new_test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #cd ../..
4 |
5 | # custom config
6 | DATA="/path/to/dataset/folder"
7 | TRAINER=CoCoOp
8 |
9 | DATASET=$1
10 | SEED=$2
11 |
12 | CFG=vit_b16_c4_ep10_batch1_ctxv1
13 | SHOTS=16
14 | LOADEP=10
15 | SUB=new
16 |
17 |
18 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
19 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR}
20 | DIR=output/base2new/test_${SUB}/${COMMON_DIR}
21 | if [ -d "$DIR" ]; then
22 | echo "Evaluating model"
23 | echo "Results are available in ${DIR}. Resuming..."
24 |
25 | python train.py \
26 | --root ${DATA} \
27 | --seed ${SEED} \
28 | --trainer ${TRAINER} \
29 | --dataset-config-file configs/datasets/${DATASET}.yaml \
30 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
31 | --output-dir ${DIR} \
32 | --model-dir ${MODEL_DIR} \
33 | --load-epoch ${LOADEP} \
34 | --eval-only \
35 | DATASET.NUM_SHOTS ${SHOTS} \
36 | DATASET.SUBSAMPLE_CLASSES ${SUB}
37 |
38 | else
39 | echo "Evaluating model"
40 | echo "Runing the first phase job and save the output to ${DIR}"
41 |
42 | python train.py \
43 | --root ${DATA} \
44 | --seed ${SEED} \
45 | --trainer ${TRAINER} \
46 | --dataset-config-file configs/datasets/${DATASET}.yaml \
47 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
48 | --output-dir ${DIR} \
49 | --model-dir ${MODEL_DIR} \
50 | --load-epoch ${LOADEP} \
51 | --eval-only \
52 | DATASET.NUM_SHOTS ${SHOTS} \
53 | DATASET.SUBSAMPLE_CLASSES ${SUB}
54 | fi
--------------------------------------------------------------------------------
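An invocation sketch for base2new_test.sh above (dataset and seed are examples only): the script expects a dataset name matching a YAML under configs/datasets plus a seed, and loads the checkpoint that base2new_train.sh wrote under output/base2new/train_base/ for the same dataset, config and seed. DATA must first be edited to point at the local dataset root.

    bash scripts/cocoop/base2new_test.sh eurosat 1
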
/scripts/cocoop/base2new_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #cd ../..
4 |
5 | # custom config
6 | DATA="/path/to/dataset/folder"
7 | TRAINER=CoCoOp
8 |
9 | DATASET=$1
10 | SEED=$2
11 |
12 | CFG=vit_b16_c4_ep10_batch1_ctxv1
13 | SHOTS=16
14 |
15 |
16 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
17 | if [ -d "$DIR" ]; then
18 | echo "Results are available in ${DIR}. Resuming..."
19 | python train.py \
20 | --root ${DATA} \
21 | --seed ${SEED} \
22 | --trainer ${TRAINER} \
23 | --dataset-config-file configs/datasets/${DATASET}.yaml \
24 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
25 | --output-dir ${DIR} \
26 | DATASET.NUM_SHOTS ${SHOTS} \
27 | DATASET.SUBSAMPLE_CLASSES base
28 | else
29 | echo "Run this job and save the output to ${DIR}"
30 | python train.py \
31 | --root ${DATA} \
32 | --seed ${SEED} \
33 | --trainer ${TRAINER} \
34 | --dataset-config-file configs/datasets/${DATASET}.yaml \
35 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
36 | --output-dir ${DIR} \
37 | DATASET.NUM_SHOTS ${SHOTS} \
38 | DATASET.SUBSAMPLE_CLASSES base
39 | fi
--------------------------------------------------------------------------------
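base2new_train.sh above takes the same two positional arguments. A small sketch (dataset and seeds chosen for illustration) that trains on the base split for three seeds and then evaluates each run on the new split:

    for SEED in 1 2 3; do
        bash scripts/cocoop/base2new_train.sh eurosat ${SEED}
        bash scripts/cocoop/base2new_test.sh eurosat ${SEED}
    done
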
/scripts/cocoop/xd_test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #cd ../..
4 |
5 | # custom config
6 | DATA=/path/to/datasets
7 | TRAINER=CoCoOp
8 |
9 | DATASET=$1
10 | SEED=$2
11 |
12 | CFG=vit_b16_c4_ep10_batch1_ctxv1
13 | SHOTS=16
14 |
15 |
16 | DIR=output/evaluation/${TRAINER}/${CFG}_${SHOTS}shots/${DATASET}/seed${SEED}
17 | if [ -d "$DIR" ]; then
18 | echo "Results are available in ${DIR}. Skip this job"
19 | else
20 | echo "Run this job and save the output to ${DIR}"
21 |
22 | python train.py \
23 | --root ${DATA} \
24 | --seed ${SEED} \
25 | --trainer ${TRAINER} \
26 | --dataset-config-file configs/datasets/${DATASET}.yaml \
27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
28 | --output-dir ${DIR} \
29 | --model-dir output/imagenet/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED} \
30 | --load-epoch 10 \
31 | --eval-only
32 | fi
--------------------------------------------------------------------------------
/scripts/cocoop/xd_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #cd ../..
4 |
5 | # custom config
6 | DATA=/path/to/datasets
7 | TRAINER=CoCoOp
8 |
9 | DATASET=imagenet
10 | SEED=$1
11 |
12 | CFG=vit_b16_c4_ep10_batch1_ctxv1
13 | SHOTS=16
14 |
15 |
16 | DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED}
17 | if [ -d "$DIR" ]; then
18 | echo "Results are available in ${DIR}. Skip this job"
19 | else
20 | echo "Run this job and save the output to ${DIR}"
21 |
22 | python train.py \
23 | --root ${DATA} \
24 | --seed ${SEED} \
25 | --trainer ${TRAINER} \
26 | --dataset-config-file configs/datasets/${DATASET}.yaml \
27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
28 | --output-dir ${DIR} \
29 | DATASET.NUM_SHOTS ${SHOTS}
30 | fi
--------------------------------------------------------------------------------
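xd_train.sh and xd_test.sh above form the cross-dataset pair: the first trains CoCoOp on ImageNet (the dataset is hardcoded, only the seed is passed) and the second evaluates that ImageNet checkpoint on a target dataset via --model-dir. A sketch with an illustrative target dataset:

    # 1) train the source model on ImageNet, seed 1
    bash scripts/cocoop/xd_train.sh 1
    # 2) evaluate the same checkpoint on a target dataset
    bash scripts/cocoop/xd_test.sh food101 1
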
/scripts/coop/basenewtrain.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ../..
4 |
5 | # custom config
6 | DATA="/data4/wds/dataset/CoOpData/"
7 | TRAINER=CoOp
8 |
9 | DATASET=$1
10 | SEED=$2
11 |
12 | CFG=vit_b16_ep50
13 | SHOTS=16
14 | NCTX=16
15 |
16 |
17 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
18 | if [ -d "$DIR" ]; then
19 | echo "Results are available in ${DIR}. Resuming..."
20 | python train.py \
21 | --root ${DATA} \
22 | --seed ${SEED} \
23 | --trainer ${TRAINER} \
24 | --dataset-config-file configs/datasets/${DATASET}.yaml \
25 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
26 | --output-dir ${DIR} \
27 | TRAINER.COOP.N_CTX ${NCTX} \
28 | DATASET.NUM_SHOTS ${SHOTS} \
29 | DATASET.SUBSAMPLE_CLASSES base
30 | else
31 | echo "Run this job and save the output to ${DIR}"
32 | python train.py \
33 | --root ${DATA} \
34 | --seed ${SEED} \
35 | --trainer ${TRAINER} \
36 | --dataset-config-file configs/datasets/${DATASET}.yaml \
37 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
38 | --output-dir ${DIR} \
39 | TRAINER.COOP.N_CTX ${NCTX} \
40 | DATASET.NUM_SHOTS ${SHOTS} \
41 | DATASET.SUBSAMPLE_CLASSES base
42 | fi
--------------------------------------------------------------------------------
/scripts/coop/eval.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #cd ../..
4 |
5 | # custom config
6 | DATA=/path/to/datasets
7 | TRAINER=CoOp
8 | SHOTS=16
9 | NCTX=16
10 | CSC=False
11 | CTP=end
12 |
13 | DATASET=$1
14 | CFG=$2
15 |
16 | for SEED in 1 2 3
17 | do
18 | python train.py \
19 | --root ${DATA} \
20 | --seed ${SEED} \
21 | --trainer ${TRAINER} \
22 | --dataset-config-file configs/datasets/${DATASET}.yaml \
23 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
24 | --output-dir output/evaluation/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/${DATASET}/seed${SEED} \
25 | --model-dir output/imagenet/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/seed${SEED} \
26 | --load-epoch 50 \
27 | --eval-only \
28 | TRAINER.COOP.N_CTX ${NCTX} \
29 | TRAINER.COOP.CSC ${CSC} \
30 | TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP}
31 | done
--------------------------------------------------------------------------------
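eval.sh above takes a dataset and a CoOp config name, loops over seeds 1-3, and loads ImageNet-trained CoOp weights at epoch 50, so it assumes those checkpoints already exist under output/imagenet/CoOp/. An illustrative call:

    bash scripts/coop/eval.sh food101 vit_b16_ep50
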
/scripts/coop/main.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ../..
4 |
5 | # custom config
6 | DATA="/data4/wds/dataset/CoOpData/"
7 | TRAINER=CoOp
8 |
9 | DATASET=eurosat
10 | CFG=rn50 # config file
11 | CTP=end # class token position (end or middle)
12 | NCTX=16 # number of context tokens
13 | SHOTS=8 # number of shots (1, 2, 4, 8, 16)
14 | CSC=False # class-specific context (False or True)
15 |
16 | for SEED in 1
17 | do
18 | DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/nctx${NCTX}_csc${CSC}_ctp${CTP}/seed${SEED}
19 | if [ -d "$DIR" ]; then
20 | echo "Results are available in ${DIR}. Skip this job"
21 | else
22 | echo "Run this job and save the output to ${DIR}"
23 | python train.py \
24 | --root ${DATA} \
25 | --seed ${SEED} \
26 | --trainer ${TRAINER} \
27 | --dataset-config-file configs/datasets/${DATASET}.yaml \
28 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
29 | --output-dir ${DIR} \
30 | TRAINER.COOP.N_CTX ${NCTX} \
31 | TRAINER.COOP.CSC ${CSC} \
32 | TRAINER.COOP.CLASS_TOKEN_POSITION ${CTP} \
33 | DATASET.NUM_SHOTS ${SHOTS}
34 | fi
35 | done
--------------------------------------------------------------------------------
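main.sh above hardcodes the dataset, config, shots and context settings, so those variables have to be edited before each run (and, since it does `cd ../..`, it is meant to be launched from scripts/coop). One optional way to keep the same defaults while allowing command-line overrides (a sketch, not part of the repo) is to replace the fixed assignments with parameter defaults:

    # e.g. `bash main.sh dtd rn50 16`, or plain `bash main.sh` for the original defaults
    DATASET=${1:-eurosat}
    CFG=${2:-rn50}
    SHOTS=${3:-8}
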
/scripts/independent-vlp/base2new_test_ivlp.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #cd ../..
4 |
5 | # custom config
6 | DATA="/path/to/dataset/folder"
7 | TRAINER=IVLP
8 |
9 | DATASET=$1
10 | SEED=$2
11 |
12 | CFG=vit_b16_c2_ep5_batch4_2+2ctx
13 | SHOTS=16
14 | LOADEP=5
15 | SUB=new
16 |
17 |
18 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
19 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR}
20 | DIR=output/base2new/test_${SUB}/${COMMON_DIR}
21 | if [ -d "$DIR" ]; then
22 | echo "Evaluating model"
23 | echo "Results are available in ${DIR}. Resuming..."
24 |
25 | python train.py \
26 | --root ${DATA} \
27 | --seed ${SEED} \
28 | --trainer ${TRAINER} \
29 | --dataset-config-file configs/datasets/${DATASET}.yaml \
30 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
31 | --output-dir ${DIR} \
32 | --model-dir ${MODEL_DIR} \
33 | --load-epoch ${LOADEP} \
34 | --eval-only \
35 | DATASET.NUM_SHOTS ${SHOTS} \
36 | DATASET.SUBSAMPLE_CLASSES ${SUB}
37 |
38 | else
39 | echo "Evaluating model"
40 | echo "Runing the first phase job and save the output to ${DIR}"
41 |
42 | python train.py \
43 | --root ${DATA} \
44 | --seed ${SEED} \
45 | --trainer ${TRAINER} \
46 | --dataset-config-file configs/datasets/${DATASET}.yaml \
47 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
48 | --output-dir ${DIR} \
49 | --model-dir ${MODEL_DIR} \
50 | --load-epoch ${LOADEP} \
51 | --eval-only \
52 | DATASET.NUM_SHOTS ${SHOTS} \
53 | DATASET.SUBSAMPLE_CLASSES ${SUB}
54 | fi
--------------------------------------------------------------------------------
/scripts/independent-vlp/base2new_train_ivlp.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #cd ../..
4 |
5 | # custom config
6 | DATA="/path/to/dataset/folder"
7 | TRAINER=IVLP
8 |
9 | DATASET=$1
10 | SEED=$2
11 |
12 | CFG=vit_b16_c2_ep5_batch4_2+2ctx
13 | SHOTS=16
14 |
15 |
16 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
17 | if [ -d "$DIR" ]; then
18 | echo "Results are available in ${DIR}. Resuming..."
19 | python train.py \
20 | --root ${DATA} \
21 | --seed ${SEED} \
22 | --trainer ${TRAINER} \
23 | --dataset-config-file configs/datasets/${DATASET}.yaml \
24 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
25 | --output-dir ${DIR} \
26 | DATASET.NUM_SHOTS ${SHOTS} \
27 | DATASET.SUBSAMPLE_CLASSES base
28 | else
29 | echo "Run this job and save the output to ${DIR}"
30 | python train.py \
31 | --root ${DATA} \
32 | --seed ${SEED} \
33 | --trainer ${TRAINER} \
34 | --dataset-config-file configs/datasets/${DATASET}.yaml \
35 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
36 | --output-dir ${DIR} \
37 | DATASET.NUM_SHOTS ${SHOTS} \
38 | DATASET.SUBSAMPLE_CLASSES base
39 | fi
--------------------------------------------------------------------------------
/scripts/independent-vlp/reproduce_ivlp.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #cd ../..
4 |
5 | # custom config
6 | DATA="/path/to/dataset/folder"
7 | TRAINER=IVLP
8 |
9 | DATASET=$1
10 | SEED=$2
11 | WEIGHTSPATH=$3
12 |
13 | CFG=vit_b16_c2_ep5_batch4_2+2ctx
14 | SHOTS=16
15 | LOADEP=5
16 | SUB_base=base
17 | SUB_novel=new
18 |
19 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
20 | MODEL_DIR=${WEIGHTSPATH}/base/seed${SEED}
21 | DIR_base=output/base2new/test_${SUB_base}/${COMMON_DIR}
22 | DIR_novel=output/base2new/test_${SUB_novel}/${COMMON_DIR}
23 | if [ -d "$DIR" ]; then
24 | echo "Results are already available in ${DIR}. Skipping..."
25 | else
26 | echo "Evaluating model"
27 | echo "Runing the first phase job and save the output to ${DIR}"
28 | # Evaluate on base classes
29 | python train.py \
30 | --root ${DATA} \
31 | --seed ${SEED} \
32 | --trainer ${TRAINER} \
33 | --dataset-config-file configs/datasets/${DATASET}.yaml \
34 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
35 | --output-dir ${DIR_base} \
36 | --model-dir ${MODEL_DIR} \
37 | --load-epoch ${LOADEP} \
38 | --eval-only \
39 | DATASET.NUM_SHOTS ${SHOTS} \
40 | DATASET.SUBSAMPLE_CLASSES ${SUB_base}
41 |
42 | # Evaluate on novel classes
43 | python train.py \
44 | --root ${DATA} \
45 | --seed ${SEED} \
46 | --trainer ${TRAINER} \
47 | --dataset-config-file configs/datasets/${DATASET}.yaml \
48 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
49 | --output-dir ${DIR_novel} \
50 | --model-dir ${MODEL_DIR} \
51 | --load-epoch ${LOADEP} \
52 | --eval-only \
53 | DATASET.NUM_SHOTS ${SHOTS} \
54 | DATASET.SUBSAMPLE_CLASSES ${SUB_novel}
55 |
56 | fi
--------------------------------------------------------------------------------
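reproduce_ivlp.sh above adds a third positional argument, the path to pre-trained weights; MODEL_DIR resolves to ${WEIGHTSPATH}/base/seed${SEED}, so the weights folder is expected to contain a base/seed<N> sub-directory per seed. An illustrative call (the weights path is a placeholder):

    bash scripts/independent-vlp/reproduce_ivlp.sh eurosat 1 /path/to/ivlp/weights/eurosat
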
/scripts/independent-vlp/xd_test_ivlp.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #cd ../..
4 |
5 | # custom config
6 | DATA="/path/to/dataset/folder"
7 | TRAINER=IVLP
8 |
9 | DATASET=$1
10 | SEED=$2
11 |
12 | CFG=vit_b16_c2_ep5_batch4_2+2ctx
13 | SHOTS=16
14 |
15 |
16 | DIR=output/evaluation/${TRAINER}/${CFG}_${SHOTS}shots/${DATASET}/seed${SEED}
17 | if [ -d "$DIR" ]; then
18 | echo "Results are available in ${DIR}. Skip this job"
19 | else
20 | echo "Run this job and save the output to ${DIR}"
21 |
22 | python train.py \
23 | --root ${DATA} \
24 | --seed ${SEED} \
25 | --trainer ${TRAINER} \
26 | --dataset-config-file configs/datasets/${DATASET}.yaml \
27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
28 | --output-dir ${DIR} \
29 | --model-dir output/imagenet/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED} \
30 | --load-epoch 2 \
31 | --eval-only
32 | fi
--------------------------------------------------------------------------------
/scripts/independent-vlp/xd_train_ivlp.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #cd ../..
4 |
5 | # custom config
6 | DATA="/path/to/dataset/folder"
7 | TRAINER=IVLP
8 |
9 | DATASET=$1
10 | SEED=$2
11 |
12 | CFG=vit_b16_c2_ep5_batch4_2+2ctx
13 | SHOTS=16
14 |
15 |
16 | DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED}
17 | if [ -d "$DIR" ]; then
18 | echo "Results are available in ${DIR}."
19 | else
20 | echo "Run this job and save the output to ${DIR}"
21 |
22 | python train.py \
23 | --root ${DATA} \
24 | --seed ${SEED} \
25 | --trainer ${TRAINER} \
26 | --dataset-config-file configs/datasets/${DATASET}.yaml \
27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
28 | --output-dir ${DIR} \
29 | DATASET.NUM_SHOTS ${SHOTS}
30 | fi
--------------------------------------------------------------------------------
/scripts/language-prompting/base2new_test_lp.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #cd ../..
4 |
5 | # custom config
6 | DATA="/path/to/dataset/folder"
7 | TRAINER=IVLP
8 |
9 | DATASET=$1
10 | SEED=$2
11 |
12 | CFG=vit_b16_c2_ep5_batch4_4ctx_language_only
13 | SHOTS=16
14 | LOADEP=5
15 | SUB=new
16 |
17 |
18 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
19 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR}
20 | DIR=output/base2new/test_${SUB}/${COMMON_DIR}
21 | if [ -d "$DIR" ]; then
22 | echo "Evaluating model"
23 | echo "Results are available in ${DIR}. Resuming..."
24 |
25 | python train.py \
26 | --root ${DATA} \
27 | --seed ${SEED} \
28 | --trainer ${TRAINER} \
29 | --dataset-config-file configs/datasets/${DATASET}.yaml \
30 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
31 | --output-dir ${DIR} \
32 | --model-dir ${MODEL_DIR} \
33 | --load-epoch ${LOADEP} \
34 | --eval-only \
35 | DATASET.NUM_SHOTS ${SHOTS} \
36 | DATASET.SUBSAMPLE_CLASSES ${SUB}
37 |
38 | else
39 | echo "Evaluating model"
40 | echo "Runing the first phase job and save the output to ${DIR}"
41 |
42 | python train.py \
43 | --root ${DATA} \
44 | --seed ${SEED} \
45 | --trainer ${TRAINER} \
46 | --dataset-config-file configs/datasets/${DATASET}.yaml \
47 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
48 | --output-dir ${DIR} \
49 | --model-dir ${MODEL_DIR} \
50 | --load-epoch ${LOADEP} \
51 | --eval-only \
52 | DATASET.NUM_SHOTS ${SHOTS} \
53 | DATASET.SUBSAMPLE_CLASSES ${SUB}
54 | fi
--------------------------------------------------------------------------------
/scripts/language-prompting/base2new_train_lp.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #cd ../..
4 |
5 | # custom config
6 | DATA="/path/to/dataset/folder"
7 | TRAINER=IVLP
8 |
9 | DATASET=$1
10 | SEED=$2
11 |
12 | CFG=vit_b16_c2_ep5_batch4_4ctx_language_only
13 | SHOTS=16
14 |
15 |
16 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
17 | if [ -d "$DIR" ]; then
18 | echo "Results are available in ${DIR}. Resuming..."
19 | python train.py \
20 | --root ${DATA} \
21 | --seed ${SEED} \
22 | --trainer ${TRAINER} \
23 | --dataset-config-file configs/datasets/${DATASET}.yaml \
24 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
25 | --output-dir ${DIR} \
26 | DATASET.NUM_SHOTS ${SHOTS} \
27 | DATASET.SUBSAMPLE_CLASSES base
28 | else
29 | echo "Run this job and save the output to ${DIR}"
30 | python train.py \
31 | --root ${DATA} \
32 | --seed ${SEED} \
33 | --trainer ${TRAINER} \
34 | --dataset-config-file configs/datasets/${DATASET}.yaml \
35 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
36 | --output-dir ${DIR} \
37 | DATASET.NUM_SHOTS ${SHOTS} \
38 | DATASET.SUBSAMPLE_CLASSES base
39 | fi
--------------------------------------------------------------------------------
/scripts/language-prompting/reproduce_lp.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #cd ../..
4 |
5 | # custom config
6 | DATA="/path/to/dataset/folder"
7 | TRAINER=IVLP
8 |
9 | DATASET=$1
10 | SEED=$2
11 | WEIGHTSPATH=$3
12 |
13 | CFG=vit_b16_c2_ep5_batch4_4ctx_language_only
14 | SHOTS=16
15 | LOADEP=5
16 | SUB_base=base
17 | SUB_novel=new
18 |
19 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
20 | MODEL_DIR=${WEIGHTSPATH}/base/seed${SEED}
21 | DIR_base=output/base2new/test_${SUB_base}/${COMMON_DIR}
22 | DIR_novel=output/base2new/test_${SUB_novel}/${COMMON_DIR}
23 | if [ -d "$DIR" ]; then
24 | echo "Results are already available in ${DIR}. Skipping..."
25 | else
26 | echo "Evaluating model"
27 | echo "Runing the first phase job and save the output to ${DIR}"
28 | # Evaluate on base classes
29 | python train.py \
30 | --root ${DATA} \
31 | --seed ${SEED} \
32 | --trainer ${TRAINER} \
33 | --dataset-config-file configs/datasets/${DATASET}.yaml \
34 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
35 | --output-dir ${DIR_base} \
36 | --model-dir ${MODEL_DIR} \
37 | --load-epoch ${LOADEP} \
38 | --eval-only \
39 | DATASET.NUM_SHOTS ${SHOTS} \
40 | DATASET.SUBSAMPLE_CLASSES ${SUB_base}
41 |
42 | # Evaluate on novel classes
43 | python train.py \
44 | --root ${DATA} \
45 | --seed ${SEED} \
46 | --trainer ${TRAINER} \
47 | --dataset-config-file configs/datasets/${DATASET}.yaml \
48 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
49 | --output-dir ${DIR_novel} \
50 | --model-dir ${MODEL_DIR} \
51 | --load-epoch ${LOADEP} \
52 | --eval-only \
53 | DATASET.NUM_SHOTS ${SHOTS} \
54 | DATASET.SUBSAMPLE_CLASSES ${SUB_novel}
55 |
56 | fi
--------------------------------------------------------------------------------
/scripts/language-prompting/xd_test_lp.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #cd ../..
4 |
5 | # custom config
6 | DATA="/path/to/dataset/folder"
7 | TRAINER=IVLP
8 |
9 | DATASET=$1
10 | SEED=$2
11 |
12 | CFG=vit_b16_c2_ep5_batch4_4ctx_language_only
13 | SHOTS=16
14 |
15 |
16 | DIR=output/evaluation/${TRAINER}/${CFG}_${SHOTS}shots/${DATASET}/seed${SEED}
17 | if [ -d "$DIR" ]; then
18 | echo "Results are available in ${DIR}. Skip this job"
19 | else
20 | echo "Run this job and save the output to ${DIR}"
21 |
22 | python train.py \
23 | --root ${DATA} \
24 | --seed ${SEED} \
25 | --trainer ${TRAINER} \
26 | --dataset-config-file configs/datasets/${DATASET}.yaml \
27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
28 | --output-dir ${DIR} \
29 | --model-dir output/imagenet/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED} \
30 | --load-epoch 2 \
31 | --eval-only
32 | fi
--------------------------------------------------------------------------------
/scripts/language-prompting/xd_train_lp.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #cd ../..
4 |
5 | # custom config
6 | DATA="/path/to/dataset/folder"
7 | TRAINER=IVLP
8 |
9 | DATASET=$1
10 | SEED=$2
11 |
12 | CFG=vit_b16_c2_ep5_batch4_4ctx_language_only
13 | SHOTS=16
14 |
15 |
16 | DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED}
17 | if [ -d "$DIR" ]; then
18 | echo "Results are available in ${DIR}."
19 | else
20 | echo "Run this job and save the output to ${DIR}"
21 |
22 | python train.py \
23 | --root ${DATA} \
24 | --seed ${SEED} \
25 | --trainer ${TRAINER} \
26 | --dataset-config-file configs/datasets/${DATASET}.yaml \
27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
28 | --output-dir ${DIR} \
29 | DATASET.NUM_SHOTS ${SHOTS}
30 | fi
--------------------------------------------------------------------------------
/scripts/maple/base2new_test_maple.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ../..
4 |
5 | # custom config
6 | DATA="/data4/wds/dataset/CoOpData/"
7 | TRAINER=MaPLe
8 |
9 | DATASET=$1
10 | SEED=$2
11 |
12 | CFG=vit_b16_c2_ep5_batch4_2ctx
13 | SHOTS=16
14 | LOADEP=10
15 | SUB=new
16 |
17 |
18 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
19 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR}
20 | DIR=output/base2new/test_${SUB}/${COMMON_DIR}
21 | if [ -d "$DIR" ]; then
22 | echo "Evaluating model"
23 | echo "Results are available in ${DIR}. Resuming..."
24 |
25 | python train.py \
26 | --root ${DATA} \
27 | --seed ${SEED} \
28 | --trainer ${TRAINER} \
29 | --dataset-config-file configs/datasets/${DATASET}.yaml \
30 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
31 | --output-dir ${DIR} \
32 | --model-dir ${MODEL_DIR} \
33 | --load-epoch ${LOADEP} \
34 | --eval-only \
35 | DATASET.NUM_SHOTS ${SHOTS} \
36 | DATASET.SUBSAMPLE_CLASSES ${SUB}
37 |
38 | else
39 | echo "Evaluating model"
40 | echo "Runing the first phase job and save the output to ${DIR}"
41 |
42 | python train.py \
43 | --root ${DATA} \
44 | --seed ${SEED} \
45 | --trainer ${TRAINER} \
46 | --dataset-config-file configs/datasets/${DATASET}.yaml \
47 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
48 | --output-dir ${DIR} \
49 | --model-dir ${MODEL_DIR} \
50 | --load-epoch ${LOADEP} \
51 | --eval-only \
52 | DATASET.NUM_SHOTS ${SHOTS} \
53 | DATASET.SUBSAMPLE_CLASSES ${SUB}
54 | fi
--------------------------------------------------------------------------------
/scripts/maple/base2new_train_maple.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ../..
4 |
5 | # custom config
6 | DATA="/data4/wds/dataset/CoOpData/"
7 | TRAINER=MaPLe
8 |
9 | DATASET=$1
10 | SEED=$2
11 |
12 | CFG=vit_b16_c2_ep5_batch4_2ctx
13 | SHOTS=16
14 |
15 |
16 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
17 | if [ -d "$DIR" ]; then
18 | echo "Results are available in ${DIR}. Resuming..."
19 | python train.py \
20 | --root ${DATA} \
21 | --seed ${SEED} \
22 | --trainer ${TRAINER} \
23 | --dataset-config-file configs/datasets/${DATASET}.yaml \
24 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
25 | --output-dir ${DIR} \
26 | DATASET.NUM_SHOTS ${SHOTS} \
27 | DATASET.SUBSAMPLE_CLASSES base
28 | else
29 | echo "Run this job and save the output to ${DIR}"
30 | python train.py \
31 | --root ${DATA} \
32 | --seed ${SEED} \
33 | --trainer ${TRAINER} \
34 | --dataset-config-file configs/datasets/${DATASET}.yaml \
35 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
36 | --output-dir ${DIR} \
37 | DATASET.NUM_SHOTS ${SHOTS} \
38 | DATASET.SUBSAMPLE_CLASSES base
39 | fi
--------------------------------------------------------------------------------
/scripts/maple/fst.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ../..
4 |
5 | # custom config
6 | DATA="/data4/wds/dataset/CoOpData/"
7 | TRAINER=MaPLe
8 |
9 | CFG=vit_b16_c2_ep5_batch4_2ctx
10 | #DATASET=$1
11 |
12 | #for DATASET in caltech101 dtd eurosat fgvc_aircraft
13 | #for DATASET in food101 oxford_flowers oxford_pets stanford_cars
14 | for DATASET in ucf101
15 | do
16 | for SHOTS in 1 2 4 8 16
17 | do
18 | for SEED in 1 2 3
19 | do
20 | DIR=output/fewshot/${DATASET}/${TRAINER}_2/shots_${SHOTS}/seed${SEED}
21 | if [ -d "$DIR" ]; then
22 | echo "Results are available in ${DIR}. Resuming..."
23 | python train.py \
24 | --root ${DATA} \
25 | --seed ${SEED} \
26 | --trainer ${TRAINER} \
27 | --dataset-config-file configs/datasets/${DATASET}.yaml \
28 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
29 | --output-dir ${DIR} \
30 | DATASET.NUM_SHOTS ${SHOTS}
31 | else
32 | echo "Run this job and save the output to ${DIR}"
33 | python train.py \
34 | --root ${DATA} \
35 | --seed ${SEED} \
36 | --trainer ${TRAINER} \
37 | --dataset-config-file configs/datasets/${DATASET}.yaml \
38 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
39 | --output-dir ${DIR} \
40 | DATASET.NUM_SHOTS ${SHOTS}
41 | fi
42 | done
43 | done
44 | done
--------------------------------------------------------------------------------
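fst.sh above runs the MaPLe few-shot sweep: the nested loops cover shots 1/2/4/8/16 and seeds 1-3, and the commented-out `for DATASET in ...` lines show how the dataset list was rotated between runs; edit that list (and DATA) to sweep other datasets. Because the script does `cd ../..`, it is meant to be launched from its own directory, e.g. (illustrative):

    # run from scripts/maple (the script changes to the repo root itself)
    cd scripts/maple && bash fst.sh
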
/scripts/maple/reproduce_maple.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #cd ../..
4 |
5 | # custom config
6 | DATA="/path/to/dataset/folder"
7 | TRAINER=MaPLe
8 |
9 | DATASET=$1
10 | SEED=$2
11 | WEIGHTSPATH=$3
12 |
13 | CFG=vit_b16_c2_ep5_batch4_2ctx
14 | SHOTS=16
15 | LOADEP=5
16 | SUB_base=base
17 | SUB_novel=new
18 |
19 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
20 | MODEL_DIR=${WEIGHTSPATH}/base/seed${SEED}
21 | DIR_base=output/base2new/test_${SUB_base}/${COMMON_DIR}
22 | DIR_novel=output/base2new/test_${SUB_novel}/${COMMON_DIR}
23 | if [ -d "$DIR" ]; then
24 | echo "Results are already available in ${DIR}. Skipping..."
25 | else
26 | echo "Evaluating model"
27 | echo "Runing the first phase job and save the output to ${DIR}"
28 | # Evaluate on base classes
29 | python train.py \
30 | --root ${DATA} \
31 | --seed ${SEED} \
32 | --trainer ${TRAINER} \
33 | --dataset-config-file configs/datasets/${DATASET}.yaml \
34 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
35 | --output-dir ${DIR_base} \
36 | --model-dir ${MODEL_DIR} \
37 | --load-epoch ${LOADEP} \
38 | --eval-only \
39 | DATASET.NUM_SHOTS ${SHOTS} \
40 | DATASET.SUBSAMPLE_CLASSES ${SUB_base}
41 |
42 | # Evaluate on novel classes
43 | python train.py \
44 | --root ${DATA} \
45 | --seed ${SEED} \
46 | --trainer ${TRAINER} \
47 | --dataset-config-file configs/datasets/${DATASET}.yaml \
48 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
49 | --output-dir ${DIR_novel} \
50 | --model-dir ${MODEL_DIR} \
51 | --load-epoch ${LOADEP} \
52 | --eval-only \
53 | DATASET.NUM_SHOTS ${SHOTS} \
54 | DATASET.SUBSAMPLE_CLASSES ${SUB_novel}
55 |
56 |
57 |
58 | fi
--------------------------------------------------------------------------------
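reproduce_maple.sh above evaluates a released MaPLe checkpoint on both the base and the new split in one call; like the other reproduce_* scripts it takes the dataset, the seed and a weights folder that must contain base/seed<N>. Illustrative call with a placeholder weights path:

    bash scripts/maple/reproduce_maple.sh imagenet 1 /path/to/maple/weights/imagenet
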
/scripts/maple/reproduce_maple_xd.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #cd ../..
4 |
5 | # custom config
6 | DATA="/path/to/dataset/folder"
7 | TRAINER=MaPLe
8 |
9 | DATASET=$1
10 | SEED=$2
11 | WEIGHTSPATH=$3
12 |
13 | CFG=vit_b16_c2_ep5_batch4_2ctx_cross_datasets
14 | SHOTS=16
15 | LOADEP=2
16 |
17 | MODEL_DIR=${WEIGHTSPATH}/seed${SEED}
18 |
19 | DIR=output/evaluation/${TRAINER}/${CFG}_${SHOTS}shots/${DATASET}/seed${SEED}
20 | if [ -d "$DIR" ]; then
21 | echo "Results are already available in ${DIR}. Skipping..."
22 | else
23 | echo "Evaluating model"
24 | echo "Runing the first phase job and save the output to ${DIR}"
25 | # Evaluate on evaluation datasets
26 | python train.py \
27 | --root ${DATA} \
28 | --seed ${SEED} \
29 | --trainer ${TRAINER} \
30 | --dataset-config-file configs/datasets/${DATASET}.yaml \
31 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
32 | --output-dir ${DIR} \
33 | --model-dir ${MODEL_DIR} \
34 | --load-epoch ${LOADEP} \
35 | --eval-only \
36 |     DATASET.NUM_SHOTS ${SHOTS}
37 |
38 | fi
--------------------------------------------------------------------------------
/scripts/maple/xd_test_maple.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #cd ../..
4 |
5 | # custom config
6 | DATA="/path/to/dataset/folder"
7 | TRAINER=MaPLe
8 |
9 | DATASET=$1
10 | SEED=$2
11 |
12 | CFG=vit_b16_c2_ep5_batch4_2ctx_cross_datasets
13 | SHOTS=16
14 |
15 |
16 | DIR=output/evaluation/${TRAINER}/${CFG}_${SHOTS}shots/${DATASET}/seed${SEED}
17 | if [ -d "$DIR" ]; then
18 | echo "Results are available in ${DIR}. Skip this job"
19 | else
20 | echo "Run this job and save the output to ${DIR}"
21 |
22 | python train.py \
23 | --root ${DATA} \
24 | --seed ${SEED} \
25 | --trainer ${TRAINER} \
26 | --dataset-config-file configs/datasets/${DATASET}.yaml \
27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
28 | --output-dir ${DIR} \
29 | --model-dir output/imagenet/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED} \
30 | --load-epoch 2 \
31 | --eval-only
32 | fi
--------------------------------------------------------------------------------
/scripts/maple/xd_train_maple.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #cd ../..
4 |
5 | # custom config
6 | DATA="/path/to/dataset/folder"
7 | TRAINER=MaPLe
8 |
9 | DATASET=$1
10 | SEED=$2
11 |
12 | CFG=vit_b16_c2_ep5_batch4_2ctx_cross_datasets
13 | SHOTS=16
14 |
15 |
16 | DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED}
17 | if [ -d "$DIR" ]; then
18 | echo "Results are available in ${DIR}."
19 | else
20 | echo "Run this job and save the output to ${DIR}"
21 |
22 | python train.py \
23 | --root ${DATA} \
24 | --seed ${SEED} \
25 | --trainer ${TRAINER} \
26 | --dataset-config-file configs/datasets/${DATASET}.yaml \
27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
28 | --output-dir ${DIR} \
29 | DATASET.NUM_SHOTS ${SHOTS}
30 | fi
--------------------------------------------------------------------------------
/scripts/mmp/base_to_new_test.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ../..
4 |
5 | # custom config
6 | DATA="dirs to datasets"
7 | TRAINER=MMP
8 |
9 | DATASET=$1
10 | SEED=$2
11 |
12 | CFG=vit_b16_c2_ep5_batch4_2ctx
13 | SHOTS=16
14 | LOADEP=15
15 | SUB=base
16 |
17 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
18 | MODEL_DIR=/home/dongsheng/wds/maple/output/base2new/train_vis_2/${COMMON_DIR}
19 | DIR=/home/dongsheng/wds/maple/output/base2new/train_vis_2/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
20 |
21 |
22 | if [ -d "$DIR" ]; then
23 | echo "Evaluating model"
24 | echo "Results are available in ${DIR}. Resuming..."
25 |
26 | python train.py \
27 | --root ${DATA} \
28 | --seed ${SEED} \
29 | --trainer ${TRAINER} \
30 | --dataset-config-file configs/datasets/${DATASET}.yaml \
31 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
32 | --output-dir ${DIR} \
33 | --model-dir ${MODEL_DIR} \
34 | --load-epoch ${LOADEP} \
35 | --eval-only \
36 | DATASET.NUM_SHOTS ${SHOTS} \
37 | DATASET.SUBSAMPLE_CLASSES ${SUB}
38 |
39 | else
40 | echo "Evaluating model"
41 | echo "Runing the first phase job and save the output to ${DIR}"
42 |
43 | python train.py \
44 | --root ${DATA} \
45 | --seed ${SEED} \
46 | --trainer ${TRAINER} \
47 | --dataset-config-file configs/datasets/${DATASET}.yaml \
48 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
49 | --output-dir ${DIR} \
50 | --model-dir ${MODEL_DIR} \
51 | --load-epoch ${LOADEP} \
52 | --eval-only \
53 | DATASET.NUM_SHOTS ${SHOTS} \
54 | DATASET.SUBSAMPLE_CLASSES ${SUB}
55 | fi
--------------------------------------------------------------------------------
/scripts/mmp/base_to_new_train.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | cd ../..
4 |
5 | # custom config
6 | DATA="dirs to datasets"
7 | TRAINER=MMP
8 |
9 | #DATASET=$1
10 | #SEED=$2
11 |
12 | CFG=vit_b16_c2_ep5_batch4_2ctx
13 | #CFG=sun397
14 | SHOTS=16
15 | N_CTX=2
16 | for DATASET in caltech101
17 | do
18 | for SEED in 1
19 | do
20 |
21 | DIR=/home/dongsheng/wds/maple/output/base2new/train_vis_${N_CTX}/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
22 | if [ -d "$DIR" ]; then
23 | echo "Results are available in ${DIR}. Resuming..."
24 | python train.py \
25 | --root ${DATA} \
26 | --seed ${SEED} \
27 | --trainer ${TRAINER} \
28 | --dataset-config-file configs/datasets/${DATASET}.yaml \
29 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
30 | --output-dir ${DIR} \
31 | DATASET.NUM_SHOTS ${SHOTS} \
32 | DATASET.SUBSAMPLE_CLASSES base
33 | else
34 | echo "Run this job and save the output to ${DIR}"
35 | python train.py \
36 | --root ${DATA} \
37 | --seed ${SEED} \
38 | --trainer ${TRAINER} \
39 | --dataset-config-file configs/datasets/${DATASET}.yaml \
40 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
41 | --output-dir ${DIR} \
42 | DATASET.NUM_SHOTS ${SHOTS} \
43 | DATASET.SUBSAMPLE_CLASSES base
44 | fi
45 | done
46 | done
--------------------------------------------------------------------------------
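The two MMP scripts above are wired to the author's machine: DATA is a placeholder and the model/output paths point at /home/dongsheng/, so both must be edited to local paths before use. They also do `cd ../..`, so they are meant to be launched from scripts/mmp; base_to_new_train.sh takes no arguments (the dataset and seed loops live inside the script), while base_to_new_test.sh still expects a dataset and a seed, e.g. (illustrative, after editing the paths):

    cd scripts/mmp
    bash base_to_new_train.sh
    bash base_to_new_test.sh caltech101 1
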
/scripts/vpt/base2new_test_vpt.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #cd ../..
4 |
5 | # custom config
6 | DATA="/path/to/dataset/folder"
7 | TRAINER=VPT
8 |
9 | DATASET=$1
10 | SEED=$2
11 |
12 | CFG=vit_b16_c2_ep5_batch4_4
13 | SHOTS=16
14 | LOADEP=5
15 | SUB=new
16 |
17 |
18 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
19 | MODEL_DIR=output/base2new/train_base/${COMMON_DIR}
20 | DIR=output/base2new/test_${SUB}/${COMMON_DIR}
21 | if [ -d "$DIR" ]; then
22 | echo "Evaluating model"
23 | echo "Results are available in ${DIR}. Resuming..."
24 |
25 | python train.py \
26 | --root ${DATA} \
27 | --seed ${SEED} \
28 | --trainer ${TRAINER} \
29 | --dataset-config-file configs/datasets/${DATASET}.yaml \
30 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
31 | --output-dir ${DIR} \
32 | --model-dir ${MODEL_DIR} \
33 | --load-epoch ${LOADEP} \
34 | --eval-only \
35 | DATASET.NUM_SHOTS ${SHOTS} \
36 | DATASET.SUBSAMPLE_CLASSES ${SUB}
37 |
38 | else
39 | echo "Evaluating model"
40 | echo "Runing the first phase job and save the output to ${DIR}"
41 |
42 | python train.py \
43 | --root ${DATA} \
44 | --seed ${SEED} \
45 | --trainer ${TRAINER} \
46 | --dataset-config-file configs/datasets/${DATASET}.yaml \
47 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
48 | --output-dir ${DIR} \
49 | --model-dir ${MODEL_DIR} \
50 | --load-epoch ${LOADEP} \
51 | --eval-only \
52 | DATASET.NUM_SHOTS ${SHOTS} \
53 | DATASET.SUBSAMPLE_CLASSES ${SUB}
54 | fi
--------------------------------------------------------------------------------
/scripts/vpt/base2new_train_vpt.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #cd ../..
4 |
5 | # custom config
6 | DATA="/path/to/dataset/folder"
7 | TRAINER=VPT
8 |
9 | DATASET=$1
10 | SEED=$2
11 |
12 | CFG=vit_b16_c2_ep5_batch4_4
13 | SHOTS=16
14 |
15 |
16 | DIR=output/base2new/train_base/${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
17 | if [ -d "$DIR" ]; then
18 | echo "Results are available in ${DIR}. Resuming..."
19 | python train.py \
20 | --root ${DATA} \
21 | --seed ${SEED} \
22 | --trainer ${TRAINER} \
23 | --dataset-config-file configs/datasets/${DATASET}.yaml \
24 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
25 | --output-dir ${DIR} \
26 | DATASET.NUM_SHOTS ${SHOTS} \
27 | DATASET.SUBSAMPLE_CLASSES base
28 | else
29 | echo "Run this job and save the output to ${DIR}"
30 | python train.py \
31 | --root ${DATA} \
32 | --seed ${SEED} \
33 | --trainer ${TRAINER} \
34 | --dataset-config-file configs/datasets/${DATASET}.yaml \
35 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
36 | --output-dir ${DIR} \
37 | DATASET.NUM_SHOTS ${SHOTS} \
38 | DATASET.SUBSAMPLE_CLASSES base
39 | fi
--------------------------------------------------------------------------------
/scripts/vpt/reproduce_vpt.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #cd ../..
4 |
5 | # custom config
6 | DATA="/path/to/dataset/folder"
7 | TRAINER=VPT
8 |
9 | DATASET=$1
10 | SEED=$2
11 | WEIGHTSPATH=$3
12 |
13 | CFG=vit_b16_c2_ep5_batch4_4
14 | SHOTS=16
15 | LOADEP=5
16 | SUB_base=base
17 | SUB_novel=new
18 |
19 | COMMON_DIR=${DATASET}/shots_${SHOTS}/${TRAINER}/${CFG}/seed${SEED}
20 | MODEL_DIR=${WEIGHTSPATH}/base/seed${SEED}
21 | DIR_base=output/base2new/test_${SUB_base}/${COMMON_DIR}
22 | DIR_novel=output/base2new/test_${SUB_novel}/${COMMON_DIR}
23 | if [ -d "$DIR" ]; then
24 | echo "Results are already available in ${DIR}. Skipping..."
25 | else
26 | echo "Evaluating model"
27 | echo "Runing the first phase job and save the output to ${DIR}"
28 | # Evaluate on base classes
29 | python train.py \
30 | --root ${DATA} \
31 | --seed ${SEED} \
32 | --trainer ${TRAINER} \
33 | --dataset-config-file configs/datasets/${DATASET}.yaml \
34 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
35 | --output-dir ${DIR_base} \
36 | --model-dir ${MODEL_DIR} \
37 | --load-epoch ${LOADEP} \
38 | --eval-only \
39 | DATASET.NUM_SHOTS ${SHOTS} \
40 | DATASET.SUBSAMPLE_CLASSES ${SUB_base}
41 |
42 | # Evaluate on novel classes
43 | python train.py \
44 | --root ${DATA} \
45 | --seed ${SEED} \
46 | --trainer ${TRAINER} \
47 | --dataset-config-file configs/datasets/${DATASET}.yaml \
48 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
49 | --output-dir ${DIR_novel} \
50 | --model-dir ${MODEL_DIR} \
51 | --load-epoch ${LOADEP} \
52 | --eval-only \
53 | DATASET.NUM_SHOTS ${SHOTS} \
54 | DATASET.SUBSAMPLE_CLASSES ${SUB_novel}
55 |
56 | fi
--------------------------------------------------------------------------------
/scripts/vpt/xd_test_vpt.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #cd ../..
4 |
5 | # custom config
6 | DATA="/path/to/dataset/folder"
7 | TRAINER=VPT
8 |
9 | DATASET=$1
10 | SEED=$2
11 |
12 | CFG=vit_b16_c2_ep5_batch4_4
13 | SHOTS=16
14 |
15 |
16 | DIR=output/evaluation/${TRAINER}/${CFG}_${SHOTS}shots/${DATASET}/seed${SEED}
17 | if [ -d "$DIR" ]; then
18 | echo "Results are available in ${DIR}. Skip this job"
19 | else
20 | echo "Run this job and save the output to ${DIR}"
21 |
22 | python train.py \
23 | --root ${DATA} \
24 | --seed ${SEED} \
25 | --trainer ${TRAINER} \
26 | --dataset-config-file configs/datasets/${DATASET}.yaml \
27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
28 | --output-dir ${DIR} \
29 | --model-dir output/imagenet/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED} \
30 | --load-epoch 2 \
31 | --eval-only
32 | fi
--------------------------------------------------------------------------------
/scripts/vpt/xd_train_vpt.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #cd ../..
4 |
5 | # custom config
6 | DATA="/path/to/dataset/folder"
7 | TRAINER=VPT
8 |
9 | DATASET=$1
10 | SEED=$2
11 |
12 | CFG=vit_b16_c2_ep5_batch4_4
13 | SHOTS=16
14 |
15 |
16 | DIR=output/${DATASET}/${TRAINER}/${CFG}_${SHOTS}shots/seed${SEED}
17 | if [ -d "$DIR" ]; then
18 | echo "Results are available in ${DIR}."
19 | else
20 | echo "Run this job and save the output to ${DIR}"
21 |
22 | python train.py \
23 | --root ${DATA} \
24 | --seed ${SEED} \
25 | --trainer ${TRAINER} \
26 | --dataset-config-file configs/datasets/${DATASET}.yaml \
27 | --config-file configs/trainers/${TRAINER}/${CFG}.yaml \
28 | --output-dir ${DIR} \
29 | DATASET.NUM_SHOTS ${SHOTS}
30 | fi
--------------------------------------------------------------------------------
/scripts/zsclip/zeroshot.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #cd ../..
4 |
5 | # custom config
6 | DATA=/path/to/datasets
7 | TRAINER=ZeroshotCLIP
8 | DATASET=$1
9 | CFG=$2 # rn50, rn101, vit_b32 or vit_b16
10 |
11 | python train.py \
12 | --root ${DATA} \
13 | --trainer ${TRAINER} \
14 | --dataset-config-file configs/datasets/${DATASET}.yaml \
15 | --config-file configs/trainers/CoOp/${CFG}.yaml \
16 | --output-dir output/${TRAINER}/${CFG}/${DATASET} \
17 | --eval-only
--------------------------------------------------------------------------------
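The zero-shot CLIP script above needs only a dataset name and a CLIP backbone config (it reuses the CoOp config files and runs in --eval-only mode with no learned prompts). Illustrative call, with DATA first set to the local dataset root:

    bash scripts/zsclip/zeroshot.sh imagenet vit_b16
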
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 |
4 | from dassl.utils import setup_logger, set_random_seed, collect_env_info
5 | from dassl.config import get_cfg_default
6 | from dassl.engine import build_trainer
7 |
8 | # custom
9 | import datasets.oxford_pets
10 | import datasets.oxford_flowers
11 | import datasets.fgvc_aircraft
12 | import datasets.dtd
13 | import datasets.eurosat
14 | import datasets.stanford_cars
15 | import datasets.food101
16 | import datasets.sun397
17 | import datasets.caltech101
18 | import datasets.ucf101
19 | import datasets.imagenet
20 |
21 | import datasets.imagenet_sketch
22 | import datasets.imagenetv2
23 | import datasets.imagenet_a
24 | import datasets.imagenet_r
25 |
26 | import trainers.coop
27 | import trainers.cocoop
28 | import trainers.zsclip
29 | import trainers.maple
30 | import trainers.mmp
31 | import trainers.independentVL
32 | import trainers.vpt
33 |
34 | def print_args(args, cfg):
35 | print("***************")
36 | print("** Arguments **")
37 | print("***************")
38 | optkeys = list(args.__dict__.keys())
39 | optkeys.sort()
40 | for key in optkeys:
41 | print("{}: {}".format(key, args.__dict__[key]))
42 | print("************")
43 | print("** Config **")
44 | print("************")
45 | print(cfg)
46 |
47 |
48 | def reset_cfg(cfg, args):
49 | if args.root:
50 | cfg.DATASET.ROOT = args.root
51 |
52 | if args.output_dir:
53 | cfg.OUTPUT_DIR = args.output_dir
54 |
55 | if args.resume:
56 | cfg.RESUME = args.resume
57 |
58 | if args.seed:
59 | cfg.SEED = args.seed
60 |
61 | if args.source_domains:
62 | cfg.DATASET.SOURCE_DOMAINS = args.source_domains
63 |
64 | if args.target_domains:
65 | cfg.DATASET.TARGET_DOMAINS = args.target_domains
66 |
67 | if args.transforms:
68 | cfg.INPUT.TRANSFORMS = args.transforms
69 |
70 | if args.trainer:
71 | cfg.TRAINER.NAME = args.trainer
72 |
73 | if args.backbone:
74 | cfg.MODEL.BACKBONE.NAME = args.backbone
75 |
76 | if args.head:
77 | cfg.MODEL.HEAD.NAME = args.head
78 |
79 |
80 | def extend_cfg(cfg):
81 | """
82 | Add new config variables.
83 |
84 | E.g.
85 | from yacs.config import CfgNode as CN
86 | cfg.TRAINER.MY_MODEL = CN()
87 | cfg.TRAINER.MY_MODEL.PARAM_A = 1.
88 | cfg.TRAINER.MY_MODEL.PARAM_B = 0.5
89 | cfg.TRAINER.MY_MODEL.PARAM_C = False
90 | """
91 | from yacs.config import CfgNode as CN
92 |
93 | cfg.TRAINER.COOP = CN()
94 | cfg.TRAINER.COOP.N_CTX = 16 # number of context vectors
95 | cfg.TRAINER.COOP.CSC = False # class-specific context
96 | cfg.TRAINER.COOP.CTX_INIT = "" # initialization words
97 | cfg.TRAINER.COOP.PREC = "fp16" # fp16, fp32, amp
98 | cfg.TRAINER.COOP.CLASS_TOKEN_POSITION = "end" # 'middle' or 'end' or 'front'
99 |
100 | cfg.TRAINER.COCOOP = CN()
101 | cfg.TRAINER.COCOOP.N_CTX = 16 # number of context vectors
102 | cfg.TRAINER.COCOOP.CTX_INIT = "" # initialization words
103 | cfg.TRAINER.COCOOP.PREC = "fp16" # fp16, fp32, amp
104 |
105 | # Config for MaPLe
106 | cfg.TRAINER.MAPLE = CN()
107 | cfg.TRAINER.MAPLE.N_CTX = 16 # number of context vectors
108 | cfg.TRAINER.MAPLE.CTX_INIT = "a photo of a" # initialization words
109 | cfg.TRAINER.MAPLE.PREC = "fp16" # fp16, fp32, amp
110 | cfg.TRAINER.MAPLE.PROMPT_DEPTH = 9 # Max 12, minimum 0, for 1 it will act as shallow MaPLe (J=1)
111 | cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new
112 |
113 | # Config for MMP
114 | cfg.TRAINER.MMP = CN()
115 | cfg.TRAINER.MMP.N_CTX = 2 # number of context vectors
116 | cfg.TRAINER.MMP.CTX_INIT = "a photo of a" # initialization words
117 | cfg.TRAINER.MMP.PREC = "fp16" # fp16, fp32, amp
118 | cfg.TRAINER.MMP.TEXT_PROMPT_DEPTH = 1 # Max 12, minimum 0, for 1 it will act as shallow MMP (J=1)
119 | cfg.TRAINER.MMP.VISION_PROMPT_DEPTH = 1 # Max 12, minimum 0, for 1 it will act as shallow MMP (J=1)
120 |     cfg.TRAINER.MMP.TEXT_PROMPT_NUMBER = 4 # number of learnable language prompts
121 |     cfg.TRAINER.MMP.VISION_PROMPT_NUMBER = 4 # number of learnable vision prompts
122 | cfg.TRAINER.MMP.HIERARCHICAL = True
123 | cfg.TRAINER.MMP.USECT = True
124 | cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new
125 |
126 |
127 | # Config for independent Vision Language prompting (independent-vlp)
128 | cfg.TRAINER.IVLP = CN()
129 | cfg.TRAINER.IVLP.N_CTX_VISION = 2 # number of context vectors at the vision branch
130 | cfg.TRAINER.IVLP.N_CTX_TEXT = 2 # number of context vectors at the language branch
131 | cfg.TRAINER.IVLP.CTX_INIT = "a photo of a" # initialization words (only for language prompts)
132 | cfg.TRAINER.IVLP.PREC = "fp16" # fp16, fp32, amp
133 |     # If both variables below are set to 0, the config will degenerate to the CoOp model
134 | cfg.TRAINER.IVLP.PROMPT_DEPTH_VISION = 9 # Max 12, minimum 0, for 0 it will act as shallow MaPLe (J=1)
135 | cfg.TRAINER.IVLP.PROMPT_DEPTH_TEXT = 9 # Max 12, minimum 0, for 0 it will act as shallow MaPLe (J=1)
136 | cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new
137 |
138 | # Config for only vision side prompting
139 | cfg.TRAINER.VPT = CN()
140 | cfg.TRAINER.VPT.N_CTX_VISION = 2 # number of context vectors at the vision branch
141 | cfg.TRAINER.VPT.CTX_INIT = "a photo of a" # initialization words
142 | cfg.TRAINER.VPT.PREC = "fp16" # fp16, fp32, amp
143 | cfg.TRAINER.VPT.PROMPT_DEPTH_VISION = 1 # if set to 1, will represent shallow vision prompting only
144 | cfg.DATASET.SUBSAMPLE_CLASSES = "all" # all, base or new
145 |
146 |
147 | def setup_cfg(args):
148 | cfg = get_cfg_default()
149 | extend_cfg(cfg)
150 |
151 | # 1. From the dataset config file
152 | if args.dataset_config_file:
153 | cfg.merge_from_file(args.dataset_config_file)
154 |
155 | # 2. From the method config file
156 | if args.config_file:
157 | cfg.merge_from_file(args.config_file)
158 |
159 | # 3. From input arguments
160 | reset_cfg(cfg, args)
161 |
162 | # 4. From optional input arguments
163 | cfg.merge_from_list(args.opts)
164 |
165 | if cfg.DATASET.SUBSAMPLE_CLASSES == 'all': ## few shot setting
166 | if cfg.DATASET.NUM_SHOTS == 1:
167 | cfg.OPTIM.MAX_EPOCH = 30
168 | elif cfg.DATASET.NUM_SHOTS == 2 or cfg.DATASET.NUM_SHOTS == 4:
169 | cfg.OPTIM.MAX_EPOCH = 50
170 | else:
171 | cfg.OPTIM.MAX_EPOCH = 80
172 |
173 | if cfg.DATASET.NAME == "ImageNet":
174 | cfg.OPTIM.MAX_EPOCH = 20
175 |
176 | # if cfg.DATASET.NAME in ['OxfordFlowers', 'FGVCAircraft', 'StanfordCars']:
177 | # cfg.DATALOADER.TRAIN_X.BATCH_SIZE = 32 # 32 for small dataset such as Car,Air,Flowers
178 |
179 | cfg.freeze()
180 |
181 | return cfg
182 |
183 |
184 | def main(args):
185 | cfg = setup_cfg(args)
186 | if cfg.SEED >= 0:
187 | print("Setting fixed seed: {}".format(cfg.SEED))
188 | set_random_seed(cfg.SEED)
189 | setup_logger(cfg.OUTPUT_DIR)
190 |
191 | if torch.cuda.is_available() and cfg.USE_CUDA:
192 | torch.backends.cudnn.benchmark = True
193 |
194 | print_args(args, cfg)
195 | print("Collecting env info ...")
196 | print("** System info **\n{}\n".format(collect_env_info()))
197 |
198 | trainer = build_trainer(cfg)
199 |
200 | if args.eval_only:
201 | trainer.load_model(args.model_dir, epoch=args.load_epoch)
202 | trainer.test()
203 | return
204 |
205 | if not args.no_train:
206 | trainer.train()
207 |
208 |
209 | if __name__ == "__main__":
210 | parser = argparse.ArgumentParser()
211 | parser.add_argument("--root", type=str, default="", help="path to dataset")
212 | parser.add_argument("--output-dir", type=str, default="", help="output directory")
213 | parser.add_argument(
214 | "--resume",
215 | type=str,
216 | default="",
217 | help="checkpoint directory (from which the training resumes)",
218 | )
219 | parser.add_argument(
220 | "--seed", type=int, default=-1, help="only positive value enables a fixed seed"
221 | )
222 | parser.add_argument(
223 | "--source-domains", type=str, nargs="+", help="source domains for DA/DG"
224 | )
225 | parser.add_argument(
226 | "--target-domains", type=str, nargs="+", help="target domains for DA/DG"
227 | )
228 | parser.add_argument(
229 | "--transforms", type=str, nargs="+", help="data augmentation methods"
230 | )
231 | parser.add_argument(
232 | "--config-file", type=str, default="", help="path to config file"
233 | )
234 | parser.add_argument(
235 | "--dataset-config-file",
236 | type=str,
237 | default="",
238 | help="path to config file for dataset setup",
239 | )
240 | parser.add_argument("--trainer", type=str, default="", help="name of trainer")
241 | parser.add_argument("--backbone", type=str, default="", help="name of CNN backbone")
242 | parser.add_argument("--head", type=str, default="", help="name of head")
243 | parser.add_argument("--eval-only", action="store_true", help="evaluation only")
244 | parser.add_argument(
245 | "--model-dir",
246 | type=str,
247 | default="",
248 | help="load model from this directory for eval-only mode",
249 | )
250 | parser.add_argument(
251 | "--load-epoch", type=int, help="load model weights at this epoch for evaluation"
252 | )
253 | parser.add_argument(
254 | "--no-train", action="store_true", help="do not call trainer.train()"
255 | )
256 | parser.add_argument(
257 | "opts",
258 | default=None,
259 | nargs=argparse.REMAINDER,
260 | help="modify config options using the command-line",
261 | )
262 | args = parser.parse_args()
263 | main(args)
264 |
--------------------------------------------------------------------------------
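All of the scripts above are thin wrappers around train.py; a run can also be launched directly, with any yacs config overrides appended after the flags. A sketch mirroring the base-to-new training scripts (paths and values are illustrative):

    # train MaPLe on the 16-shot base split of oxford_pets, seed 1
    python train.py \
        --root /path/to/datasets \
        --seed 1 \
        --trainer MaPLe \
        --dataset-config-file configs/datasets/oxford_pets.yaml \
        --config-file configs/trainers/MaPLe/vit_b16_c2_ep5_batch4_2ctx.yaml \
        --output-dir output/base2new/train_base/oxford_pets/shots_16/MaPLe/vit_b16_c2_ep5_batch4_2ctx/seed1 \
        DATASET.NUM_SHOTS 16 \
        DATASET.SUBSAMPLE_CLASSES base
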
/trainers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__init__.py
--------------------------------------------------------------------------------
/trainers/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/cocoop.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/cocoop.cpython-37.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/cocoop.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/cocoop.cpython-38.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/coop.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/coop.cpython-37.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/coop.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/coop.cpython-38.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/imagenet_templates.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/imagenet_templates.cpython-37.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/imagenet_templates.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/imagenet_templates.cpython-38.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/independentVL.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/independentVL.cpython-37.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/independentVL.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/independentVL.cpython-38.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/maple.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/maple.cpython-37.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/maple.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/maple.cpython-38.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/mmp.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/mmp.cpython-37.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/mmp.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/mmp.cpython-38.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/vpt.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/vpt.cpython-37.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/vpt.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/vpt.cpython-38.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/zsclip.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/zsclip.cpython-37.pyc
--------------------------------------------------------------------------------
/trainers/__pycache__/zsclip.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/wds2014/ALIGN/9abd4be25b85a79e0308e7cdc8ac6932c749850a/trainers/__pycache__/zsclip.cpython-38.pyc
--------------------------------------------------------------------------------
/trainers/cocoop.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | from collections import OrderedDict
3 | import math
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torch.nn import functional as F
8 | from torch.cuda.amp import GradScaler, autocast
9 |
10 | from dassl.engine import TRAINER_REGISTRY, TrainerX
11 | from dassl.metrics import compute_accuracy
12 | from dassl.utils import load_pretrained_weights, load_checkpoint
13 | from dassl.optim import build_optimizer, build_lr_scheduler
14 |
15 | from clip import clip
16 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
17 |
18 | _tokenizer = _Tokenizer()
19 |
20 |
21 | def load_clip_to_cpu(cfg):
22 | backbone_name = cfg.MODEL.BACKBONE.NAME
23 | url = clip._MODELS[backbone_name]
24 | model_path = clip._download(url)
25 |
26 | try:
27 | # loading JIT archive
28 | model = torch.jit.load(model_path, map_location="cpu").eval()
29 | state_dict = None
30 |
31 | except RuntimeError:
32 | state_dict = torch.load(model_path, map_location="cpu")
33 | design_details = {"trainer": 'CoCoOp',
34 | "vision_depth": 0,
35 | "language_depth": 0, "vision_ctx": 0,
36 | "language_ctx": 0}
37 | model = clip.build_model(state_dict or model.state_dict(), design_details)
38 |
39 | return model
40 |
41 |
42 | class TextEncoder(nn.Module):
43 | def __init__(self, clip_model):
44 | super().__init__()
45 | self.transformer = clip_model.transformer
46 | self.positional_embedding = clip_model.positional_embedding
47 | self.ln_final = clip_model.ln_final
48 | self.text_projection = clip_model.text_projection
49 | self.dtype = clip_model.dtype
50 |
51 | def forward(self, prompts, tokenized_prompts):
52 | x = prompts + self.positional_embedding.type(self.dtype)
53 | x = x.permute(1, 0, 2) # NLD -> LND
54 | x = self.transformer(x)
55 | x = x.permute(1, 0, 2) # LND -> NLD
56 | x = self.ln_final(x).type(self.dtype)
57 |
58 | # x.shape = [batch_size, n_ctx, transformer.width]
59 | # take features from the eot embedding (eot_token is the highest number in each sequence)
60 | x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
61 |
62 | return x
63 |
64 |
65 | class PromptLearner(nn.Module):
66 | def __init__(self, cfg, classnames, clip_model):
67 | super().__init__()
68 | n_cls = len(classnames)
69 | n_ctx = cfg.TRAINER.COCOOP.N_CTX
70 | ctx_init = cfg.TRAINER.COCOOP.CTX_INIT
71 | dtype = clip_model.dtype
72 | ctx_dim = clip_model.ln_final.weight.shape[0]
73 | vis_dim = clip_model.visual.output_dim
74 | clip_imsize = clip_model.visual.input_resolution
75 | cfg_imsize = cfg.INPUT.SIZE[0]
76 |         assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal clip_imsize ({clip_imsize})"
77 |
78 | if ctx_init:
79 | # use given words to initialize context vectors
80 | ctx_init = ctx_init.replace("_", " ")
81 | n_ctx = len(ctx_init.split(" "))
82 | prompt = clip.tokenize(ctx_init)
83 | with torch.no_grad():
84 | embedding = clip_model.token_embedding(prompt).type(dtype)
85 | ctx_vectors = embedding[0, 1: 1 + n_ctx, :]
86 | prompt_prefix = ctx_init
87 | else:
88 | # random initialization
89 | ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
90 | nn.init.normal_(ctx_vectors, std=0.02)
91 | prompt_prefix = " ".join(["X"] * n_ctx)
92 |
93 | print(f'Initial context: "{prompt_prefix}"')
94 | print(f"Number of context words (tokens): {n_ctx}")
95 |
96 | self.ctx = nn.Parameter(ctx_vectors)
97 |
98 | self.meta_net = nn.Sequential(OrderedDict([
99 | ("linear1", nn.Linear(vis_dim, vis_dim // 16)),
100 | ("relu", nn.ReLU(inplace=True)),
101 | ("linear2", nn.Linear(vis_dim // 16, ctx_dim))
102 | ]))
103 |
104 | if cfg.TRAINER.COCOOP.PREC == "fp16":
105 | self.meta_net.half()
106 |
107 | classnames = [name.replace("_", " ") for name in classnames]
108 | name_lens = [len(_tokenizer.encode(name)) for name in classnames]
109 | prompts = [prompt_prefix + " " + name + "." for name in classnames]
110 |
111 | tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]) # (n_cls, n_tkn)
112 | with torch.no_grad():
113 | embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
114 |
115 |         # These token vectors will be saved by save_model(),
116 | # but they should be ignored in load_model() as we want to use
117 | # those computed using the current class names
118 | self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS
119 | self.register_buffer("token_suffix", embedding[:, 1 + n_ctx:, :]) # CLS, EOS
120 |
121 | self.n_cls = n_cls
122 | self.n_ctx = n_ctx
123 | self.tokenized_prompts = tokenized_prompts # torch.Tensor
124 | self.name_lens = name_lens
125 |
126 | def construct_prompts(self, ctx, prefix, suffix, label=None):
127 | # dim0 is either batch_size (during training) or n_cls (during testing)
128 | # ctx: context tokens, with shape of (dim0, n_ctx, ctx_dim)
129 | # prefix: the sos token, with shape of (n_cls, 1, ctx_dim)
130 | # suffix: remaining tokens, with shape of (n_cls, *, ctx_dim)
131 |
132 | if label is not None:
133 | prefix = prefix[label]
134 | suffix = suffix[label]
135 |
136 | prompts = torch.cat(
137 | [
138 | prefix, # (dim0, 1, dim)
139 | ctx, # (dim0, n_ctx, dim)
140 | suffix, # (dim0, *, dim)
141 | ],
142 | dim=1,
143 | )
144 |
145 | return prompts
146 |
147 | def forward(self, im_features):
148 | prefix = self.token_prefix
149 | suffix = self.token_suffix
150 | ctx = self.ctx # (n_ctx, ctx_dim)
151 | bias = self.meta_net(im_features) # (batch, ctx_dim)
152 | bias = bias.unsqueeze(1) # (batch, 1, ctx_dim)
153 | ctx = ctx.unsqueeze(0) # (1, n_ctx, ctx_dim)
154 | ctx_shifted = ctx + bias # (batch, n_ctx, ctx_dim)
155 |
156 | # Use instance-conditioned context tokens for all classes
157 | prompts = []
158 | for ctx_shifted_i in ctx_shifted:
159 | ctx_i = ctx_shifted_i.unsqueeze(0).expand(self.n_cls, -1, -1)
160 | pts_i = self.construct_prompts(ctx_i, prefix, suffix) # (n_cls, n_tkn, ctx_dim)
161 | prompts.append(pts_i)
162 | prompts = torch.stack(prompts)
163 |
164 | return prompts
165 |
166 |
167 | class CustomCLIP(nn.Module):
168 | def __init__(self, cfg, classnames, clip_model):
169 | super().__init__()
170 | self.prompt_learner = PromptLearner(cfg, classnames, clip_model)
171 | self.tokenized_prompts = self.prompt_learner.tokenized_prompts
172 | self.image_encoder = clip_model.visual
173 | self.text_encoder = TextEncoder(clip_model)
174 | self.logit_scale = clip_model.logit_scale
175 | self.dtype = clip_model.dtype
176 |
177 | def forward(self, image, label=None):
178 | tokenized_prompts = self.tokenized_prompts
179 | logit_scale = self.logit_scale.exp()
180 |
181 | image_features = self.image_encoder(image.type(self.dtype))
182 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
183 |
184 | prompts = self.prompt_learner(image_features)
185 |
186 | logits = []
187 | for pts_i, imf_i in zip(prompts, image_features):
188 | text_features = self.text_encoder(pts_i, tokenized_prompts)
189 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
190 | l_i = logit_scale * imf_i @ text_features.t()
191 | logits.append(l_i)
192 | logits = torch.stack(logits)
193 |
194 | if self.prompt_learner.training:
195 | return F.cross_entropy(logits, label)
196 |
197 | return logits
198 |
199 |
200 | @TRAINER_REGISTRY.register()
201 | class CoCoOp(TrainerX):
202 | def check_cfg(self, cfg):
203 | assert cfg.TRAINER.COCOOP.PREC in ["fp16", "fp32", "amp"]
204 |
205 | def build_model(self):
206 | cfg = self.cfg
207 | classnames = self.dm.dataset.classnames
208 |
209 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
210 | clip_model = load_clip_to_cpu(cfg)
211 |
212 | if cfg.TRAINER.COCOOP.PREC == "fp32" or cfg.TRAINER.COCOOP.PREC == "amp":
213 | # CLIP's default precision is fp16
214 | clip_model.float()
215 |
216 | print("Building custom CLIP")
217 | self.model = CustomCLIP(cfg, classnames, clip_model)
218 |
219 | print("Turning off gradients in both the image and the text encoder")
220 | name_to_update = "prompt_learner"
221 |
222 | for name, param in self.model.named_parameters():
223 | if name_to_update not in name:
224 | param.requires_grad_(False)
225 |
226 | # Double check
227 | enabled = set()
228 | for name, param in self.model.named_parameters():
229 | if param.requires_grad:
230 | enabled.add(name)
231 | print(f"Parameters to be updated: {enabled}")
232 |
233 | if cfg.MODEL.INIT_WEIGHTS:
234 | load_pretrained_weights(self.model.prompt_learner, cfg.MODEL.INIT_WEIGHTS)
235 |
236 | self.model.to(self.device)
237 | # NOTE: only give prompt_learner to the optimizer
238 | self.optim = build_optimizer(self.model.prompt_learner, cfg.OPTIM)
239 | self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
240 | self.register_model("prompt_learner", self.model.prompt_learner, self.optim, self.sched)
241 |
242 | self.scaler = GradScaler() if cfg.TRAINER.COCOOP.PREC == "amp" else None
243 |
244 | # Note that multi-gpu training could be slow because CLIP's size is
245 | # big, which slows down the copy operation in DataParallel
246 | device_count = torch.cuda.device_count()
247 | if device_count > 1:
248 |             print(f"Multiple GPUs detected (n_gpus={device_count}), using all of them!")
249 | self.model = nn.DataParallel(self.model)
250 |
251 | def forward_backward(self, batch):
252 | image, label = self.parse_batch_train(batch)
253 |
254 | model = self.model
255 | optim = self.optim
256 | scaler = self.scaler
257 |
258 | prec = self.cfg.TRAINER.COCOOP.PREC
259 | if prec == "amp":
260 | with autocast():
261 | loss = model(image, label)
262 | optim.zero_grad()
263 | scaler.scale(loss).backward()
264 | scaler.step(optim)
265 | scaler.update()
266 | else:
267 | loss = model(image, label)
268 | optim.zero_grad()
269 | loss.backward()
270 | optim.step()
271 |
272 | loss_summary = {"loss": loss.item()}
273 |
274 | if (self.batch_idx + 1) == self.num_batches:
275 | self.update_lr()
276 |
277 | return loss_summary
278 |
279 | def parse_batch_train(self, batch):
280 | input = batch["img"]
281 | label = batch["label"]
282 | input = input.to(self.device)
283 | label = label.to(self.device)
284 | return input, label
285 |
286 | def load_model(self, directory, epoch=None):
287 | if not directory:
288 | print("Note that load_model() is skipped as no pretrained model is given")
289 | return
290 |
291 | names = self.get_model_names()
292 |
293 | # By default, the best model is loaded
294 | model_file = "model-best.pth.tar"
295 |
296 | if epoch is not None:
297 | model_file = "model.pth.tar-" + str(epoch)
298 |
299 | for name in names:
300 | model_path = osp.join(directory, name, model_file)
301 |
302 | if not osp.exists(model_path):
303 | raise FileNotFoundError('Model not found at "{}"'.format(model_path))
304 |
305 | checkpoint = load_checkpoint(model_path)
306 | state_dict = checkpoint["state_dict"]
307 | epoch = checkpoint["epoch"]
308 |
309 | # Ignore fixed token vectors
310 | if "token_prefix" in state_dict:
311 | del state_dict["token_prefix"]
312 |
313 | if "token_suffix" in state_dict:
314 | del state_dict["token_suffix"]
315 |
316 |             print(f'Loading weights to {name} from "{model_path}" (epoch = {epoch})')
317 | # set strict=False
318 | self._models[name].load_state_dict(state_dict, strict=False)
319 |
--------------------------------------------------------------------------------
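A minimal, self-contained sketch of the instance-conditioning step in PromptLearner.forward above: the meta-net maps each image feature to a bias that shifts the shared context vectors, so every image gets its own prompt before the class tokens are attached. The sizes below are illustrative assumptions, not values taken from any config in this repo.

    import torch
    import torch.nn as nn

    vis_dim, ctx_dim, n_ctx, batch = 512, 512, 4, 2   # assumed sizes
    meta_net = nn.Sequential(
        nn.Linear(vis_dim, vis_dim // 16),
        nn.ReLU(inplace=True),
        nn.Linear(vis_dim // 16, ctx_dim),
    )
    ctx = torch.randn(n_ctx, ctx_dim)                 # shared learnable context
    im_features = torch.randn(batch, vis_dim)         # stand-in image features
    bias = meta_net(im_features).unsqueeze(1)         # (batch, 1, ctx_dim)
    ctx_shifted = ctx.unsqueeze(0) + bias             # (batch, n_ctx, ctx_dim)
    print(ctx_shifted.shape)                          # torch.Size([2, 4, 512])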
/trainers/imagenet_templates.py:
--------------------------------------------------------------------------------
1 | # source: https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb
2 |
3 | IMAGENET_TEMPLATES = [
4 | "a bad photo of a {}.",
5 | "a photo of many {}.",
6 | "a sculpture of a {}.",
7 | "a photo of the hard to see {}.",
8 | "a low resolution photo of the {}.",
9 | "a rendering of a {}.",
10 | "graffiti of a {}.",
11 | "a bad photo of the {}.",
12 | "a cropped photo of the {}.",
13 | "a tattoo of a {}.",
14 | "the embroidered {}.",
15 | "a photo of a hard to see {}.",
16 | "a bright photo of a {}.",
17 | "a photo of a clean {}.",
18 | "a photo of a dirty {}.",
19 | "a dark photo of the {}.",
20 | "a drawing of a {}.",
21 | "a photo of my {}.",
22 | "the plastic {}.",
23 | "a photo of the cool {}.",
24 | "a close-up photo of a {}.",
25 | "a black and white photo of the {}.",
26 | "a painting of the {}.",
27 | "a painting of a {}.",
28 | "a pixelated photo of the {}.",
29 | "a sculpture of the {}.",
30 | "a bright photo of the {}.",
31 | "a cropped photo of a {}.",
32 | "a plastic {}.",
33 | "a photo of the dirty {}.",
34 | "a jpeg corrupted photo of a {}.",
35 | "a blurry photo of the {}.",
36 | "a photo of the {}.",
37 | "a good photo of the {}.",
38 | "a rendering of the {}.",
39 | "a {} in a video game.",
40 | "a photo of one {}.",
41 | "a doodle of a {}.",
42 | "a close-up photo of the {}.",
43 | "a photo of a {}.",
44 | "the origami {}.",
45 | "the {} in a video game.",
46 | "a sketch of a {}.",
47 | "a doodle of the {}.",
48 | "a origami {}.",
49 | "a low resolution photo of a {}.",
50 | "the toy {}.",
51 | "a rendition of the {}.",
52 | "a photo of the clean {}.",
53 | "a photo of a large {}.",
54 | "a rendition of a {}.",
55 | "a photo of a nice {}.",
56 | "a photo of a weird {}.",
57 | "a blurry photo of a {}.",
58 | "a cartoon {}.",
59 | "art of a {}.",
60 | "a sketch of the {}.",
61 | "a embroidered {}.",
62 | "a pixelated photo of a {}.",
63 | "itap of the {}.",
64 | "a jpeg corrupted photo of the {}.",
65 | "a good photo of a {}.",
66 | "a plushie {}.",
67 | "a photo of the nice {}.",
68 | "a photo of the small {}.",
69 | "a photo of the weird {}.",
70 | "the cartoon {}.",
71 | "art of the {}.",
72 | "a drawing of the {}.",
73 | "a photo of the large {}.",
74 | "a black and white photo of a {}.",
75 | "the plushie {}.",
76 | "a dark photo of a {}.",
77 | "itap of a {}.",
78 | "graffiti of the {}.",
79 | "a toy {}.",
80 | "itap of my {}.",
81 | "a photo of a cool {}.",
82 | "a photo of a small {}.",
83 | "a tattoo of the {}.",
84 | ]
85 |
86 | IMAGENET_TEMPLATES_SELECT = [
87 | "itap of a {}.",
88 | "a bad photo of the {}.",
89 | "a origami {}.",
90 | "a photo of the large {}.",
91 | "a {} in a video game.",
92 | "art of the {}.",
93 | "a photo of the small {}.",
94 | ]
95 |
--------------------------------------------------------------------------------
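These template lists are plain Python format strings; each trainer fills the {} placeholder with a class name. A small illustrative example (the class name and the import path are assumptions based on the file location above):

    from trainers.imagenet_templates import IMAGENET_TEMPLATES_SELECT

    classname = "golden retriever"
    prompts = [t.format(classname) for t in IMAGENET_TEMPLATES_SELECT]
    # ['itap of a golden retriever.', 'a bad photo of the golden retriever.', ...]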
/trainers/independentVL.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | from collections import OrderedDict
3 | import math
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torch.nn import functional as F
8 | from torch.cuda.amp import GradScaler, autocast
9 |
10 | from dassl.engine import TRAINER_REGISTRY, TrainerX
11 | from dassl.metrics import compute_accuracy
12 | from dassl.utils import load_pretrained_weights, load_checkpoint
13 | from dassl.optim import build_optimizer, build_lr_scheduler
14 |
15 | from clip import clip
16 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
17 |
18 | _tokenizer = _Tokenizer()
19 |
20 |
21 | def load_clip_to_cpu(cfg):
22 | backbone_name = cfg.MODEL.BACKBONE.NAME
23 | url = clip._MODELS[backbone_name]
24 | model_path = clip._download(url)
25 |
26 | try:
27 | # loading JIT archive
28 | model = torch.jit.load(model_path, map_location="cpu").eval()
29 | state_dict = None
30 |
31 | except RuntimeError:
32 | state_dict = torch.load(model_path, map_location="cpu")
33 | design_details = {"trainer": 'IVLP',
34 | "vision_depth": cfg.TRAINER.IVLP.PROMPT_DEPTH_VISION,
35 | "language_depth": cfg.TRAINER.IVLP.PROMPT_DEPTH_TEXT, "vision_ctx": cfg.TRAINER.IVLP.N_CTX_VISION,
36 | "language_ctx": cfg.TRAINER.IVLP.N_CTX_TEXT}
37 | model = clip.build_model(state_dict or model.state_dict(), design_details)
38 |
39 | return model
40 |
41 |
42 | class TextEncoder(nn.Module):
43 | def __init__(self, clip_model):
44 | super().__init__()
45 | self.transformer = clip_model.transformer
46 | self.positional_embedding = clip_model.positional_embedding
47 | self.ln_final = clip_model.ln_final
48 | self.text_projection = clip_model.text_projection
49 | self.dtype = clip_model.dtype
50 |
51 | def forward(self, prompts, tokenized_prompts):
52 | x = prompts + self.positional_embedding.type(self.dtype)
53 | x = x.permute(1, 0, 2) # NLD -> LND
54 | x = self.transformer(x)
55 | x = x.permute(1, 0, 2) # LND -> NLD
56 | x = self.ln_final(x).type(self.dtype)
57 |
58 | # x.shape = [batch_size, n_ctx, transformer.width]
59 | # take features from the eot embedding (eot_token is the highest number in each sequence)
60 | x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
61 |
62 | return x
63 |
64 |
65 | class VLPromptLearner(nn.Module):
66 | def __init__(self, cfg, classnames, clip_model):
67 | super().__init__()
68 | n_cls = len(classnames)
69 | # Make sure Language depth >= 1
70 | assert cfg.TRAINER.IVLP.PROMPT_DEPTH_TEXT >= 1, "In Independent VL prompting, Language prompt depth should be >=1" \
71 | "\nPlease use VPT trainer if you want to learn only vision " \
72 | "branch "
73 | n_ctx = cfg.TRAINER.IVLP.N_CTX_TEXT
74 | ctx_init = cfg.TRAINER.IVLP.CTX_INIT
75 | dtype = clip_model.dtype
76 | ctx_dim = clip_model.ln_final.weight.shape[0]
77 | vis_dim = clip_model.visual.output_dim
78 | clip_imsize = clip_model.visual.input_resolution
79 | cfg_imsize = cfg.INPUT.SIZE[0]
80 |         assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal clip_imsize ({clip_imsize})"
81 |
82 | if ctx_init and (n_ctx) <= 4:
83 | # use given words to initialize context vectors
84 | ctx_init = ctx_init.replace("_", " ")
85 |             # n_ctx unchanged: the first n_ctx token embeddings of ctx_init are used below
86 | prompt = clip.tokenize(ctx_init)
87 | with torch.no_grad():
88 | embedding = clip_model.token_embedding(prompt).type(dtype)
89 | ctx_vectors = embedding[0, 1: 1 + n_ctx, :]
90 | prompt_prefix = ctx_init
91 | else:
92 | # random initialization
93 | ctx_vectors = torch.empty(n_ctx, ctx_dim, dtype=dtype)
94 | nn.init.normal_(ctx_vectors, std=0.02)
95 | prompt_prefix = " ".join(["X"] * n_ctx)
96 | print(f"Independent V-L design")
97 | print(f'Initial text context: "{prompt_prefix}"')
98 | print(f"Number of context words (tokens) for Language prompting: {n_ctx}")
99 | print(f"Number of context words (tokens) for Vision prompting: {cfg.TRAINER.IVLP.N_CTX_VISION}")
100 | self.ctx = nn.Parameter(ctx_vectors)
101 |
102 | classnames = [name.replace("_", " ") for name in classnames]
103 | name_lens = [len(_tokenizer.encode(name)) for name in classnames]
104 | prompts = [prompt_prefix + " " + name + "." for name in classnames]
105 |
106 | tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts]) # (n_cls, n_tkn)
107 | with torch.no_grad():
108 | embedding = clip_model.token_embedding(tokenized_prompts).type(dtype)
109 |
110 |         # These token vectors will be saved by save_model(),
111 | # but they should be ignored in load_model() as we want to use
112 | # those computed using the current class names
113 | self.register_buffer("token_prefix", embedding[:, :1, :]) # SOS
114 | self.register_buffer("token_suffix", embedding[:, 1 + n_ctx:, :]) # CLS, EOS
115 |
116 | self.n_cls = n_cls
117 | self.n_ctx = n_ctx
118 | self.tokenized_prompts = tokenized_prompts # torch.Tensor
119 | self.name_lens = name_lens
120 |
121 | def construct_prompts(self, ctx, prefix, suffix, label=None):
122 | # dim0 is either batch_size (during training) or n_cls (during testing)
123 | # ctx: context tokens, with shape of (dim0, n_ctx, ctx_dim)
124 | # prefix: the sos token, with shape of (n_cls, 1, ctx_dim)
125 | # suffix: remaining tokens, with shape of (n_cls, *, ctx_dim)
126 |
127 | if label is not None:
128 | prefix = prefix[label]
129 | suffix = suffix[label]
130 |
131 | prompts = torch.cat(
132 | [
133 | prefix, # (dim0, 1, dim)
134 | ctx, # (dim0, n_ctx, dim)
135 | suffix, # (dim0, *, dim)
136 | ],
137 | dim=1,
138 | )
139 |
140 | return prompts
141 |
142 | def forward(self):
143 | ctx = self.ctx
144 | if ctx.dim() == 2:
145 | ctx = ctx.unsqueeze(0).expand(self.n_cls, -1, -1)
146 |
147 | prefix = self.token_prefix
148 | suffix = self.token_suffix
149 | prompts = self.construct_prompts(ctx, prefix, suffix)
150 |
151 | return prompts
152 |
153 |
154 | class CustomCLIP(nn.Module):
155 | def __init__(self, cfg, classnames, clip_model):
156 | super().__init__()
157 | self.prompt_learner = VLPromptLearner(cfg, classnames, clip_model)
158 | self.tokenized_prompts = self.prompt_learner.tokenized_prompts
159 | self.image_encoder = clip_model.visual
160 | self.text_encoder = TextEncoder(clip_model)
161 | self.logit_scale = clip_model.logit_scale
162 | self.dtype = clip_model.dtype
163 |
164 | def forward(self, image, label=None):
165 | tokenized_prompts = self.tokenized_prompts
166 | logit_scale = self.logit_scale.exp()
167 |
168 | prompts = self.prompt_learner()
169 | text_features = self.text_encoder(prompts, tokenized_prompts)
170 | image_features = self.image_encoder(image.type(self.dtype))
171 |
172 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
173 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
174 | logits = logit_scale * image_features @ text_features.t()
175 |
176 | if self.prompt_learner.training:
177 | return F.cross_entropy(logits, label)
178 |
179 | return logits
180 |
181 |
182 | @TRAINER_REGISTRY.register()
183 | class IVLP(TrainerX):
184 | def check_cfg(self, cfg):
185 | assert cfg.TRAINER.IVLP.PREC in ["fp16", "fp32", "amp"]
186 |
187 | def build_model(self):
188 | cfg = self.cfg
189 | classnames = self.dm.dataset.classnames
190 |
191 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
192 | clip_model = load_clip_to_cpu(cfg)
193 |
194 | if cfg.TRAINER.IVLP.PREC == "fp32" or cfg.TRAINER.IVLP.PREC == "amp":
195 | # CLIP's default precision is fp16
196 | clip_model.float()
197 |
198 | print("Building custom CLIP")
199 | self.model = CustomCLIP(cfg, classnames, clip_model)
200 |
201 | print("Turning off gradients in both the image and the text encoder")
202 | name_to_update = "prompt_learner"
203 |
204 | for name, param in self.model.named_parameters():
205 | if name_to_update not in name:
206 | # Make sure that VPT prompts are updated
207 | if "VPT" in name:
208 | param.requires_grad_(True)
209 | else:
210 | param.requires_grad_(False)
211 |
212 | # Double check
213 | enabled = set()
214 | for name, param in self.model.named_parameters():
215 | if param.requires_grad:
216 | enabled.add(name)
217 | print(f"Parameters to be updated: {enabled}")
218 |
219 | if cfg.MODEL.INIT_WEIGHTS:
220 | load_pretrained_weights(self.model, cfg.MODEL.INIT_WEIGHTS)
221 |
222 | self.model.to(self.device)
223 |         # NOTE: the full model is passed here; only parameters left with requires_grad=True are updated
224 | self.optim = build_optimizer(self.model, cfg.OPTIM)
225 | self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
226 | self.register_model("VLPromptLearner", self.model, self.optim, self.sched)
227 |
228 | self.scaler = GradScaler() if cfg.TRAINER.IVLP.PREC == "amp" else None
229 |
230 | # Note that multi-gpu training could be slow because CLIP's size is
231 | # big, which slows down the copy operation in DataParallel
232 | device_count = torch.cuda.device_count()
233 | if device_count > 1:
234 |             print(f"Multiple GPUs detected (n_gpus={device_count}), using all of them!")
235 | self.model = nn.DataParallel(self.model)
236 |
237 | def forward_backward(self, batch):
238 | image, label = self.parse_batch_train(batch)
239 |
240 | model = self.model
241 | optim = self.optim
242 | scaler = self.scaler
243 |
244 | prec = self.cfg.TRAINER.IVLP.PREC
245 | if prec == "amp":
246 | with autocast():
247 | loss = model(image, label)
248 | optim.zero_grad()
249 | scaler.scale(loss).backward()
250 | scaler.step(optim)
251 | scaler.update()
252 | else:
253 | loss = model(image, label)
254 | optim.zero_grad()
255 | loss.backward()
256 | optim.step()
257 |
258 | loss_summary = {"loss": loss.item()}
259 |
260 | if (self.batch_idx + 1) == self.num_batches:
261 | self.update_lr()
262 |
263 | return loss_summary
264 |
265 | def parse_batch_train(self, batch):
266 | input = batch["img"]
267 | label = batch["label"]
268 | input = input.to(self.device)
269 | label = label.to(self.device)
270 | return input, label
271 |
272 | def load_model(self, directory, epoch=None):
273 | if not directory:
274 | print("Note that load_model() is skipped as no pretrained model is given")
275 | return
276 |
277 | names = self.get_model_names()
278 |
279 | # By default, the best model is loaded
280 | model_file = "model-best.pth.tar"
281 |
282 | if epoch is not None:
283 | model_file = "model.pth.tar-" + str(epoch)
284 |
285 | for name in names:
286 | model_path = osp.join(directory, name, model_file)
287 |
288 | if not osp.exists(model_path):
289 | raise FileNotFoundError('Model not found at "{}"'.format(model_path))
290 |
291 | checkpoint = load_checkpoint(model_path)
292 | state_dict = checkpoint["state_dict"]
293 | epoch = checkpoint["epoch"]
294 |
295 | # Ignore fixed token vectors
296 | if "prompt_learner.token_prefix" in state_dict:
297 | del state_dict["prompt_learner.token_prefix"]
298 |
299 | if "prompt_learner.token_suffix" in state_dict:
300 | del state_dict["prompt_learner.token_suffix"]
301 |
302 |             print(f'Loading weights to {name} from "{model_path}" (epoch = {epoch})')
303 | # set strict=False
304 | self._models[name].load_state_dict(state_dict, strict=False)
305 |
--------------------------------------------------------------------------------
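load_clip_to_cpu above forwards four IVLP settings to clip.build_model as design_details. A hedged sketch of the config node it expects, built with yacs (the config library Dassl builds on); the values shown are placeholders for illustration, not the repo's yaml defaults:

    from yacs.config import CfgNode as CN

    cfg = CN()
    cfg.TRAINER = CN()
    cfg.TRAINER.IVLP = CN()
    cfg.TRAINER.IVLP.N_CTX_TEXT = 2             # learnable text context tokens
    cfg.TRAINER.IVLP.N_CTX_VISION = 2           # learnable vision context tokens
    cfg.TRAINER.IVLP.CTX_INIT = "a photo of a"  # only used when N_CTX_TEXT <= 4
    cfg.TRAINER.IVLP.PROMPT_DEPTH_TEXT = 9      # must be >= 1 (asserted in VLPromptLearner)
    cfg.TRAINER.IVLP.PROMPT_DEPTH_VISION = 9
    cfg.TRAINER.IVLP.PREC = "fp16"              # one of fp16 / fp32 / amp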
/trainers/vpt.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | from collections import OrderedDict
3 | import math
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torch.nn import functional as F
8 | from torch.cuda.amp import GradScaler, autocast
9 |
10 | from dassl.engine import TRAINER_REGISTRY, TrainerX
11 | from dassl.metrics import compute_accuracy
12 | from dassl.utils import load_pretrained_weights, load_checkpoint
13 | from dassl.optim import build_optimizer, build_lr_scheduler
14 |
15 | from clip import clip
16 | from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
17 |
18 | _tokenizer = _Tokenizer()
19 |
20 |
21 | def load_clip_to_cpu(cfg):
22 | backbone_name = cfg.MODEL.BACKBONE.NAME
23 | url = clip._MODELS[backbone_name]
24 | model_path = clip._download(url)
25 |
26 | try:
27 | # loading JIT archive
28 | model = torch.jit.load(model_path, map_location="cpu").eval()
29 | state_dict = None
30 |
31 | except RuntimeError:
32 | state_dict = torch.load(model_path, map_location="cpu")
33 | design_details = { "trainer": "VPT",
34 | "vision_depth": cfg.TRAINER.VPT.PROMPT_DEPTH_VISION,
35 | "vision_ctx": cfg.TRAINER.VPT.N_CTX_VISION,
36 | "language_depth": 0,
37 | "language_ctx": 0}
38 | assert cfg.TRAINER.VPT.PROMPT_DEPTH_VISION >= 1, "For Vision Prompting, PROMPT_DEPTH_VISION should be >= 1"
39 | model = clip.build_model(state_dict or model.state_dict(), design_details)
40 |
41 | return model.float()
42 |
43 |
44 | class TextEncoder(nn.Module):
45 | def __init__(self, clip_model):
46 | super().__init__()
47 | self.transformer = clip_model.transformer
48 | self.positional_embedding = clip_model.positional_embedding
49 | self.ln_final = clip_model.ln_final
50 | self.text_projection = clip_model.text_projection
51 | self.dtype = clip_model.dtype
52 |
53 | def forward(self, prompts, tokenized_prompts):
54 | x = prompts + self.positional_embedding.type(self.dtype)
55 | x = x.permute(1, 0, 2) # NLD -> LND
56 | x = self.transformer(x)
57 | x = x.permute(1, 0, 2) # LND -> NLD
58 | x = self.ln_final(x).type(self.dtype)
59 |
60 | # x.shape = [batch_size, n_ctx, transformer.width]
61 | # take features from the eot embedding (eot_token is the highest number in each sequence)
62 | x = x[torch.arange(x.shape[0]), tokenized_prompts.argmax(dim=-1)] @ self.text_projection
63 |
64 | return x
65 |
66 |
67 | class FixedEmbeddings():
68 | def __init__(self, cfg, classnames, clip_model):
69 | clip_imsize = clip_model.visual.input_resolution
70 | cfg_imsize = cfg.INPUT.SIZE[0]
71 |         assert cfg_imsize == clip_imsize, f"cfg_imsize ({cfg_imsize}) must equal clip_imsize ({clip_imsize})"
72 |
73 | prompt_prefix = "a photo of a"
74 | print('Vision Prompting Design')
75 | print(f'Initial context: "{prompt_prefix}"')
76 | print(f"Number of context words (tokens) for Vision prompting: {cfg.TRAINER.VPT.N_CTX_VISION}")
77 |         print("Using fixed hand-crafted prompts")
78 |
79 | classnames = [name.replace("_", " ") for name in classnames]
80 | prompts = [prompt_prefix + " " + name + "." for name in classnames]
81 |
82 | tokenized_prompts = torch.cat([clip.tokenize(p) for p in prompts])
83 | with torch.no_grad():
84 | text_features = clip_model.encode_text(tokenized_prompts)
85 |
86 | self.fixed_embeddings = text_features
87 |
88 | def return_fixed_embeddings(self):
89 | return self.fixed_embeddings
90 |
91 |
92 | class CustomCLIP(nn.Module):
93 | def __init__(self, cfg, classnames, clip_model):
94 | super().__init__()
95 | self.embeddings = FixedEmbeddings(cfg, classnames, clip_model)
96 | self.image_encoder = clip_model.visual
97 | self.text_encoder = TextEncoder(clip_model)
98 | self.logit_scale = clip_model.logit_scale
99 | self.dtype = clip_model.dtype
100 |
101 | def forward(self, image, label=None, training=False):
102 | logit_scale = self.logit_scale.exp()
103 |
104 |         text_features = self.embeddings.return_fixed_embeddings().to(image.device)
105 | image_features = self.image_encoder(image.type(self.dtype))
106 |
107 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
108 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
109 | logits = logit_scale * image_features @ text_features.t()
110 |
111 | if training:
112 | return F.cross_entropy(logits, label)
113 |
114 | return logits
115 |
116 |
117 | @TRAINER_REGISTRY.register()
118 | class VPT(TrainerX):
119 | def check_cfg(self, cfg):
120 | assert cfg.TRAINER.VPT.PREC in ["fp16", "fp32", "amp"]
121 |
122 | def build_model(self):
123 | cfg = self.cfg
124 | classnames = self.dm.dataset.classnames
125 |
126 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
127 | clip_model = load_clip_to_cpu(cfg)
128 |
129 | if cfg.TRAINER.VPT.PREC == "fp32" or cfg.TRAINER.VPT.PREC == "amp":
130 | # CLIP's default precision is fp16
131 | clip_model.float()
132 |
133 | print("Building custom CLIP")
134 | self.model = CustomCLIP(cfg, classnames, clip_model)
135 |
136 | print("Turning off gradients in both the image and the text encoder")
137 | name_to_update = "prompt_learner"
138 |
139 | for name, param in self.model.named_parameters():
140 | if name_to_update not in name:
141 | # Make sure that VPT prompts are updated
142 | if "VPT" in name:
143 | param.requires_grad_(True)
144 | else:
145 | param.requires_grad_(False)
146 |
147 | # Double check
148 | enabled = set()
149 | for name, param in self.model.named_parameters():
150 | if param.requires_grad:
151 | enabled.add(name)
152 | print(f"Parameters to be updated: {enabled}")
153 |
154 | if cfg.MODEL.INIT_WEIGHTS:
155 | load_pretrained_weights(self.model, cfg.MODEL.INIT_WEIGHTS)
156 |
157 | self.model.to(self.device)
158 |         # NOTE: the full model is passed here; only parameters left with requires_grad=True are updated
159 | self.optim = build_optimizer(self.model, cfg.OPTIM)
160 | self.sched = build_lr_scheduler(self.optim, cfg.OPTIM)
161 | self.register_model("prompt_learner", self.model, self.optim, self.sched)
162 |
163 | self.scaler = GradScaler() if cfg.TRAINER.VPT.PREC == "amp" else None
164 |
165 | # Note that multi-gpu training could be slow because CLIP's size is
166 | # big, which slows down the copy operation in DataParallel
167 | device_count = torch.cuda.device_count()
168 | if device_count > 1:
169 |             print(f"Multiple GPUs detected (n_gpus={device_count}), using all of them!")
170 | self.model = nn.DataParallel(self.model)
171 |
172 | def forward_backward(self, batch):
173 | image, label = self.parse_batch_train(batch)
174 |
175 | model = self.model
176 | optim = self.optim
177 | scaler = self.scaler
178 |
179 | prec = self.cfg.TRAINER.VPT.PREC
180 | if prec == "amp":
181 | with autocast():
182 |                 loss = model(image, label, training=True)
183 | optim.zero_grad()
184 | scaler.scale(loss).backward()
185 | scaler.step(optim)
186 | scaler.update()
187 | else:
188 | loss = model(image, label, training=True)
189 | optim.zero_grad()
190 | loss.backward()
191 | optim.step()
192 |
193 | loss_summary = {"loss": loss.item()}
194 |
195 | if (self.batch_idx + 1) == self.num_batches:
196 | self.update_lr()
197 |
198 | return loss_summary
199 |
200 | def parse_batch_train(self, batch):
201 | input = batch["img"]
202 | label = batch["label"]
203 | input = input.to(self.device)
204 | label = label.to(self.device)
205 | return input, label
206 |
207 | def load_model(self, directory, epoch=None):
208 | if not directory:
209 | print("Note that load_model() is skipped as no pretrained model is given")
210 | return
211 |
212 | names = self.get_model_names()
213 |
214 | # By default, the best model is loaded
215 | model_file = "model-best.pth.tar"
216 |
217 | if epoch is not None:
218 | model_file = "model.pth.tar-" + str(epoch)
219 |
220 | for name in names:
221 | model_path = osp.join(directory, name, model_file)
222 |
223 | if not osp.exists(model_path):
224 | raise FileNotFoundError('Model not found at "{}"'.format(model_path))
225 |
226 | checkpoint = load_checkpoint(model_path)
227 | state_dict = checkpoint["state_dict"]
228 | epoch = checkpoint["epoch"]
229 |
230 | # Ignore fixed token vectors
231 | if "prompt_learner.token_prefix" in state_dict:
232 | del state_dict["prompt_learner.token_prefix"]
233 |
234 | if "prompt_learner.token_suffix" in state_dict:
235 | del state_dict["prompt_learner.token_suffix"]
236 |
237 |             print(f'Loading weights to {name} from "{model_path}" (epoch = {epoch})')
238 | # set strict=False
239 | self._models[name].load_state_dict(state_dict, strict=False)
240 |
--------------------------------------------------------------------------------
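Every trainer shown above scores image-text pairs the same way once the features exist: L2-normalize both modalities, then scale the cosine similarities by logit_scale.exp(). A minimal stand-alone sketch with random stand-in tensors (the logit_scale value here is illustrative, not the checkpoint's):

    import torch

    n_cls, dim, batch = 10, 512, 4
    logit_scale = torch.tensor(4.6052)      # illustrative; CLIP learns and clamps this value
    image_features = torch.randn(batch, dim)
    text_features = torch.randn(n_cls, dim)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    logits = logit_scale.exp() * image_features @ text_features.t()   # (batch, n_cls)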
/trainers/zsclip.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from dassl.engine import TRAINER_REGISTRY, TrainerX
5 | from dassl.optim import build_optimizer, build_lr_scheduler
6 |
7 | from clip import clip
8 | from clip.model import convert_weights
9 |
10 | from .coop import load_clip_to_cpu
11 | from .imagenet_templates import IMAGENET_TEMPLATES, IMAGENET_TEMPLATES_SELECT
12 |
13 | CUSTOM_TEMPLATES = {
14 | "OxfordPets": "a photo of a {}, a type of pet.",
15 | "OxfordFlowers": "a photo of a {}, a type of flower.",
16 | "FGVCAircraft": "a photo of a {}, a type of aircraft.",
17 | "DescribableTextures": "{} texture.",
18 | "EuroSAT": "a centered satellite photo of {}.",
19 | "StanfordCars": "a photo of a {}.",
20 | "Food101": "a photo of {}, a type of food.",
21 | "SUN397": "a photo of a {}.",
22 | "Caltech101": "a photo of a {}.",
23 | "UCF101": "a photo of a person doing {}.",
24 | "ImageNet": "a photo of a {}.",
25 | "ImageNetSketch": "a photo of a {}.",
26 | "ImageNetV2": "a photo of a {}.",
27 | "ImageNetA": "a photo of a {}.",
28 | "ImageNetR": "a photo of a {}.",
29 | }
30 |
31 |
32 | @TRAINER_REGISTRY.register()
33 | class ZeroshotCLIP(TrainerX):
34 | def build_model(self):
35 | cfg = self.cfg
36 | classnames = self.dm.dataset.classnames
37 |
38 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
39 | clip_model = load_clip_to_cpu(cfg)
40 | clip_model.to(self.device)
41 |
42 | temp = CUSTOM_TEMPLATES[cfg.DATASET.NAME]
43 | prompts = [temp.format(c.replace("_", " ")) for c in classnames]
44 | print(f"Prompts: {prompts}")
45 | prompts = torch.cat([clip.tokenize(p) for p in prompts])
46 | prompts = prompts.to(self.device)
47 |
48 | with torch.no_grad():
49 | text_features = clip_model.encode_text(prompts)
50 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
51 |
52 | self.text_features = text_features
53 | self.clip_model = clip_model
54 |
55 | def model_inference(self, image):
56 | image_features = self.clip_model.encode_image(image)
57 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
58 | logit_scale = self.clip_model.logit_scale.exp()
59 | logits = logit_scale * image_features @ self.text_features.t()
60 | return logits
61 |
62 |
63 | @TRAINER_REGISTRY.register()
64 | class ZeroshotCLIP2(ZeroshotCLIP):
65 | """Prompt ensembling."""
66 |
67 | # templates = IMAGENET_TEMPLATES
68 | templates = IMAGENET_TEMPLATES_SELECT
69 |
70 | def build_model(self):
71 | cfg = self.cfg
72 | classnames = self.dm.dataset.classnames
73 |
74 | print(f"Loading CLIP (backbone: {cfg.MODEL.BACKBONE.NAME})")
75 | clip_model = load_clip_to_cpu(cfg)
76 | clip_model.to(self.device)
77 |
78 | for params in clip_model.parameters():
79 | params.requires_grad_(False)
80 |
81 | # add custom-made prompt
82 | if cfg.DATASET.NAME != "ImageNet":
83 |             self.templates = self.templates + [CUSTOM_TEMPLATES[cfg.DATASET.NAME]]  # avoid mutating the shared class-level list
84 |
85 | num_temp = len(self.templates)
86 | print(f"Prompt ensembling (n={num_temp})")
87 |
88 | mean_text_features = 0
89 | for i, temp in enumerate(self.templates):
90 | prompts = [temp.format(c.replace("_", " ")) for c in classnames]
91 | prompts = torch.cat([clip.tokenize(p) for p in prompts]).to(self.device)
92 | text_features = clip_model.encode_text(prompts)
93 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
94 | mean_text_features = mean_text_features + text_features
95 | mean_text_features = mean_text_features / num_temp
96 | mean_text_features = mean_text_features / mean_text_features.norm(dim=-1, keepdim=True)
97 |
98 | self.text_features = mean_text_features
99 | self.clip_model = clip_model
100 |
--------------------------------------------------------------------------------
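The core of ZeroshotCLIP2 is the ensembling loop in build_model: encode the class names under each template, average the normalized text features, and re-normalize the mean. A hedged, framework-free distillation of that loop; encode_text and tokenize are stand-ins for clip_model.encode_text and clip.tokenize:

    import torch

    def ensemble_text_features(encode_text, tokenize, classnames, templates):
        mean_features = 0
        for temp in templates:
            prompts = torch.cat([tokenize(temp.format(c)) for c in classnames])
            feats = encode_text(prompts)
            feats = feats / feats.norm(dim=-1, keepdim=True)   # normalize per template
            mean_features = mean_features + feats
        mean_features = mean_features / len(templates)          # average over templates
        return mean_features / mean_features.norm(dim=-1, keepdim=True)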