├── AFT_model_weights
│   └── readme.md
├── README.md
├── TTC.sh
├── code
│   ├── attacks.py
│   ├── func.py
│   ├── models
│   │   ├── download_models.sh
│   │   ├── model.py
│   │   └── prompters.py
│   ├── replace
│   │   ├── bpe_simple_vocab_16e6.txt.gz
│   │   ├── clip.py
│   │   ├── datasets
│   │   │   ├── caltech.py
│   │   │   ├── country211.py
│   │   │   ├── dtd.py
│   │   │   ├── eurosat.py
│   │   │   ├── fgvc_aircraft.py
│   │   │   ├── flowers102.py
│   │   │   ├── folder.py
│   │   │   ├── food101.py
│   │   │   ├── oxford_iiit_pet.py
│   │   │   ├── pcam.py
│   │   │   ├── stanford_cars.py
│   │   │   └── sun397.py
│   │   ├── model.py
│   │   └── simple_tokenizer.py
│   ├── test_time_counterattack.py
│   └── utils.py
├── data
│   └── readme.md
├── download_weights.py
├── environment.yml
├── figures
│   ├── fig2b.png
│   └── teaser.png
├── poster_CVPR_XING.png
├── requirements.txt
└── support
    ├── imagenet_classes_names.txt
    ├── imagenet_refined_labels.json
    ├── readme.md
    └── tinyimagenet_refined_labels.json
/AFT_model_weights/readme.md:
--------------------------------------------------------------------------------
1 | This folder keeps the model weights of AFT (adversarial fine-tuning) methods, in case one wants to test TTC on adversarially finetuned models such as TeCoA.
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CLIP-Test-time-Counterattacks 🚀
2 | This is the official code of our work:
3 |
4 | [CLIP is Strong Enough to Fight Back: Test-time Counterattacks towards Zero-shot Adversarial Robustness of CLIP](https://arxiv.org/abs/2503.03613). Songlong Xing, Zhengyu Zhao, Nicu Sebe. To appear in CVPR 2025.
5 |
6 | > **Abstract**: Despite its prevalent use in image-text matching tasks in a zero-shot manner, CLIP has been shown to be highly vulnerable to adversarial perturbations added onto images. Recent studies propose to finetune the vision encoder of CLIP with adversarial samples generated on the fly, and show improved robustness against adversarial attacks on a spectrum of downstream datasets, a property termed zero-shot robustness. In this paper, we show that malicious perturbations that seek to maximise the classification loss lead to 'falsely stable' images, and propose to leverage the pre-trained vision encoder of CLIP to counterattack such adversarial images during inference to achieve robustness. Our paradigm is simple and training-free, providing the first method to defend CLIP from adversarial attacks at test time, which is orthogonal to existing methods aiming to boost zero-shot adversarial robustness of CLIP. We conduct experiments across 16 classification datasets, and demonstrate stable and consistent gains compared to test-time defence methods adapted from existing adversarial robustness studies that do not rely on external networks, without noticeably impairing performance on clean images. We also show that our paradigm can be employed on CLIP models that have been adversarially finetuned to further enhance their robustness at test time. Our code is available here.
7 |
8 |
9 |
10 |
11 |
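In a nutshell, the counterattack perturbs the incoming (possibly adversarial) image so that its CLIP image embedding moves away from its current, 'falsely stable' embedding, within a small perturbation budget. The snippet below is only a schematic sketch of this idea in the spirit of `attack_unlabelled` in `code/attacks.py`; the names and values (`encode_image`, `eps`, `alpha`, `steps`) are placeholders, and the actual procedure with its `--ttc_eps`, `--ttc_numsteps`, `--beta` and `--tau_thres` options lives in `code/test_time_counterattack.py`.

```python
import torch

def counterattack(encode_image, x, eps=4/255, alpha=1/255, steps=2):
    """Schematic test-time counterattack (illustrative only): push the embedding of
    x away from its current embedding within an L-infinity ball of radius eps.
    `encode_image` stands for a frozen CLIP vision encoder; x lives in [0, 1]."""
    with torch.no_grad():
        anchor = encode_image(x)                         # embedding of the incoming image
    delta = torch.empty_like(x).uniform_(-eps, eps)      # random start
    delta = ((x + delta).clamp(0, 1) - x).requires_grad_(True)  # keep x + delta a valid image
    for _ in range(steps):
        feats = encode_image(x + delta)
        loss = ((feats - anchor) ** 2).sum()             # squared distance to the anchor embedding
        grad = torch.autograd.grad(loss, delta)[0]
        delta = (delta + alpha * grad.sign()).clamp(-eps, eps)   # gradient ascent on the distance
        delta = ((x + delta).clamp(0, 1) - x).detach().requires_grad_(True)
    return (x + delta).detach()
```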
12 | ## 🛠️ Setup
13 | ### Environment
14 | Make sure you have conda installed, then use the following commands to get the environment ready:
15 | ```bash
16 | conda env create -f environment.yml
17 | conda activate TTC
18 | pip install -r requirements.txt
19 | ```
20 | ### Data preparation
21 | Please download and unzip all the raw datasets into `./data`. You can also skip this step: `torchvision.datasets` will automatically download most of them as you run the code if they are not already present.
22 |
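For instance, most of the benchmark datasets can be fetched on first use by instantiating them with `download=True`; the patched dataset classes under `code/replace/datasets/` follow the same `torchvision.datasets` interface. A minimal illustration (not part of the training/evaluation script):

```python
# Illustrative only: fetch one benchmark dataset into ./data on first use.
from torchvision import datasets

dtd = datasets.DTD(root="./data", split="test", download=True)
print(len(dtd), "images,", len(dtd.classes), "classes")
```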
23 | ## 🔥 Run
24 | The file `code/test_time_counterattack.py` contains the main program. To reproduce the results of TTC employed on the original CLIP (Tab.1 in the paper), run the following command:
25 | ```bash
26 | conda activate TTC
27 | python code/test_time_counterattack.py --batch_size 256 --test_attack_type 'pgd' --test_eps 1 --test_numsteps 10 --test_stepsize 1 --outdir 'TTC_results' --seed 1 --ttc_eps 4 --beta 2 --tau_thres 0.2 --ttc_numsteps 2
28 | ```
29 | This command can also be found in `TTC.sh`. You can run `bash TTC.sh` to avoid typing the lengthy command in the terminal.
30 | The results will be saved in the folder specified by `--outdir`.
31 |
32 | To employ TTC on adversarially finetuned models, specify the path to the finetuned model weights with `--victim_resume`. The checkpoints we obtained via adversarial finetuning, which were used to produce the results reported in Tab.3 in the paper, are available at this [Google drive link](https://drive.google.com/drive/folders/1aDChTWGOrqK6IrIKVqSyMf4IdIBHEiJr?usp=drive_link). Please download those checkpoints and keep them in the folder `./AFT_model_weights`; this can be done by running `python download_weights.py`. Alternatively, you can implement adversarial finetuning methods on your own and put the model weights in `./AFT_model_weights`.
33 |
34 | After the AFT checkpoints are in place, you can run the same command as in `TTC.sh` with an additional argument `--victim_resume` specifying the path to the model weights. For example, use `--victim_resume "./AFT_model_weights/TeCoA.pth.tar"` if you want to employ TTC on top of TeCoA.
35 |
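For example, combining the arguments from `TTC.sh` with `--victim_resume` gives a complete TeCoA + TTC run:

```bash
python code/test_time_counterattack.py --batch_size 256 --test_attack_type 'pgd' --test_eps 1 --test_numsteps 10 --test_stepsize 1 --outdir 'TTC_results' --seed 1 --ttc_eps 4 --beta 2 --tau_thres 0.2 --ttc_numsteps 2 --victim_resume "./AFT_model_weights/TeCoA.pth.tar"
```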
36 | ## 📬 Updates
37 | 7 Mar 2025: **Please stay tuned for instructions to run the code!**
38 |
39 | 10 Mar 2025: **Setup** updated!
40 |
41 | 10 Mar 2025: **Run Instructions** updated!
42 |
43 | ## 🗂️ Reference
44 | ```
45 | @InProceedings{Xing_2025_CVPR,
46 | author = {Xing, Songlong and Zhao, Zhengyu and Sebe, Nicu},
47 | title = {CLIP is Strong Enough to Fight Back: Test-time Counterattacks towards Zero-shot Adversarial Robustness of CLIP},
48 | booktitle = {Proceedings of the Computer Vision and Pattern Recognition Conference (CVPR)},
49 | month = {June},
50 | year = {2025},
51 | pages = {15172-15182}
52 | }
53 | ```
54 |
55 | ## Acknowledgement
56 | Our code is developed based on the open-source code of [TeCoA (ICLR-23)](https://github.com/cvlab-columbia/ZSRobust4FoundationModel). We thank the authors for their brilliant work. Please also consider citing their paper:
57 | ```
58 | @inproceedings{maounderstanding,
59 | title={Understanding Zero-shot Adversarial Robustness for Large-Scale Models},
60 | author={Mao, Chengzhi and Geng, Scott and Yang, Junfeng and Wang, Xin and Vondrick, Carl},
61 | booktitle={The Eleventh International Conference on Learning Representations},
62 | year={2023}
63 | }
64 | ```
65 |
--------------------------------------------------------------------------------
/TTC.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | conda activate TTC
4 |
5 | cd ./
6 |
7 | python code/test_time_counterattack.py \
8 | --batch_size 256 \
9 | --test_attack_type 'pgd' \
10 | --test_eps 1 \
11 | --test_numsteps 10 \
12 | --test_stepsize 1 \
13 | --outdir 'TTC_results' \
14 | --seed 1 \
15 | --ttc_eps 4 \
16 | --beta 2 \
17 | --tau_thres 0.2 \
18 | --ttc_numsteps 2
19 |
--------------------------------------------------------------------------------
/code/attacks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from utils import one_hot_embedding
3 | # from models.model import *
4 | import torch.nn.functional as F
5 | import functools
6 | from autoattack import AutoAttack
7 | from func import clip_img_preprocessing, multiGPU_CLIP, multiGPU_CLIP_image_logits
8 |
9 | lower_limit, upper_limit = 0, 1
10 | def clamp(X, lower_limit, upper_limit):  # element-wise clamp; the limits are tensors broadcastable against X
11 | return torch.max(torch.min(X, upper_limit), lower_limit)
12 |
13 |
14 | def attack_CW(args, prompter, model, model_text, model_image, add_prompter, criterion, X, target, text_tokens, alpha,
15 | attack_iters, norm, restarts=1, early_stop=True, epsilon=0):
16 | delta = torch.zeros_like(X).cuda()
17 | if norm == "l_inf":
18 | delta.uniform_(-epsilon, epsilon)
19 | elif norm == "l_2":
20 | delta.normal_()
21 | d_flat = delta.view(delta.size(0), -1)
22 | n = d_flat.norm(p=2, dim=1).view(delta.size(0), 1, 1, 1)
23 | r = torch.zeros_like(n).uniform_(0, 1)
24 | delta *= r / n * epsilon
25 | else:
26 | raise ValueError
27 | delta = clamp(delta, lower_limit - X, upper_limit - X)
28 | delta.requires_grad = True
29 | for _ in range(attack_iters):
30 | # output = model(normalize(X ))
31 |
32 | prompted_images = prompter(clip_img_preprocessing(X + delta))
33 | prompt_token = add_prompter()
34 |
35 | output, _, _, _ = multiGPU_CLIP(args, model_image, model_text, model, prompted_images, text_tokens, prompt_token)
36 |
37 | num_class = output.size(1)
38 | label_mask = one_hot_embedding(target, num_class)
39 | label_mask = label_mask.cuda()
40 |
41 | correct_logit = torch.sum(label_mask * output, dim=1)
42 | wrong_logit, _ = torch.max((1 - label_mask) * output - 1e4 * label_mask, axis=1)
43 |
44 | # loss = criterion(output, target)
45 | loss = - torch.sum(F.relu(correct_logit - wrong_logit + 50))
46 |
47 | loss.backward()
48 | grad = delta.grad.detach()
49 | d = delta[:, :, :, :]
50 | g = grad[:, :, :, :]
51 | x = X[:, :, :, :]
52 | if norm == "l_inf":
53 | d = torch.clamp(d + alpha * torch.sign(g), min=-epsilon, max=epsilon)
54 | elif norm == "l_2":
55 | g_norm = torch.norm(g.view(g.shape[0], -1), dim=1).view(-1, 1, 1, 1)
56 | scaled_g = g / (g_norm + 1e-10)
57 | d = (d + scaled_g * alpha).view(d.size(0), -1).renorm(p=2, dim=0, maxnorm=epsilon).view_as(d)
58 | d = clamp(d, lower_limit - x, upper_limit - x)
59 | delta.data[:, :, :, :] = d
60 | delta.grad.zero_()
61 |
62 | return delta
63 |
64 |
65 | def attack_CW_noprompt(args, prompter, model, model_text, model_image, criterion, X, target, text_tokens, alpha,
66 | attack_iters, norm, restarts=1, early_stop=True, epsilon=0):
67 | delta = torch.zeros_like(X).cuda()
68 | if norm == "l_inf":
69 | delta.uniform_(-epsilon, epsilon)
70 | elif norm == "l_2":
71 | delta.normal_()
72 | d_flat = delta.view(delta.size(0), -1)
73 | n = d_flat.norm(p=2, dim=1).view(delta.size(0), 1, 1, 1)
74 | r = torch.zeros_like(n).uniform_(0, 1)
75 | delta *= r / n * epsilon
76 | else:
77 | raise ValueError
78 | delta = clamp(delta, lower_limit - X, upper_limit - X)
79 | delta.requires_grad = True
80 | for _ in range(attack_iters):
81 | # output = model(normalize(X ))
82 |
83 | _images = clip_img_preprocessing(X + delta)
84 | # output, _ = model(_images, text_tokens)
85 |
86 | output, _, _, _ = multiGPU_CLIP(args, model_image, model_text, model, _images, text_tokens, None)
87 |
88 | num_class = output.size(1)
89 | label_mask = one_hot_embedding(target, num_class)
90 | label_mask = label_mask.cuda()
91 |
92 | correct_logit = torch.sum(label_mask * output, dim=1)
93 | wrong_logit, _ = torch.max((1 - label_mask) * output - 1e4 * label_mask, axis=1)
94 |
95 | # loss = criterion(output, target)
96 | loss = - torch.sum(F.relu(correct_logit - wrong_logit + 50))
97 |
98 | loss.backward()
99 | grad = delta.grad.detach()
100 | d = delta[:, :, :, :]
101 | g = grad[:, :, :, :]
102 | x = X[:, :, :, :]
103 | if norm == "l_inf":
104 | d = torch.clamp(d + alpha * torch.sign(g), min=-epsilon, max=epsilon)
105 | elif norm == "l_2":
106 | g_norm = torch.norm(g.view(g.shape[0], -1), dim=1).view(-1, 1, 1, 1)
107 | scaled_g = g / (g_norm + 1e-10)
108 | d = (d + scaled_g * alpha).view(d.size(0), -1).renorm(p=2, dim=0, maxnorm=epsilon).view_as(d)
109 | d = clamp(d, lower_limit - x, upper_limit - x)
110 | delta.data[:, :, :, :] = d
111 | delta.grad.zero_()
112 |
113 | return delta
114 |
115 | def attack_unlabelled(model, X, prompter, add_prompter, alpha, attack_iters, norm="l_inf", epsilon=0,
116 | visual_model_orig=None):
117 | delta = torch.zeros_like(X)
118 | if norm == "l_inf":
119 | delta.uniform_(-epsilon, epsilon)
120 | elif norm == "l_2":
121 | delta.normal_()
122 | d_flat = delta.view(delta.size(0), -1)
123 | n = d_flat.norm(p=2, dim=1).view(delta.size(0), 1, 1, 1)
124 | r = torch.zeros_like(n).uniform_(0, 1)
125 | delta *= r / n * epsilon
126 | else:
127 | raise ValueError
128 | 
129 |     delta = clamp(delta, lower_limit - X, upper_limit - X)
130 |     delta.requires_grad = True
131 | 
132 |     if attack_iters <= 0:  # nothing to do; return the random perturbation before freezing any parameters
133 |         return delta
134 | 
135 |     # turn off model parameters temporarily
136 |     tunable_param_names = []
137 |     for n,p in model.module.named_parameters():
138 |         if p.requires_grad:
139 |             tunable_param_names.append(n)
140 |             p.requires_grad = False
141 | 
142 | prompt_token = add_prompter()
143 | with torch.no_grad():
144 | if visual_model_orig is None: # use the model itself as anchor
145 | X_ori_reps = model.module.encode_image(
146 | prompter(clip_img_preprocessing(X)), prompt_token
147 | )
148 | else: # use original frozen model as anchor
149 | X_ori_reps = visual_model_orig.module(
150 | prompter(clip_img_preprocessing(X)), prompt_token
151 | )
152 |
153 | for _ in range(attack_iters):
154 |
155 | prompted_images = prompter(clip_img_preprocessing(X + delta))
156 |
157 | X_att_reps = model.module.encode_image(prompted_images, prompt_token)
158 | # l2_loss = ((((X_att_reps - X_ori_reps)**2).sum(1))**(0.5)).sum()
159 | l2_loss = ((((X_att_reps - X_ori_reps)**2).sum(1))).sum()
160 |
161 | grad = torch.autograd.grad(l2_loss, delta)[0]
162 | d = delta[:, :, :, :]
163 | g = grad[:, :, :, :]
164 | x = X[:, :, :, :]
165 | if norm == "l_inf":
166 | d = torch.clamp(d + alpha * torch.sign(g), min=-epsilon, max=epsilon)
167 | elif norm == "l_2":
168 | g_norm = torch.norm(g.view(g.shape[0], -1), dim=1).view(-1, 1, 1, 1)
169 | scaled_g = g / (g_norm + 1e-10)
170 | d = (d + scaled_g * alpha).view(d.size(0), -1).renorm(p=2, dim=0, maxnorm=epsilon).view_as(d)
171 | d = clamp(d, lower_limit - x, upper_limit - x)
172 | delta.data[:, :, :, :] = d
173 |
174 | # # Turn on model parameters
175 | for n,p in model.module.named_parameters():
176 | if n in tunable_param_names:
177 | p.requires_grad = True
178 |
179 | return delta
180 |
181 | #### opposite update direction of attack_unlabelled()
182 | def attack_unlabelled_opp(model, X, prompter, add_prompter, alpha, attack_iters, norm="l_inf", epsilon=0,
183 | visual_model_orig=None):
184 | delta = torch.zeros_like(X)
185 | if norm == "l_inf":
186 | delta.uniform_(-epsilon, epsilon)
187 | elif norm == "l_2":
188 | delta.normal_()
189 | d_flat = delta.view(delta.size(0), -1)
190 | n = d_flat.norm(p=2, dim=1).view(delta.size(0), 1, 1, 1)
191 | r = torch.zeros_like(n).uniform_(0, 1)
192 | delta *= r / n * epsilon
193 | else:
194 | raise ValueError
195 |
196 |     # turn off model parameters temporarily
197 | tunable_param_names = []
198 | for n,p in model.module.named_parameters():
199 | if p.requires_grad:
200 | tunable_param_names.append(n)
201 | p.requires_grad = False
202 |
203 | delta = clamp(delta, lower_limit - X, upper_limit - X)
204 | delta.requires_grad = True
205 |
206 | prompt_token = add_prompter()
207 | with torch.no_grad():
208 | if visual_model_orig is None: # use the model itself as anchor
209 | X_ori_reps = model.module.encode_image(
210 | prompter(clip_img_preprocessing(X)), prompt_token
211 | )
212 | else: # use original frozen model as anchor
213 | X_ori_reps = visual_model_orig.module(
214 | prompter(clip_img_preprocessing(X)), prompt_token
215 | )
216 |
217 | for _ in range(attack_iters):
218 |
219 | prompted_images = prompter(clip_img_preprocessing(X + delta))
220 |
221 | X_att_reps = model.module.encode_image(prompted_images, prompt_token)
222 | # l2_loss = ((((X_att_reps - X_ori_reps)**2).sum(1))**(0.5)).sum()
223 | l2_loss = ((((X_att_reps - X_ori_reps)**2).sum(1))).sum()
224 |
225 | grad = torch.autograd.grad(l2_loss, delta)[0]
226 | d = delta[:, :, :, :]
227 | g = grad[:, :, :, :]
228 | x = X[:, :, :, :]
229 | if norm == "l_inf":
230 | # d = torch.clamp(d + alpha * torch.sign(g), min=-epsilon, max=epsilon)
231 | d = torch.clamp(d - alpha * torch.sign(g), min=-epsilon, max=epsilon)
232 | elif norm == "l_2":
233 | g_norm = torch.norm(g.view(g.shape[0], -1), dim=1).view(-1, 1, 1, 1)
234 | scaled_g = g / (g_norm + 1e-10)
235 | # d = (d + scaled_g * alpha).view(d.size(0), -1).renorm(p=2, dim=0, maxnorm=epsilon).view_as(d)
236 | d = (d - scaled_g * alpha).view(d.size(0), -1).renorm(p=2, dim=0, maxnorm=epsilon).view_as(d)
237 | d = clamp(d, lower_limit - x, upper_limit - x)
238 | delta.data[:, :, :, :] = d
239 |
240 | # # Turn on model parameters
241 | for n,p in model.module.named_parameters():
242 | if n in tunable_param_names:
243 | p.requires_grad = True
244 |
245 | return delta
246 |
247 | def attack_unlabelled_cosine(model, X, prompter, add_prompter, alpha, attack_iters, norm="l_inf", epsilon=0,
248 | visual_model_orig=None):
249 |     # unlabelled attack to maximise the cosine distance (1 - cosine similarity) between
250 |     # the attacked image and the original image, computed by PGD
251 | delta = torch.zeros_like(X)
252 | if norm == "l_inf":
253 | delta.uniform_(-epsilon, epsilon)
254 | elif norm == "l_2":
255 | delta.normal_()
256 | d_flat = delta.view(delta.size(0), -1)
257 | n = d_flat.norm(p=2, dim=1).view(delta.size(0), 1, 1, 1)
258 | r = torch.zeros_like(n).uniform_(0, 1)
259 | delta *= r / n * epsilon
260 | else:
261 | raise ValueError
262 |
263 |     # turn off model parameters temporarily
264 | tunable_param_names = []
265 | for n,p in model.module.named_parameters():
266 | if p.requires_grad:
267 | tunable_param_names.append(n)
268 | p.requires_grad = False
269 |
270 | delta = clamp(delta, lower_limit - X, upper_limit - X)
271 | delta.requires_grad = True
272 |
273 | prompt_token = add_prompter()
274 | with torch.no_grad():
275 | if visual_model_orig is None: # use the model itself as anchor
276 | X_ori_reps = model.module.encode_image(
277 | prompter(clip_img_preprocessing(X)), prompt_token
278 | )
279 | else: # use original frozen model as anchor
280 | X_ori_reps = visual_model_orig.module(
281 | prompter(clip_img_preprocessing(X)), prompt_token
282 | )
283 | # X_ori_reps_norm = X_ori_reps / X_ori_reps.norm(dim=-1, keepdim=True)
284 |
285 | for _ in range(attack_iters):
286 |
287 | prompted_images = prompter(clip_img_preprocessing(X + delta))
288 |
289 | X_att_reps = model.module.encode_image(prompted_images, prompt_token) # [bs, d_out]
290 | # X_att_reps_norm = X_att_reps / X_att_reps.norm(dim=-1, keepdim=True)
291 |
292 |         cos_loss = (1 - F.cosine_similarity(X_att_reps, X_ori_reps)).sum()  # sum over the batch so autograd.grad receives a scalar
293 |
294 | grad = torch.autograd.grad(cos_loss, delta)[0]
295 | d = delta[:, :, :, :]
296 | g = grad[:, :, :, :]
297 | x = X[:, :, :, :]
298 | if norm == "l_inf":
299 | d = torch.clamp(d + alpha * torch.sign(g), min=-epsilon, max=epsilon)
300 | elif norm == "l_2":
301 | g_norm = torch.norm(g.view(g.shape[0], -1), dim=1).view(-1, 1, 1, 1)
302 | scaled_g = g / (g_norm + 1e-10)
303 | d = (d + scaled_g * alpha).view(d.size(0), -1).renorm(p=2, dim=0, maxnorm=epsilon).view_as(d)
304 | d = clamp(d, lower_limit - x, upper_limit - x)
305 | delta.data[:, :, :, :] = d
306 |
307 | # # Turn on model parameters
308 | for n,p in model.module.named_parameters():
309 | if n in tunable_param_names:
310 | p.requires_grad = True
311 |
312 | return delta
313 |
314 |
315 | def attack_pgd(args, prompter, model, model_text, model_image, add_prompter, criterion, X, target, alpha,
316 | attack_iters, norm, text_tokens=None, restarts=1, early_stop=True, epsilon=0, dataset_name=None):
317 | delta = torch.zeros_like(X).cuda()
318 | if norm == "l_inf":
319 | delta.uniform_(-epsilon, epsilon)
320 | elif norm == "l_2":
321 | delta.normal_()
322 | d_flat = delta.view(delta.size(0), -1)
323 | n = d_flat.norm(p=2, dim=1).view(delta.size(0), 1, 1, 1)
324 | r = torch.zeros_like(n).uniform_(0, 1)
325 | delta *= r / n * epsilon
326 | else:
327 | raise ValueError
328 | delta = clamp(delta, lower_limit - X, upper_limit - X)
329 | delta.requires_grad = True
330 |
331 |     # turn off model parameters temporarily
332 | tunable_param_names = []
333 | for n,p in model.module.named_parameters():
334 | if p.requires_grad:
335 | tunable_param_names.append(n)
336 | p.requires_grad = False
337 |
338 | for iter in range(attack_iters):
339 |
340 | prompted_images = prompter(clip_img_preprocessing(X + delta))
341 | prompt_token = add_prompter()
342 |
343 | output, _, _, _ = multiGPU_CLIP(args, model_image, model_text, model, prompted_images,
344 | text_tokens=text_tokens, prompt_token=prompt_token, dataset_name=dataset_name)
345 |
346 | loss = criterion(output, target)
347 |
348 | grad = torch.autograd.grad(loss, delta)[0]
349 |
350 | d = delta[:, :, :, :]
351 | g = grad[:, :, :, :]
352 | x = X[:, :, :, :]
353 | if norm == "l_inf":
354 | d = torch.clamp(d + alpha * torch.sign(g), min=-epsilon, max=epsilon)
355 | elif norm == "l_2":
356 | g_norm = torch.norm(g.view(g.shape[0], -1), dim=1).view(-1, 1, 1, 1)
357 | scaled_g = g / (g_norm + 1e-10)
358 | d = (d + scaled_g * alpha).view(d.size(0), -1).renorm(p=2, dim=0, maxnorm=epsilon).view_as(d)
359 | d = clamp(d, lower_limit - x, upper_limit - x)
360 | delta.data[:, :, :, :] = d
361 | # delta.grad.zero_()
362 |
363 | # # Turn on model parameters
364 | for n,p in model.module.named_parameters():
365 | if n in tunable_param_names:
366 | p.requires_grad = True
367 |
368 | return delta
369 |
370 |
371 | def attack_pgd_noprompt(args, prompter, model, model_text, model_image, criterion, X, target, text_tokens, alpha,
372 | attack_iters, norm, restarts=1, early_stop=True, epsilon=0):
373 | delta = torch.zeros_like(X).cuda()
374 | if norm == "l_inf":
375 | delta.uniform_(-epsilon, epsilon)
376 | elif norm == "l_2":
377 | delta.normal_()
378 | d_flat = delta.view(delta.size(0), -1)
379 | n = d_flat.norm(p=2, dim=1).view(delta.size(0), 1, 1, 1)
380 | r = torch.zeros_like(n).uniform_(0, 1)
381 | delta *= r / n * epsilon
382 | else:
383 | raise ValueError
384 | delta = clamp(delta, lower_limit - X, upper_limit - X)
385 | delta.requires_grad = True
386 | for _ in range(attack_iters):
387 |
388 | _images = clip_img_preprocessing(X + delta)
389 | output, _, _, _ = multiGPU_CLIP(args, model_image, model_text, model, _images, text_tokens, None)
390 |
391 | loss = criterion(output, target)
392 |
393 | loss.backward()
394 | grad = delta.grad.detach()
395 | d = delta[:, :, :, :]
396 | g = grad[:, :, :, :]
397 | x = X[:, :, :, :]
398 | if norm == "l_inf":
399 | d = torch.clamp(d + alpha * torch.sign(g), min=-epsilon, max=epsilon)
400 | elif norm == "l_2":
401 | g_norm = torch.norm(g.view(g.shape[0], -1), dim=1).view(-1, 1, 1, 1)
402 | scaled_g = g / (g_norm + 1e-10)
403 | d = (d + scaled_g * alpha).view(d.size(0), -1).renorm(p=2, dim=0, maxnorm=epsilon).view_as(d)
404 | d = clamp(d, lower_limit - x, upper_limit - x)
405 | delta.data[:, :, :, :] = d
406 | delta.grad.zero_()
407 |
408 | return delta
409 |
410 | def attack_auto(model, images, target, text_tokens, prompter, add_prompter,
411 | attacks_to_run=['apgd-ce', 'apgd-dlr'], epsilon=0):
412 |
413 | forward_pass = functools.partial(
414 | multiGPU_CLIP_image_logits,
415 | model=model, text_tokens=text_tokens,
416 | prompter=None, add_prompter=None
417 | )
418 |
419 | adversary = AutoAttack(forward_pass, norm='Linf', eps=epsilon, version='standard', verbose=False)
420 | adversary.attacks_to_run = attacks_to_run
421 | x_adv = adversary.run_standard_evaluation(images, target, bs=images.shape[0])
422 | return x_adv
--------------------------------------------------------------------------------
/code/func.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | device = "cuda" if torch.cuda.is_available() else "cpu"
5 |
6 | CIFAR100_MEAN = (0.48145466, 0.4578275, 0.40821073)  # CLIP's normalization mean, despite the variable name
7 | CIFAR100_STD = (0.26862954, 0.26130258, 0.27577711)  # CLIP's normalization std, despite the variable name
8 |
9 | mu = torch.tensor(CIFAR100_MEAN).view(3, 1, 1).to(device)
10 | std = torch.tensor(CIFAR100_STD).view(3, 1, 1).to(device)
11 |
12 | def normalize(X):
13 | return (X - mu) / std
14 | def clip_img_preprocessing(X):
15 | img_size = 224
16 | X = torch.nn.functional.interpolate(X, size=(img_size, img_size), mode='bicubic')
17 | X = normalize(X)
18 | return X
19 |
20 | def rev_normalize(X):
21 | return X * std + mu
22 | def reverse_clip_img_preprocessing(X):
23 | X = rev_normalize(X)
24 | return X
25 |
26 | def multiGPU_CLIP_image_logits(images, model, text_tokens, prompter=None, add_prompter=None):
27 | image_tokens = clip_img_preprocessing(images)
28 | prompt_token = None if add_prompter is None else add_prompter()
29 | if prompter is not None:
30 | image_tokens = prompter(image_tokens)
31 | return multiGPU_CLIP(None, None, None, model, image_tokens, text_tokens, prompt_token=prompt_token)[0]
32 |
33 |
34 | def multiGPU_CLIP(args, model_image, model_text, model, images, text_tokens=None, prompt_token=None, dataset_name=None):
35 | if prompt_token is not None:
36 | bs = images.size(0)
37 | prompt_token = prompt_token.repeat(bs, 1, 1)
38 | if args is not None and dataset_name is not None:
39 | cache_prompts = os.path.join(args.cache, f"refined_{dataset_name.lower()}_prompts.pt")
40 | cache_wordnet_def = os.path.join(args.cache, f"refined_{dataset_name.lower()}_wn_def.pt")
41 | else:
42 | cache_prompts, cache_wordnet_def = None, None
43 | if cache_prompts is not None and os.path.exists(cache_prompts):
44 | text_features = torch.load(cache_prompts).to('cpu')
45 | if args.advanced_text == "wordnet_def":
46 | a_text_features = torch.load(cache_wordnet_def).to('cpu')
47 | text_features = (text_features + a_text_features) * 0.5
48 | else:
49 | text_features = model.module.encode_text(text_tokens)
50 | text_features = text_features / text_features.norm(dim=1, keepdim=True) # [n_class, d_emb]
51 | text_features = text_features.to(device)
52 | image_features = model.module.encode_image(images, prompt_token)
53 | image_features = image_features / image_features.norm(dim=1, keepdim=True) # [bs, d_emb]
54 | logits_per_image = image_features @ text_features.t() * model.module.logit_scale.exp()
55 | logits_per_text = text_features @ image_features.t() * model.module.logit_scale.exp()
56 |
57 | return logits_per_image, logits_per_text, image_features, text_features
58 |
59 | def kl_div(p_logits, q_logits):
60 | # p_logits, q_logits [bs, n_class] both have been softmax normalized
61 | kl_divs = (p_logits * (p_logits.log() - q_logits.log())).sum(dim=1) # [bs,]
62 | return kl_divs.mean()
63 |
64 | def get_loss_general(tgt_logits, a_images, model_image_copy, text_features):
65 | # feed the perturbed image into the original visual encoder, regularise the predictive logits
66 | image_features = model_image_copy(a_images) # [bs, d_emb]
67 | image_features = image_features / image_features.norm(dim=1, keepdim=True)
68 | text_features = text_features / text_features.norm(dim=1, keepdim=True)
69 | logits_per_image_ = image_features @ text_features.t() * model_image_copy.module.logit_scale.exp() # [bs, n_class]
70 | l_general = kl_div(tgt_logits.softmax(dim=1), logits_per_image_.softmax(dim=1))
71 | # l_general = criterion_(F.log_softmax(logits_per_image_, dim=1), F.softmax(tgt_logits))
72 | return l_general
73 |
74 | def get_loss_clean(clean_images, tgt_logits, model, text_features, prompt_token=None):
75 | # feed the clean image into the visual encoder, regularise the predictive logits
76 | image_features = model.module.encode_image(clean_images, prompt_token) # [bs, d_emb]
77 | image_features = image_features / image_features.norm(dim=1, keepdim=True)
78 | logits_per_image = image_features @ text_features.t() * model.module.logit_scale.exp() # [bs, n_class]
79 | l_clean = kl_div(tgt_logits.softmax(dim=1), logits_per_image.softmax(dim=1))
80 | # l_clean = criterion_(F.log_softmax(logits_per_image, dim=1), F.softmax(tgt_logits, dim=1))
81 | return l_clean
--------------------------------------------------------------------------------
/code/models/download_models.sh:
--------------------------------------------------------------------------------
1 | # CLIP
2 | pip install git+https://github.com/openai/CLIP.git
3 |
--------------------------------------------------------------------------------
/code/models/model.py:
--------------------------------------------------------------------------------
1 | import torch, clip
2 |
3 | IMAGENET_MEAN = (0.48145466, 0.4578275, 0.40821073)
4 | IMAGENET_STD = (0.26862954, 0.26130258, 0.27577711)
5 |
6 | mu = torch.tensor(IMAGENET_MEAN).view(3, 1, 1).cuda()
7 | std = torch.tensor(IMAGENET_STD).view(3, 1, 1).cuda()
8 |
9 | def normalize(X):
10 | return (X - mu) / std
11 |
12 | def clip_img_preprocessing(X):
13 | img_size = 224
14 |     X = torch.nn.functional.interpolate(X, size=(img_size, img_size), mode='bicubic')  # interpolate supersedes the deprecated upsample
15 | X = normalize(X)
16 | return X
17 |
18 | def create_logits(x1, x2, logit_scale):
19 | x1 = x1 / x1.norm(dim=-1, keepdim=True)
20 | x2 = x2 / x2.norm(dim=-1, keepdim=True)
21 | # cosine similarity as logits
22 | logits_per_x1 = logit_scale * x1 @ x2.t()
23 | logits_per_x2 = logit_scale * x2 @ x1.t()
24 | return logits_per_x1, logits_per_x2
25 |
26 | def multiGPU_CLIP(clip_model, images, text_tokens, prompt_token=None):
27 | if prompt_token is not None:
28 | bs = images.size(0)
29 | prompt_token = prompt_token.repeat(bs, 1, 1)
30 |
31 | img_embed, scale_text_embed = clip_model(images, text_tokens, prompt_token)
32 | logits_per_image = img_embed @ scale_text_embed.t()
33 | logits_per_text = scale_text_embed @ img_embed.t()
34 | return logits_per_image, logits_per_text
35 |
--------------------------------------------------------------------------------
/code/models/prompters.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 |
6 | import torch
7 | from torch import nn
8 |
9 | from einops import rearrange, repeat
10 | from einops.layers.torch import Rearrange
11 |
12 | #############
13 |
14 | class PreNorm(nn.Module):
15 | def __init__(self, dim, fn):
16 | super().__init__()
17 | self.norm = nn.LayerNorm(dim)
18 | self.fn = fn
19 | def forward(self, x, **kwargs):
20 | return self.fn(self.norm(x), **kwargs)
21 |
22 | class FeedForward(nn.Module):
23 | def __init__(self, dim, hidden_dim, dropout = 0.):
24 | super().__init__()
25 | self.net = nn.Sequential(
26 | nn.Linear(dim, hidden_dim),
27 | nn.GELU(),
28 | nn.Dropout(dropout),
29 | nn.Linear(hidden_dim, dim),
30 | nn.Dropout(dropout)
31 | )
32 | def forward(self, x):
33 | return self.net(x)
34 |
35 | class Attention(nn.Module):
36 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
37 | super().__init__()
38 | inner_dim = dim_head * heads
39 | project_out = not (heads == 1 and dim_head == dim)
40 |
41 | self.heads = heads
42 | self.scale = dim_head ** -0.5
43 |
44 | self.attend = nn.Softmax(dim = -1)
45 | self.dropout = nn.Dropout(dropout)
46 |
47 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
48 |
49 | self.to_out = nn.Sequential(
50 | nn.Linear(inner_dim, dim),
51 | nn.Dropout(dropout)
52 | ) if project_out else nn.Identity()
53 |
54 | def forward(self, x):
55 | qkv = self.to_qkv(x).chunk(3, dim = -1)
56 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)
57 |
58 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
59 |
60 | attn = self.attend(dots)
61 | attn = self.dropout(attn)
62 |
63 | out = torch.matmul(attn, v)
64 | out = rearrange(out, 'b h n d -> b n (h d)')
65 | return self.to_out(out)
66 |
67 | class Transformer(nn.Module):
68 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
69 | super().__init__()
70 | self.layers = nn.ModuleList([])
71 | for _ in range(depth):
72 | self.layers.append(nn.ModuleList([
73 | PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
74 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
75 | ]))
76 | def forward(self, x):
77 | for attn, ff in self.layers:
78 | x = attn(x) + x
79 | x = ff(x) + x
80 | return x
81 | #############
82 |
83 |
84 |
85 |
86 | class NullPrompter(nn.Module):
87 | def __init__(self):
88 | super(NullPrompter, self).__init__()
89 | pass
90 |
91 | def forward(self, x):
92 | return x
93 |
94 |
95 | class PadPrompter(nn.Module):
96 | def __init__(self, args):
97 | super(PadPrompter, self).__init__()
98 | pad_size = args.prompt_size
99 | image_size = args.image_size
100 |
101 | self.base_size = image_size - pad_size*2
102 | self.pad_up = nn.Parameter(torch.randn([1, 3, pad_size, image_size]))
103 | self.pad_down = nn.Parameter(torch.randn([1, 3, pad_size, image_size]))
104 | self.pad_left = nn.Parameter(torch.randn([1, 3, image_size - pad_size*2, pad_size]))
105 | self.pad_right = nn.Parameter(torch.randn([1, 3, image_size - pad_size*2, pad_size]))
106 |
107 | def forward(self, x):
108 | base = torch.zeros(1, 3, self.base_size, self.base_size).cuda()
109 | prompt = torch.cat([self.pad_left, base, self.pad_right], dim=3)
110 | prompt = torch.cat([self.pad_up, prompt, self.pad_down], dim=2)
111 | prompt = torch.cat(x.size(0) * [prompt])
112 |
113 | return x + prompt
114 |
115 | class TokenPrompter(nn.Module):
116 | def __init__(self, prompt_len) -> None:
117 | super(TokenPrompter, self).__init__()
118 |
119 | self.prompt = nn.Parameter(torch.randn([1, prompt_len, 768]))
120 |
121 | def forward(self):
122 | return self.prompt
123 |
124 |
125 | class TokenPrompter_w_pos(nn.Module):
126 | def __init__(self, prompt_len) -> None:
127 | super(TokenPrompter_w_pos, self).__init__()
128 |
129 | self.prompt = nn.Parameter(torch.randn([1, prompt_len, 768]))
130 | self.pos_embedding = nn.Parameter(torch.randn(1, prompt_len, 1))
131 |
132 | def forward(self):
133 | return self.prompt + self.pos_embedding
134 |
135 |
136 | class TokenPrompter_w_pos_TransformerGEN(nn.Module):
137 | def __init__(self, prompt_len) -> None:
138 | super(TokenPrompter_w_pos_TransformerGEN, self).__init__()
139 |
140 | self.prompt = nn.Parameter(torch.randn([1, prompt_len, 768]))
141 |
142 | self.dropout = nn.Dropout(0)
143 | self.transformer = Transformer(768, 3, 4, 768, 768)
144 |
145 | self.pos_embedding = nn.Parameter(torch.randn(1, prompt_len, 1))
146 |
147 | def forward(self):
148 | return self.transformer(self.prompt + self.pos_embedding)
149 |
150 | class FixedPatchPrompter(nn.Module):
151 | def __init__(self, args):
152 | super(FixedPatchPrompter, self).__init__()
153 | self.isize = args.image_size
154 | self.psize = args.prompt_size
155 | self.patch = nn.Parameter(torch.randn([1, 3, self.psize, self.psize]))
156 |
157 | def forward(self, x):
158 | prompt = torch.zeros([1, 3, self.isize, self.isize]).cuda()
159 | prompt[:, :, :self.psize, :self.psize] = self.patch
160 |
161 | return x + prompt
162 |
163 |
164 | class RandomPatchPrompter(nn.Module):
165 | def __init__(self, args):
166 | super(RandomPatchPrompter, self).__init__()
167 | self.isize = args.image_size
168 | self.psize = args.prompt_size
169 | self.patch = nn.Parameter(torch.randn([1, 3, self.psize, self.psize]))
170 |
171 | def forward(self, x):
172 | x_ = np.random.choice(self.isize - self.psize)
173 | y_ = np.random.choice(self.isize - self.psize)
174 |
175 | prompt = torch.zeros([1, 3, self.isize, self.isize]).cuda()
176 | prompt[:, :, x_:x_ + self.psize, y_:y_ + self.psize] = self.patch
177 |
178 | return x + prompt
179 |
180 |
181 | def padding(args):
182 | return PadPrompter(args)
183 |
184 |
185 | def fixed_patch(args):
186 | return FixedPatchPrompter(args)
187 |
188 |
189 | def random_patch(args):
190 | return RandomPatchPrompter(args)
191 |
192 | def null_patch(args):
193 | return NullPrompter()
194 |
--------------------------------------------------------------------------------
/code/replace/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sxing2/CLIP-Test-time-Counterattacks/793fbff69cd7b63881e761e9a73a38e76b3ce5bf/code/replace/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/code/replace/clip.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import urllib
4 | import warnings
5 | from typing import Any, Union, List
6 | from pkg_resources import packaging
7 |
8 | import torch
9 | from PIL import Image
10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
11 | from tqdm import tqdm
12 |
13 | from .model import build_model
14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer
15 |
16 | try:
17 | from torchvision.transforms import InterpolationMode
18 | BICUBIC = InterpolationMode.BICUBIC
19 | except ImportError:
20 | BICUBIC = Image.BICUBIC
21 |
22 |
23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended")
25 |
26 |
27 | __all__ = ["available_models", "load", "tokenize"]
28 | _tokenizer = _Tokenizer()
29 |
30 | _MODELS = {
31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
35 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
36 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
37 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
38 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
39 | "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
40 | }
41 |
42 |
43 | def _download(url: str, root: str):
44 | os.makedirs(root, exist_ok=True)
45 | filename = os.path.basename(url)
46 |
47 | expected_sha256 = url.split("/")[-2]
48 | download_target = os.path.join(root, filename)
49 |
50 | if os.path.exists(download_target) and not os.path.isfile(download_target):
51 | raise RuntimeError(f"{download_target} exists and is not a regular file")
52 |
53 | if os.path.isfile(download_target):
54 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
55 | return download_target
56 | else:
57 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
58 |
59 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
60 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
61 | while True:
62 | buffer = source.read(8192)
63 | if not buffer:
64 | break
65 |
66 | output.write(buffer)
67 | loop.update(len(buffer))
68 |
69 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
70 |         raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
71 |
72 | return download_target
73 |
74 |
75 | def _convert_image_to_rgb(image):
76 | return image.convert("RGB")
77 |
78 |
79 | def _transform(n_px):
80 | return Compose([
81 | Resize(n_px, interpolation=BICUBIC),
82 | CenterCrop(n_px),
83 | _convert_image_to_rgb,
84 | ToTensor(),
85 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
86 | ])
87 |
88 |
89 | def available_models() -> List[str]:
90 | """Returns the names of available CLIP models"""
91 | return list(_MODELS.keys())
92 |
93 |
94 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None, prompt_len: int = 0):
95 | """Load a CLIP model
96 |
97 | Parameters
98 | ----------
99 | name : str
100 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
101 |
102 | device : Union[str, torch.device]
103 | The device to put the loaded model
104 |
105 | jit : bool
106 | Whether to load the optimized JIT model or more hackable non-JIT model (default).
107 |
108 | download_root: str
109 | path to download the model files; by default, it uses "~/.cache/clip"
110 |
111 | Returns
112 | -------
113 | model : torch.nn.Module
114 | The CLIP model
115 |
116 | preprocess : Callable[[PIL.Image], torch.Tensor]
117 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
118 | """
119 | if name in _MODELS:
120 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
121 | elif os.path.isfile(name):
122 | model_path = name
123 | else:
124 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
125 |
126 | with open(model_path, 'rb') as opened_file:
127 | try:
128 | # loading JIT archive
129 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
130 | state_dict = None
131 | except RuntimeError:
132 | # loading saved state dict
133 | if jit:
134 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
135 | jit = False
136 | state_dict = torch.load(opened_file, map_location="cpu")
137 |
138 | if not jit:
139 | model = build_model(state_dict or model.state_dict(), prompt_len).to(device)
140 | if str(device) == "cpu":
141 | model.float()
142 | return model, _transform(model.visual.input_resolution)
143 |
144 | # patch the device names
145 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
146 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
147 |
148 | def patch_device(module):
149 | try:
150 | graphs = [module.graph] if hasattr(module, "graph") else []
151 | except RuntimeError:
152 | graphs = []
153 |
154 | if hasattr(module, "forward1"):
155 | graphs.append(module.forward1.graph)
156 |
157 | for graph in graphs:
158 | for node in graph.findAllNodes("prim::Constant"):
159 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
160 | node.copyAttributes(device_node)
161 |
162 | model.apply(patch_device)
163 | patch_device(model.encode_image)
164 | patch_device(model.encode_text)
165 |
166 | # patch dtype to float32 on CPU
167 | if str(device) == "cpu":
168 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
169 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
170 | float_node = float_input.node()
171 |
172 | def patch_float(module):
173 | try:
174 | graphs = [module.graph] if hasattr(module, "graph") else []
175 | except RuntimeError:
176 | graphs = []
177 |
178 | if hasattr(module, "forward1"):
179 | graphs.append(module.forward1.graph)
180 |
181 | for graph in graphs:
182 | for node in graph.findAllNodes("aten::to"):
183 | inputs = list(node.inputs())
184 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
185 | if inputs[i].node()["value"] == 5:
186 | inputs[i].node().copyAttributes(float_node)
187 |
188 | model.apply(patch_float)
189 | patch_float(model.encode_image)
190 | patch_float(model.encode_text)
191 |
192 | model.float()
193 |
194 | return model, _transform(model.input_resolution.item())
195 |
196 |
197 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
198 | """
199 | Returns the tokenized representation of given input string(s)
200 |
201 | Parameters
202 | ----------
203 | texts : Union[str, List[str]]
204 | An input string or a list of input strings to tokenize
205 |
206 | context_length : int
207 | The context length to use; all CLIP models use 77 as the context length
208 |
209 | truncate: bool
210 | Whether to truncate the text in case its encoding is longer than the context length
211 |
212 | Returns
213 | -------
214 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
215 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
216 | """
217 | if isinstance(texts, str):
218 | texts = [texts]
219 |
220 | sot_token = _tokenizer.encoder["<|startoftext|>"]
221 | eot_token = _tokenizer.encoder["<|endoftext|>"]
222 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
223 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
224 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
225 | else:
226 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
227 |
228 | for i, tokens in enumerate(all_tokens):
229 | if len(tokens) > context_length:
230 | if truncate:
231 | tokens = tokens[:context_length]
232 | tokens[-1] = eot_token
233 | else:
234 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
235 | result[i, :len(tokens)] = torch.tensor(tokens)
236 |
237 | return result
238 |
--------------------------------------------------------------------------------
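As a quick illustration of the patched `load()`/`tokenize()` interface above, here is a usage sketch. It assumes the files under `code/replace/` are used in place of the corresponding modules of the installed `clip` package; the `prompt_len` argument is specific to this repo's patched `build_model`, and `0` is just a placeholder value.

```python
# Usage sketch for the patched clip.load()/tokenize() (illustrative only).
import torch
from clip import load, tokenize  # assumes code/replace/ has replaced the installed clip modules

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = load("ViT-B/32", device=device, prompt_len=0)
text_tokens = tokenize(["a photo of a dog", "a photo of a cat"]).to(device)
with torch.no_grad():
    text_features = model.encode_text(text_tokens)
print(text_features.shape)  # [2, embedding_dim]
```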
/code/replace/datasets/caltech.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | from typing import Any, Callable, List, Optional, Union, Tuple
4 |
5 | from PIL import Image
6 |
7 | from torchvision.datasets.utils import download_and_extract_archive, verify_str_arg
8 | from torchvision.datasets.vision import VisionDataset
9 |
10 |
11 | class Caltech101(VisionDataset):
12 | """`Caltech 101 `_ Dataset.
13 |
14 | .. warning::
15 |
16 | This class needs `scipy `_ to load target files from `.mat` format.
17 |
18 | Args:
19 | root (string): Root directory of dataset where directory
20 | ``caltech101`` exists or will be saved to if download is set to True.
21 | target_type (string or list, optional): Type of target to use, ``category`` or
22 | ``annotation``. Can also be a list to output a tuple with all specified
23 | target types. ``category`` represents the target class, and
24 | ``annotation`` is a list of points from a hand-generated outline.
25 | Defaults to ``category``.
26 |         transform (callable, optional): A function/transform that takes in a PIL image
27 | and returns a transformed version. E.g, ``transforms.RandomCrop``
28 | target_transform (callable, optional): A function/transform that takes in the
29 | target and transforms it.
30 | download (bool, optional): If true, downloads the dataset from the internet and
31 | puts it in root directory. If dataset is already downloaded, it is not
32 | downloaded again.
33 | """
34 |
35 | def __init__(
36 | self,
37 | root: str,
38 | target_type: Union[List[str], str] = "category",
39 | transform: Optional[Callable] = None,
40 | target_transform: Optional[Callable] = None,
41 | download: bool = False,
42 | ) -> None:
43 | super().__init__(os.path.join(root, "caltech101"), transform=transform, target_transform=target_transform)
44 | os.makedirs(self.root, exist_ok=True)
45 | if isinstance(target_type, str):
46 | target_type = [target_type]
47 | self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation", "category_name")) for t in target_type]
48 |
49 | if download:
50 | self.download()
51 |
52 | if not self._check_integrity():
53 | raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
54 |
55 | self.categories = sorted(os.listdir(os.path.join(self.root, "101_ObjectCategories")))
56 | self.categories.remove("BACKGROUND_Google") # this is not a real class
57 |
58 | # For some reason, the category names in "101_ObjectCategories" and
59 | # "Annotations" do not always match. This is a manual map between the
60 | # two. Defaults to using same name, since most names are fine.
61 | name_map = {
62 | "Faces": "Faces_2",
63 | "Faces_easy": "Faces_3",
64 | "Motorbikes": "Motorbikes_16",
65 | "airplanes": "Airplanes_Side_2",
66 | }
67 | self.annotation_categories = list(map(lambda x: name_map[x] if x in name_map else x, self.categories))
68 |
69 | self.index: List[int] = []
70 | self.y = []
71 | for (i, c) in enumerate(self.categories):
72 | n = len(os.listdir(os.path.join(self.root, "101_ObjectCategories", c)))
73 | self.index.extend(range(1, n + 1))
74 | self.y.extend(n * [i])
75 |
76 | self.clip_categories = self.categories.copy()
77 | self.clip_categories[0] = "person"
78 | self.clip_categories[1] = "person"
79 |
80 | def __getitem__(self, index: int) -> Tuple[Any, Any]:
81 | """
82 | Args:
83 | index (int): Index
84 |
85 | Returns:
86 | tuple: (image, target) where the type of target specified by target_type.
87 | """
88 | import scipy.io
89 |
90 | img = Image.open(
91 | os.path.join(
92 | self.root,
93 | "101_ObjectCategories",
94 | self.categories[self.y[index]],
95 | f"image_{self.index[index]:04d}.jpg",
96 | )
97 | )
98 |
99 | target: Any = []
100 | for t in self.target_type:
101 | if t == "category":
102 | target.append(self.y[index])
103 | elif t == "category_name":
104 | target.append(self.clip_categories[self.y[index]])
105 | elif t == "annotation":
106 | data = scipy.io.loadmat(
107 | os.path.join(
108 | self.root,
109 | "Annotations",
110 | self.annotation_categories[self.y[index]],
111 | f"annotation_{self.index[index]:04d}.mat",
112 | )
113 | )
114 | target.append(data["obj_contour"])
115 | target = tuple(target) if len(target) > 1 else target[0]
116 |
117 | if self.transform is not None:
118 | img = self.transform(img)
119 |
120 | if self.target_transform is not None:
121 | target = self.target_transform(target)
122 |
123 | return img, target
124 |
125 | def _check_integrity(self) -> bool:
126 | # can be more robust and check hash of files
127 | return os.path.exists(os.path.join(self.root, "101_ObjectCategories"))
128 |
129 | def __len__(self) -> int:
130 | return len(self.index)
131 |
132 | def download(self) -> None:
133 | if self._check_integrity():
134 | print("Files already downloaded and verified")
135 | return
136 |
137 | download_and_extract_archive(
138 | "https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp",
139 | self.root,
140 | filename="101_ObjectCategories.tar.gz",
141 | md5="b224c7392d521a49829488ab0f1120d9",
142 | )
143 | download_and_extract_archive(
144 | "https://drive.google.com/file/d/175kQy3UsZ0wUEHZjqkUDdNVssr7bgh_m",
145 | self.root,
146 | filename="Annotations.tar",
147 | md5="6f83eeb1f24d99cab4eb377263132c91",
148 | )
149 |
150 | def extra_repr(self) -> str:
151 | return "Target type: {target_type}".format(**self.__dict__)
152 |
153 |
154 | class Caltech256(VisionDataset):
155 | """`Caltech 256 `_ Dataset.
156 |
157 | Args:
158 | root (string): Root directory of dataset where directory
159 | ``caltech256`` exists or will be saved to if download is set to True.
160 |         transform (callable, optional): A function/transform that takes in a PIL image
161 | and returns a transformed version. E.g, ``transforms.RandomCrop``
162 | target_transform (callable, optional): A function/transform that takes in the
163 | target and transforms it.
164 | download (bool, optional): If true, downloads the dataset from the internet and
165 | puts it in root directory. If dataset is already downloaded, it is not
166 | downloaded again.
167 | """
168 |
169 | def __init__(
170 | self,
171 | root: str,
172 | transform: Optional[Callable] = None,
173 | target_transform: Optional[Callable] = None,
174 | download: bool = False,
175 | prompt_template = "A photo of a {}."
176 | ) -> None:
177 | super().__init__(os.path.join(root, "caltech256"), transform=transform, target_transform=target_transform)
178 | os.makedirs(self.root, exist_ok=True)
179 |
180 | if download:
181 | self.download()
182 |
183 | if not self._check_integrity():
184 | raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
185 |
186 | self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories")))
187 | self.index: List[int] = []
188 | self.y = []
189 | for (i, c) in enumerate(self.categories):
190 | n = len(
191 | [
192 | item
193 | for item in os.listdir(os.path.join(self.root, "256_ObjectCategories", c))
194 | if item.endswith(".jpg")
195 | ]
196 | )
197 | self.index.extend(range(1, n + 1))
198 | self.y.extend(n * [i])
199 |
200 | refined_classes = []
201 | for class_name in self.categories:
202 | class_name = class_name[4:]
203 | refined_classes.append(class_name.replace('-101', ''))
204 |
205 | self.prompt_template = prompt_template
206 | self.clip_prompts = [
207 | prompt_template.format(label.lower().replace('_', ' ').replace('-', ' ').strip()) \
208 | for label in refined_classes
209 | ]
210 |
211 | def __getitem__(self, index: int) -> Tuple[Any, Any]:
212 | """
213 | Args:
214 | index (int): Index
215 |
216 | Returns:
217 | tuple: (image, target) where target is index of the target class.
218 | """
219 | img = Image.open(
220 | os.path.join(
221 | self.root,
222 | "256_ObjectCategories",
223 | self.categories[self.y[index]],
224 | f"{self.y[index] + 1:03d}_{self.index[index]:04d}.jpg",
225 | )
226 | )
227 |
228 | target = self.y[index]
229 |
230 | if self.transform is not None:
231 | img = self.transform(img)
232 |
233 | if self.target_transform is not None:
234 | target = self.target_transform(target)
235 |
236 | return img, target
237 |
238 | def _check_integrity(self) -> bool:
239 | # can be more robust and check hash of files
240 | return os.path.exists(os.path.join(self.root, "256_ObjectCategories"))
241 |
242 | def __len__(self) -> int:
243 | return len(self.index)
244 |
245 | def download(self) -> None:
246 | if self._check_integrity():
247 | print("Files already downloaded and verified")
248 | return
249 |
250 | download_and_extract_archive(
251 | "https://drive.google.com/file/d/1r6o0pSROcV1_VwT4oSjA2FBUSCWGuxLK",
252 | self.root,
253 | filename="256_ObjectCategories.tar",
254 | md5="67b4f42ca05d46448c6bb8ecd2220f6d",
255 | )
256 |
--------------------------------------------------------------------------------
/code/replace/datasets/country211.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Callable, Optional
3 | import pycountry
4 |
5 | from torchvision.datasets.folder import ImageFolder
6 | from torchvision.datasets.utils import verify_str_arg, download_and_extract_archive
7 |
8 |
9 | class Country211(ImageFolder):
10 | """`The Country211 Data Set `_ from OpenAI.
11 |
12 | This dataset was built by filtering the images from the YFCC100m dataset
13 |     that have GPS coordinates corresponding to an ISO-3166 country code. The
14 |     dataset is balanced by sampling 150 train images, 50 validation images, and
15 |     100 test images for each country.
16 |
17 | Args:
18 | root (string): Root directory of the dataset.
19 | split (string, optional): The dataset split, supports ``"train"`` (default), ``"valid"`` and ``"test"``.
20 |         transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
21 | version. E.g, ``transforms.RandomCrop``.
22 | target_transform (callable, optional): A function/transform that takes in the target and transforms it.
23 | download (bool, optional): If True, downloads the dataset from the internet and puts it into
24 | ``root/country211/``. If dataset is already downloaded, it is not downloaded again.
25 | """
26 |
27 | _URL = "https://openaipublic.azureedge.net/clip/data/country211.tgz"
28 | _MD5 = "84988d7644798601126c29e9877aab6a"
29 |
30 | def __init__(
31 | self,
32 | root: str,
33 | split: str = "train",
34 | transform: Optional[Callable] = None,
35 | target_transform: Optional[Callable] = None,
36 | download: bool = False,
37 | prompt_template = "A photo I took in {}."
38 | ) -> None:
39 | self._split = verify_str_arg(split, "split", ("train", "valid", "test"))
40 |
41 | root = Path(root).expanduser()
42 | self.root = str(root)
43 | self._base_folder = root / "country211"
44 |
45 | if download:
46 | self._download()
47 |
48 | if not self._check_exists():
49 | raise RuntimeError("Dataset not found. You can use download=True to download it")
50 |
51 | super().__init__(str(self._base_folder / self._split), transform=transform, target_transform=target_transform)
52 |
53 | countries = {}
54 | self.code_to_country = dict()
55 | for country in pycountry.countries:
56 | countries[country.alpha_2] = country.name
57 |
58 | countries['XK'] = 'Kosovo'
59 |
60 | for code in self.class_to_idx:
61 | self.code_to_country[code] = countries[code].split(',')[0]
62 |
63 | self.prompt_template = prompt_template
64 | self.clip_prompts = [
65 | prompt_template.format(countries[label].replace('_', ' ').replace('-', ' ')) \
66 | for label in self.classes
67 | ]
68 |
69 | self.root = str(root)
70 |
71 | def _check_exists(self) -> bool:
72 | return self._base_folder.exists() and self._base_folder.is_dir()
73 |
74 | def _download(self) -> None:
75 | if self._check_exists():
76 | return
77 | download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)
78 |
--------------------------------------------------------------------------------
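
A minimal usage sketch for the Country211 class above (not part of the repository). It assumes the module is importable as `datasets.country211` (it lives under `code/replace/datasets/`) and that `./data` is the data root, as in the README; adjust both to your setup.

```python
from torchvision import transforms
from datasets.country211 import Country211  # assumed import path

preprocess = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

# download=True fetches and extracts country211.tgz into ./data/country211 on first use.
ds = Country211(root="./data", split="test", transform=preprocess, download=True)

# Folder names are ISO-3166 alpha-2 codes; pycountry maps them to readable names
# that fill the "A photo I took in {}." template.
print(ds.code_to_country["IT"])   # "Italy"
print(ds.clip_prompts[:3])
```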
/code/replace/datasets/dtd.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 | from typing import Optional, Callable
4 |
5 | import PIL.Image
6 |
7 | from torchvision.datasets.utils import verify_str_arg, download_and_extract_archive
8 | from torchvision.datasets.vision import VisionDataset
9 |
10 |
11 | class DTD(VisionDataset):
12 | """`Describable Textures Dataset (DTD) `_.
13 |
14 | Args:
15 | root (string): Root directory of the dataset.
16 | split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
17 | partition (int, optional): The dataset partition. Should be ``1 <= partition <= 10``. Defaults to ``1``.
18 |
19 | .. note::
20 |
21 | The partition only changes which split each image belongs to. Thus, regardless of the selected
22 | partition, combining all splits will result in all images.
23 |
24 | transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
25 | version. E.g, ``transforms.RandomCrop``.
26 | target_transform (callable, optional): A function/transform that takes in the target and transforms it.
27 | download (bool, optional): If True, downloads the dataset from the internet and
28 | puts it in root directory. If dataset is already downloaded, it is not
29 | downloaded again. Default is False.
30 | """
31 |
32 | _URL = "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz"
33 | _MD5 = "fff73e5086ae6bdbea199a49dfb8a4c1"
34 |
35 | def __init__(
36 | self,
37 | root: str,
38 | split: str = "train",
39 | partition: int = 1,
40 | transform: Optional[Callable] = None,
41 | target_transform: Optional[Callable] = None,
42 | download: bool = False,
43 | prompt_template = "A surface with a {} texture."
44 | ) -> None:
45 | self._split = verify_str_arg(split, "split", ("train", "val", "test"))
46 |         if not isinstance(partition, int) or not (1 <= partition <= 10):  # reject non-int or out-of-range partitions
47 | raise ValueError(
48 | f"Parameter 'partition' should be an integer with `1 <= partition <= 10`, "
49 | f"but got {partition} instead"
50 | )
51 | self._partition = partition
52 |
53 | super().__init__(root, transform=transform, target_transform=target_transform)
54 | self._base_folder = pathlib.Path(self.root) / type(self).__name__.lower()
55 | self._data_folder = self._base_folder / "dtd"
56 | self._meta_folder = self._data_folder / "labels"
57 | self._images_folder = self._data_folder / "images"
58 |
59 | if download:
60 | self._download()
61 |
62 | if not self._check_exists():
63 | raise RuntimeError("Dataset not found. You can use download=True to download it")
64 |
65 | self._image_files = []
66 | classes = []
67 | with open(self._meta_folder / f"{self._split}{self._partition}.txt") as file:
68 | for line in file:
69 | cls, name = line.strip().split("/")
70 | self._image_files.append(self._images_folder.joinpath(cls, name))
71 | classes.append(cls)
72 |
73 | self.classes = sorted(set(classes))
74 | self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
75 | self.idx_to_class = dict(zip(range(len(self.classes)), self.classes))
76 | self._labels = [self.class_to_idx[cls] for cls in classes]
77 |
78 | self.prompt_template = prompt_template
79 |
80 | self.clip_prompts = [
81 | prompt_template.format(label.lower().replace('_', ' ').replace('-', ' ')) \
82 | for label in self.classes
83 | ]
84 |
85 | def __len__(self) -> int:
86 | return len(self._image_files)
87 |
88 | def __getitem__(self, idx):
89 | image_file, label = self._image_files[idx], self._labels[idx]
90 | image = PIL.Image.open(image_file).convert("RGB")
91 |
92 | if self.transform:
93 | image = self.transform(image)
94 |
95 | if self.target_transform:
96 | label = self.target_transform(label)
97 |
98 | return image, label
99 |
100 | def extra_repr(self) -> str:
101 | return f"split={self._split}, partition={self._partition}"
102 |
103 | def _check_exists(self) -> bool:
104 | return os.path.exists(self._data_folder) and os.path.isdir(self._data_folder)
105 |
106 | def _download(self) -> None:
107 | if self._check_exists():
108 | return
109 | download_and_extract_archive(self._URL, download_root=str(self._base_folder), md5=self._MD5)
110 |
--------------------------------------------------------------------------------
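
A minimal sketch of how the DTD class above pairs labels with its zero-shot prompts (not part of the repository; the import path and data root are assumptions).

```python
from torchvision import transforms
from datasets.dtd import DTD  # assumed import path

preprocess = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

# partition (1..10) selects one of the official splits; the images themselves are shared.
ds = DTD(root="./data", split="test", partition=1, transform=preprocess, download=True)

img, label = ds[0]
print(ds.idx_to_class[label])   # texture name, e.g. "banded"
print(ds.clip_prompts[label])   # e.g. "A surface with a banded texture."
```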
/code/replace/datasets/eurosat.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Callable, Optional
3 |
4 | from torchvision.datasets.folder import ImageFolder
5 | from torchvision.datasets.utils import download_and_extract_archive
6 |
7 |
8 | class EuroSAT(ImageFolder):
9 | """RGB version of the `EuroSAT `_ Dataset.
10 |
11 | Args:
12 | root (string): Root directory of dataset where ``root/eurosat`` exists.
13 | transform (callable, optional): A function/transform that takes in an PIL image
14 | and returns a transformed version. E.g, ``transforms.RandomCrop``
15 | target_transform (callable, optional): A function/transform that takes in the
16 | target and transforms it.
17 | download (bool, optional): If True, downloads the dataset from the internet and
18 | puts it in root directory. If dataset is already downloaded, it is not
19 | downloaded again. Default is False.
20 | """
21 | idx_to_class = {
22 | 0: 'annual crop land',
23 | 1: 'a forest',
24 | 2: 'brushland or shrubland',
25 | 3: 'a highway or a road',
26 | 4: 'industrial buildings',
27 | 5: 'pasture land',
28 | 6: 'permanent crop land',
29 | 7: 'residential buildings',
30 | 8: 'a river',
31 | 9: 'a sea or a lake'
32 | }
33 |
34 | def __init__(
35 | self,
36 | root: str,
37 | transform: Optional[Callable] = None,
38 | target_transform: Optional[Callable] = None,
39 | download: bool = False,
40 | prompt_template = "A centered satellite photo of {}."
41 | ) -> None:
42 | self.root = os.path.expanduser(root)
43 | self._base_folder = os.path.join(self.root, "eurosat")
44 | self._data_folder = os.path.join(self._base_folder, "2750")
45 |
46 | if download:
47 | self.download()
48 |
49 | if not self._check_exists():
50 | raise RuntimeError("Dataset not found. You can use download=True to download it")
51 |
52 | super().__init__(self._data_folder, transform=transform, target_transform=target_transform)
53 | self.root = os.path.expanduser(root)
54 |
55 | self.prompt_template = prompt_template
56 | self.clip_prompts = [
57 | prompt_template.format(label.lower().replace('_', ' ').replace('-', ' ')) \
58 | for label in self.idx_to_class.values()
59 | ]
60 |
61 | def __len__(self) -> int:
62 | return len(self.samples)
63 |
64 | def _check_exists(self) -> bool:
65 | return os.path.exists(self._data_folder)
66 |
67 | def download(self) -> None:
68 |
69 | if self._check_exists():
70 | return
71 |
72 | os.makedirs(self._base_folder, exist_ok=True)
73 | download_and_extract_archive(
74 | "https://madm.dfki.de/files/sentinel/EuroSAT.zip",
75 | download_root=self._base_folder,
76 | md5="c8fa014336c82ac7804f0398fcb19387",
77 | )
78 |
--------------------------------------------------------------------------------
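
A small sketch (not part of the repository) showing that the EuroSAT class above keeps two naming schemes: the ImageFolder labels follow the alphabetically sorted folder names, while the prompts come from the hand-written descriptions in `idx_to_class`. The import path and data root are assumptions.

```python
from datasets.eurosat import EuroSAT  # assumed import path

ds = EuroSAT(root="./data", download=True)

# Folder names (e.g. "AnnualCrop") and the descriptive phrases used in the prompts
# are aligned by index, so both can be printed side by side.
for idx, folder_name in enumerate(ds.classes):
    print(f"{idx}: {folder_name:25s} -> {ds.clip_prompts[idx]}")
```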
/code/replace/datasets/fgvc_aircraft.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import os
4 | from typing import Any, Callable, Optional, Tuple
5 |
6 | import PIL.Image
7 |
8 | from torchvision.datasets.utils import download_and_extract_archive, verify_str_arg
9 | from torchvision.datasets.vision import VisionDataset
10 |
11 |
12 | class FGVCAircraft(VisionDataset):
13 | """`FGVC Aircraft `_ Dataset.
14 |
15 | The dataset contains 10,000 images of aircraft, with 100 images for each of 100
16 | different aircraft model variants, most of which are airplanes.
17 |     Aircraft models are organized in a three-level hierarchy. The three levels, from
18 | finer to coarser, are:
19 |
20 | - ``variant``, e.g. Boeing 737-700. A variant collapses all the models that are visually
21 | indistinguishable into one class. The dataset comprises 100 different variants.
22 | - ``family``, e.g. Boeing 737. The dataset comprises 70 different families.
23 | - ``manufacturer``, e.g. Boeing. The dataset comprises 30 different manufacturers.
24 |
25 | Args:
26 | root (string): Root directory of the FGVC Aircraft dataset.
27 | split (string, optional): The dataset split, supports ``train``, ``val``,
28 | ``trainval`` and ``test``.
29 | annotation_level (str, optional): The annotation level, supports ``variant``,
30 | ``family`` and ``manufacturer``.
31 | transform (callable, optional): A function/transform that takes in an PIL image
32 | and returns a transformed version. E.g, ``transforms.RandomCrop``
33 | target_transform (callable, optional): A function/transform that takes in the
34 | target and transforms it.
35 | download (bool, optional): If True, downloads the dataset from the internet and
36 | puts it in root directory. If dataset is already downloaded, it is not
37 | downloaded again.
38 | """
39 |
40 | _URL = "https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz"
41 |
42 | def __init__(
43 | self,
44 | root: str,
45 | split: str = "trainval",
46 | annotation_level: str = "variant",
47 | transform: Optional[Callable] = None,
48 | target_transform: Optional[Callable] = None,
49 | download: bool = False,
50 | prompt_template = "A photo of a {}, a type of aircraft."
51 | ) -> None:
52 | super().__init__(root, transform=transform, target_transform=target_transform)
53 | self._split = verify_str_arg(split, "split", ("train", "val", "trainval", "test"))
54 | self._annotation_level = verify_str_arg(
55 | annotation_level, "annotation_level", ("variant", "family", "manufacturer")
56 | )
57 |
58 | self._data_path = os.path.join(self.root, "fgvc-aircraft-2013b")
59 | if download:
60 | self._download()
61 |
62 | if not self._check_exists():
63 | raise RuntimeError("Dataset not found. You can use download=True to download it")
64 |
65 | annotation_file = os.path.join(
66 | self._data_path,
67 | "data",
68 | {
69 | "variant": "variants.txt",
70 | "family": "families.txt",
71 | "manufacturer": "manufacturers.txt",
72 | }[self._annotation_level],
73 | )
74 | with open(annotation_file, "r") as f:
75 | self.classes = [line.strip() for line in f]
76 |
77 | self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
78 |
79 | image_data_folder = os.path.join(self._data_path, "data", "images")
80 | labels_file = os.path.join(self._data_path, "data", f"images_{self._annotation_level}_{self._split}.txt")
81 |
82 | self._image_files = []
83 | self._labels = []
84 |
85 | with open(labels_file, "r") as f:
86 | for line in f:
87 | image_name, label_name = line.strip().split(" ", 1)
88 | self._image_files.append(os.path.join(image_data_folder, f"{image_name}.jpg"))
89 | self._labels.append(self.class_to_idx[label_name])
90 |
91 |         refined_classes = []  # variant names starting with "7" (e.g. "747-400") are Boeing models; prefix them for a more natural prompt
92 | for label in self.classes:
93 | if label[0] == '7':
94 | refined_classes.append(f"Boeing {label}")
95 | else:
96 | refined_classes.append(label)
97 |
98 | self.prompt_template = prompt_template
99 | self.clip_prompts = [
100 | prompt_template.format(label.lower().replace('_', ' ')) \
101 | for label in refined_classes
102 | ]
103 |
104 | def __len__(self) -> int:
105 | return len(self._image_files)
106 |
107 | def __getitem__(self, idx) -> Tuple[Any, Any]:
108 | image_file, label = self._image_files[idx], self._labels[idx]
109 | image = PIL.Image.open(image_file).convert("RGB")
110 |
111 | if self.transform:
112 | image = self.transform(image)
113 |
114 | if self.target_transform:
115 | label = self.target_transform(label)
116 |
117 | return image, label
118 |
119 | def _download(self) -> None:
120 | """
121 | Download the FGVC Aircraft dataset archive and extract it under root.
122 | """
123 | if self._check_exists():
124 | return
125 | download_and_extract_archive(self._URL, self.root)
126 |
127 | def _check_exists(self) -> bool:
128 | return os.path.exists(self._data_path) and os.path.isdir(self._data_path)
129 |
--------------------------------------------------------------------------------
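
A usage sketch for the FGVCAircraft class above (not part of the repository; the import path and data root are assumptions). `annotation_level` controls the label granularity, and the prompt list is built from the refined class names.

```python
from datasets.fgvc_aircraft import FGVCAircraft  # assumed import path

# annotation_level: "variant" (fine-grained), "family", or "manufacturer".
ds = FGVCAircraft(root="./data", split="test", annotation_level="variant", download=True)

# Variant names starting with "7" (e.g. "747-400") are Boeing models, so the class
# prepends "Boeing" before lower-casing and templating.
print(ds.classes[:3])
print(ds.clip_prompts[:3])
```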
/code/replace/datasets/flowers102.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Any, Tuple, Callable, Optional
3 |
4 | import PIL.Image
5 |
6 | from torchvision.datasets.utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg
7 | from torchvision.datasets.vision import VisionDataset
8 |
9 |
10 | class Flowers102(VisionDataset):
11 | """`Oxford 102 Flower `_ Dataset.
12 |
13 | .. warning::
14 |
15 |         This class needs scipy to load target files from `.mat` format.
16 |
17 | Oxford 102 Flower is an image classification dataset consisting of 102 flower categories. The
18 | flowers were chosen to be flowers commonly occurring in the United Kingdom. Each class consists of
19 | between 40 and 258 images.
20 |
21 | The images have large scale, pose and light variations. In addition, there are categories that
22 | have large variations within the category, and several very similar categories.
23 |
24 | Args:
25 | root (string): Root directory of the dataset.
26 | split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``.
27 | transform (callable, optional): A function/transform that takes in an PIL image and returns a
28 | transformed version. E.g, ``transforms.RandomCrop``.
29 | target_transform (callable, optional): A function/transform that takes in the target and transforms it.
30 | download (bool, optional): If true, downloads the dataset from the internet and
31 | puts it in root directory. If dataset is already downloaded, it is not
32 | downloaded again.
33 | """
34 |
35 | _download_url_prefix = "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/"
36 | _file_dict = { # filename, md5
37 | "image": ("102flowers.tgz", "52808999861908f626f3c1f4e79d11fa"),
38 | "label": ("imagelabels.mat", "e0620be6f572b9609742df49c70aed4d"),
39 | "setid": ("setid.mat", "a5357ecc9cb78c4bef273ce3793fc85c"),
40 | }
41 | _splits_map = {"train": "trnid", "val": "valid", "test": "tstid"}
42 |
43 | _classes = ['pink primrose', 'hard-leaved pocket orchid', 'canterbury bells',
44 | 'sweet pea', 'english marigold', 'tiger lily', 'moon orchid', 'bird of paradise',
45 | 'monkshood', 'globe thistle', 'snapdragon', 'colts foot', 'king protea', 'spear thistle',
46 | 'yellow iris', 'globe-flower', 'purple coneflower', 'peruvian lily', 'balloon flower',
47 | 'giant white arum lily', 'fire lily', 'pincushion flower', 'fritillary', 'red ginger',
48 | 'grape hyacinth', 'corn poppy', 'prince of wales feathers', 'stemless gentian', 'artichoke',
49 | 'sweet william', 'carnation', 'garden phlox', 'love in the mist', 'mexican aster',
50 | 'alpine sea holly', 'ruby-lipped cattleya', 'cape flower', 'great masterwort', 'siam tulip',
51 | 'lenten rose', 'barbeton daisy', 'daffodil', 'sword lily', 'poinsettia', 'bolero deep blue',
52 | 'wallflower', 'marigold', 'buttercup', 'oxeye daisy', 'common dandelion', 'petunia',
53 | 'wild pansy', 'primula', 'sunflower', 'pelargonium', 'bishop of llandaff', 'gaura', 'geranium',
54 | 'orange dahlia', 'pink-yellow dahlia?', 'cautleya spicata', 'japanese anemone', 'black-eyed susan',
55 | 'silverbush', 'californian poppy', 'osteospermum', 'spring crocus', 'bearded iris', 'windflower',
56 | 'tree poppy', 'gazania', 'azalea', 'water lily', 'rose', 'thorn apple', 'morning glory',
57 | 'passion flower', 'lotus', 'toad lily', 'anthurium', 'frangipani', 'clematis', 'hibiscus',
58 | 'columbine', 'desert-rose', 'tree mallow', 'magnolia', 'cyclamen ', 'watercress', 'canna lily',
59 | 'hippeastrum ', 'bee balm', 'ball moss', 'foxglove', 'bougainvillea', 'camellia', 'mallow',
60 | 'mexican petunia', 'bromelia', 'blanket flower', 'trumpet creeper', 'blackberry lily']
61 |
62 | def __init__(
63 | self,
64 | root: str,
65 | split: str = "train",
66 | transform: Optional[Callable] = None,
67 | target_transform: Optional[Callable] = None,
68 | download: bool = False,
69 | prompt_template = "A photo of a {}, a type of flower."
70 | ) -> None:
71 | super().__init__(root, transform=transform, target_transform=target_transform)
72 | self._split = verify_str_arg(split, "split", ("train", "val", "test"))
73 | self._base_folder = Path(self.root) / "flowers-102"
74 | self._images_folder = self._base_folder / "jpg"
75 |
76 | if download:
77 | self.download()
78 |
79 | if not self._check_integrity():
80 | raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
81 |
82 | from scipy.io import loadmat
83 |
84 | set_ids = loadmat(self._base_folder / self._file_dict["setid"][0], squeeze_me=True)
85 | image_ids = set_ids[self._splits_map[self._split]].tolist()
86 |
87 | labels = loadmat(self._base_folder / self._file_dict["label"][0], squeeze_me=True)
88 | image_id_to_label = dict(enumerate((labels["labels"] - 1).tolist(), 1))
89 |
90 | self._labels = []
91 | self._image_files = []
92 | for image_id in image_ids:
93 | self._labels.append(image_id_to_label[image_id])
94 | self._image_files.append(self._images_folder / f"image_{image_id:05d}.jpg")
95 |
96 | self.idx_to_class = {label:self._classes[label] for label in range(102)}
97 |
98 | self.prompt_template = prompt_template
99 | self.clip_prompts = [
100 | prompt_template.format(label.lower().replace('_', ' ').replace('-', ' ')) \
101 | for label in self._classes
102 | ]
103 |
104 |
105 | def __len__(self) -> int:
106 | return len(self._image_files)
107 |
108 | def __getitem__(self, idx) -> Tuple[Any, Any]:
109 | image_file, label = self._image_files[idx], self._labels[idx]
110 | image = PIL.Image.open(image_file).convert("RGB")
111 |
112 | if self.transform:
113 | image = self.transform(image)
114 |
115 | if self.target_transform:
116 | label = self.target_transform(label)
117 |
118 | return image, label
119 |
120 | def extra_repr(self) -> str:
121 | return f"split={self._split}"
122 |
123 | def _check_integrity(self):
124 | if not (self._images_folder.exists() and self._images_folder.is_dir()):
125 | return False
126 |
127 | for id in ["label", "setid"]:
128 | filename, md5 = self._file_dict[id]
129 | if not check_integrity(str(self._base_folder / filename), md5):
130 | return False
131 | return True
132 |
133 | def download(self):
134 | if self._check_integrity():
135 | return
136 | download_and_extract_archive(
137 | f"{self._download_url_prefix}{self._file_dict['image'][0]}",
138 | str(self._base_folder),
139 | md5=self._file_dict["image"][1],
140 | )
141 | for id in ["label", "setid"]:
142 | filename, md5 = self._file_dict[id]
143 | download_url(self._download_url_prefix + filename, str(self._base_folder), md5=md5)
144 |
--------------------------------------------------------------------------------
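
A minimal sketch for the Flowers102 class above (not part of the repository; the import path and data root are assumptions). The `.mat` files carry only numeric labels, so the readable names come from the hard-coded `_classes` list.

```python
from datasets.flowers102 import Flowers102  # assumed import path; requires scipy

ds = Flowers102(root="./data", split="test", download=True)

img, label = ds[0]              # a PIL image here, since no transform was passed
print(ds.idx_to_class[label])   # e.g. "pink primrose"
print(ds.clip_prompts[label])   # e.g. "A photo of a pink primrose, a type of flower."
```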
/code/replace/datasets/folder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
4 |
5 | from PIL import Image
6 |
7 | from torchvision.datasets.vision import VisionDataset
8 |
9 |
10 | def has_file_allowed_extension(filename: str, extensions: Union[str, Tuple[str, ...]]) -> bool:
11 | """Checks if a file is an allowed extension.
12 |
13 | Args:
14 | filename (string): path to a file
15 | extensions (tuple of strings): extensions to consider (lowercase)
16 |
17 | Returns:
18 | bool: True if the filename ends with one of given extensions
19 | """
20 | return filename.lower().endswith(extensions if isinstance(extensions, str) else tuple(extensions))
21 |
22 |
23 | def is_image_file(filename: str) -> bool:
24 | """Checks if a file is an allowed image extension.
25 |
26 | Args:
27 | filename (string): path to a file
28 |
29 | Returns:
30 | bool: True if the filename ends with a known image extension
31 | """
32 | return has_file_allowed_extension(filename, IMG_EXTENSIONS)
33 |
34 |
35 | def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
36 | """Finds the class folders in a dataset.
37 |
38 | See :class:`DatasetFolder` for details.
39 | """
40 | classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
41 | if not classes:
42 | raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
43 |
44 | class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
45 | return classes, class_to_idx
46 |
47 |
48 | def make_dataset(
49 | directory: str,
50 | class_to_idx: Optional[Dict[str, int]] = None,
51 | extensions: Optional[Union[str, Tuple[str, ...]]] = None,
52 | is_valid_file: Optional[Callable[[str], bool]] = None,
53 | ) -> List[Tuple[str, int]]:
54 | """Generates a list of samples of a form (path_to_sample, class).
55 |
56 | See :class:`DatasetFolder` for details.
57 |
58 | Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function
59 | by default.
60 | """
61 | directory = os.path.expanduser(directory)
62 |
63 | if class_to_idx is None:
64 | _, class_to_idx = find_classes(directory)
65 | elif not class_to_idx:
66 | raise ValueError("'class_to_index' must have at least one entry to collect any samples.")
67 |
68 | both_none = extensions is None and is_valid_file is None
69 | both_something = extensions is not None and is_valid_file is not None
70 | if both_none or both_something:
71 | raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
72 |
73 | if extensions is not None:
74 |
75 | def is_valid_file(x: str) -> bool:
76 | return has_file_allowed_extension(x, extensions) # type: ignore[arg-type]
77 |
78 | is_valid_file = cast(Callable[[str], bool], is_valid_file)
79 |
80 | instances = []
81 | available_classes = set()
82 | for target_class in sorted(class_to_idx.keys()):
83 | class_index = class_to_idx[target_class]
84 | target_dir = os.path.join(directory, target_class)
85 | if not os.path.isdir(target_dir):
86 | continue
87 | for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
88 | for fname in sorted(fnames):
89 | path = os.path.join(root, fname)
90 | if is_valid_file(path):
91 | item = path, class_index
92 | instances.append(item)
93 |
94 | if target_class not in available_classes:
95 | available_classes.add(target_class)
96 |
97 | empty_classes = set(class_to_idx.keys()) - available_classes
98 | if empty_classes:
99 | msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
100 | if extensions is not None:
101 | msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
102 | raise FileNotFoundError(msg)
103 |
104 | return instances
105 |
106 |
107 | class DatasetFolder(VisionDataset):
108 | """A generic data loader.
109 |
110 | This default directory structure can be customized by overriding the
111 | :meth:`find_classes` method.
112 |
113 | Args:
114 | root (string): Root directory path.
115 | loader (callable): A function to load a sample given its path.
116 | extensions (tuple[string]): A list of allowed extensions.
117 |             Both extensions and is_valid_file should not be passed.
118 | transform (callable, optional): A function/transform that takes in
119 | a sample and returns a transformed version.
120 | E.g, ``transforms.RandomCrop`` for images.
121 | target_transform (callable, optional): A function/transform that takes
122 | in the target and transforms it.
123 |         is_valid_file (callable, optional): A function that takes the path of a file
124 |             and checks if the file is a valid file (used to check for corrupt files).
125 |             Both extensions and is_valid_file should not be passed.
126 |
127 | Attributes:
128 | classes (list): List of the class names sorted alphabetically.
129 | class_to_idx (dict): Dict with items (class_name, class_index).
130 | samples (list): List of (sample path, class_index) tuples
131 | targets (list): The class_index value for each image in the dataset
132 | """
133 |
134 | def __init__(
135 | self,
136 | root: str,
137 | loader: Callable[[str], Any],
138 | extensions: Optional[Tuple[str, ...]] = None,
139 | transform: Optional[Callable] = None,
140 | target_transform: Optional[Callable] = None,
141 | is_valid_file: Optional[Callable[[str], bool]] = None,
142 | ) -> None:
143 | super().__init__(root, transform=transform, target_transform=target_transform)
144 | classes, class_to_idx = self.find_classes(self.root)
145 | samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)
146 |
147 | self.loader = loader
148 | self.extensions = extensions
149 |
150 | self.classes = classes
151 | self.class_to_idx = class_to_idx
152 | self.samples = samples
153 | self.targets = [s[1] for s in samples]
154 |
155 | @staticmethod
156 | def make_dataset(
157 | directory: str,
158 | class_to_idx: Dict[str, int],
159 | extensions: Optional[Tuple[str, ...]] = None,
160 | is_valid_file: Optional[Callable[[str], bool]] = None,
161 | ) -> List[Tuple[str, int]]:
162 | """Generates a list of samples of a form (path_to_sample, class).
163 |
164 | This can be overridden to e.g. read files from a compressed zip file instead of from the disk.
165 |
166 | Args:
167 | directory (str): root dataset directory, corresponding to ``self.root``.
168 | class_to_idx (Dict[str, int]): Dictionary mapping class name to class index.
169 | extensions (optional): A list of allowed extensions.
170 | Either extensions or is_valid_file should be passed. Defaults to None.
171 |             is_valid_file (optional): A function that takes the path of a file
172 |                 and checks if the file is a valid file
173 |                 (used to check for corrupt files). Both extensions and
174 |                 is_valid_file should not be passed. Defaults to None.
175 |
176 | Raises:
177 | ValueError: In case ``class_to_idx`` is empty.
178 | ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None.
179 | FileNotFoundError: In case no valid file was found for any class.
180 |
181 | Returns:
182 | List[Tuple[str, int]]: samples of a form (path_to_sample, class)
183 | """
184 | if class_to_idx is None:
185 | # prevent potential bug since make_dataset() would use the class_to_idx logic of the
186 | # find_classes() function, instead of using that of the find_classes() method, which
187 | # is potentially overridden and thus could have a different logic.
188 | raise ValueError("The class_to_idx parameter cannot be None.")
189 | return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)
190 |
191 | def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
192 | """Find the class folders in a dataset structured as follows::
193 |
194 | directory/
195 | ├── class_x
196 | │ ├── xxx.ext
197 | │ ├── xxy.ext
198 | │ └── ...
199 | │ └── xxz.ext
200 | └── class_y
201 | ├── 123.ext
202 | ├── nsdf3.ext
203 | └── ...
204 | └── asd932_.ext
205 |
206 | This method can be overridden to only consider
207 | a subset of classes, or to adapt to a different dataset directory structure.
208 |
209 | Args:
210 | directory(str): Root directory path, corresponding to ``self.root``
211 |
212 | Raises:
213 | FileNotFoundError: If ``dir`` has no class folders.
214 |
215 | Returns:
216 | (Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index.
217 | """
218 | return find_classes(directory)
219 |
220 | def __getitem__(self, index: int) -> Tuple[Any, Any]:
221 | """
222 | Args:
223 | index (int): Index
224 |
225 | Returns:
226 | tuple: (sample, target) where target is class_index of the target class.
227 | """
228 | path, target = self.samples[index]
229 | sample = self.loader(path)
230 | if self.transform is not None:
231 | sample = self.transform(sample)
232 | if self.target_transform is not None:
233 | target = self.target_transform(target)
234 |
235 | return sample, target
236 |
237 | def __len__(self) -> int:
238 | return len(self.samples)
239 |
240 |
241 | IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp")
242 |
243 |
244 | def pil_loader(path: str) -> Image.Image:
245 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
246 | with open(path, "rb") as f:
247 | img = Image.open(f)
248 | return img.convert("RGB")
249 |
250 |
251 | # TODO: specify the return type
252 | def accimage_loader(path: str) -> Any:
253 | import accimage
254 |
255 | try:
256 | return accimage.Image(path)
257 | except OSError:
258 | # Potentially a decoding problem, fall back to PIL.Image
259 | return pil_loader(path)
260 |
261 |
262 | def default_loader(path: str) -> Any:
263 | from torchvision import get_image_backend
264 |
265 | if get_image_backend() == "accimage":
266 | return accimage_loader(path)
267 | else:
268 | return pil_loader(path)
269 |
270 |
271 | class ImageFolder(DatasetFolder):
272 | """A generic data loader where the images are arranged in this way by default: ::
273 |
274 | root/dog/xxx.png
275 | root/dog/xxy.png
276 | root/dog/[...]/xxz.png
277 |
278 | root/cat/123.png
279 | root/cat/nsdf3.png
280 | root/cat/[...]/asd932_.png
281 |
282 | This class inherits from :class:`~torchvision.datasets.DatasetFolder` so
283 | the same methods can be overridden to customize the dataset.
284 |
285 | Args:
286 | root (string): Root directory path.
287 | transform (callable, optional): A function/transform that takes in an PIL image
288 | and returns a transformed version. E.g, ``transforms.RandomCrop``
289 | target_transform (callable, optional): A function/transform that takes in the
290 | target and transforms it.
291 | loader (callable, optional): A function to load an image given its path.
292 |         is_valid_file (callable, optional): A function that takes the path of an image file
293 |             and checks if the file is a valid file (used to check for corrupt files).
294 |
295 | Attributes:
296 | classes (list): List of the class names sorted alphabetically.
297 | class_to_idx (dict): Dict with items (class_name, class_index).
298 | imgs (list): List of (image path, class_index) tuples
299 | """
300 |
301 | def __init__(
302 | self,
303 | root: str,
304 | transform: Optional[Callable] = None,
305 | target_transform: Optional[Callable] = None,
306 | loader: Callable[[str], Any] = default_loader,
307 | is_valid_file: Optional[Callable[[str], bool]] = None,
308 | ):
309 | super().__init__(
310 | root,
311 | loader,
312 | IMG_EXTENSIONS if is_valid_file is None else None,
313 | transform=transform,
314 | target_transform=target_transform,
315 | is_valid_file=is_valid_file,
316 | )
317 | self.imgs = self.samples
318 |
319 | class ImageNetFolder(DatasetFolder):
320 | """A generic data loader where the images are arranged in this way by default: ::
321 |
322 | root/dog/xxx.png
323 | root/dog/xxy.png
324 | root/dog/[...]/xxz.png
325 |
326 | root/cat/123.png
327 | root/cat/nsdf3.png
328 | root/cat/[...]/asd932_.png
329 |
330 | This class inherits from :class:`~torchvision.datasets.DatasetFolder` so
331 | the same methods can be overridden to customize the dataset.
332 |
333 | Args:
334 | root (string): Root directory path.
335 | transform (callable, optional): A function/transform that takes in an PIL image
336 | and returns a transformed version. E.g, ``transforms.RandomCrop``
337 | target_transform (callable, optional): A function/transform that takes in the
338 | target and transforms it.
339 | loader (callable, optional): A function to load an image given its path.
340 |         is_valid_file (callable, optional): A function that takes the path of an image file
341 |             and checks if the file is a valid file (used to check for corrupt files).
342 |
343 | Attributes:
344 | classes (list): List of the class names sorted alphabetically.
345 | class_to_idx (dict): Dict with items (class_name, class_index).
346 | imgs (list): List of (image path, class_index) tuples
347 | """
348 |
349 |     def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
350 | classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())
351 | if not classes:
352 | raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")
353 | # remove duplicate class 'maillot' ('n03710637' and 'n03710721'). We retain 'n03710721' only.
354 | # resulting in 999 classes
355 | if 'n03710637' in classes:
356 | classes.remove('n03710637')
357 | class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
358 | return classes, class_to_idx
359 |
360 | def make_dataset(
361 | self,
362 | directory: str,
363 | class_to_idx: Dict[str, int],
364 | extensions: Optional[Tuple[str, ...]] = None,
365 | is_valid_file: Optional[Callable[[str], bool]] = None,
366 | ) -> List[Tuple[str, int]]:
367 |
368 | directory = os.path.expanduser(directory)
369 |
370 | if class_to_idx is None:
371 | _, class_to_idx = self.find_classes(directory)
372 | elif not class_to_idx:
373 | raise ValueError("'class_to_index' must have at least one entry to collect any samples.")
374 |
375 | both_none = extensions is None and is_valid_file is None
376 | both_something = extensions is not None and is_valid_file is not None
377 | if both_none or both_something:
378 | raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
379 |
380 | if extensions is not None:
381 |
382 | def is_valid_file(x: str) -> bool:
383 | return has_file_allowed_extension(x, extensions) # type: ignore[arg-type]
384 |
385 | is_valid_file = cast(Callable[[str], bool], is_valid_file)
386 |
387 | instances = []
388 | available_classes = set()
389 | for target_class in sorted(class_to_idx.keys()):
390 | class_index = class_to_idx[target_class]
391 | target_dir = os.path.join(directory, target_class)
392 | select_files = self.select_files[target_class] if self.select_files is not None else None
393 | if not os.path.isdir(target_dir):
394 | continue
395 | for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
396 | for fname in sorted(fnames):
397 | if select_files is not None and fname not in select_files:
398 | continue
399 | path = os.path.join(root, fname)
400 | if is_valid_file(path):
401 | item = path, class_index
402 | instances.append(item)
403 |
404 | if target_class not in available_classes:
405 | available_classes.add(target_class)
406 |
407 | empty_classes = set(class_to_idx.keys()) - available_classes
408 | if empty_classes:
409 | msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
410 | if extensions is not None:
411 | msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}"
412 | raise FileNotFoundError(msg)
413 |
414 | return instances
415 |
416 | def __init__(
417 | self,
418 | root: str,
419 | transform: Optional[Callable] = None,
420 | target_transform: Optional[Callable] = None,
421 | loader: Callable[[str], Any] = default_loader,
422 | is_valid_file: Optional[Callable[[str], bool]] = None,
423 | select_files: Optional[Dict[str, list]] = None, # an argument that we add to hold out some samples for evaluation
424 | ):
425 | self.select_files = select_files
426 | super().__init__(
427 | root,
428 | loader,
429 | IMG_EXTENSIONS if is_valid_file is None else None,
430 | transform=transform,
431 | target_transform=target_transform,
432 | is_valid_file=is_valid_file,
433 | )
434 | self.imgs = self.samples
435 |
--------------------------------------------------------------------------------
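
A sketch of how the `select_files` argument of the ImageNetFolder class above could be used (not part of the repository). The JSON file name, its layout, and the ImageNet path are hypothetical; the repo's own held-out file lists may be stored differently.

```python
import json

from datasets.folder import ImageNetFolder  # assumed import path

# Hypothetical file: maps each wnid folder to the image file names to keep,
# e.g. {"n01440764": ["n01440764_10026.JPEG", ...], ...}
with open("support/heldout_files.json") as f:
    select_files = json.load(f)

# Only the listed files per class are indexed, so a fixed evaluation subset can be
# carved out of a full ImageNet-style directory.
ds = ImageNetFolder(root="./data/imagenet/train", select_files=select_files)
print(len(ds.classes))  # 999 when the duplicate 'maillot' synset n03710637 is present and dropped
print(len(ds))
```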
/code/replace/datasets/food101.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pathlib import Path
3 | from typing import Any, Tuple, Callable, Optional
4 |
5 | import PIL.Image
6 |
7 | from torchvision.datasets.utils import verify_str_arg, download_and_extract_archive
8 | from torchvision.datasets.vision import VisionDataset
9 |
10 |
11 | class Food101(VisionDataset):
12 | """`The Food-101 Data Set `_.
13 |
14 | The Food-101 is a challenging data set of 101 food categories, with 101'000 images.
15 | For each class, 250 manually reviewed test images are provided as well as 750 training images.
16 | On purpose, the training images were not cleaned, and thus still contain some amount of noise.
17 | This comes mostly in the form of intense colors and sometimes wrong labels. All images were
18 | rescaled to have a maximum side length of 512 pixels.
19 |
20 |
21 | Args:
22 | root (string): Root directory of the dataset.
23 | split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``.
24 | transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
25 | version. E.g, ``transforms.RandomCrop``.
26 | target_transform (callable, optional): A function/transform that takes in the target and transforms it.
27 | download (bool, optional): If True, downloads the dataset from the internet and
28 | puts it in root directory. If dataset is already downloaded, it is not
29 | downloaded again. Default is False.
30 | """
31 |
32 | _URL = "http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz"
33 | _MD5 = "85eeb15f3717b99a5da872d97d918f87"
34 |
35 | def __init__(
36 | self,
37 | root: str,
38 | split: str = "train",
39 | transform: Optional[Callable] = None,
40 | target_transform: Optional[Callable] = None,
41 | download: bool = False,
42 | prompt_template = "A photo of {}, a type of food."
43 | ) -> None:
44 | super().__init__(root, transform=transform, target_transform=target_transform)
45 | self._split = verify_str_arg(split, "split", ("train", "test"))
46 | self._base_folder = Path(self.root) / "food-101"
47 | self._meta_folder = self._base_folder / "meta"
48 | self._images_folder = self._base_folder / "images"
49 |
50 | if download:
51 | self._download()
52 |
53 | if not self._check_exists():
54 | raise RuntimeError("Dataset not found. You can use download=True to download it")
55 |
56 | self._labels = []
57 | self._image_files = []
58 | with open(self._meta_folder / f"{split}.json") as f:
59 | metadata = json.loads(f.read())
60 |
61 | self.classes = sorted(metadata.keys())
62 | self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
63 |
64 | for class_label, im_rel_paths in metadata.items():
65 | self._labels += [self.class_to_idx[class_label]] * len(im_rel_paths)
66 | self._image_files += [
67 | self._images_folder.joinpath(*f"{im_rel_path}.jpg".split("/")) for im_rel_path in im_rel_paths
68 | ]
69 |
70 | self.prompt_template = prompt_template
71 | self.clip_prompts = [
72 | prompt_template.format(label.lower().replace('_', ' ').replace('-', ' ')) \
73 | for label in self.classes
74 | ]
75 |
76 | def __len__(self) -> int:
77 | return len(self._image_files)
78 |
79 | def __getitem__(self, idx) -> Tuple[Any, Any]:
80 | image_file, label = self._image_files[idx], self._labels[idx]
81 | image = PIL.Image.open(image_file).convert("RGB")
82 |
83 | if self.transform:
84 | image = self.transform(image)
85 |
86 | if self.target_transform:
87 | label = self.target_transform(label)
88 |
89 | return image, label
90 |
91 | def extra_repr(self) -> str:
92 | return f"split={self._split}"
93 |
94 | def _check_exists(self) -> bool:
95 | return all(folder.exists() and folder.is_dir() for folder in (self._meta_folder, self._images_folder))
96 |
97 | def _download(self) -> None:
98 | if self._check_exists():
99 | return
100 | download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5)
101 |
--------------------------------------------------------------------------------
/code/replace/datasets/oxford_iiit_pet.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import pathlib
4 | from typing import Any, Callable, Optional, Union, Tuple
5 | from typing import Sequence
6 |
7 | from PIL import Image
8 |
9 | from torchvision.datasets.utils import download_and_extract_archive, verify_str_arg
10 | from torchvision.datasets.vision import VisionDataset
11 |
12 |
13 | class OxfordIIITPet(VisionDataset):
14 | """`Oxford-IIIT Pet Dataset `_.
15 |
16 | Args:
17 | root (string): Root directory of the dataset.
18 | split (string, optional): The dataset split, supports ``"trainval"`` (default) or ``"test"``.
19 | target_types (string, sequence of strings, optional): Types of target to use. Can be ``category`` (default) or
20 | ``segmentation``. Can also be a list to output a tuple with all specified target types. The types represent:
21 |
22 | - ``category`` (int): Label for one of the 37 pet categories.
23 | - ``segmentation`` (PIL image): Segmentation trimap of the image.
24 |
25 | If empty, ``None`` will be returned as target.
26 |
27 | transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
28 | version. E.g, ``transforms.RandomCrop``.
29 | target_transform (callable, optional): A function/transform that takes in the target and transforms it.
30 | download (bool, optional): If True, downloads the dataset from the internet and puts it into
31 | ``root/oxford-iiit-pet``. If dataset is already downloaded, it is not downloaded again.
32 | """
33 |
34 | _RESOURCES = (
35 | ("https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz", "5c4f3ee8e5d25df40f4fd59a7f44e54c"),
36 | ("https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz", "95a8c909bbe2e81eed6a22bccdf3f68f"),
37 | )
38 | _VALID_TARGET_TYPES = ("category", "segmentation")
39 |
40 | def __init__(
41 | self,
42 | root: str,
43 | split: str = "trainval",
44 | target_types: Union[Sequence[str], str] = "category",
45 | transforms: Optional[Callable] = None,
46 | transform: Optional[Callable] = None,
47 | target_transform: Optional[Callable] = None,
48 | download: bool = False,
49 | prompt_template = "A photo of a {}, a type of pet."
50 | ):
51 | self._split = verify_str_arg(split, "split", ("trainval", "test"))
52 | if isinstance(target_types, str):
53 | target_types = [target_types]
54 | self._target_types = [
55 | verify_str_arg(target_type, "target_types", self._VALID_TARGET_TYPES) for target_type in target_types
56 | ]
57 |
58 | super().__init__(root, transforms=transforms, transform=transform, target_transform=target_transform)
59 | self._base_folder = pathlib.Path(self.root) / "oxford-iiit-pet"
60 | self._images_folder = self._base_folder / "images"
61 | self._anns_folder = self._base_folder / "annotations"
62 | self._segs_folder = self._anns_folder / "trimaps"
63 |
64 | if download:
65 | self._download()
66 |
67 | if not self._check_exists():
68 | raise RuntimeError("Dataset not found. You can use download=True to download it")
69 |
70 | image_ids = []
71 | self._labels = []
72 | with open(self._anns_folder / f"{self._split}.txt") as file:
73 | for line in file:
74 | image_id, label, *_ = line.strip().split()
75 | image_ids.append(image_id)
76 | self._labels.append(int(label) - 1)
77 |
78 | self.classes = [
79 | " ".join(part.title() for part in raw_cls.split("_"))
80 | for raw_cls, _ in sorted(
81 | {(image_id.rsplit("_", 1)[0], label) for image_id, label in zip(image_ids, self._labels)},
82 | key=lambda image_id_and_label: image_id_and_label[1],
83 | )
84 | ]
85 | self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
86 |
87 | self._images = [self._images_folder / f"{image_id}.jpg" for image_id in image_ids]
88 | self._segs = [self._segs_folder / f"{image_id}.png" for image_id in image_ids]
89 |
90 | self.prompt_template = prompt_template
91 | self.clip_prompts = [
92 | prompt_template.format(label.lower().replace('_', ' ').replace('-', ' ')) \
93 | for label in self.classes
94 | ]
95 |
96 | def __len__(self) -> int:
97 | return len(self._images)
98 |
99 | def __getitem__(self, idx: int) -> Tuple[Any, Any]:
100 | image = Image.open(self._images[idx]).convert("RGB")
101 |
102 | target: Any = []
103 | for target_type in self._target_types:
104 | if target_type == "category":
105 | target.append(self._labels[idx])
106 | else: # target_type == "segmentation"
107 | target.append(Image.open(self._segs[idx]))
108 |
109 | if not target:
110 | target = None
111 | elif len(target) == 1:
112 | target = target[0]
113 | else:
114 | target = tuple(target)
115 |
116 | if self.transforms:
117 | image, target = self.transforms(image, target)
118 |
119 | return image, target
120 |
121 | def _check_exists(self) -> bool:
122 | for folder in (self._images_folder, self._anns_folder):
123 | if not (os.path.exists(folder) and os.path.isdir(folder)):
124 | return False
125 | else:
126 | return True
127 |
128 | def _download(self) -> None:
129 | if self._check_exists():
130 | return
131 |
132 | for url, md5 in self._RESOURCES:
133 | download_and_extract_archive(url, download_root=str(self._base_folder), md5=md5)
134 |
--------------------------------------------------------------------------------
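
A short sketch for the OxfordIIITPet class above (not part of the repository; the import path and data root are assumptions), showing how `target_types` changes what `__getitem__` returns.

```python
from datasets.oxford_iiit_pet import OxfordIIITPet  # assumed import path

# With both target types, each item is (image, (category_id, trimap)).
ds = OxfordIIITPet(
    root="./data",
    split="test",
    target_types=["category", "segmentation"],
    download=True,
)

img, (category, trimap) = ds[0]
print(ds.classes[category], trimap.size)   # breed name and trimap resolution
print(ds.clip_prompts[category])           # prompt built from the lower-cased breed name
```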
/code/replace/datasets/pcam.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | from typing import Any, Callable, Optional, Tuple
3 |
4 | from PIL import Image
5 |
6 | from torchvision.datasets.utils import download_file_from_google_drive, _decompress, verify_str_arg
7 | from torchvision.datasets.vision import VisionDataset
8 |
9 |
10 | class PCAM(VisionDataset):
11 | """`PCAM Dataset `_.
12 |
13 | The PatchCamelyon dataset is a binary classification dataset with 327,680
14 | color images (96px x 96px), extracted from histopathologic scans of lymph node
15 | sections. Each image is annotated with a binary label indicating presence of
16 | metastatic tissue.
17 |
18 | This dataset requires the ``h5py`` package which you can install with ``pip install h5py``.
19 |
20 | Args:
21 | root (string): Root directory of the dataset.
22 | split (string, optional): The dataset split, supports ``"train"`` (default), ``"test"`` or ``"val"``.
23 | transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed
24 | version. E.g, ``transforms.RandomCrop``.
25 | target_transform (callable, optional): A function/transform that takes in the target and transforms it.
26 | download (bool, optional): If True, downloads the dataset from the internet and puts it into ``root/pcam``. If
27 | dataset is already downloaded, it is not downloaded again.
28 | """
29 |
30 | _FILES = {
31 | "train": {
32 | "images": (
33 | "camelyonpatch_level_2_split_train_x.h5", # Data file name
34 | "1Ka0XfEMiwgCYPdTI-vv6eUElOBnKFKQ2", # Google Drive ID
35 | "1571f514728f59376b705fc836ff4b63", # md5 hash
36 | ),
37 | "targets": (
38 | "camelyonpatch_level_2_split_train_y.h5",
39 | "1269yhu3pZDP8UYFQs-NYs3FPwuK-nGSG",
40 | "35c2d7259d906cfc8143347bb8e05be7",
41 | ),
42 | },
43 | "test": {
44 | "images": (
45 | "camelyonpatch_level_2_split_test_x.h5",
46 | "1qV65ZqZvWzuIVthK8eVDhIwrbnsJdbg_",
47 | "d5b63470df7cfa627aeec8b9dc0c066e",
48 | ),
49 | "targets": (
50 | "camelyonpatch_level_2_split_test_y.h5",
51 | "17BHrSrwWKjYsOgTMmoqrIjDy6Fa2o_gP",
52 | "2b85f58b927af9964a4c15b8f7e8f179",
53 | ),
54 | },
55 | "val": {
56 | "images": (
57 | "camelyonpatch_level_2_split_valid_x.h5",
58 | "1hgshYGWK8V-eGRy8LToWJJgDU_rXWVJ3",
59 | "d8c2d60d490dbd479f8199bdfa0cf6ec",
60 | ),
61 | "targets": (
62 | "camelyonpatch_level_2_split_valid_y.h5",
63 | "1bH8ZRbhSVAhScTS0p9-ZzGnX91cHT3uO",
64 | "60a7035772fbdb7f34eb86d4420cf66a",
65 | ),
66 | },
67 | }
68 |
69 | clip_prompts = [
70 | "This is a photo of a healthy lymph node tissue.",
71 | "This is a photo of a lymph node tumor tissue."
72 | ]
73 |
74 | def __init__(
75 | self,
76 | root: str,
77 | split: str = "train",
78 | transform: Optional[Callable] = None,
79 | target_transform: Optional[Callable] = None,
80 | download: bool = False,
81 | ):
82 | try:
83 | import h5py
84 |
85 | self.h5py = h5py
86 | except ImportError:
87 | raise RuntimeError(
88 | "h5py is not found. This dataset needs to have h5py installed: please run pip install h5py"
89 | )
90 |
91 | self._split = verify_str_arg(split, "split", ("train", "test", "val"))
92 |
93 | super().__init__(root, transform=transform, target_transform=target_transform)
94 | self._base_folder = pathlib.Path(self.root) / "pcam"
95 |
96 | if download:
97 | self._download()
98 |
99 | if not self._check_exists():
100 | raise RuntimeError("Dataset not found. You can use download=True to download it")
101 |
102 | def __len__(self) -> int:
103 | images_file = self._FILES[self._split]["images"][0]
104 | with self.h5py.File(self._base_folder / images_file) as images_data:
105 | return images_data["x"].shape[0]
106 |
107 | def __getitem__(self, idx: int) -> Tuple[Any, Any]:
108 | images_file = self._FILES[self._split]["images"][0]
109 | with self.h5py.File(self._base_folder / images_file) as images_data:
110 | image = Image.fromarray(images_data["x"][idx]).convert("RGB")
111 |
112 | targets_file = self._FILES[self._split]["targets"][0]
113 | with self.h5py.File(self._base_folder / targets_file) as targets_data:
114 | target = int(targets_data["y"][idx, 0, 0, 0]) # shape is [num_images, 1, 1, 1]
115 |
116 | if self.transform:
117 | image = self.transform(image)
118 | if self.target_transform:
119 | target = self.target_transform(target)
120 |
121 | return image, target
122 |
123 | def _check_exists(self) -> bool:
124 | images_file = self._FILES[self._split]["images"][0]
125 | targets_file = self._FILES[self._split]["targets"][0]
126 | return all(self._base_folder.joinpath(h5_file).exists() for h5_file in (images_file, targets_file))
127 |
128 | def _download(self) -> None:
129 | if self._check_exists():
130 | return
131 |
132 | for file_name, file_id, md5 in self._FILES[self._split].values():
133 | archive_name = file_name + ".gz"
134 | download_file_from_google_drive(file_id, str(self._base_folder), filename=archive_name, md5=md5)
135 | _decompress(str(self._base_folder / archive_name))
136 |
--------------------------------------------------------------------------------
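
A minimal sketch for the PCAM class above (not part of the repository; the import path and data root are assumptions). The task is binary, so the two prompts are fixed on the class, and images are read lazily from the HDF5 files.

```python
from torch.utils.data import DataLoader
from torchvision import transforms

from datasets.pcam import PCAM  # assumed import path; requires h5py

preprocess = transforms.Compose([transforms.Resize(224), transforms.ToTensor()])

ds = PCAM(root="./data", split="test", transform=preprocess, download=True)
print(PCAM.clip_prompts)   # healthy tissue vs. tumor tissue

loader = DataLoader(ds, batch_size=64, shuffle=False)
images, targets = next(iter(loader))
print(images.shape, targets[:8])
```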
/code/replace/datasets/stanford_cars.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | from typing import Callable, Optional, Any, Tuple
3 |
4 | from PIL import Image
5 |
6 | from torchvision.datasets.utils import download_and_extract_archive, download_url, verify_str_arg
7 | from torchvision.datasets.vision import VisionDataset
8 |
9 |
10 | class StanfordCars(VisionDataset):
11 | """`Stanford Cars `_ Dataset
12 |
13 | The Cars dataset contains 16,185 images of 196 classes of cars. The data is
14 | split into 8,144 training images and 8,041 testing images, where each class
15 | has been split roughly in a 50-50 split
16 |
17 | .. note::
18 |
19 |         This class needs scipy to load target files from `.mat` format.
20 |
21 | Args:
22 | root (string): Root directory of dataset
23 | split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``.
24 | transform (callable, optional): A function/transform that takes in an PIL image
25 | and returns a transformed version. E.g, ``transforms.RandomCrop``
26 | target_transform (callable, optional): A function/transform that takes in the
27 | target and transforms it.
28 | download (bool, optional): If True, downloads the dataset from the internet and
29 | puts it in root directory. If dataset is already downloaded, it is not
30 | downloaded again."""
31 |
32 | def __init__(
33 | self,
34 | root: str,
35 | split: str = "train",
36 | transform: Optional[Callable] = None,
37 | target_transform: Optional[Callable] = None,
38 | download: bool = False,
39 | prompt_template = "A photo of a {}."
40 | ) -> None:
41 |
42 | try:
43 | import scipy.io as sio
44 | except ImportError:
45 | raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy")
46 |
47 | super().__init__(root, transform=transform, target_transform=target_transform)
48 |
49 | self._split = verify_str_arg(split, "split", ("train", "test"))
50 | self._base_folder = pathlib.Path(root) / "stanford_cars"
51 | devkit = self._base_folder / "devkit"
52 |
53 | if self._split == "train":
54 | self._annotations_mat_path = devkit / "cars_train_annos.mat"
55 | self._images_base_path = self._base_folder / "cars_train"
56 | else:
57 | self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat"
58 | self._images_base_path = self._base_folder / "cars_test"
59 |
60 | if download:
61 | self.download()
62 |
63 | if not self._check_exists():
64 | raise RuntimeError("Dataset not found. You can use download=True to download it")
65 |
66 | self._samples = [
67 | (
68 | str(self._images_base_path / annotation["fname"]),
69 | annotation["class"] - 1, # Original target mapping starts from 1, hence -1
70 | )
71 | for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"]
72 | ]
73 |
74 | self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist()
75 | self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
76 |
77 | self.prompt_template = prompt_template
78 | self.clip_prompts = [
79 | prompt_template.format(label[:-5].lower().replace('_', ' ').replace('-', ' ')) \
80 | for label in self.classes
81 | ]
82 |
83 | def __len__(self) -> int:
84 | return len(self._samples)
85 |
86 | def __getitem__(self, idx: int) -> Tuple[Any, Any]:
87 | """Returns pil_image and class_id for given index"""
88 | image_path, target = self._samples[idx]
89 | pil_image = Image.open(image_path).convert("RGB")
90 |
91 | if self.transform is not None:
92 | pil_image = self.transform(pil_image)
93 | if self.target_transform is not None:
94 | target = self.target_transform(target)
95 | return pil_image, target
96 |
97 | def download(self) -> None:
98 | if self._check_exists():
99 | return
100 |
101 | download_and_extract_archive(
102 | url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz",
103 | download_root=str(self._base_folder),
104 | md5="c3b158d763b6e2245038c8ad08e45376",
105 | )
106 | if self._split == "train":
107 | download_and_extract_archive(
108 | url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz",
109 | download_root=str(self._base_folder),
110 | md5="065e5b463ae28d29e77c1b4b166cfe61",
111 | )
112 | else:
113 | download_and_extract_archive(
114 | url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz",
115 | download_root=str(self._base_folder),
116 | md5="4ce7ebf6a94d07f1952d94dd34c4d501",
117 | )
118 | download_url(
119 | url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat",
120 | root=str(self._base_folder),
121 | md5="b0a2b23655a3edd16d84508592a98d10",
122 | )
123 |
124 | def _check_exists(self) -> bool:
125 | if not (self._base_folder / "devkit").is_dir():
126 | return False
127 |
128 | return self._annotations_mat_path.exists() and self._images_base_path.is_dir()
129 |
--------------------------------------------------------------------------------
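
A short sketch for the StanfordCars class above (not part of the repository; the import path and data root are assumptions). Class names end with a four-digit model year, which `label[:-5]` strips before templating.

```python
from datasets.stanford_cars import StanfordCars  # assumed import path; requires scipy

ds = StanfordCars(root="./data", split="test", download=True)

print(ds.classes[0])        # e.g. "AM General Hummer SUV 2000"
print(ds.clip_prompts[0])   # same name, lower-cased, with the trailing year removed
```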
/code/replace/datasets/sun397.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Any, Tuple, Callable, Optional
3 |
4 | import PIL.Image
5 |
6 | from torchvision.datasets.utils import download_and_extract_archive
7 | from torchvision.datasets.vision import VisionDataset
8 |
9 |
10 | class SUN397(VisionDataset):
11 | """`The SUN397 Data Set `_.
12 |
13 | The SUN397 or Scene UNderstanding (SUN) is a dataset for scene recognition consisting of
14 | 397 categories with 108'754 images.
15 |
16 | Args:
17 | root (string): Root directory of the dataset.
18 | transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed
19 | version. E.g, ``transforms.RandomCrop``.
20 | target_transform (callable, optional): A function/transform that takes in the target and transforms it.
21 | download (bool, optional): If true, downloads the dataset from the internet and
22 | puts it in root directory. If dataset is already downloaded, it is not
23 | downloaded again.
24 | """
25 |
26 | _DATASET_URL = "http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz"
27 | _DATASET_MD5 = "8ca2778205c41d23104230ba66911c7a"
28 |
29 | def __init__(
30 | self,
31 | root: str,
32 | transform: Optional[Callable] = None,
33 | target_transform: Optional[Callable] = None,
34 | download: bool = False,
35 | ) -> None:
36 | super().__init__(root, transform=transform, target_transform=target_transform)
37 | self._data_dir = Path(self.root) / "SUN397"
38 |
39 | if download:
40 | self._download()
41 |
42 | if not self._check_exists():
43 | raise RuntimeError("Dataset not found. You can use download=True to download it")
44 |
45 | with open(self._data_dir / "ClassName.txt") as f:
46 | self.classes = [c[3:].strip() for c in f]
47 |
48 | self.class_to_idx = dict(zip(self.classes, range(len(self.classes))))
49 | self._image_files = list(self._data_dir.rglob("sun_*.jpg"))
50 |
51 | self._labels = [
52 | self.class_to_idx["/".join(path.relative_to(self._data_dir).parts[1:-1])] for path in self._image_files
53 | ]
54 |
55 | def __len__(self) -> int:
56 | return len(self._image_files)
57 |
58 | def __getitem__(self, idx) -> Tuple[Any, Any]:
59 | image_file, label = self._image_files[idx], self._labels[idx]
60 | image = PIL.Image.open(image_file).convert("RGB")
61 |
62 | if self.transform:
63 | image = self.transform(image)
64 |
65 | if self.target_transform:
66 | label = self.target_transform(label)
67 |
68 | return image, label
69 |
70 | def _check_exists(self) -> bool:
71 | return self._data_dir.is_dir()
72 |
73 | def _download(self) -> None:
74 | if self._check_exists():
75 | return
76 | download_and_extract_archive(self._DATASET_URL, download_root=self.root, md5=self._DATASET_MD5)
77 |
--------------------------------------------------------------------------------
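
The dataset classes under `code/replace/datasets/` mirror their torchvision counterparts but expose CLIP-friendly class names (and, for some datasets such as StanfordCars above, ready-made `clip_prompts`). Below is a minimal sketch, not part of the repo, showing how the SUN397 class defined above is instantiated; it assumes `code/` is on the import path and `./data` is the dataset root, as elsewhere in the repo, and note that `download=True` fetches the full (very large) SUN397 archive.

```python
from torchvision import transforms
from replace.datasets.sun397 import SUN397

# 224-px preprocessing matching `preprocess224` in code/utils.py
preprocess224 = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

dataset = SUN397(root="./data", transform=preprocess224, download=True)
print(len(dataset), "images,", len(dataset.classes), "scene categories")

image, label = dataset[0]       # CHW tensor in [0, 1] and an integer class index
print(dataset.classes[label])   # one of the 397 scene names
```
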
/code/replace/model.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Tuple, Union
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn
8 |
9 |
10 | class Bottleneck(nn.Module):
11 | expansion = 4
12 |
13 | def __init__(self, inplanes, planes, stride=1):
14 | super().__init__()
15 |
16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18 | self.bn1 = nn.BatchNorm2d(planes)
19 | self.relu1 = nn.ReLU(inplace=True)
20 |
21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
22 | self.bn2 = nn.BatchNorm2d(planes)
23 | self.relu2 = nn.ReLU(inplace=True)
24 |
25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
26 |
27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
29 | self.relu3 = nn.ReLU(inplace=True)
30 |
31 | self.downsample = None
32 | self.stride = stride
33 |
34 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
36 | self.downsample = nn.Sequential(OrderedDict([
37 | ("-1", nn.AvgPool2d(stride)),
38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
39 | ("1", nn.BatchNorm2d(planes * self.expansion))
40 | ]))
41 |
42 | def forward(self, x: torch.Tensor):
43 | identity = x
44 |
45 | out = self.relu1(self.bn1(self.conv1(x)))
46 | out = self.relu2(self.bn2(self.conv2(out)))
47 | out = self.avgpool(out)
48 | out = self.bn3(self.conv3(out))
49 |
50 | if self.downsample is not None:
51 | identity = self.downsample(x)
52 |
53 | out += identity
54 | out = self.relu3(out)
55 | return out
56 |
57 |
58 | class AttentionPool2d(nn.Module):
59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
60 | super().__init__()
61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
62 | self.k_proj = nn.Linear(embed_dim, embed_dim)
63 | self.q_proj = nn.Linear(embed_dim, embed_dim)
64 | self.v_proj = nn.Linear(embed_dim, embed_dim)
65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
66 | self.num_heads = num_heads
67 |
68 | def forward(self, x):
69 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
72 | x, _ = F.multi_head_attention_forward(
73 | query=x[:1], key=x, value=x,
74 | embed_dim_to_check=x.shape[-1],
75 | num_heads=self.num_heads,
76 | q_proj_weight=self.q_proj.weight,
77 | k_proj_weight=self.k_proj.weight,
78 | v_proj_weight=self.v_proj.weight,
79 | in_proj_weight=None,
80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
81 | bias_k=None,
82 | bias_v=None,
83 | add_zero_attn=False,
84 | dropout_p=0,
85 | out_proj_weight=self.c_proj.weight,
86 | out_proj_bias=self.c_proj.bias,
87 | use_separate_proj_weight=True,
88 | training=self.training,
89 | need_weights=False
90 | )
91 | return x.squeeze(0)
92 |
93 |
94 | class ModifiedResNet(nn.Module):
95 | """
96 | A ResNet class that is similar to torchvision's but contains the following changes:
97 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
98 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
99 | - The final pooling layer is a QKV attention instead of an average pool
100 | """
101 |
102 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
103 | super().__init__()
104 | self.output_dim = output_dim
105 | self.input_resolution = input_resolution
106 |
107 | # the 3-layer stem
108 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
109 | self.bn1 = nn.BatchNorm2d(width // 2)
110 | self.relu1 = nn.ReLU(inplace=True)
111 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
112 | self.bn2 = nn.BatchNorm2d(width // 2)
113 | self.relu2 = nn.ReLU(inplace=True)
114 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
115 | self.bn3 = nn.BatchNorm2d(width)
116 | self.relu3 = nn.ReLU(inplace=True)
117 | self.avgpool = nn.AvgPool2d(2)
118 |
119 | # residual layers
120 | self._inplanes = width # this is a *mutable* variable used during construction
121 | self.layer1 = self._make_layer(width, layers[0])
122 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
123 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
124 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
125 |
126 | embed_dim = width * 32 # the ResNet feature dimension
127 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
128 |
129 | def _make_layer(self, planes, blocks, stride=1):
130 | layers = [Bottleneck(self._inplanes, planes, stride)]
131 |
132 | self._inplanes = planes * Bottleneck.expansion
133 | for _ in range(1, blocks):
134 | layers.append(Bottleneck(self._inplanes, planes))
135 |
136 | return nn.Sequential(*layers)
137 |
138 | def forward(self, x):
139 | def stem(x):
140 | x = self.relu1(self.bn1(self.conv1(x)))
141 | x = self.relu2(self.bn2(self.conv2(x)))
142 | x = self.relu3(self.bn3(self.conv3(x)))
143 | x = self.avgpool(x)
144 | return x
145 |
146 | x = x.type(self.conv1.weight.dtype)
147 | x = stem(x)
148 | x = self.layer1(x)
149 | x = self.layer2(x)
150 | x = self.layer3(x)
151 | x = self.layer4(x)
152 | x = self.attnpool(x)
153 |
154 | return x
155 |
156 |
157 | class LayerNorm(nn.LayerNorm):
158 | """Subclass torch's LayerNorm to handle fp16."""
159 |
160 | def forward(self, x: torch.Tensor):
161 | orig_type = x.dtype
162 | ret = super().forward(x.type(torch.float32))
163 | return ret.type(orig_type)
164 |
165 |
166 | class QuickGELU(nn.Module):
167 | def forward(self, x: torch.Tensor):
168 | return x * torch.sigmoid(1.702 * x)
169 |
170 |
171 | class ResidualAttentionBlock(nn.Module):
172 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
173 | super().__init__()
174 |
175 | self.attn = nn.MultiheadAttention(d_model, n_head)
176 | self.ln_1 = LayerNorm(d_model)
177 | self.mlp = nn.Sequential(OrderedDict([
178 | ("c_fc", nn.Linear(d_model, d_model * 4)),
179 | ("gelu", QuickGELU()),
180 | ("c_proj", nn.Linear(d_model * 4, d_model))
181 | ]))
182 | self.ln_2 = LayerNorm(d_model)
183 | self.attn_mask = attn_mask
184 |
185 | def attention(self, x: torch.Tensor):
186 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
187 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
188 |
189 | def forward(self, x: torch.Tensor):
190 | x = x + self.attention(self.ln_1(x))
191 | x = x + self.mlp(self.ln_2(x))
192 | return x
193 |
194 |
195 | class Transformer(nn.Module):
196 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
197 | super().__init__()
198 | self.width = width
199 | self.layers = layers
200 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
201 |
202 | def forward(self, x: torch.Tensor):
203 | return self.resblocks(x)
204 |
205 |
206 | class VisionTransformer(nn.Module):
207 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int, prompt_len: int):
208 | super().__init__()
209 | self.input_resolution = input_resolution
210 | self.output_dim = output_dim
211 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
212 | self.prompt_len = prompt_len
213 | scale = width ** -0.5
214 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
215 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
216 | if prompt_len > 0:
217 | self.prompt_pos_embedding = nn.Parameter(scale * torch.zeros(prompt_len, width))
218 |
219 | self.ln_pre = LayerNorm(width)
220 |
221 | self.transformer = Transformer(width, layers, heads)
222 |
223 | self.ln_post = LayerNorm(width)
224 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
225 |
226 | def forward(self, x: torch.Tensor, ind_prompt: torch.Tensor = None):
227 | x = self.conv1(x) # shape = [*, width, grid, grid]
228 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
229 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
230 |         if self.prompt_len > 0 and ind_prompt is not None:
231 | tmp_ind_prompt = ind_prompt + torch.zeros(x.shape[0],1,1, dtype=x.dtype, device=x.device)
232 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x, tmp_ind_prompt], dim=1) # shape = [*, grid ** 2 + 1, width]
233 | x = x + torch.cat([self.positional_embedding.to(x.dtype), self.prompt_pos_embedding], dim=0)
234 | else:
235 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
236 | x = x + self.positional_embedding.to(x.dtype)
237 |
238 | x = self.ln_pre(x)
239 |
240 | x = x.permute(1, 0, 2) # NLD -> LND
241 | x = self.transformer(x)
242 | x = x.permute(1, 0, 2) # LND -> NLD
243 |
244 | x = self.ln_post(x[:, 0, :])
245 |
246 | if self.proj is not None:
247 | x = x @ self.proj
248 |
249 | return x
250 |
251 |
252 | class CLIP(nn.Module):
253 | def __init__(self,
254 | embed_dim: int,
255 | # vision
256 | image_resolution: int,
257 | vision_layers: Union[Tuple[int, int, int, int], int],
258 | vision_width: int,
259 | vision_patch_size: int,
260 | # text
261 | context_length: int,
262 | vocab_size: int,
263 | transformer_width: int,
264 | transformer_heads: int,
265 | transformer_layers: int,
266 | prompt_len: int
267 | ):
268 | super().__init__()
269 |
270 | self.context_length = context_length
271 |
272 | if isinstance(vision_layers, (tuple, list)):
273 | vision_heads = vision_width * 32 // 64
274 | self.visual = ModifiedResNet(
275 | layers=vision_layers,
276 | output_dim=embed_dim,
277 | heads=vision_heads,
278 | input_resolution=image_resolution,
279 | width=vision_width,
280 |                 # note: ModifiedResNet does not accept a prompt_len argument; prompt tokens are only supported by the ViT visual encoder
281 | )
282 | else:
283 | vision_heads = vision_width // 64
284 | self.visual = VisionTransformer(
285 | input_resolution=image_resolution,
286 | patch_size=vision_patch_size,
287 | width=vision_width,
288 | layers=vision_layers,
289 | heads=vision_heads,
290 | output_dim=embed_dim,
291 | prompt_len=prompt_len
292 | )
293 |
294 | self.transformer = Transformer(
295 | width=transformer_width,
296 | layers=transformer_layers,
297 | heads=transformer_heads,
298 | attn_mask=self.build_attention_mask()
299 | )
300 |
301 | self.vocab_size = vocab_size
302 | self.token_embedding = nn.Embedding(vocab_size, transformer_width)
303 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
304 | self.ln_final = LayerNorm(transformer_width)
305 |
306 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
307 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
308 |
309 | self.initialize_parameters()
310 |
311 | def initialize_parameters(self):
312 | nn.init.normal_(self.token_embedding.weight, std=0.02)
313 | nn.init.normal_(self.positional_embedding, std=0.01)
314 |
315 | if isinstance(self.visual, ModifiedResNet):
316 | if self.visual.attnpool is not None:
317 | std = self.visual.attnpool.c_proj.in_features ** -0.5
318 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
319 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
320 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
321 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
322 |
323 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
324 | for name, param in resnet_block.named_parameters():
325 | if name.endswith("bn3.weight"):
326 | nn.init.zeros_(param)
327 |
328 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
329 | attn_std = self.transformer.width ** -0.5
330 | fc_std = (2 * self.transformer.width) ** -0.5
331 | for block in self.transformer.resblocks:
332 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
333 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
334 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
335 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
336 |
337 | if self.text_projection is not None:
338 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
339 |
340 | def build_attention_mask(self):
341 | # lazily create causal attention mask, with full attention between the vision tokens
342 | # pytorch uses additive attention mask; fill with -inf
343 | mask = torch.empty(self.context_length, self.context_length)
344 | mask.fill_(float("-inf"))
345 | mask.triu_(1) # zero out the lower diagonal
346 | return mask
347 |
348 | @property
349 | def dtype(self):
350 | return self.visual.conv1.weight.dtype
351 |
352 | def encode_image(self, image, ind_prompt):
353 |         if ind_prompt is None:
354 | return self.visual(image.type(self.dtype))
355 | return self.visual(image.type(self.dtype), ind_prompt.type(self.dtype))
356 |
357 | def encode_text(self, text):
358 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
359 |
360 | x = x + self.positional_embedding.type(self.dtype)
361 | x = x.permute(1, 0, 2) # NLD -> LND
362 | x = self.transformer(x)
363 | x = x.permute(1, 0, 2) # LND -> NLD
364 | x = self.ln_final(x).type(self.dtype)
365 |
366 | # x.shape = [batch_size, n_ctx, transformer.width]
367 | # take features from the eot embedding (eot_token is the highest number in each sequence)
368 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
369 |
370 | return x
371 |
372 | def forward(self, image, text, ind_prompt=None):
373 | image_features = self.encode_image(image, ind_prompt)
374 | text_features = self.encode_text(text)
375 |
376 | # normalized features
377 | image_features = image_features / image_features.norm(dim=1, keepdim=True)
378 | text_features = text_features / text_features.norm(dim=1, keepdim=True)
379 |
380 | # cosine similarity as logits
381 | logit_scale = self.logit_scale.exp()
382 | return image_features, logit_scale * text_features
383 |         # NOTE: the code below is unreachable; the original CLIP logits computation is kept for reference only
384 | logits_per_image = logit_scale * image_features @ text_features.t()
385 | logits_per_text = logits_per_image.t()
386 |
387 | # shape = [global_batch_size, global_batch_size]
388 | return logits_per_image, logits_per_text
389 |
390 |
391 | def convert_weights(model: nn.Module):
392 | """Convert applicable model parameters to fp16"""
393 |
394 | def _convert_weights_to_fp16(l):
395 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
396 | l.weight.data = l.weight.data.half()
397 | if l.bias is not None:
398 | l.bias.data = l.bias.data.half()
399 |
400 | if isinstance(l, nn.MultiheadAttention):
401 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
402 | tensor = getattr(l, attr)
403 | if tensor is not None:
404 | tensor.data = tensor.data.half()
405 |
406 | for name in ["text_projection", "proj"]:
407 | if hasattr(l, name):
408 | attr = getattr(l, name)
409 | if attr is not None:
410 | attr.data = attr.data.half()
411 |
412 | model.apply(_convert_weights_to_fp16)
413 |
414 |
415 | def build_model(state_dict: dict, prompt_len: int):
416 | vit = "visual.proj" in state_dict
417 |
418 | if vit:
419 | vision_width = state_dict["visual.conv1.weight"].shape[0]
420 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
421 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
422 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
423 | image_resolution = vision_patch_size * grid_size
424 | else:
425 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
426 | vision_layers = tuple(counts)
427 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
428 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
429 | vision_patch_size = None
430 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
431 | image_resolution = output_width * 32
432 |
433 | embed_dim = state_dict["text_projection"].shape[1]
434 | context_length = state_dict["positional_embedding"].shape[0]
435 | vocab_size = state_dict["token_embedding.weight"].shape[0]
436 | transformer_width = state_dict["ln_final.weight"].shape[0]
437 | transformer_heads = transformer_width // 64
438 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
439 |
440 | model = CLIP(
441 | embed_dim,
442 | image_resolution, vision_layers, vision_width, vision_patch_size,
443 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, prompt_len
444 | )
445 |
446 | for key in ["input_resolution", "context_length", "vocab_size"]:
447 | if key in state_dict:
448 | del state_dict[key]
449 |
450 | convert_weights(model)
451 | model.load_state_dict(state_dict, strict=False)
452 | return model.eval()
453 |
--------------------------------------------------------------------------------
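
Note that `CLIP.forward` above returns the L2-normalised image features together with logit-scaled text features, rather than the usual logit pair (the original logits code after the first `return` is unreachable), so callers assemble the similarity logits themselves. Below is a minimal sketch, not part of the repo, of zero-shot scoring through the patched loader; it assumes a CUDA device and reuses the `clip.load(..., prompt_len=0)` call that appears in `code/test_time_counterattack.py`, with a random stand-in batch instead of properly preprocessed images.

```python
import torch
from replace import clip  # the patched loader in code/replace/clip.py

device = "cuda" if torch.cuda.is_available() else "cpu"
model, _ = clip.load("ViT-B/32", device, jit=False, prompt_len=0)

text_tokens = clip.tokenize(["a photo of a dog", "a photo of a cat"]).to(device)
# stand-in batch in [0, 1]; the repo normalises real images via func.clip_img_preprocessing
images = torch.rand(4, 3, 224, 224, device=device)

with torch.no_grad():
    image_feat, scaled_text_feat = model(images, text_tokens)  # see CLIP.forward above
    logits = image_feat @ scaled_text_feat.t()                 # [4, 2] similarity logits
    preds = logits.argmax(dim=-1)
print(preds)
```
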
/code/replace/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 |
14 |
15 | @lru_cache()
16 | def bytes_to_unicode():
17 | """
18 | Returns list of utf-8 byte and a corresponding list of unicode strings.
19 | The reversible bpe codes work on unicode strings.
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 | This is a signficant percentage of your normal, say, 32K bpe vocab.
23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 | And avoids mapping to whitespace/control characters the bpe code barfs on.
25 | """
26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
27 | cs = bs[:]
28 | n = 0
29 | for b in range(2**8):
30 | if b not in bs:
31 | bs.append(b)
32 | cs.append(2**8+n)
33 | n += 1
34 | cs = [chr(n) for n in cs]
35 | return dict(zip(bs, cs))
36 |
37 |
38 | def get_pairs(word):
39 | """Return set of symbol pairs in a word.
40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """
42 | pairs = set()
43 | prev_char = word[0]
44 | for char in word[1:]:
45 | pairs.add((prev_char, char))
46 | prev_char = char
47 | return pairs
48 |
49 |
50 | def basic_clean(text):
51 | text = ftfy.fix_text(text)
52 | text = html.unescape(html.unescape(text))
53 | return text.strip()
54 |
55 |
56 | def whitespace_clean(text):
57 | text = re.sub(r'\s+', ' ', text)
58 | text = text.strip()
59 | return text
60 |
61 |
62 | class SimpleTokenizer(object):
63 | def __init__(self, bpe_path: str = default_bpe()):
64 | self.byte_encoder = bytes_to_unicode()
65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 | merges = merges[1:49152-256-2+1]
68 | merges = [tuple(merge.split()) for merge in merges]
69 | vocab = list(bytes_to_unicode().values())
70 |         vocab = vocab + [v+'</w>' for v in vocab]
71 | for merge in merges:
72 | vocab.append(''.join(merge))
73 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74 | self.encoder = dict(zip(vocab, range(len(vocab))))
75 | self.decoder = {v: k for k, v in self.encoder.items()}
76 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79 |
80 | def bpe(self, token):
81 | if token in self.cache:
82 | return self.cache[token]
83 |         word = tuple(token[:-1]) + ( token[-1] + '</w>',)
84 | pairs = get_pairs(word)
85 |
86 | if not pairs:
87 |             return token+'</w>'
88 |
89 | while True:
90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
91 | if bigram not in self.bpe_ranks:
92 | break
93 | first, second = bigram
94 | new_word = []
95 | i = 0
96 | while i < len(word):
97 | try:
98 | j = word.index(first, i)
99 | new_word.extend(word[i:j])
100 | i = j
101 | except:
102 | new_word.extend(word[i:])
103 | break
104 |
105 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
106 | new_word.append(first+second)
107 | i += 2
108 | else:
109 | new_word.append(word[i])
110 | i += 1
111 | new_word = tuple(new_word)
112 | word = new_word
113 | if len(word) == 1:
114 | break
115 | else:
116 | pairs = get_pairs(word)
117 | word = ' '.join(word)
118 | self.cache[token] = word
119 | return word
120 |
121 | def encode(self, text):
122 | bpe_tokens = []
123 | text = whitespace_clean(basic_clean(text)).lower()
124 | for token in re.findall(self.pat, text):
125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127 | return bpe_tokens
128 |
129 | def decode(self, tokens):
130 | text = ''.join([self.decoder[token] for token in tokens])
131 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132 | return text
133 |
--------------------------------------------------------------------------------
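
A minimal sketch, not part of the repo, of the BPE tokenizer above in isolation (`clip.tokenize`, used by the main script, builds on it and pads sequences to CLIP's context length); it assumes `code/replace` resolves as a package so the bundled `bpe_simple_vocab_16e6.txt.gz` is found next to the module.

```python
from replace.simple_tokenizer import SimpleTokenizer

tokenizer = SimpleTokenizer()   # loads bpe_simple_vocab_16e6.txt.gz from the module directory
ids = tokenizer.encode("A photo of a Siberian husky!")
print(ids)                      # a short list of BPE token ids
print(tokenizer.decode(ids))    # roughly "a photo of a siberian husky ! " (lower-cased, space-delimited)
```
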
/code/test_time_counterattack.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import argparse
4 | import os
5 | import time
6 | import random
7 | import logging
8 | from tqdm import tqdm
9 | from copy import deepcopy as dcopy
10 |
11 | import torch
12 | from torch.cuda.amp import GradScaler, autocast
13 |
14 | from replace import clip
15 | from models.prompters import TokenPrompter, NullPrompter
16 | from utils import *
17 | from attacks import *
18 | from func import clip_img_preprocessing, multiGPU_CLIP
19 |
20 | def parse_options():
21 | parser = argparse.ArgumentParser()
22 |
23 | parser.add_argument('--evaluate', type=bool, default=True) # eval mode
24 | parser.add_argument('--batch_size', type=int, default=256, help='batch_size')
25 | parser.add_argument('--num_workers', type=int, default=32, help='num of workers to use')
26 | parser.add_argument('--cache', type=str, default='./cache')
27 |
28 | # test setting
29 | parser.add_argument('--test_set', default=[], type=str, nargs='*') # defaults to 17 datasets, if not specified
30 | parser.add_argument('--test_attack_type', type=str, default="pgd", choices=['pgd', 'CW', 'autoattack',])
31 | parser.add_argument('--test_eps', type=float, default=1,help='test attack budget')
32 | parser.add_argument('--test_numsteps', type=int, default=10)
33 | parser.add_argument('--test_stepsize', type=int, default=1)
34 |
35 | # model
36 | parser.add_argument('--model', type=str, default='clip')
37 | parser.add_argument('--arch', type=str, default='vit_b32')
38 | parser.add_argument('--method', type=str, default='null_patch',
39 | choices=['null_patch'], help='choose visual prompting method')
40 | parser.add_argument('--name', type=str, default='')
41 | parser.add_argument('--prompt_size', type=int, default=30, help='size for visual prompts')
42 | parser.add_argument('--add_prompt_size', type=int, default=0, help='size for additional visual prompts')
43 |
44 | # data
45 | parser.add_argument('--root', type=str, default='./data', help='dataset path')
46 | parser.add_argument('--dataset', type=str, default='tinyImageNet', help='dataset used for AFT methods')
47 | parser.add_argument('--image_size', type=int, default=224, help='image size')
48 |
49 | # TTC config
50 | parser.add_argument('--seed', type=int, default=0, help='seed for initializing training')
51 | parser.add_argument('--victim_resume', type=str, default=None, help='model weights of victim to attack.')
52 | parser.add_argument('--outdir', type=str, default=None, help='output directory for results')
53 | parser.add_argument('--tau_thres', type=float, default=0.2)
54 | parser.add_argument('--beta', type=float, default=2.,)
55 | parser.add_argument('--ttc_eps', type=float, default=4.)
56 | parser.add_argument('--ttc_numsteps', type=int, default=2)
57 | parser.add_argument('--ttc_stepsize', type=float, default=1.)
58 |
59 | args = parser.parse_args()
60 | return args
61 |
62 | def compute_tau(clip_visual, images, n):
63 | orig_feat = clip_visual(clip_img_preprocessing(images), None) # [bs, 512]
64 | noisy_feat = clip_visual(clip_img_preprocessing(images + n), None)
65 | diff_ratio = (noisy_feat - orig_feat).norm(dim=-1) / orig_feat.norm(dim=-1) # [bs]
66 | return diff_ratio
67 |
68 | def tau_thres_weighted_counterattacks(model, X, prompter, add_prompter, alpha, attack_iters,
69 | norm="l_inf", epsilon=0, visual_model_orig=None,
70 | tau_thres:float=None, beta:float=None, clip_visual=None):
71 | delta = torch.zeros_like(X)
72 | if epsilon <= 0.:
73 | return delta
74 |
75 | if norm == "l_inf":
76 | delta.uniform_(-epsilon, epsilon)
77 | elif norm == "l_2":
78 | delta.normal_()
79 | d_flat = delta.view(delta.size(0), -1)
80 | n = d_flat.norm(p=2, dim=1).view(delta.size(0), 1, 1, 1)
81 | r = torch.zeros_like(n).uniform_(0, 1)
82 | delta *= r / n * epsilon
83 | else:
84 | raise ValueError
85 |
86 | delta = clamp(delta, lower_limit - X, upper_limit - X)
87 | delta.requires_grad = True
88 |
89 | if attack_iters == 0: # apply random noise (RN)
90 | return delta.data
91 |
92 | diff_ratio = compute_tau(clip_visual, X, delta.data) if clip_visual is not None else None
93 |
94 |     # Temporarily freeze model parameters; not strictly necessary (gradients are only taken w.r.t. delta), but kept for completeness
95 | tunable_param_names = []
96 | for n,p in model.module.named_parameters():
97 | if p.requires_grad:
98 | tunable_param_names.append(n)
99 | p.requires_grad = False
100 |
101 | prompt_token = add_prompter()
102 | with torch.no_grad():
103 | X_ori_reps = model.module.encode_image(
104 | prompter(clip_img_preprocessing(X)), prompt_token
105 | )
106 | X_ori_norm = torch.norm(X_ori_reps, dim=-1) # [ bs]
107 |
108 | deltas_per_step = []
109 | deltas_per_step.append(delta.data.clone())
110 |
111 | for _step_id in range(attack_iters):
112 |
113 | prompted_images = prompter(clip_img_preprocessing(X + delta))
114 | X_att_reps = model.module.encode_image(prompted_images, prompt_token)
115 | if _step_id == 0 and diff_ratio is None: # compute tau at the zero-th step
116 | feature_diff = X_att_reps - X_ori_reps # [bs, 512]
117 | diff_ratio = torch.norm(feature_diff, dim=-1) / X_ori_norm # [bs]
118 |
119 | scheme_sign = (tau_thres - diff_ratio).sign()
120 |
121 | l2_loss = ((((X_att_reps - X_ori_reps)**2).sum(1))).sum()
122 | grad = torch.autograd.grad(l2_loss, delta)[0]
123 | d = delta[:, :, :, :]
124 | g = grad[:, :, :, :]
125 | x = X[:, :, :, :]
126 |
127 | if norm == "l_inf":
128 | d = torch.clamp(d + alpha * torch.sign(g), min=-epsilon, max=epsilon)
129 | elif norm == "l_2":
130 | g_norm = torch.norm(g.view(g.shape[0], -1), dim=1).view(-1, 1, 1, 1)
131 | scaled_g = g / (g_norm + 1e-10)
132 | d = (d + scaled_g * alpha).view(d.size(0), -1).renorm(p=2, dim=0, maxnorm=epsilon).view_as(d)
133 | d = clamp(d, lower_limit - x, upper_limit - x)
134 | delta.data[:, :, :, :] = d
135 | deltas_per_step.append(delta.data.clone())
136 |
137 | Delta = torch.stack(deltas_per_step, dim=1) # [bs, numsteps+1, C, W, H]
138 |
139 | # create weights across steps
140 | weights = torch.arange(attack_iters+1).unsqueeze(0).expand(X.size(0), -1).to(device) # [bs, numsteps+1]
141 | weights = torch.exp(
142 | scheme_sign.view(-1, 1) * weights * beta
143 | ) # [bs, numsteps+1]
144 | weights /= weights.sum(dim=1, keepdim=True)
145 |
146 | weights_hard = torch.zeros_like(weights) # [bs, numsteps+1]
147 | weights_hard[:,0] = 1.
148 |
149 | weights = torch.where(scheme_sign.unsqueeze(1)>0, weights, weights_hard)
150 | weights = weights.view(X.size(0), attack_iters+1, 1, 1, 1)
151 |
152 | Delta = (weights * Delta).sum(dim=1)
153 |
154 |     # Restore requires_grad on the previously tunable parameters (again, only for completeness)
155 | for n,p in model.module.named_parameters():
156 | if n in tunable_param_names:
157 | p.requires_grad = True
158 |
159 | return Delta
160 |
161 |
162 | def validate(args, val_dataset_name, model, model_text, model_image,
163 | prompter, add_prompter, criterion, visual_model_orig=None,
164 | clip_visual=None
165 | ):
166 |
167 | logging.info(f"Evaluate with Attack method: {args.test_attack_type}")
168 |
169 | dataset_num = len(val_dataset_name)
170 | all_clean_org, all_clean_ttc, all_adv_org, all_adv_ttc = {},{},{},{}
171 |
172 | test_stepsize = args.test_stepsize
173 |
174 | ttc_eps = args.ttc_eps
175 | ttc_numsteps = args.ttc_numsteps
176 | ttc_stepsize = args.ttc_stepsize
177 | beta = args.beta
178 | tau_thres = args.tau_thres
179 |
180 | for cnt in range(dataset_num):
181 | val_dataset, val_loader = load_val_dataset(args, val_dataset_name[cnt])
182 | dataset_name = val_dataset_name[cnt]
183 | texts = get_text_prompts_val([val_dataset], [dataset_name])[0]
184 |
185 | binary = ['PCAM', 'hateful_memes']
186 | attacks_to_run=['apgd-ce', 'apgd-dlr']
187 | if dataset_name in binary:
188 | attacks_to_run=['apgd-ce']
189 |
190 | batch_time = AverageMeter('Time', ':6.3f')
191 | losses = AverageMeter('Loss', ':.4e')
192 | top1_org = AverageMeter('Original Acc@1', ':6.2f')
193 | top1_org_ttc = AverageMeter('Prompt Acc@1', ':6.2f')
194 | top1_adv = AverageMeter('Adv Original Acc@1', ':6.2f')
195 | top1_adv_ttc = AverageMeter('Adv Prompt Acc@1', ':6.2f')
196 |
197 | # switch to evaluation mode
198 | prompter.eval()
199 | add_prompter.eval()
200 | model.eval()
201 |
202 | text_tokens = clip.tokenize(texts).to(device)
203 | end = time.time()
204 |
205 | for i, (images, target) in enumerate(tqdm(val_loader)):
206 |
207 | images = images.to(device)
208 | target = target.to(device)
209 |
210 | with autocast():
211 |
212 | # original acc of clean images
213 | with torch.no_grad():
214 | clean_output,_,_,_ = multiGPU_CLIP(
215 | None, None, None, model, prompter(clip_img_preprocessing(images)),
216 | text_tokens = text_tokens,
217 | prompt_token = None, dataset_name = dataset_name
218 | )
219 | clean_acc = accuracy(clean_output, target, topk=(1,))
220 | top1_org.update(clean_acc[0].item(), images.size(0))
221 |
222 | # TTC on clean images
223 | ttc_delta_clean = tau_thres_weighted_counterattacks(
224 | model, images, prompter, add_prompter,
225 | alpha=ttc_stepsize, attack_iters=ttc_numsteps,
226 | norm='l_inf', epsilon=ttc_eps, visual_model_orig=None,
227 | tau_thres=tau_thres, beta = beta,
228 | clip_visual=clip_visual
229 | )
230 | with torch.no_grad():
231 | clean_output_ttc,_,_,_ = multiGPU_CLIP(
232 | None, None, None, model, prompter(clip_img_preprocessing(images+ttc_delta_clean)),
233 | text_tokens = text_tokens,
234 | prompt_token = None, dataset_name = dataset_name
235 | )
236 | clean_acc_ttc = accuracy(clean_output_ttc, target, topk=(1,))
237 | top1_org_ttc.update(clean_acc_ttc[0].item(), images.size(0))
238 |
239 | # generate adv samples for this batch
240 | torch.cuda.empty_cache()
241 | if args.test_attack_type == "pgd":
242 | delta_prompt = attack_pgd(args, prompter, model, model_text, model_image, add_prompter, criterion,
243 | images, target, test_stepsize, args.test_numsteps, 'l_inf',
244 | text_tokens=text_tokens, epsilon=args.test_eps, dataset_name=dataset_name)
245 | attacked_images = images + delta_prompt
246 | elif args.test_attack_type == "CW":
247 | delta_prompt = attack_CW(args, prompter, model, model_text, model_image, add_prompter, criterion,
248 | images, target, text_tokens,
249 | test_stepsize, args.test_numsteps, 'l_inf', epsilon=args.test_eps)
250 | attacked_images = images + delta_prompt
251 | elif args.test_attack_type == "autoattack":
252 | attacked_images = attack_auto(model, images, target, text_tokens,
253 | None, None, epsilon=args.test_eps, attacks_to_run=attacks_to_run)
254 |
255 | # acc of adv images without ttc
256 | with torch.no_grad():
257 | adv_output,_,_,_ = multiGPU_CLIP(
258 | None,None,None, model, prompter(clip_img_preprocessing(attacked_images)),
259 | text_tokens, prompt_token=None, dataset_name=dataset_name
260 | )
261 | adv_acc = accuracy(adv_output, target, topk=(1,))
262 | top1_adv.update(adv_acc[0].item(), images.size(0))
263 |
264 | ttc_delta_adv = tau_thres_weighted_counterattacks(
265 | model, attacked_images.data, prompter, add_prompter,
266 | alpha=ttc_stepsize, attack_iters=ttc_numsteps,
267 | norm='l_inf', epsilon=ttc_eps, visual_model_orig=None,
268 | tau_thres=tau_thres, beta = beta,
269 | clip_visual = clip_visual
270 | )
271 | with torch.no_grad():
272 | adv_output_ttc,_,_,_ = multiGPU_CLIP(
273 | None,None,None, model, prompter(clip_img_preprocessing(attacked_images+ttc_delta_adv)),
274 | text_tokens, prompt_token=None, dataset_name=dataset_name
275 | )
276 | adv_output_acc = accuracy(adv_output_ttc, target, topk=(1,))
277 | top1_adv_ttc.update(adv_output_acc[0].item(), images.size(0))
278 |
279 | batch_time.update(time.time() - end)
280 | end = time.time()
281 |
282 | torch.cuda.empty_cache()
283 | clean_acc = top1_org.avg
284 | clean_ttc_acc = top1_org_ttc.avg
285 | adv_acc = top1_adv.avg
286 | adv_ttc_acc = top1_adv_ttc.avg
287 |
288 | all_clean_org[dataset_name] = clean_acc
289 | all_clean_ttc[dataset_name] = clean_ttc_acc
290 | all_adv_org[dataset_name] = adv_acc
291 | all_adv_ttc[dataset_name] = adv_ttc_acc
292 |
293 | show_text = f"{dataset_name}:\n\t"
294 | show_text += f"- clean acc. {clean_acc:.2f} (ttc: {clean_ttc_acc:.2f})\n\t"
295 | show_text += f"- robust acc. {adv_acc:.2f} (ttc: {adv_ttc_acc:.2f})"
296 |
297 | logging.info(show_text)
298 |
299 | all_clean_org_avg = np.mean([all_clean_org[name] for name in all_clean_org]).item()
300 | all_clean_ttc_avg = np.mean([all_clean_ttc[name] for name in all_clean_ttc]).item()
301 | all_adv_org_avg = np.mean([all_adv_org[name] for name in all_adv_org]).item()
302 | all_adv_ttc_avg = np.mean([all_adv_ttc[name] for name in all_adv_ttc]).item()
303 | show_text = f"===== SUMMARY ACROSS {dataset_num} DATASETS =====\n\t"
304 | show_text += f"AVG acc. {all_clean_org_avg:.2f} (ttc: {all_clean_ttc_avg:.2f})\n\t"
305 | show_text += f"AVG acc. {all_adv_org_avg:.2f} (ttc: {all_adv_ttc_avg:.2f})"
306 | logging.info(show_text)
307 |
308 | # Exclude the dataset used for implementing AFT methods (tinyImageNet in the paper)
309 | zs_clean_org_avg = np.mean([all_clean_org[name] for name in val_dataset_name if name != args.dataset]).item()
310 | zs_clean_ttc_avg = np.mean([all_clean_ttc[name] for name in val_dataset_name if name != args.dataset]).item()
311 | zs_adv_org_avg = np.mean([all_adv_org[name] for name in val_dataset_name if name != args.dataset]).item()
312 | zs_adv_ttc_avg = np.mean([all_adv_ttc[name] for name in val_dataset_name if name != args.dataset]).item()
313 | valid_dataset_num = dataset_num - 1 if args.dataset in val_dataset_name else dataset_num
314 | show_text = f"===== SUMMARY ACROSS {valid_dataset_num} DATASETS (EXCEPT {args.dataset}) =====\n\t"
315 | show_text += f"AVG acc. {zs_clean_org_avg:.2f} (ttc: {zs_clean_ttc_avg:.2f})\n\t"
316 | show_text += f"AVG acc. {zs_adv_org_avg:.2f} (ttc: {zs_adv_ttc_avg:.2f})"
317 | logging.info(show_text)
318 |
319 | return all_clean_org_avg, all_clean_ttc_avg, all_adv_org_avg, all_adv_ttc_avg
320 |
321 | device = "cuda" if torch.cuda.is_available() else "cpu"
322 |
323 | def main():
324 |
325 | args = parse_options()
326 |
327 | outdir = args.outdir if args.outdir is not None else "TTC_results"
328 | outdir = os.path.join(outdir, f"{args.test_attack_type}_eps_{args.test_eps}_numsteps_{args.test_numsteps}")
329 | os.makedirs(outdir, exist_ok=True)
330 |
331 | args.test_eps = args.test_eps / 255.
332 | args.test_stepsize = args.test_stepsize / 255.
333 |
334 | seed = args.seed
335 | random.seed(seed)
336 | np.random.seed(seed)
337 | torch.manual_seed(seed)
338 | torch.cuda.manual_seed(seed)
339 | torch.cuda.manual_seed_all(seed)
340 | torch.backends.cudnn.deterministic = True
341 | torch.backends.cudnn.benchmark = False
342 |
343 | log_filename = ""
344 | log_filename += f"ttc_eps_{args.ttc_eps}_thres_{args.tau_thres}_beta_{args.beta}_numsteps_{args.ttc_numsteps}_stepsize_{int(args.ttc_stepsize)}_seed_{seed}.log".replace(" ", "")
345 | log_filename = os.path.join(outdir, log_filename)
346 | logging.basicConfig(
347 | filename = log_filename,
348 | level = logging.INFO,
349 | format="%(asctime)s - %(levelname)s - %(message)s"
350 | )
351 | logging.info(args)
352 |
353 | args.ttc_stepsize = args.ttc_stepsize / 255.
354 | args.ttc_eps = args.ttc_eps / 255.
355 |
356 | imagenet_root = './data/ImageNet'
357 | tinyimagenet_root = "./data/tiny-imagenet-200"
358 | args.imagenet_root = imagenet_root
359 | args.tinyimagenet_root = tinyimagenet_root
360 |
361 | # load model
362 | model, _ = clip.load('ViT-B/32', device, jit=False, prompt_len=0)
363 | for p in model.parameters():
364 | p.requires_grad = False
365 | convert_models_to_fp32(model)
366 |
367 | if args.victim_resume: # employ TTC on AFT checkpoints
368 | clip_visual = dcopy(model.visual)
369 | model = load_checkpoints2(args, args.victim_resume, model, None)
370 | else: # employ TTC on the original CLIP
371 | clip_visual = None
372 |
373 | model = torch.nn.DataParallel(model)
374 | model.eval()
375 | prompter = NullPrompter()
376 | add_prompter = TokenPrompter(0)
377 | prompter = torch.nn.DataParallel(prompter).cuda()
378 | add_prompter = torch.nn.DataParallel(add_prompter).cuda()
379 | logging.info("done loading model.")
380 |
381 | if len(args.test_set) == 0:
382 | test_set = DATASETS
383 | else:
384 | test_set = args.test_set
385 |
386 | # criterion to compute attack loss, the reduction of 'sum' is important for effective attacks
387 | criterion_attack = torch.nn.CrossEntropyLoss(reduction='sum').to(device)
388 |
389 | validate(
390 | args, test_set, model, None, None, prompter,
391 | add_prompter, criterion_attack, None, clip_visual
392 | )
393 |
394 | if __name__ == "__main__":
395 | main()
396 |
--------------------------------------------------------------------------------
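
The core of `tau_thres_weighted_counterattacks` is the per-image weighting across counterattack steps: when the feature-shift ratio tau under the initial random perturbation falls below `tau_thres`, the image is treated as 'falsely stable' (likely adversarial) and the later counterattack deltas receive exponentially larger weights; when tau exceeds the threshold, the image is treated as likely clean and only the step-0 random delta is kept. The numeric sketch below, not part of the repo and with made-up tau values, mirrors that weighting.

```python
import torch

attack_iters, beta, tau_thres = 2, 2.0, 0.2        # the script's default TTC settings
diff_ratio = torch.tensor([0.05, 0.35])            # made-up per-image tau for a batch of two

scheme_sign = (tau_thres - diff_ratio).sign()      # [+1., -1.]

steps = torch.arange(attack_iters + 1).float()     # counterattack steps 0, 1, 2
weights = torch.exp(scheme_sign.view(-1, 1) * steps * beta)
weights /= weights.sum(dim=1, keepdim=True)

weights_hard = torch.zeros_like(weights)
weights_hard[:, 0] = 1.0                           # keep only the step-0 random delta

weights = torch.where(scheme_sign.unsqueeze(1) > 0, weights, weights_hard)
print(weights)
# image 0 (tau < thres, falsely stable):  ~[0.016, 0.117, 0.867]  -> later counterattack steps dominate
# image 1 (tau > thres, likely clean):     [1.000, 0.000, 0.000]  -> only the initial random delta
```
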
/code/utils.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | import os
3 | import torch
4 | import numpy as np
5 | import torchvision.transforms as transforms
6 | from torchvision.datasets import CIFAR10, CIFAR100, STL10, ImageFolder
7 | from replace.datasets import caltech, country211, dtd, eurosat, fgvc_aircraft, food101, \
8 | flowers102, oxford_iiit_pet, pcam, stanford_cars, sun397
9 | from torch.utils.data import DataLoader
10 | import json
11 |
12 | def convert_models_to_fp32(model):
13 | for p in model.parameters():
14 | p.data = p.data.float()
15 |         if p.grad is not None:
16 | p.grad.data = p.grad.data.float()
17 |
18 |
19 | def refine_classname(class_names):
20 | for i, class_name in enumerate(class_names):
21 | class_names[i] = class_name.lower().replace('_', ' ').replace('-', ' ').replace('/', ' ')
22 | return class_names
23 |
24 | def save_checkpoint(state, save_folder, is_best=False, filename='checkpoint.pth.tar'):
25 | savefile = os.path.join(save_folder, filename)
26 | bestfile = os.path.join(save_folder, 'model_best.pth.tar')
27 | torch.save(state, savefile)
28 | if is_best:
29 | shutil.copyfile(savefile, bestfile)
30 | print ('saved best file')
31 |
32 | def assign_learning_rate(optimizer, new_lr, tgt_group_idx=None):
33 | for group_idx, param_group in enumerate(optimizer.param_groups):
34 | if tgt_group_idx is None or tgt_group_idx==group_idx:
35 | param_group["lr"] = new_lr
36 |
37 | def _warmup_lr(base_lr, warmup_length, step):
38 | return base_lr * (step + 1) / warmup_length
39 |
40 | def cosine_lr(optimizer, base_lr, warmup_length, steps, tgt_group_idx=None):
41 | def _lr_adjuster(step):
42 | if step < warmup_length:
43 | lr = _warmup_lr(base_lr, warmup_length, step)
44 | else:
45 | e = step - warmup_length
46 | es = steps - warmup_length
47 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr
48 | assign_learning_rate(optimizer, lr, tgt_group_idx)
49 | return lr
50 | return _lr_adjuster
51 |
52 | def null_scheduler(init_lr):
53 | return lambda step:init_lr
54 |
55 | def accuracy(output, target, topk=(1,)):
56 | """Computes the accuracy over the k top predictions for the specified values of k"""
57 | with torch.no_grad():
58 | maxk = max(topk)
59 | batch_size = target.size(0)
60 |
61 | _, pred = output.topk(maxk, 1, True, True)
62 | pred = pred.t()
63 | correct = pred.eq(target.view(1, -1).expand_as(pred))
64 |
65 | res = []
66 | for k in topk:
67 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
68 | res.append(correct_k.mul_(100.0 / batch_size))
69 | return res
70 |
71 |
72 | class AverageMeter(object):
73 | """Computes and stores the average and current value"""
74 | def __init__(self, name, fmt=':f'):
75 | self.name = name
76 | self.fmt = fmt
77 | self.reset()
78 |
79 | def reset(self):
80 | self.val = 0
81 | self.avg = 0
82 | self.sum = 0
83 | self.count = 0
84 |
85 | def update(self, val, n=1):
86 | self.val = val
87 | self.sum += val * n
88 | self.count += n
89 | self.avg = self.sum / self.count
90 |
91 | def __str__(self):
92 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
93 | return fmtstr.format(**self.__dict__)
94 |
95 |
96 | class ProgressMeter(object):
97 | def __init__(self, num_batches, meters, prefix=""):
98 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
99 | self.meters = meters
100 | self.prefix = prefix
101 |
102 | def display(self, batch):
103 | entries = [self.prefix + self.batch_fmtstr.format(batch)]
104 | entries += [str(meter) for meter in self.meters]
105 | print('\t'.join(entries))
106 |
107 | def _get_batch_fmtstr(self, num_batches):
108 | num_digits = len(str(num_batches // 1))
109 | fmt = '{:' + str(num_digits) + 'd}'
110 | return '[' + fmt + '/' + fmt.format(num_batches) + ']'
111 |
112 |
113 | def load_imagenet_folder2name(path):
114 | dict_imagenet_folder2name = {}
115 | with open(path) as f:
116 | line = f.readline()
117 | while line:
118 | split_name = line.strip().split()
119 | cat_name = split_name[2]
120 | id = split_name[0]
121 | dict_imagenet_folder2name[id] = cat_name
122 | line = f.readline()
123 | # print(dict_imagenet_folder2name)
124 | return dict_imagenet_folder2name
125 |
126 |
127 |
128 | def one_hot_embedding(labels, num_classes):
129 | """Embedding labels to one-hot form.
130 | Args:
131 | labels: (LongTensor) class labels, sized [N,].
132 | num_classes: (int) number of classes.
133 | Returns:
134 | (tensor) encoded labels, sized [N, #classes].
135 | """
136 | y = torch.eye(num_classes)
137 | return y[labels.cpu()]
138 |
139 |
140 | preprocess = transforms.Compose([
141 | transforms.ToTensor()
142 | ])
143 | preprocess224 = transforms.Compose([
144 | transforms.Resize(256),
145 | transforms.CenterCrop(224),
146 | transforms.ToTensor()
147 | ])
148 | preprocess224_caltech = transforms.Compose([
149 | transforms.Resize(256),
150 | transforms.CenterCrop(224),
151 | transforms.Lambda(lambda img: img.convert("RGB")),
152 | transforms.ToTensor()
153 | ])
154 | preprocess224_interpolate = transforms.Compose([
155 | transforms.Resize((224, 224)),
156 | transforms.ToTensor()
157 | ])
158 |
159 | def load_train_dataset(args):
160 | if args.dataset == 'cifar100':
161 | return CIFAR100(args.root, transform=preprocess, download=True, train=True)
162 | elif args.dataset == 'cifar10':
163 | return CIFAR10(args.root, transform=preprocess, download=True, train=True)
164 | elif args.dataset == 'ImageNet':
165 | assert args.imagenet_root is not None
166 | print(f"Loading ImageNet from {args.imagenet_root}")
167 | return ImageFolder(os.path.join(args.imagenet_root, 'train'), transform=preprocess224)
168 | else:
169 | print(f"Train dataset {args.dataset} not implemented")
170 | raise NotImplementedError
171 |
172 | def get_eval_files(dataset_name):
173 |     # only for ImageNet and tinyImageNet
174 | refined_data_file = f"./support/{dataset_name.lower()}_refined_labels.json"
175 | refined_data = read_json(refined_data_file)
176 | eval_select = {ssid:refined_data[ssid]['eval_files'] for ssid in refined_data}
177 | return eval_select
178 |
179 | def load_val_dataset(args, val_dataset_name):
180 |
181 | if val_dataset_name == 'cifar10':
182 | val_dataset = CIFAR10(args.root, transform=preprocess224, download=True, train=False)
183 |
184 | elif val_dataset_name == 'cifar100':
185 | val_dataset = CIFAR100(args.root, transform=preprocess224, download=True, train=False)
186 |
187 | elif val_dataset_name == 'Caltech101':
188 | val_dataset = caltech.Caltech101(args.root, target_type='category', transform=preprocess224_caltech, download=True)
189 |
190 | elif val_dataset_name == 'PCAM':
191 | val_dataset = pcam.PCAM(args.root, split='test', transform=preprocess224, download=True)
192 |
193 | elif val_dataset_name == 'STL10':
194 | val_dataset = STL10(args.root, split='test', transform=preprocess224, download=True)
195 |
196 | elif val_dataset_name == 'SUN397':
197 | val_dataset = sun397.SUN397(args.root, transform=preprocess224, download=True)
198 |
199 | elif val_dataset_name == 'StanfordCars':
200 | val_dataset = stanford_cars.StanfordCars(args.root, split='test', transform=preprocess224, download=True)
201 |
202 | elif val_dataset_name == 'Food101':
203 | val_dataset = food101.Food101(args.root, split='test', transform=preprocess224, download=True)
204 |
205 | elif val_dataset_name == 'oxfordpet':
206 | val_dataset = oxford_iiit_pet.OxfordIIITPet(args.root, split='test', transform=preprocess224, download=True)
207 |
208 | elif val_dataset_name == 'EuroSAT':
209 | val_dataset = eurosat.EuroSAT(args.root, transform=preprocess224, download=True)
210 |
211 | elif val_dataset_name == 'Caltech256':
212 | val_dataset = caltech.Caltech256(args.root, transform=preprocess224_caltech, download=True)
213 |
214 | elif val_dataset_name == 'flowers102':
215 | val_dataset = flowers102.Flowers102(args.root, split='test', transform=preprocess224, download=True)
216 |
217 | elif val_dataset_name == 'Country211':
218 | val_dataset = country211.Country211(args.root, split='test', transform=preprocess224, download=True)
219 |
220 | elif val_dataset_name == 'dtd':
221 | val_dataset = dtd.DTD(args.root, split='test', transform=preprocess224, download=True)
222 |
223 | elif val_dataset_name == 'fgvc_aircraft':
224 | val_dataset = fgvc_aircraft.FGVCAircraft(args.root, split='test', transform=preprocess224, download=True)
225 |
226 | elif val_dataset_name == 'ImageNet':
227 | from replace.datasets.folder import ImageNetFolder
228 | if args.evaluate:
229 | val_dataset = ImageNetFolder(os.path.join(args.imagenet_root, 'val'), transform=preprocess224)
230 | else:
231 | eval_select = get_eval_files(val_dataset_name)
232 | val_dataset = ImageNetFolder(os.path.join(args.imagenet_root, 'train'), transform=preprocess224, select_files=eval_select)
233 |
234 | elif val_dataset_name == 'tinyImageNet':
235 | from replace.datasets.folder import ImageNetFolder
236 | if args.evaluate:
237 | val_dataset = ImageNetFolder(os.path.join(args.tinyimagenet_root, 'val_'), transform=preprocess224)
238 | else:
239 | eval_select = get_eval_files(val_dataset_name)
240 | val_dataset = ImageNetFolder(os.path.join(args.tinyimagenet_root, 'train'), transform=preprocess224, select_files=eval_select)
241 | else:
242 | print(f"Val dataset {val_dataset_name} not implemented")
243 | raise NotImplementedError
244 |
245 | val_loader = DataLoader(val_dataset, batch_size=args.batch_size, pin_memory=True,
246 | num_workers=args.num_workers, shuffle=False,)
247 |
248 | return val_dataset, val_loader
249 |
250 | def get_text_prompts_train(args, train_dataset, template='This is a photo of a {}'):
251 | class_names = train_dataset.classes
252 | if args.dataset == 'ImageNet':
253 | folder2name = load_imagenet_folder2name('support/imagenet_classes_names.txt')
254 | new_class_names = []
255 | for each in class_names:
256 | new_class_names.append(folder2name[each])
257 |
258 | class_names = new_class_names
259 |
260 | class_names = refine_classname(class_names)
261 | texts_train = [template.format(label) for label in class_names]
262 | return texts_train
263 |
264 | def get_text_prompts_val(val_dataset_list, val_dataset_name, template='This is a photo of a {}.'):
265 | texts_list = []
266 | for cnt, each in enumerate(val_dataset_list):
267 | if hasattr(each, 'clip_prompts'):
268 | texts_tmp = each.clip_prompts
269 | else:
270 | class_names = each.classes if hasattr(each, 'classes') else each.clip_categories
271 | if val_dataset_name[cnt] in ['ImageNet', 'tinyImageNet']:
272 | refined_data = read_json(f"./support/{val_dataset_name[cnt].lower()}_refined_labels.json")
273 | clean_class_names = [refined_data[ssid]['clean_name'] for ssid in class_names]
274 | class_names = clean_class_names
275 |
276 | texts_tmp = [template.format(label) for label in class_names]
277 | texts_list.append(texts_tmp)
278 | assert len(texts_list) == len(val_dataset_list)
279 | return texts_list
280 |
281 |
282 | def read_json(json_file:str):
283 | with open(json_file, 'r') as f:
284 | data = json.load(f)
285 | return data
286 |
287 | # def get_prompts(class_names):
288 | # # consider using correct articles
289 | # template = "This is a photo of a {}."
290 | # template_v = "This is a photo of an {}."
291 | # prompts = []
292 | # for class_name in class_names:
293 | # if class_name[0].lower() in ['a','e','i','o','u'] or class_name == "hourglass":
294 | # prompts.append(template_v.format(class_name))
295 | # else:
296 | # prompts.append(template.format(class_name))
297 | # return prompts
298 |
299 | def freeze(model:torch.nn.Module):
300 | for param in model.parameters():
301 | param.requires_grad=False
302 | return
303 |
304 | DATASETS = [
305 | 'cifar10', 'cifar100', 'STL10','ImageNet',
306 | 'Caltech101', 'Caltech256', 'oxfordpet', 'flowers102', 'fgvc_aircraft',
307 | 'StanfordCars', 'SUN397', 'Country211', 'Food101', 'EuroSAT',
308 | 'dtd', 'PCAM', 'tinyImageNet',
309 | ]
310 |
311 |
312 | def write_file(txt:str, file:str, mode='a'):
313 | with open(file, mode) as f:
314 | f.write(txt)
315 |
316 |
317 | def load_resume_file(file:str, gpu:int):
318 | if os.path.isfile(file):
319 | print("=> loading checkpoint '{}'".format(file))
320 | if gpu is None:
321 | checkpoint = torch.load(file)
322 | else:
323 | loc = 'cuda:{}'.format(gpu)
324 | checkpoint = torch.load(file, map_location=loc)
325 | print("=> loaded checkpoint '{}' (epoch {})".format(file, checkpoint['epoch']))
326 | return checkpoint
327 | else:
328 | print("=> no checkpoint found at '{}'".format(file))
329 | return None
330 |
331 |
332 | def load_checkpoints2(args, resume_file, model, optimizer=None):
333 | checkpoint = load_resume_file(resume_file, None)
334 | try:
335 | model.module.visual.load_state_dict(checkpoint['vision_encoder_state_dict'], strict=False)
336 | except:
337 | model.visual.load_state_dict(checkpoint['vision_encoder_state_dict'], strict=False)
338 | if optimizer is not None:
339 | optimizer.load_state_dict(checkpoint['optimizer'])
340 | return model
341 |
--------------------------------------------------------------------------------
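
A minimal sketch, not part of the repo, of the `accuracy` and `AverageMeter` helpers above as the validation loop uses them; the logits and targets are made up, and `code/` is assumed to be on the import path.

```python
import torch
from utils import accuracy, AverageMeter

top1 = AverageMeter('Acc@1', ':6.2f')

logits = torch.tensor([[2.0, 0.1, 0.3],
                       [0.2, 0.5, 1.7]])        # batch of 2, 3 classes
target = torch.tensor([0, 1])                   # only the first prediction is correct

(acc1,) = accuracy(logits, target, topk=(1,))   # -> tensor([50.])
top1.update(acc1.item(), n=target.size(0))
print(top1)                                     # prints something like "Acc@1 50.00 (50.00)"
```
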
/data/readme.md:
--------------------------------------------------------------------------------
1 | This directory is the root path for all datasets. Please download the raw datasets into this folder.
2 |
--------------------------------------------------------------------------------
/download_weights.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gdown
3 |
4 | file_ids = {
5 | "FARE": "1IMtb5SG1ajYphR8cK-3w3Nr7zvAHa8bi",
6 | "TeCoA": "1m4Iw9pCjtBHj7OVqHFlRu2OkO0rd7C7j",
7 | "PMG_AFT": "1JMXdMheNaYWiwqcWI0tvRrp6MBweQagn",
8 | }
9 |
10 | os.makedirs("./AFT_model_weights", exist_ok=True)
11 |
12 | for name, fid in file_ids.items():
13 | url = f"https://drive.google.com/uc?id={fid}"
14 | output = f"./AFT_model_weights/{name}.pth.tar"
15 | gdown.download(url, output, quiet=False)
16 | print(f"Downloaded {name} weights to {output}")
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: TTC
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1=main
8 | - _openmp_mutex=5.1=1_gnu
9 | - blas=1.0=mkl
10 | - brotli-python=1.0.9=py311h6a678d5_9
11 | - bzip2=1.0.8=h5eee18b_6
12 | - c-ares=1.19.1=h5eee18b_0
13 | - ca-certificates=2025.2.25=h06a4308_0
14 | - certifi=2025.1.31=py311h06a4308_0
15 | - charset-normalizer=3.3.2=pyhd3eb1b0_0
16 | - cuda-cudart=11.8.89=0
17 | - cuda-cupti=11.8.87=0
18 | - cuda-libraries=11.8.0=0
19 | - cuda-nvrtc=11.8.89=0
20 | - cuda-nvtx=11.8.86=0
21 | - cuda-runtime=11.8.0=0
22 | - cuda-version=12.8=3
23 | - ffmpeg=4.3=hf484d3e_0
24 | - filelock=3.13.1=py311h06a4308_0
25 | - freetype=2.12.1=h4a9f257_0
26 | - ftfy=5.8=py_0
27 | - gmp=6.3.0=h6a678d5_0
28 | - gmpy2=2.2.1=py311h5eee18b_0
29 | - gnutls=3.6.15=he1e5248_0
30 | - h5py=3.12.1=py311h5842655_1
31 | - hdf5=1.14.5=h2b7332f_2
32 | - idna=3.7=py311h06a4308_0
33 | - intel-openmp=2023.1.0=hdb19cb5_46306
34 | - jinja2=3.1.5=py311h06a4308_0
35 | - jpeg=9e=h5eee18b_3
36 | - krb5=1.20.1=h143b758_1
37 | - lame=3.100=h7b6447c_0
38 | - lcms2=2.16=hb9589c4_0
39 | - ld_impl_linux-64=2.40=h12ee557_0
40 | - lerc=4.0.0=h6a678d5_0
41 | - libcublas=11.11.3.6=0
42 | - libcufft=10.9.0.58=0
43 | - libcufile=1.13.1.3=0
44 | - libcurand=10.3.9.90=0
45 | - libcurl=8.12.1=hc9e6f67_0
46 | - libcusolver=11.4.1.48=0
47 | - libcusparse=11.7.5.86=0
48 | - libdeflate=1.22=h5eee18b_0
49 | - libedit=3.1.20230828=h5eee18b_0
50 | - libev=4.33=h7f8727e_1
51 | - libffi=3.4.4=h6a678d5_1
52 | - libgcc-ng=11.2.0=h1234567_1
53 | - libgfortran-ng=11.2.0=h00389a5_1
54 | - libgfortran5=11.2.0=h1234567_1
55 | - libgomp=11.2.0=h1234567_1
56 | - libiconv=1.16=h5eee18b_3
57 | - libidn2=2.3.4=h5eee18b_0
58 | - libnghttp2=1.57.0=h2d74bed_0
59 | - libnpp=11.8.0.86=0
60 | - libnvjpeg=11.9.0.86=0
61 | - libpng=1.6.39=h5eee18b_0
62 | - libssh2=1.11.1=h251f7ec_0
63 | - libstdcxx-ng=11.2.0=h1234567_1
64 | - libtasn1=4.19.0=h5eee18b_0
65 | - libtiff=4.5.1=hffd6297_1
66 | - libunistring=0.9.10=h27cfd23_0
67 | - libuuid=1.41.5=h5eee18b_0
68 | - libwebp-base=1.3.2=h5eee18b_1
69 | - lz4-c=1.9.4=h6a678d5_1
70 | - markupsafe=3.0.2=py311h5eee18b_0
71 | - mkl=2023.1.0=h213fc3f_46344
72 | - mkl-service=2.4.0=py311h5eee18b_2
73 | - mkl_fft=1.3.11=py311h5eee18b_0
74 | - mkl_random=1.2.8=py311ha02d727_0
75 | - mpc=1.3.1=h5eee18b_0
76 | - mpfr=4.2.1=h5eee18b_0
77 | - mpmath=1.3.0=py311h06a4308_0
78 | - ncurses=6.4=h6a678d5_0
79 | - nettle=3.7.3=hbbd107a_1
80 | - networkx=3.4.2=py311h06a4308_0
81 | - numpy=1.26.0=py311h08b1b3b_0
82 | - numpy-base=1.26.0=py311hf175353_0
83 | - openh264=2.1.1=h4ff587b_0
84 | - openjpeg=2.5.2=he7f1fd0_0
85 | - openssl=3.0.16=h5eee18b_0
86 | - pillow=11.1.0=py311hcea889d_0
87 | - pip=25.0=py311h06a4308_0
88 | - pysocks=1.7.1=py311h06a4308_0
89 | - python=3.11.11=he870216_0
90 | - pytorch=2.0.1=py3.11_cuda11.8_cudnn8.7.0_0
91 | - pytorch-cuda=11.8=h7e8668a_6
92 | - pytorch-mutex=1.0=cuda
93 | - readline=8.2=h5eee18b_0
94 | - regex=2024.11.6=py311h5eee18b_0
95 | - requests=2.32.3=py311h06a4308_1
96 | - scipy=1.15.1=py311h08b1b3b_0
97 | - setuptools=72.1.0=py311h06a4308_0
98 | - sqlite=3.45.3=h5eee18b_0
99 | - sympy=1.13.3=py311h06a4308_1
100 | - tbb=2021.8.0=hdb19cb5_0
101 | - tk=8.6.14=h39e8969_0
102 | - torchtriton=2.0.0=py311
103 | - torchvision=0.15.2=py311_cu118
104 | - tqdm=4.67.1=py311h92b7b1e_0
105 | - typing_extensions=4.12.2=py311h06a4308_0
106 | - tzdata=2025a=h04d1e81_0
107 | - urllib3=2.3.0=py311h06a4308_0
108 | - wcwidth=0.2.5=pyhd3eb1b0_0
109 | - wheel=0.45.1=py311h06a4308_0
110 | - xz=5.6.4=h5eee18b_1
111 | - zlib=1.2.13=h5eee18b_1
112 | - zstd=1.5.6=hc292b87_0
113 | prefix: /home/sxing/miniconda3/envs/TTC
114 |
--------------------------------------------------------------------------------
/figures/fig2b.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sxing2/CLIP-Test-time-Counterattacks/793fbff69cd7b63881e761e9a73a38e76b3ce5bf/figures/fig2b.png
--------------------------------------------------------------------------------
/figures/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sxing2/CLIP-Test-time-Counterattacks/793fbff69cd7b63881e761e9a73a38e76b3ce5bf/figures/teaser.png
--------------------------------------------------------------------------------
/poster_CVPR_XING.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sxing2/CLIP-Test-time-Counterattacks/793fbff69cd7b63881e761e9a73a38e76b3ce5bf/poster_CVPR_XING.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | git+https://github.com/openai/CLIP.git
2 | git+https://github.com/fra31/auto-attack
3 | einops==0.8.1
4 | pycountry==24.6.1
5 | packaging==24.2
6 | gdown==5.2.0
7 |
--------------------------------------------------------------------------------
/support/imagenet_classes_names.txt:
--------------------------------------------------------------------------------
1 | n02119789 1 kit_fox
2 | n02100735 2 English_setter
3 | n02110185 3 Siberian_husky
4 | n02096294 4 Australian_terrier
5 | n02102040 5 English_springer
6 | n02066245 6 grey_whale
7 | n02509815 7 lesser_panda
8 | n02124075 8 Egyptian_cat
9 | n02417914 9 ibex
10 | n02123394 10 Persian_cat
11 | n02125311 11 cougar
12 | n02423022 12 gazelle
13 | n02346627 13 porcupine
14 | n02077923 14 sea_lion
15 | n02110063 15 malamute
16 | n02447366 16 badger
17 | n02109047 17 Great_Dane
18 | n02089867 18 Walker_hound
19 | n02102177 19 Welsh_springer_spaniel
20 | n02091134 20 whippet
21 | n02092002 21 Scottish_deerhound
22 | n02071294 22 killer_whale
23 | n02442845 23 mink
24 | n02504458 24 African_elephant
25 | n02092339 25 Weimaraner
26 | n02098105 26 soft-coated_wheaten_terrier
27 | n02096437 27 Dandie_Dinmont
28 | n02114712 28 red_wolf
29 | n02105641 29 Old_English_sheepdog
30 | n02128925 30 jaguar
31 | n02091635 31 otterhound
32 | n02088466 32 bloodhound
33 | n02096051 33 Airedale
34 | n02117135 34 hyena
35 | n02138441 35 meerkat
36 | n02097130 36 giant_schnauzer
37 | n02493509 37 titi
38 | n02457408 38 three-toed_sloth
39 | n02389026 39 sorrel
40 | n02443484 40 black-footed_ferret
41 | n02110341 41 dalmatian
42 | n02089078 42 black-and-tan_coonhound
43 | n02086910 43 papillon
44 | n02445715 44 skunk
45 | n02093256 45 Staffordshire_bullterrier
46 | n02113978 46 Mexican_hairless
47 | n02106382 47 Bouvier_des_Flandres
48 | n02441942 48 weasel
49 | n02113712 49 miniature_poodle
50 | n02113186 50 Cardigan
51 | n02105162 51 malinois
52 | n02415577 52 bighorn
53 | n02356798 53 fox_squirrel
54 | n02488702 54 colobus
55 | n02123159 55 tiger_cat
56 | n02098413 56 Lhasa
57 | n02422699 57 impala
58 | n02114855 58 coyote
59 | n02094433 59 Yorkshire_terrier
60 | n02111277 60 Newfoundland
61 | n02132136 61 brown_bear
62 | n02119022 62 red_fox
63 | n02091467 63 Norwegian_elkhound
64 | n02106550 64 Rottweiler
65 | n02422106 65 hartebeest
66 | n02091831 66 Saluki
67 | n02120505 67 grey_fox
68 | n02104365 68 schipperke
69 | n02086079 69 Pekinese
70 | n02112706 70 Brabancon_griffon
71 | n02098286 71 West_Highland_white_terrier
72 | n02095889 72 Sealyham_terrier
73 | n02484975 73 guenon
74 | n02137549 74 mongoose
75 | n02500267 75 indri
76 | n02129604 76 tiger
77 | n02090721 77 Irish_wolfhound
78 | n02396427 78 wild_boar
79 | n02108000 79 EntleBucher
80 | n02391049 80 zebra
81 | n02412080 81 ram
82 | n02108915 82 French_bulldog
83 | n02480495 83 orangutan
84 | n02110806 84 basenji
85 | n02128385 85 leopard
86 | n02107683 86 Bernese_mountain_dog
87 | n02085936 87 Maltese_dog
88 | n02094114 88 Norfolk_terrier
89 | n02087046 89 toy_terrier
90 | n02100583 90 vizsla
91 | n02096177 91 cairn
92 | n02494079 92 squirrel_monkey
93 | n02105056 93 groenendael
94 | n02101556 94 clumber
95 | n02123597 95 Siamese_cat
96 | n02481823 96 chimpanzee
97 | n02105505 97 komondor
98 | n02088094 98 Afghan_hound
99 | n02085782 99 Japanese_spaniel
100 | n02489166 100 proboscis_monkey
101 | n02364673 101 guinea_pig
102 | n02114548 102 white_wolf
103 | n02134084 103 ice_bear
104 | n02480855 104 gorilla
105 | n02090622 105 borzoi
106 | n02113624 106 toy_poodle
107 | n02093859 107 Kerry_blue_terrier
108 | n02403003 108 ox
109 | n02097298 109 Scotch_terrier
110 | n02108551 110 Tibetan_mastiff
111 | n02493793 111 spider_monkey
112 | n02107142 112 Doberman
113 | n02096585 113 Boston_bull
114 | n02107574 114 Greater_Swiss_Mountain_dog
115 | n02107908 115 Appenzeller
116 | n02086240 116 Shih-Tzu
117 | n02102973 117 Irish_water_spaniel
118 | n02112018 118 Pomeranian
119 | n02093647 119 Bedlington_terrier
120 | n02397096 120 warthog
121 | n02437312 121 Arabian_camel
122 | n02483708 122 siamang
123 | n02097047 123 miniature_schnauzer
124 | n02106030 124 collie
125 | n02099601 125 golden_retriever
126 | n02093991 126 Irish_terrier
127 | n02110627 127 affenpinscher
128 | n02106166 128 Border_collie
129 | n02326432 129 hare
130 | n02108089 130 boxer
131 | n02097658 131 silky_terrier
132 | n02088364 132 beagle
133 | n02111129 133 Leonberg
134 | n02100236 134 German_short-haired_pointer
135 | n02486261 135 patas
136 | n02115913 136 dhole
137 | n02486410 137 baboon
138 | n02487347 138 macaque
139 | n02099849 139 Chesapeake_Bay_retriever
140 | n02108422 140 bull_mastiff
141 | n02104029 141 kuvasz
142 | n02492035 142 capuchin
143 | n02110958 143 pug
144 | n02099429 144 curly-coated_retriever
145 | n02094258 145 Norwich_terrier
146 | n02099267 146 flat-coated_retriever
147 | n02395406 147 hog
148 | n02112350 148 keeshond
149 | n02109961 149 Eskimo_dog
150 | n02101388 150 Brittany_spaniel
151 | n02113799 151 standard_poodle
152 | n02095570 152 Lakeland_terrier
153 | n02128757 153 snow_leopard
154 | n02101006 154 Gordon_setter
155 | n02115641 155 dingo
156 | n02097209 156 standard_schnauzer
157 | n02342885 157 hamster
158 | n02097474 158 Tibetan_terrier
159 | n02120079 159 Arctic_fox
160 | n02095314 160 wire-haired_fox_terrier
161 | n02088238 161 basset
162 | n02408429 162 water_buffalo
163 | n02133161 163 American_black_bear
164 | n02328150 164 Angora
165 | n02410509 165 bison
166 | n02492660 166 howler_monkey
167 | n02398521 167 hippopotamus
168 | n02112137 168 chow
169 | n02510455 169 giant_panda
170 | n02093428 170 American_Staffordshire_terrier
171 | n02105855 171 Shetland_sheepdog
172 | n02111500 172 Great_Pyrenees
173 | n02085620 173 Chihuahua
174 | n02123045 174 tabby
175 | n02490219 175 marmoset
176 | n02099712 176 Labrador_retriever
177 | n02109525 177 Saint_Bernard
178 | n02454379 178 armadillo
179 | n02111889 179 Samoyed
180 | n02088632 180 bluetick
181 | n02090379 181 redbone
182 | n02443114 182 polecat
183 | n02361337 183 marmot
184 | n02105412 184 kelpie
185 | n02483362 185 gibbon
186 | n02437616 186 llama
187 | n02107312 187 miniature_pinscher
188 | n02325366 188 wood_rabbit
189 | n02091032 189 Italian_greyhound
190 | n02129165 190 lion
191 | n02102318 191 cocker_spaniel
192 | n02100877 192 Irish_setter
193 | n02074367 193 dugong
194 | n02504013 194 Indian_elephant
195 | n02363005 195 beaver
196 | n02102480 196 Sussex_spaniel
197 | n02113023 197 Pembroke
198 | n02086646 198 Blenheim_spaniel
199 | n02497673 199 Madagascar_cat
200 | n02087394 200 Rhodesian_ridgeback
201 | n02127052 201 lynx
202 | n02116738 202 African_hunting_dog
203 | n02488291 203 langur
204 | n02091244 204 Ibizan_hound
205 | n02114367 205 timber_wolf
206 | n02130308 206 cheetah
207 | n02089973 207 English_foxhound
208 | n02105251 208 briard
209 | n02134418 209 sloth_bear
210 | n02093754 210 Border_terrier
211 | n02106662 211 German_shepherd
212 | n02444819 212 otter
213 | n01882714 213 koala
214 | n01871265 214 tusker
215 | n01872401 215 echidna
216 | n01877812 216 wallaby
217 | n01873310 217 platypus
218 | n01883070 218 wombat
219 | n04086273 219 revolver
220 | n04507155 220 umbrella
221 | n04147183 221 schooner
222 | n04254680 222 soccer_ball
223 | n02672831 223 accordion
224 | n02219486 224 ant
225 | n02317335 225 starfish
226 | n01968897 226 chambered_nautilus
227 | n03452741 227 grand_piano
228 | n03642806 228 laptop
229 | n07745940 229 strawberry
230 | n02690373 230 airliner
231 | n04552348 231 warplane
232 | n02692877 232 airship
233 | n02782093 233 balloon
234 | n04266014 234 space_shuttle
235 | n03344393 235 fireboat
236 | n03447447 236 gondola
237 | n04273569 237 speedboat
238 | n03662601 238 lifeboat
239 | n02951358 239 canoe
240 | n04612504 240 yawl
241 | n02981792 241 catamaran
242 | n04483307 242 trimaran
243 | n03095699 243 container_ship
244 | n03673027 244 liner
245 | n03947888 245 pirate
246 | n02687172 246 aircraft_carrier
247 | n04347754 247 submarine
248 | n04606251 248 wreck
249 | n03478589 249 half_track
250 | n04389033 250 tank
251 | n03773504 251 missile
252 | n02860847 252 bobsled
253 | n03218198 253 dogsled
254 | n02835271 254 bicycle-built-for-two
255 | n03792782 255 mountain_bike
256 | n03393912 256 freight_car
257 | n03895866 257 passenger_car
258 | n02797295 258 barrow
259 | n04204347 259 shopping_cart
260 | n03791053 260 motor_scooter
261 | n03384352 261 forklift
262 | n03272562 262 electric_locomotive
263 | n04310018 263 steam_locomotive
264 | n02704792 264 amphibian
265 | n02701002 265 ambulance
266 | n02814533 266 beach_wagon
267 | n02930766 267 cab
268 | n03100240 268 convertible
269 | n03594945 269 jeep
270 | n03670208 270 limousine
271 | n03770679 271 minivan
272 | n03777568 272 Model_T
273 | n04037443 273 racer
274 | n04285008 274 sports_car
275 | n03444034 275 go-kart
276 | n03445924 276 golfcart
277 | n03785016 277 moped
278 | n04252225 278 snowplow
279 | n03345487 279 fire_engine
280 | n03417042 280 garbage_truck
281 | n03930630 281 pickup
282 | n04461696 282 tow_truck
283 | n04467665 283 trailer_truck
284 | n03796401 284 moving_van
285 | n03977966 285 police_van
286 | n04065272 286 recreational_vehicle
287 | n04335435 287 streetcar
288 | n04252077 288 snowmobile
289 | n04465501 289 tractor
290 | n03776460 290 mobile_home
291 | n04482393 291 tricycle
292 | n04509417 292 unicycle
293 | n03538406 293 horse_cart
294 | n03599486 294 jinrikisha
295 | n03868242 295 oxcart
296 | n02804414 296 bassinet
297 | n03125729 297 cradle
298 | n03131574 298 crib
299 | n03388549 299 four-poster
300 | n02870880 300 bookcase
301 | n03018349 301 china_cabinet
302 | n03742115 302 medicine_chest
303 | n03016953 303 chiffonier
304 | n04380533 304 table_lamp
305 | n03337140 305 file
306 | n03891251 306 park_bench
307 | n02791124 307 barber_chair
308 | n04429376 308 throne
309 | n03376595 309 folding_chair
310 | n04099969 310 rocking_chair
311 | n04344873 311 studio_couch
312 | n04447861 312 toilet_seat
313 | n03179701 313 desk
314 | n03982430 314 pool_table
315 | n03201208 315 dining_table
316 | n03290653 316 entertainment_center
317 | n04550184 317 wardrobe
318 | n07742313 318 Granny_Smith
319 | n07747607 319 orange
320 | n07749582 320 lemon
321 | n07753113 321 fig
322 | n07753275 322 pineapple
323 | n07753592 323 banana
324 | n07754684 324 jackfruit
325 | n07760859 325 custard_apple
326 | n07768694 326 pomegranate
327 | n12267677 327 acorn
328 | n12620546 328 hip
329 | n13133613 329 ear
330 | n11879895 330 rapeseed
331 | n12144580 331 corn
332 | n12768682 332 buckeye
333 | n03854065 333 organ
334 | n04515003 334 upright
335 | n03017168 335 chime
336 | n03249569 336 drum
337 | n03447721 337 gong
338 | n03720891 338 maraca
339 | n03721384 339 marimba
340 | n04311174 340 steel_drum
341 | n02787622 341 banjo
342 | n02992211 342 cello
343 | n04536866 343 violin
344 | n03495258 344 harp
345 | n02676566 345 acoustic_guitar
346 | n03272010 346 electric_guitar
347 | n03110669 347 cornet
348 | n03394916 348 French_horn
349 | n04487394 349 trombone
350 | n03494278 350 harmonica
351 | n03840681 351 ocarina
352 | n03884397 352 panpipe
353 | n02804610 353 bassoon
354 | n03838899 354 oboe
355 | n04141076 355 sax
356 | n03372029 356 flute
357 | n11939491 357 daisy
358 | n12057211 358 yellow_lady's_slipper
359 | n09246464 359 cliff
360 | n09468604 360 valley
361 | n09193705 361 alp
362 | n09472597 362 volcano
363 | n09399592 363 promontory
364 | n09421951 364 sandbar
365 | n09256479 365 coral_reef
366 | n09332890 366 lakeside
367 | n09428293 367 seashore
368 | n09288635 368 geyser
369 | n03498962 369 hatchet
370 | n03041632 370 cleaver
371 | n03658185 371 letter_opener
372 | n03954731 372 plane
373 | n03995372 373 power_drill
374 | n03649909 374 lawn_mower
375 | n03481172 375 hammer
376 | n03109150 376 corkscrew
377 | n02951585 377 can_opener
378 | n03970156 378 plunger
379 | n04154565 379 screwdriver
380 | n04208210 380 shovel
381 | n03967562 381 plow
382 | n03000684 382 chain_saw
383 | n01514668 383 cock
384 | n01514859 384 hen
385 | n01518878 385 ostrich
386 | n01530575 386 brambling
387 | n01531178 387 goldfinch
388 | n01532829 388 house_finch
389 | n01534433 389 junco
390 | n01537544 390 indigo_bunting
391 | n01558993 391 robin
392 | n01560419 392 bulbul
393 | n01580077 393 jay
394 | n01582220 394 magpie
395 | n01592084 395 chickadee
396 | n01601694 396 water_ouzel
397 | n01608432 397 kite
398 | n01614925 398 bald_eagle
399 | n01616318 399 vulture
400 | n01622779 400 great_grey_owl
401 | n01795545 401 black_grouse
402 | n01796340 402 ptarmigan
403 | n01797886 403 ruffed_grouse
404 | n01798484 404 prairie_chicken
405 | n01806143 405 peacock
406 | n01806567 406 quail
407 | n01807496 407 partridge
408 | n01817953 408 African_grey
409 | n01818515 409 macaw
410 | n01819313 410 sulphur-crested_cockatoo
411 | n01820546 411 lorikeet
412 | n01824575 412 coucal
413 | n01828970 413 bee_eater
414 | n01829413 414 hornbill
415 | n01833805 415 hummingbird
416 | n01843065 416 jacamar
417 | n01843383 417 toucan
418 | n01847000 418 drake
419 | n01855032 419 red-breasted_merganser
420 | n01855672 420 goose
421 | n01860187 421 black_swan
422 | n02002556 422 white_stork
423 | n02002724 423 black_stork
424 | n02006656 424 spoonbill
425 | n02007558 425 flamingo
426 | n02009912 426 American_egret
427 | n02009229 427 little_blue_heron
428 | n02011460 428 bittern
429 | n02012849 429 crane
430 | n02013706 430 limpkin
431 | n02018207 431 American_coot
432 | n02018795 432 bustard
433 | n02025239 433 ruddy_turnstone
434 | n02027492 434 red-backed_sandpiper
435 | n02028035 435 redshank
436 | n02033041 436 dowitcher
437 | n02037110 437 oystercatcher
438 | n02017213 438 European_gallinule
439 | n02051845 439 pelican
440 | n02056570 440 king_penguin
441 | n02058221 441 albatross
442 | n01484850 442 great_white_shark
443 | n01491361 443 tiger_shark
444 | n01494475 444 hammerhead
445 | n01496331 445 electric_ray
446 | n01498041 446 stingray
447 | n02514041 447 barracouta
448 | n02536864 448 coho
449 | n01440764 449 tench
450 | n01443537 450 goldfish
451 | n02526121 451 eel
452 | n02606052 452 rock_beauty
453 | n02607072 453 anemone_fish
454 | n02643566 454 lionfish
455 | n02655020 455 puffer
456 | n02640242 456 sturgeon
457 | n02641379 457 gar
458 | n01664065 458 loggerhead
459 | n01665541 459 leatherback_turtle
460 | n01667114 460 mud_turtle
461 | n01667778 461 terrapin
462 | n01669191 462 box_turtle
463 | n01675722 463 banded_gecko
464 | n01677366 464 common_iguana
465 | n01682714 465 American_chameleon
466 | n01685808 466 whiptail
467 | n01687978 467 agama
468 | n01688243 468 frilled_lizard
469 | n01689811 469 alligator_lizard
470 | n01692333 470 Gila_monster
471 | n01693334 471 green_lizard
472 | n01694178 472 African_chameleon
473 | n01695060 473 Komodo_dragon
474 | n01704323 474 triceratops
475 | n01697457 475 African_crocodile
476 | n01698640 476 American_alligator
477 | n01728572 477 thunder_snake
478 | n01728920 478 ringneck_snake
479 | n01729322 479 hognose_snake
480 | n01729977 480 green_snake
481 | n01734418 481 king_snake
482 | n01735189 482 garter_snake
483 | n01737021 483 water_snake
484 | n01739381 484 vine_snake
485 | n01740131 485 night_snake
486 | n01742172 486 boa_constrictor
487 | n01744401 487 rock_python
488 | n01748264 488 Indian_cobra
489 | n01749939 489 green_mamba
490 | n01751748 490 sea_snake
491 | n01753488 491 horned_viper
492 | n01755581 492 diamondback
493 | n01756291 493 sidewinder
494 | n01629819 494 European_fire_salamander
495 | n01630670 495 common_newt
496 | n01631663 496 eft
497 | n01632458 497 spotted_salamander
498 | n01632777 498 axolotl
499 | n01641577 499 bullfrog
500 | n01644373 500 tree_frog
501 | n01644900 501 tailed_frog
502 | n04579432 502 whistle
503 | n04592741 503 wing
504 | n03876231 504 paintbrush
505 | n03483316 505 hand_blower
506 | n03868863 506 oxygen_mask
507 | n04251144 507 snorkel
508 | n03691459 508 loudspeaker
509 | n03759954 509 microphone
510 | n04152593 510 screen
511 | n03793489 511 mouse
512 | n03271574 512 electric_fan
513 | n03843555 513 oil_filter
514 | n04332243 514 strainer
515 | n04265275 515 space_heater
516 | n04330267 516 stove
517 | n03467068 517 guillotine
518 | n02794156 518 barometer
519 | n04118776 519 rule
520 | n03841143 520 odometer
521 | n04141975 521 scale
522 | n02708093 522 analog_clock
523 | n03196217 523 digital_clock
524 | n04548280 524 wall_clock
525 | n03544143 525 hourglass
526 | n04355338 526 sundial
527 | n03891332 527 parking_meter
528 | n04328186 528 stopwatch
529 | n03197337 529 digital_watch
530 | n04317175 530 stethoscope
531 | n04376876 531 syringe
532 | n03706229 532 magnetic_compass
533 | n02841315 533 binoculars
534 | n04009552 534 projector
535 | n04356056 535 sunglasses
536 | n03692522 536 loupe
537 | n04044716 537 radio_telescope
538 | n02879718 538 bow
539 | n02950826 539 cannon
540 | n02749479 540 assault_rifle
541 | n04090263 541 rifle
542 | n04008634 542 projectile
543 | n03085013 543 computer_keyboard
544 | n04505470 544 typewriter_keyboard
545 | n03126707 545 crane
546 | n03666591 546 lighter
547 | n02666196 547 abacus
548 | n02977058 548 cash_machine
549 | n04238763 549 slide_rule
550 | n03180011 550 desktop_computer
551 | n03485407 551 hand-held_computer
552 | n03832673 552 notebook
553 | n06359193 553 web_site
554 | n03496892 554 harvester
555 | n04428191 555 thresher
556 | n04004767 556 printer
557 | n04243546 557 slot
558 | n04525305 558 vending_machine
559 | n04179913 559 sewing_machine
560 | n03602883 560 joystick
561 | n04372370 561 switch
562 | n03532672 562 hook
563 | n02974003 563 car_wheel
564 | n03874293 564 paddlewheel
565 | n03944341 565 pinwheel
566 | n03992509 566 potter's_wheel
567 | n03425413 567 gas_pump
568 | n02966193 568 carousel
569 | n04371774 569 swing
570 | n04067472 570 reel
571 | n04040759 571 radiator
572 | n04019541 572 puck
573 | n03492542 573 hard_disc
574 | n04355933 574 sunglass
575 | n03929660 575 pick
576 | n02965783 576 car_mirror
577 | n04258138 577 solar_dish
578 | n04074963 578 remote_control
579 | n03208938 579 disk_brake
580 | n02910353 580 buckle
581 | n03476684 581 hair_slide
582 | n03627232 582 knot
583 | n03075370 583 combination_lock
584 | n03874599 584 padlock
585 | n03804744 585 nail
586 | n04127249 586 safety_pin
587 | n04153751 587 screw
588 | n03803284 588 muzzle
589 | n04162706 589 seat_belt
590 | n04228054 590 ski
591 | n02948072 591 candle
592 | n03590841 592 jack-o'-lantern
593 | n04286575 593 spotlight
594 | n04456115 594 torch
595 | n03814639 595 neck_brace
596 | n03933933 596 pier
597 | n04485082 597 tripod
598 | n03733131 598 maypole
599 | n03794056 599 mousetrap
600 | n04275548 600 spider_web
601 | n01768244 601 trilobite
602 | n01770081 602 harvestman
603 | n01770393 603 scorpion
604 | n01773157 604 black_and_gold_garden_spider
605 | n01773549 605 barn_spider
606 | n01773797 606 garden_spider
607 | n01774384 607 black_widow
608 | n01774750 608 tarantula
609 | n01775062 609 wolf_spider
610 | n01776313 610 tick
611 | n01784675 611 centipede
612 | n01990800 612 isopod
613 | n01978287 613 Dungeness_crab
614 | n01978455 614 rock_crab
615 | n01980166 615 fiddler_crab
616 | n01981276 616 king_crab
617 | n01983481 617 American_lobster
618 | n01984695 618 spiny_lobster
619 | n01985128 619 crayfish
620 | n01986214 620 hermit_crab
621 | n02165105 621 tiger_beetle
622 | n02165456 622 ladybug
623 | n02167151 623 ground_beetle
624 | n02168699 624 long-horned_beetle
625 | n02169497 625 leaf_beetle
626 | n02172182 626 dung_beetle
627 | n02174001 627 rhinoceros_beetle
628 | n02177972 628 weevil
629 | n02190166 629 fly
630 | n02206856 630 bee
631 | n02226429 631 grasshopper
632 | n02229544 632 cricket
633 | n02231487 633 walking_stick
634 | n02233338 634 cockroach
635 | n02236044 635 mantis
636 | n02256656 636 cicada
637 | n02259212 637 leafhopper
638 | n02264363 638 lacewing
639 | n02268443 639 dragonfly
640 | n02268853 640 damselfly
641 | n02276258 641 admiral
642 | n02277742 642 ringlet
643 | n02279972 643 monarch
644 | n02280649 644 cabbage_butterfly
645 | n02281406 645 sulphur_butterfly
646 | n02281787 646 lycaenid
647 | n01910747 647 jellyfish
648 | n01914609 648 sea_anemone
649 | n01917289 649 brain_coral
650 | n01924916 650 flatworm
651 | n01930112 651 nematode
652 | n01943899 652 conch
653 | n01944390 653 snail
654 | n01945685 654 slug
655 | n01950731 655 sea_slug
656 | n01955084 656 chiton
657 | n02319095 657 sea_urchin
658 | n02321529 658 sea_cucumber
659 | n03584829 659 iron
660 | n03297495 660 espresso_maker
661 | n03761084 661 microwave
662 | n03259280 662 Dutch_oven
663 | n04111531 663 rotisserie
664 | n04442312 664 toaster
665 | n04542943 665 waffle_iron
666 | n04517823 666 vacuum
667 | n03207941 667 dishwasher
668 | n04070727 668 refrigerator
669 | n04554684 669 washer
670 | n03133878 670 Crock_Pot
671 | n03400231 671 frying_pan
672 | n04596742 672 wok
673 | n02939185 673 caldron
674 | n03063689 674 coffeepot
675 | n04398044 675 teapot
676 | n04270147 676 spatula
677 | n02699494 677 altar
678 | n04486054 678 triumphal_arch
679 | n03899768 679 patio
680 | n04311004 680 steel_arch_bridge
681 | n04366367 681 suspension_bridge
682 | n04532670 682 viaduct
683 | n02793495 683 barn
684 | n03457902 684 greenhouse
685 | n03877845 685 palace
686 | n03781244 686 monastery
687 | n03661043 687 library
688 | n02727426 688 apiary
689 | n02859443 689 boathouse
690 | n03028079 690 church
691 | n03788195 691 mosque
692 | n04346328 692 stupa
693 | n03956157 693 planetarium
694 | n04081281 694 restaurant
695 | n03032252 695 cinema
696 | n03529860 696 home_theater
697 | n03697007 697 lumbermill
698 | n03065424 698 coil
699 | n03837869 699 obelisk
700 | n04458633 700 totem_pole
701 | n02980441 701 castle
702 | n04005630 702 prison
703 | n03461385 703 grocery_store
704 | n02776631 704 bakery
705 | n02791270 705 barbershop
706 | n02871525 706 bookshop
707 | n02927161 707 butcher_shop
708 | n03089624 708 confectionery
709 | n04200800 709 shoe_shop
710 | n04443257 710 tobacco_shop
711 | n04462240 711 toyshop
712 | n03388043 712 fountain
713 | n03042490 713 cliff_dwelling
714 | n04613696 714 yurt
715 | n03216828 715 dock
716 | n02892201 716 brass
717 | n03743016 717 megalith
718 | n02788148 718 bannister
719 | n02894605 719 breakwater
720 | n03160309 720 dam
721 | n03000134 721 chainlink_fence
722 | n03930313 722 picket_fence
723 | n04604644 723 worm_fence
724 | n04326547 724 stone_wall
725 | n03459775 725 grille
726 | n04239074 726 sliding_door
727 | n04501370 727 turnstile
728 | n03792972 728 mountain_tent
729 | n04149813 729 scoreboard
730 | n03530642 730 honeycomb
731 | n03961711 731 plate_rack
732 | n03903868 732 pedestal
733 | n02814860 733 beacon
734 | n07711569 734 mashed_potato
735 | n07720875 735 bell_pepper
736 | n07714571 736 head_cabbage
737 | n07714990 737 broccoli
738 | n07715103 738 cauliflower
739 | n07716358 739 zucchini
740 | n07716906 740 spaghetti_squash
741 | n07717410 741 acorn_squash
742 | n07717556 742 butternut_squash
743 | n07718472 743 cucumber
744 | n07718747 744 artichoke
745 | n07730033 745 cardoon
746 | n07734744 746 mushroom
747 | n04209239 747 shower_curtain
748 | n03594734 748 jean
749 | n02971356 749 carton
750 | n03485794 750 handkerchief
751 | n04133789 751 sandal
752 | n02747177 752 ashcan
753 | n04125021 753 safe
754 | n07579787 754 plate
755 | n03814906 755 necklace
756 | n03134739 756 croquet_ball
757 | n03404251 757 fur_coat
758 | n04423845 758 thimble
759 | n03877472 759 pajama
760 | n04120489 760 running_shoe
761 | n03062245 761 cocktail_shaker
762 | n03014705 762 chest
763 | n03717622 763 manhole_cover
764 | n03777754 764 modem
765 | n04493381 765 tub
766 | n04476259 766 tray
767 | n02777292 767 balance_beam
768 | n07693725 768 bagel
769 | n03998194 769 prayer_rug
770 | n03617480 770 kimono
771 | n07590611 771 hot_pot
772 | n04579145 772 whiskey_jug
773 | n03623198 773 knee_pad
774 | n07248320 774 book_jacket
775 | n04277352 775 spindle
776 | n04229816 776 ski_mask
777 | n02823428 777 beer_bottle
778 | n03127747 778 crash_helmet
779 | n02877765 779 bottlecap
780 | n04435653 780 tile_roof
781 | n03724870 781 mask
782 | n03710637 782 maillot
783 | n03920288 783 Petri_dish
784 | n03379051 784 football_helmet
785 | n02807133 785 bathing_cap
786 | n04399382 786 teddy
787 | n03527444 787 holster
788 | n03983396 788 pop_bottle
789 | n03924679 789 photocopier
790 | n04532106 790 vestment
791 | n06785654 791 crossword_puzzle
792 | n03445777 792 golf_ball
793 | n07613480 793 trifle
794 | n04350905 794 suit
795 | n04562935 795 water_tower
796 | n03325584 796 feather_boa
797 | n03045698 797 cloak
798 | n07892512 798 red_wine
799 | n03250847 799 drumstick
800 | n04192698 800 shield
801 | n03026506 801 Christmas_stocking
802 | n03534580 802 hoopskirt
803 | n07565083 803 menu
804 | n04296562 804 stage
805 | n02869837 805 bonnet
806 | n07871810 806 meat_loaf
807 | n02799071 807 baseball
808 | n03314780 808 face_powder
809 | n04141327 809 scabbard
810 | n04357314 810 sunscreen
811 | n02823750 811 beer_glass
812 | n13052670 812 hen-of-the-woods
813 | n07583066 813 guacamole
814 | n03637318 814 lampshade
815 | n04599235 815 wool
816 | n07802026 816 hay
817 | n02883205 817 bow_tie
818 | n03709823 818 mailbag
819 | n04560804 819 water_jug
820 | n02909870 820 bucket
821 | n03207743 821 dishrag
822 | n04263257 822 soup_bowl
823 | n07932039 823 eggnog
824 | n03786901 824 mortar
825 | n04479046 825 trench_coat
826 | n03873416 826 paddle
827 | n02999410 827 chain
828 | n04367480 828 swab
829 | n03775546 829 mixing_bowl
830 | n07875152 830 potpie
831 | n04591713 831 wine_bottle
832 | n04201297 832 shoji
833 | n02916936 833 bulletproof_vest
834 | n03240683 834 drilling_platform
835 | n02840245 835 binder
836 | n02963159 836 cardigan
837 | n04370456 837 sweatshirt
838 | n03991062 838 pot
839 | n02843684 839 birdhouse
840 | n03482405 840 hamper
841 | n03942813 841 ping-pong_ball
842 | n03908618 842 pencil_box
843 | n03902125 843 pay-phone
844 | n07584110 844 consomme
845 | n02730930 845 apron
846 | n04023962 846 punching_bag
847 | n02769748 847 backpack
848 | n10148035 848 groom
849 | n02817516 849 bearskin
850 | n03908714 850 pencil_sharpener
851 | n02906734 851 broom
852 | n03788365 852 mosquito_net
853 | n02667093 853 abaya
854 | n03787032 854 mortarboard
855 | n03980874 855 poncho
856 | n03141823 856 crutch
857 | n03976467 857 Polaroid_camera
858 | n04264628 858 space_bar
859 | n07930864 859 cup
860 | n04039381 860 racket
861 | n06874185 861 traffic_light
862 | n04033901 862 quill
863 | n04041544 863 radio
864 | n07860988 864 dough
865 | n03146219 865 cuirass
866 | n03763968 866 military_uniform
867 | n03676483 867 lipstick
868 | n04209133 868 shower_cap
869 | n03782006 869 monitor
870 | n03857828 870 oscilloscope
871 | n03775071 871 mitten
872 | n02892767 872 brassiere
873 | n07684084 873 French_loaf
874 | n04522168 874 vase
875 | n03764736 875 milk_can
876 | n04118538 876 rugby_ball
877 | n03887697 877 paper_towel
878 | n13044778 878 earthstar
879 | n03291819 879 envelope
880 | n03770439 880 miniskirt
881 | n03124170 881 cowboy_hat
882 | n04487081 882 trolleybus
883 | n03916031 883 perfume
884 | n02808440 884 bathtub
885 | n07697537 885 hotdog
886 | n12985857 886 coral_fungus
887 | n02917067 887 bullet_train
888 | n03938244 888 pillow
889 | n15075141 889 toilet_tissue
890 | n02978881 890 cassette
891 | n02966687 891 carpenter's_kit
892 | n03633091 892 ladle
893 | n13040303 893 stinkhorn
894 | n03690938 894 lotion
895 | n03476991 895 hair_spray
896 | n02669723 896 academic_gown
897 | n03220513 897 dome
898 | n03127925 898 crate
899 | n04584207 899 wig
900 | n07880968 900 burrito
901 | n03937543 901 pill_bottle
902 | n03000247 902 chain_mail
903 | n04418357 903 theater_curtain
904 | n04590129 904 window_shade
905 | n02795169 905 barrel
906 | n04553703 906 washbasin
907 | n02783161 907 ballpoint
908 | n02802426 908 basketball
909 | n02808304 909 bath_towel
910 | n03124043 910 cowboy_boot
911 | n03450230 911 gown
912 | n04589890 912 window_screen
913 | n12998815 913 agaric
914 | n02992529 914 cellular_telephone
915 | n03825788 915 nipple
916 | n02790996 916 barbell
917 | n03710193 917 mailbox
918 | n03630383 918 lab_coat
919 | n03347037 919 fire_screen
920 | n03769881 920 minibus
921 | n03871628 921 packet
922 | n03733281 922 maze
923 | n03976657 923 pole
924 | n03535780 924 horizontal_bar
925 | n04259630 925 sombrero
926 | n03929855 926 pickelhaube
927 | n04049303 927 rain_barrel
928 | n04548362 928 wallet
929 | n02979186 929 cassette_player
930 | n06596364 930 comic_book
931 | n03935335 931 piggy_bank
932 | n06794110 932 street_sign
933 | n02825657 933 bell_cote
934 | n03388183 934 fountain_pen
935 | n04591157 935 Windsor_tie
936 | n04540053 936 volleyball
937 | n03866082 937 overskirt
938 | n04136333 938 sarong
939 | n04026417 939 purse
940 | n02865351 940 bolo_tie
941 | n02834397 941 bib
942 | n03888257 942 parachute
943 | n04235860 943 sleeping_bag
944 | n04404412 944 television
945 | n04371430 945 swimming_trunks
946 | n03733805 946 measuring_cup
947 | n07920052 947 espresso
948 | n07873807 948 pizza
949 | n02895154 949 breastplate
950 | n04204238 950 shopping_basket
951 | n04597913 951 wooden_spoon
952 | n04131690 952 saltshaker
953 | n07836838 953 chocolate_sauce
954 | n09835506 954 ballplayer
955 | n03443371 955 goblet
956 | n13037406 956 gyromitra
957 | n04336792 957 stretcher
958 | n04557648 958 water_bottle
959 | n03187595 959 dial_telephone
960 | n04254120 960 soap_dispenser
961 | n03595614 961 jersey
962 | n04146614 962 school_bus
963 | n03598930 963 jigsaw_puzzle
964 | n03958227 964 plastic_bag
965 | n04069434 965 reflex_camera
966 | n03188531 966 diaper
967 | n02786058 967 Band_Aid
968 | n07615774 968 ice_lolly
969 | n04525038 969 velvet
970 | n04409515 970 tennis_ball
971 | n03424325 971 gasmask
972 | n03223299 972 doormat
973 | n03680355 973 Loafer
974 | n07614500 974 ice_cream
975 | n07695742 975 pretzel
976 | n04033995 976 quilt
977 | n03710721 977 maillot
978 | n04392985 978 tape_player
979 | n03047690 979 clog
980 | n03584254 980 iPod
981 | n13054560 981 bolete
982 | n10565667 982 scuba_diver
983 | n03950228 983 pitcher
984 | n03729826 984 matchstick
985 | n02837789 985 bikini
986 | n04254777 986 sock
987 | n02988304 987 CD_player
988 | n03657121 988 lens_cap
989 | n04417672 989 thatch
990 | n04523525 990 vault
991 | n02815834 991 beaker
992 | n09229709 992 bubble
993 | n07697313 993 cheeseburger
994 | n03888605 994 parallel_bars
995 | n03355925 995 flagpole
996 | n03063599 996 coffee_mug
997 | n04116512 997 rubber_eraser
998 | n04325704 998 stole
999 | n07831146 999 carbonara
1000 | n03255030 1000 dumbbell
--------------------------------------------------------------------------------
/support/readme.md:
--------------------------------------------------------------------------------
1 | Files `tinyimagenet_refined_labels.json` and `imagenet_refined_labels.json` specify the train/val splits used in our re-implementation of adversarial finetuning (AFT) methods. In the paper, we re-implement three AFT methods (TeCoA, PMG-AFT and FARE) on the TinyImageNet dataset; for each class, we randomly sample 10% of the instances in the training set to serve as evaluation data during the finetuning phase. In the `.json` file, each class (identified by a synset id) has the following attributes:
2 | ```
3 | {
4 |     synset identifier (e.g., 'n01443537'):
5 |     {
6 |         "clean_name": cleansed textual name (e.g., 'goldfish'),
7 |         "wordnet_def": definition given by WordNet (e.g., 'small golden or orange-red freshwater fishes of Eurasia used as pond or aquarium fishes'),
8 |         "eval_files": list of instances selected as evaluation data (e.g., ["n01443537_266.JPEG", "n01443537_75.JPEG", ...])
9 |     },
10 | }
11 | ```
12 |
13 | We also provide the cleansed textual name for each class and its definition given by [WordNet](https://wordnet.princeton.edu/).
14 |
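15 | A minimal sketch of how these files can be read (the actual finetuning code may consume them differently):
16 | 
17 | ```python
18 | import json
19 | 
20 | with open('support/tinyimagenet_refined_labels.json') as f:
21 |     refined = json.load(f)
22 | 
23 | # Each synset maps to its cleaned name, WordNet definition, and held-out evaluation files.
24 | info = refined['n01443537']
25 | print(info['clean_name'], '-', info['wordnet_def'])
26 | eval_files = set(info['eval_files'])  # e.g., {"n01443537_266.JPEG", ...}
27 | ```
28 | 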
--------------------------------------------------------------------------------