├── AFT_model_weights
│   └── readme.md
├── README.md
├── TTC.sh
├── code
│   ├── attacks.py
│   ├── func.py
│   ├── models
│   │   ├── download_models.sh
│   │   ├── model.py
│   │   └── prompters.py
│   ├── replace
│   │   ├── bpe_simple_vocab_16e6.txt.gz
│   │   ├── clip.py
│   │   ├── datasets
│   │   │   ├── caltech.py
│   │   │   ├── country211.py
│   │   │   ├── dtd.py
│   │   │   ├── eurosat.py
│   │   │   ├── fgvc_aircraft.py
│   │   │   ├── flowers102.py
│   │   │   ├── folder.py
│   │   │   ├── food101.py
│   │   │   ├── oxford_iiit_pet.py
│   │   │   ├── pcam.py
│   │   │   ├── stanford_cars.py
│   │   │   └── sun397.py
│   │   ├── model.py
│   │   └── simple_tokenizer.py
│   ├── test_time_counterattack.py
│   └── utils.py
├── data
│   └── readme.md
├── download_weights.py
├── environment.yml
├── figures
│   ├── fig2b.png
│   └── teaser.png
├── poster_CVPR_XING.png
├── requirements.txt
└── support
    ├── imagenet_classes_names.txt
    ├── imagenet_refined_labels.json
    ├── readme.md
    └── tinyimagenet_refined_labels.json

/AFT_model_weights/readme.md:
--------------------------------------------------------------------------------
1 | This folder keeps the model weights of AFT methods, if one wants to test TTC on adversarially finetuned models such as TeCoA.
2 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CLIP-Test-time-Counterattacks 🚀
2 | This is the official code of our work:
3 | 
4 | [CLIP is Strong Enough to Fight Back: Test-time Counterattacks towards Zero-shot Adversarial Robustness of CLIP](https://arxiv.org/abs/2503.03613). Songlong Xing, Zhengyu Zhao, Nicu Sebe. To appear in CVPR 2025.
5 | 
6 | > **Abstract**: Despite its prevalent use in image-text matching tasks in a zero-shot manner, CLIP has been shown to be highly vulnerable to adversarial perturbations added onto images. Recent studies propose to finetune the vision encoder of CLIP with adversarial samples generated on the fly, and show improved robustness against adversarial attacks on a spectrum of downstream datasets, a property termed as zero-shot robustness. In this paper, we show that malicious perturbations that seek to maximise the classification loss lead to `falsely stable' images, and propose to leverage the pre-trained vision encoder of CLIP to counterattack such adversarial images during inference to achieve robustness. Our paradigm is simple and training-free, providing the first method to defend CLIP from adversarial attacks at test time, which is orthogonal to existing methods aiming to boost zero-shot adversarial robustness of CLIP. We conduct experiments across 16 classification datasets, and demonstrate stable and consistent gains compared to test-time defence methods adapted from existing adversarial robustness studies that do not rely on external networks, without noticeably impairing performance on clean images. We also show that our paradigm can be employed on CLIP models that have been adversarially finetuned to further enhance their robustness at test time. Our code is available here.
7 | 
8 | 
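## 💡 TTC at a glance

Before diving into `code/test_time_counterattack.py`, the snippet below gives a minimal, illustrative sketch of the counterattack idea, loosely following `attack_unlabelled` in `code/attacks.py`. It is not the repository's API: the function name `counterattack`, its arguments, and the default values are placeholders. Starting from the incoming (possibly adversarial) image, a few PGD-style steps push the CLIP image embedding away from its current, "falsely stable" position while staying inside a small ℓ∞ budget; the counterattacked image is then classified zero-shot as usual. The budget and the number of steps roughly correspond to the `--ttc_eps` and `--ttc_numsteps` flags in the run command further down.

```python
import torch

def counterattack(encode_image, x, eps=4/255, step_size=2/255, num_steps=2):
    """Illustrative sketch of one test-time counterattack (TTC) pass.

    `encode_image` is assumed to map a batch of images in [0, 1] to CLIP
    image embeddings (e.g. CLIP preprocessing followed by the image encoder).
    """
    with torch.no_grad():
        anchor = encode_image(x)                     # embedding of the incoming image
    delta = torch.zeros_like(x, requires_grad=True)  # counterattack perturbation
    for _ in range(num_steps):
        feats = encode_image(torch.clamp(x + delta, 0.0, 1.0))
        dist = ((feats - anchor) ** 2).sum()         # distance to the anchor embedding
        grad = torch.autograd.grad(dist, delta)[0]
        # gradient ascent on the distance, projected back onto the eps-ball
        delta = (delta + step_size * grad.sign()).clamp(-eps, eps)
        delta = delta.detach().requires_grad_(True)
    return torch.clamp(x + delta, 0.0, 1.0).detach()
```

In the released code the counterattack additionally interacts with arguments such as `--beta` and `--tau_thres` in the command below; please refer to `code/attacks.py` and `code/test_time_counterattack.py` for the exact procedure used in the paper.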


11 | 12 | ## 🛠️ Setup 13 | ### Environment 14 | Make sure you have installed conda and use the following commands to get the env ready! 15 | ```bash 16 | conda env create -f environment.yml 17 | conda activate TTC 18 | pip install -r requirements.txt 19 | ``` 20 | ### Data preparation 21 | Please download and unzip all the raw datasets into `./data`. It's okay to skip this step because `torchvision.datasets` will automatically download (most of) them as you run the code if you don't already have them. 22 | 23 | ## 🔥 Run 24 | The file `code/test_time_counterattack.py` contains the main program. To reproduce the results of TTC employed on the original CLIP (Tab.1 in the paper), run the following command: 25 | ```bash 26 | conda activate TTC 27 | python code/test_time_counterattack.py --batch_size 256 --test_attack_type 'pgd' --test_eps 1 --test_numsteps 10 --test_stepsize 1 --outdir 'TTC_results' --seed 1 --ttc_eps 4 --beta 2 --tau_thres 0.2 --ttc_numsteps 2 28 | ``` 29 | This command can also be found in `TTC.sh`. You can run `bash TTC.sh` to avoid typing the lengthy command in the terminal. 30 | The results will be saved in the folder specified by `--outdir`. 31 | 32 | To employ TTC on adversarially finetuned models, please specify the path to the finetuned model weights with `--victim_resume`. We provide at this [Google drive link](https://drive.google.com/drive/folders/1aDChTWGOrqK6IrIKVqSyMf4IdIBHEiJr?usp=drive_link) the checkpoints we have obtained via adversarial finetuning, which are used for producing the results reported in Tab.3 in the paper. Please download those checkpoints and keep them in the folder `./AFT_model_weights`. This can be done by running `python download_weights.py`. Alternatively, you can implement adversarial finetuning methods on your own and put the model weights in `./AFT_model_weights`. 33 | 34 | After the AFT checkpoints are in place, you can run the same command as in `TTC.sh` with an additional argument `--victim_resume` specifying the path to the model weights. For example, use `--victim_resume "./AFT_model_weights/TeCoA.pth.tar"` if you want to employ TTC on top of TeCoA. 35 | 36 | ## 📬 Updates 37 | 7 Mar 2025: **Please stay tuned for instructions to run the code!** 38 | 39 | 10 Mar 2025: **Setup** updated! 40 | 41 | 10 Mar 2025: **Run Instructions** updated! 42 | 43 | ## 🗂️ Reference 44 | ``` 45 | @InProceedings{Xing_2025_CVPR, 46 | author = {Xing, Songlong and Zhao, Zhengyu and Sebe, Nicu}, 47 | title = {CLIP is Strong Enough to Fight Back: Test-time Counterattacks towards Zero-shot Adversarial Robustness of CLIP}, 48 | booktitle = {Proceedings of the Computer Vision and Pattern Recognition Conference (CVPR)}, 49 | month = {June}, 50 | year = {2025}, 51 | pages = {15172-15182} 52 | } 53 | ``` 54 | 55 | ## Ackowledgement 56 | Our code is developed based on the open-source code of [TeCoA (ICLR-23)](https://github.com/cvlab-columbia/ZSRobust4FoundationModel). We thank the authors for their brilliant work. 
Please also consider citing their paper: 57 | ``` 58 | @inproceedings{maounderstanding, 59 | title={Understanding Zero-shot Adversarial Robustness for Large-Scale Models}, 60 | author={Mao, Chengzhi and Geng, Scott and Yang, Junfeng and Wang, Xin and Vondrick, Carl}, 61 | booktitle={The Eleventh International Conference on Learning Representations}, 62 | year={2023} 63 | } 64 | ``` 65 | -------------------------------------------------------------------------------- /TTC.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | conda activate TTC 4 | 5 | cd ./ 6 | 7 | python code/test_time_counterattack.py \ 8 | --batch_size 256 \ 9 | --test_attack_type 'pgd' \ 10 | --test_eps 1 \ 11 | --test_numsteps 10 \ 12 | --test_stepsize 1 \ 13 | --outdir 'TTC_results' \ 14 | --seed 1 \ 15 | --ttc_eps 4 \ 16 | --beta 2 \ 17 | --tau_thres 0.2 \ 18 | --ttc_numsteps 2 19 | -------------------------------------------------------------------------------- /code/attacks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from utils import one_hot_embedding 3 | # from models.model import * 4 | import torch.nn.functional as F 5 | import functools 6 | from autoattack import AutoAttack 7 | from func import clip_img_preprocessing, multiGPU_CLIP, multiGPU_CLIP_image_logits 8 | 9 | lower_limit, upper_limit = 0, 1 10 | def clamp(X, lower_limit, upper_limit): 11 | return torch.max(torch.min(X, upper_limit), lower_limit) 12 | 13 | 14 | def attack_CW(args, prompter, model, model_text, model_image, add_prompter, criterion, X, target, text_tokens, alpha, 15 | attack_iters, norm, restarts=1, early_stop=True, epsilon=0): 16 | delta = torch.zeros_like(X).cuda() 17 | if norm == "l_inf": 18 | delta.uniform_(-epsilon, epsilon) 19 | elif norm == "l_2": 20 | delta.normal_() 21 | d_flat = delta.view(delta.size(0), -1) 22 | n = d_flat.norm(p=2, dim=1).view(delta.size(0), 1, 1, 1) 23 | r = torch.zeros_like(n).uniform_(0, 1) 24 | delta *= r / n * epsilon 25 | else: 26 | raise ValueError 27 | delta = clamp(delta, lower_limit - X, upper_limit - X) 28 | delta.requires_grad = True 29 | for _ in range(attack_iters): 30 | # output = model(normalize(X )) 31 | 32 | prompted_images = prompter(clip_img_preprocessing(X + delta)) 33 | prompt_token = add_prompter() 34 | 35 | output, _, _, _ = multiGPU_CLIP(args, model_image, model_text, model, prompted_images, text_tokens, prompt_token) 36 | 37 | num_class = output.size(1) 38 | label_mask = one_hot_embedding(target, num_class) 39 | label_mask = label_mask.cuda() 40 | 41 | correct_logit = torch.sum(label_mask * output, dim=1) 42 | wrong_logit, _ = torch.max((1 - label_mask) * output - 1e4 * label_mask, axis=1) 43 | 44 | # loss = criterion(output, target) 45 | loss = - torch.sum(F.relu(correct_logit - wrong_logit + 50)) 46 | 47 | loss.backward() 48 | grad = delta.grad.detach() 49 | d = delta[:, :, :, :] 50 | g = grad[:, :, :, :] 51 | x = X[:, :, :, :] 52 | if norm == "l_inf": 53 | d = torch.clamp(d + alpha * torch.sign(g), min=-epsilon, max=epsilon) 54 | elif norm == "l_2": 55 | g_norm = torch.norm(g.view(g.shape[0], -1), dim=1).view(-1, 1, 1, 1) 56 | scaled_g = g / (g_norm + 1e-10) 57 | d = (d + scaled_g * alpha).view(d.size(0), -1).renorm(p=2, dim=0, maxnorm=epsilon).view_as(d) 58 | d = clamp(d, lower_limit - x, upper_limit - x) 59 | delta.data[:, :, :, :] = d 60 | delta.grad.zero_() 61 | 62 | return delta 63 | 64 | 65 | def attack_CW_noprompt(args, prompter, model, model_text, model_image, 
criterion, X, target, text_tokens, alpha, 66 | attack_iters, norm, restarts=1, early_stop=True, epsilon=0): 67 | delta = torch.zeros_like(X).cuda() 68 | if norm == "l_inf": 69 | delta.uniform_(-epsilon, epsilon) 70 | elif norm == "l_2": 71 | delta.normal_() 72 | d_flat = delta.view(delta.size(0), -1) 73 | n = d_flat.norm(p=2, dim=1).view(delta.size(0), 1, 1, 1) 74 | r = torch.zeros_like(n).uniform_(0, 1) 75 | delta *= r / n * epsilon 76 | else: 77 | raise ValueError 78 | delta = clamp(delta, lower_limit - X, upper_limit - X) 79 | delta.requires_grad = True 80 | for _ in range(attack_iters): 81 | # output = model(normalize(X )) 82 | 83 | _images = clip_img_preprocessing(X + delta) 84 | # output, _ = model(_images, text_tokens) 85 | 86 | output, _, _, _ = multiGPU_CLIP(args, model_image, model_text, model, _images, text_tokens, None) 87 | 88 | num_class = output.size(1) 89 | label_mask = one_hot_embedding(target, num_class) 90 | label_mask = label_mask.cuda() 91 | 92 | correct_logit = torch.sum(label_mask * output, dim=1) 93 | wrong_logit, _ = torch.max((1 - label_mask) * output - 1e4 * label_mask, axis=1) 94 | 95 | # loss = criterion(output, target) 96 | loss = - torch.sum(F.relu(correct_logit - wrong_logit + 50)) 97 | 98 | loss.backward() 99 | grad = delta.grad.detach() 100 | d = delta[:, :, :, :] 101 | g = grad[:, :, :, :] 102 | x = X[:, :, :, :] 103 | if norm == "l_inf": 104 | d = torch.clamp(d + alpha * torch.sign(g), min=-epsilon, max=epsilon) 105 | elif norm == "l_2": 106 | g_norm = torch.norm(g.view(g.shape[0], -1), dim=1).view(-1, 1, 1, 1) 107 | scaled_g = g / (g_norm + 1e-10) 108 | d = (d + scaled_g * alpha).view(d.size(0), -1).renorm(p=2, dim=0, maxnorm=epsilon).view_as(d) 109 | d = clamp(d, lower_limit - x, upper_limit - x) 110 | delta.data[:, :, :, :] = d 111 | delta.grad.zero_() 112 | 113 | return delta 114 | 115 | def attack_unlabelled(model, X, prompter, add_prompter, alpha, attack_iters, norm="l_inf", epsilon=0, 116 | visual_model_orig=None): 117 | delta = torch.zeros_like(X) 118 | if norm == "l_inf": 119 | delta.uniform_(-epsilon, epsilon) 120 | elif norm == "l_2": 121 | delta.normal_() 122 | d_flat = delta.view(delta.size(0), -1) 123 | n = d_flat.norm(p=2, dim=1).view(delta.size(0), 1, 1, 1) 124 | r = torch.zeros_like(n).uniform_(0, 1) 125 | delta *= r / n * epsilon 126 | else: 127 | raise ValueError 128 | 129 | # turn off model parameters temmporarily 130 | tunable_param_names = [] 131 | for n,p in model.module.named_parameters(): 132 | if p.requires_grad: 133 | tunable_param_names.append(n) 134 | p.requires_grad = False 135 | 136 | delta = clamp(delta, lower_limit - X, upper_limit - X) 137 | delta.requires_grad = True 138 | 139 | if attack_iters <= 0: 140 | return delta 141 | 142 | prompt_token = add_prompter() 143 | with torch.no_grad(): 144 | if visual_model_orig is None: # use the model itself as anchor 145 | X_ori_reps = model.module.encode_image( 146 | prompter(clip_img_preprocessing(X)), prompt_token 147 | ) 148 | else: # use original frozen model as anchor 149 | X_ori_reps = visual_model_orig.module( 150 | prompter(clip_img_preprocessing(X)), prompt_token 151 | ) 152 | 153 | for _ in range(attack_iters): 154 | 155 | prompted_images = prompter(clip_img_preprocessing(X + delta)) 156 | 157 | X_att_reps = model.module.encode_image(prompted_images, prompt_token) 158 | # l2_loss = ((((X_att_reps - X_ori_reps)**2).sum(1))**(0.5)).sum() 159 | l2_loss = ((((X_att_reps - X_ori_reps)**2).sum(1))).sum() 160 | 161 | grad = torch.autograd.grad(l2_loss, delta)[0] 162 | d = 
delta[:, :, :, :] 163 | g = grad[:, :, :, :] 164 | x = X[:, :, :, :] 165 | if norm == "l_inf": 166 | d = torch.clamp(d + alpha * torch.sign(g), min=-epsilon, max=epsilon) 167 | elif norm == "l_2": 168 | g_norm = torch.norm(g.view(g.shape[0], -1), dim=1).view(-1, 1, 1, 1) 169 | scaled_g = g / (g_norm + 1e-10) 170 | d = (d + scaled_g * alpha).view(d.size(0), -1).renorm(p=2, dim=0, maxnorm=epsilon).view_as(d) 171 | d = clamp(d, lower_limit - x, upper_limit - x) 172 | delta.data[:, :, :, :] = d 173 | 174 | # # Turn on model parameters 175 | for n,p in model.module.named_parameters(): 176 | if n in tunable_param_names: 177 | p.requires_grad = True 178 | 179 | return delta 180 | 181 | #### opposite update direction of attack_unlabelled() 182 | def attack_unlabelled_opp(model, X, prompter, add_prompter, alpha, attack_iters, norm="l_inf", epsilon=0, 183 | visual_model_orig=None): 184 | delta = torch.zeros_like(X) 185 | if norm == "l_inf": 186 | delta.uniform_(-epsilon, epsilon) 187 | elif norm == "l_2": 188 | delta.normal_() 189 | d_flat = delta.view(delta.size(0), -1) 190 | n = d_flat.norm(p=2, dim=1).view(delta.size(0), 1, 1, 1) 191 | r = torch.zeros_like(n).uniform_(0, 1) 192 | delta *= r / n * epsilon 193 | else: 194 | raise ValueError 195 | 196 | # turn off model parameters temmporarily 197 | tunable_param_names = [] 198 | for n,p in model.module.named_parameters(): 199 | if p.requires_grad: 200 | tunable_param_names.append(n) 201 | p.requires_grad = False 202 | 203 | delta = clamp(delta, lower_limit - X, upper_limit - X) 204 | delta.requires_grad = True 205 | 206 | prompt_token = add_prompter() 207 | with torch.no_grad(): 208 | if visual_model_orig is None: # use the model itself as anchor 209 | X_ori_reps = model.module.encode_image( 210 | prompter(clip_img_preprocessing(X)), prompt_token 211 | ) 212 | else: # use original frozen model as anchor 213 | X_ori_reps = visual_model_orig.module( 214 | prompter(clip_img_preprocessing(X)), prompt_token 215 | ) 216 | 217 | for _ in range(attack_iters): 218 | 219 | prompted_images = prompter(clip_img_preprocessing(X + delta)) 220 | 221 | X_att_reps = model.module.encode_image(prompted_images, prompt_token) 222 | # l2_loss = ((((X_att_reps - X_ori_reps)**2).sum(1))**(0.5)).sum() 223 | l2_loss = ((((X_att_reps - X_ori_reps)**2).sum(1))).sum() 224 | 225 | grad = torch.autograd.grad(l2_loss, delta)[0] 226 | d = delta[:, :, :, :] 227 | g = grad[:, :, :, :] 228 | x = X[:, :, :, :] 229 | if norm == "l_inf": 230 | # d = torch.clamp(d + alpha * torch.sign(g), min=-epsilon, max=epsilon) 231 | d = torch.clamp(d - alpha * torch.sign(g), min=-epsilon, max=epsilon) 232 | elif norm == "l_2": 233 | g_norm = torch.norm(g.view(g.shape[0], -1), dim=1).view(-1, 1, 1, 1) 234 | scaled_g = g / (g_norm + 1e-10) 235 | # d = (d + scaled_g * alpha).view(d.size(0), -1).renorm(p=2, dim=0, maxnorm=epsilon).view_as(d) 236 | d = (d - scaled_g * alpha).view(d.size(0), -1).renorm(p=2, dim=0, maxnorm=epsilon).view_as(d) 237 | d = clamp(d, lower_limit - x, upper_limit - x) 238 | delta.data[:, :, :, :] = d 239 | 240 | # # Turn on model parameters 241 | for n,p in model.module.named_parameters(): 242 | if n in tunable_param_names: 243 | p.requires_grad = True 244 | 245 | return delta 246 | 247 | def attack_unlabelled_cosine(model, X, prompter, add_prompter, alpha, attack_iters, norm="l_inf", epsilon=0, 248 | visual_model_orig=None): 249 | # unlabelled attack to maximise cosine similarity between the attacked image 250 | # and the original image, computed by PGD 251 | delta = 
torch.zeros_like(X) 252 | if norm == "l_inf": 253 | delta.uniform_(-epsilon, epsilon) 254 | elif norm == "l_2": 255 | delta.normal_() 256 | d_flat = delta.view(delta.size(0), -1) 257 | n = d_flat.norm(p=2, dim=1).view(delta.size(0), 1, 1, 1) 258 | r = torch.zeros_like(n).uniform_(0, 1) 259 | delta *= r / n * epsilon 260 | else: 261 | raise ValueError 262 | 263 | # turn off model parameters temmporarily 264 | tunable_param_names = [] 265 | for n,p in model.module.named_parameters(): 266 | if p.requires_grad: 267 | tunable_param_names.append(n) 268 | p.requires_grad = False 269 | 270 | delta = clamp(delta, lower_limit - X, upper_limit - X) 271 | delta.requires_grad = True 272 | 273 | prompt_token = add_prompter() 274 | with torch.no_grad(): 275 | if visual_model_orig is None: # use the model itself as anchor 276 | X_ori_reps = model.module.encode_image( 277 | prompter(clip_img_preprocessing(X)), prompt_token 278 | ) 279 | else: # use original frozen model as anchor 280 | X_ori_reps = visual_model_orig.module( 281 | prompter(clip_img_preprocessing(X)), prompt_token 282 | ) 283 | # X_ori_reps_norm = X_ori_reps / X_ori_reps.norm(dim=-1, keepdim=True) 284 | 285 | for _ in range(attack_iters): 286 | 287 | prompted_images = prompter(clip_img_preprocessing(X + delta)) 288 | 289 | X_att_reps = model.module.encode_image(prompted_images, prompt_token) # [bs, d_out] 290 | # X_att_reps_norm = X_att_reps / X_att_reps.norm(dim=-1, keepdim=True) 291 | 292 | cos_loss = 1 - F.cosine_similarity(X_att_reps, X_ori_reps) # [bs] 293 | 294 | grad = torch.autograd.grad(cos_loss, delta)[0] 295 | d = delta[:, :, :, :] 296 | g = grad[:, :, :, :] 297 | x = X[:, :, :, :] 298 | if norm == "l_inf": 299 | d = torch.clamp(d + alpha * torch.sign(g), min=-epsilon, max=epsilon) 300 | elif norm == "l_2": 301 | g_norm = torch.norm(g.view(g.shape[0], -1), dim=1).view(-1, 1, 1, 1) 302 | scaled_g = g / (g_norm + 1e-10) 303 | d = (d + scaled_g * alpha).view(d.size(0), -1).renorm(p=2, dim=0, maxnorm=epsilon).view_as(d) 304 | d = clamp(d, lower_limit - x, upper_limit - x) 305 | delta.data[:, :, :, :] = d 306 | 307 | # # Turn on model parameters 308 | for n,p in model.module.named_parameters(): 309 | if n in tunable_param_names: 310 | p.requires_grad = True 311 | 312 | return delta 313 | 314 | 315 | def attack_pgd(args, prompter, model, model_text, model_image, add_prompter, criterion, X, target, alpha, 316 | attack_iters, norm, text_tokens=None, restarts=1, early_stop=True, epsilon=0, dataset_name=None): 317 | delta = torch.zeros_like(X).cuda() 318 | if norm == "l_inf": 319 | delta.uniform_(-epsilon, epsilon) 320 | elif norm == "l_2": 321 | delta.normal_() 322 | d_flat = delta.view(delta.size(0), -1) 323 | n = d_flat.norm(p=2, dim=1).view(delta.size(0), 1, 1, 1) 324 | r = torch.zeros_like(n).uniform_(0, 1) 325 | delta *= r / n * epsilon 326 | else: 327 | raise ValueError 328 | delta = clamp(delta, lower_limit - X, upper_limit - X) 329 | delta.requires_grad = True 330 | 331 | # turn off model parameters temmporarily 332 | tunable_param_names = [] 333 | for n,p in model.module.named_parameters(): 334 | if p.requires_grad: 335 | tunable_param_names.append(n) 336 | p.requires_grad = False 337 | 338 | for iter in range(attack_iters): 339 | 340 | prompted_images = prompter(clip_img_preprocessing(X + delta)) 341 | prompt_token = add_prompter() 342 | 343 | output, _, _, _ = multiGPU_CLIP(args, model_image, model_text, model, prompted_images, 344 | text_tokens=text_tokens, prompt_token=prompt_token, dataset_name=dataset_name) 345 | 346 | loss 
= criterion(output, target) 347 | 348 | grad = torch.autograd.grad(loss, delta)[0] 349 | 350 | d = delta[:, :, :, :] 351 | g = grad[:, :, :, :] 352 | x = X[:, :, :, :] 353 | if norm == "l_inf": 354 | d = torch.clamp(d + alpha * torch.sign(g), min=-epsilon, max=epsilon) 355 | elif norm == "l_2": 356 | g_norm = torch.norm(g.view(g.shape[0], -1), dim=1).view(-1, 1, 1, 1) 357 | scaled_g = g / (g_norm + 1e-10) 358 | d = (d + scaled_g * alpha).view(d.size(0), -1).renorm(p=2, dim=0, maxnorm=epsilon).view_as(d) 359 | d = clamp(d, lower_limit - x, upper_limit - x) 360 | delta.data[:, :, :, :] = d 361 | # delta.grad.zero_() 362 | 363 | # # Turn on model parameters 364 | for n,p in model.module.named_parameters(): 365 | if n in tunable_param_names: 366 | p.requires_grad = True 367 | 368 | return delta 369 | 370 | 371 | def attack_pgd_noprompt(args, prompter, model, model_text, model_image, criterion, X, target, text_tokens, alpha, 372 | attack_iters, norm, restarts=1, early_stop=True, epsilon=0): 373 | delta = torch.zeros_like(X).cuda() 374 | if norm == "l_inf": 375 | delta.uniform_(-epsilon, epsilon) 376 | elif norm == "l_2": 377 | delta.normal_() 378 | d_flat = delta.view(delta.size(0), -1) 379 | n = d_flat.norm(p=2, dim=1).view(delta.size(0), 1, 1, 1) 380 | r = torch.zeros_like(n).uniform_(0, 1) 381 | delta *= r / n * epsilon 382 | else: 383 | raise ValueError 384 | delta = clamp(delta, lower_limit - X, upper_limit - X) 385 | delta.requires_grad = True 386 | for _ in range(attack_iters): 387 | 388 | _images = clip_img_preprocessing(X + delta) 389 | output, _, _, _ = multiGPU_CLIP(args, model_image, model_text, model, _images, text_tokens, None) 390 | 391 | loss = criterion(output, target) 392 | 393 | loss.backward() 394 | grad = delta.grad.detach() 395 | d = delta[:, :, :, :] 396 | g = grad[:, :, :, :] 397 | x = X[:, :, :, :] 398 | if norm == "l_inf": 399 | d = torch.clamp(d + alpha * torch.sign(g), min=-epsilon, max=epsilon) 400 | elif norm == "l_2": 401 | g_norm = torch.norm(g.view(g.shape[0], -1), dim=1).view(-1, 1, 1, 1) 402 | scaled_g = g / (g_norm + 1e-10) 403 | d = (d + scaled_g * alpha).view(d.size(0), -1).renorm(p=2, dim=0, maxnorm=epsilon).view_as(d) 404 | d = clamp(d, lower_limit - x, upper_limit - x) 405 | delta.data[:, :, :, :] = d 406 | delta.grad.zero_() 407 | 408 | return delta 409 | 410 | def attack_auto(model, images, target, text_tokens, prompter, add_prompter, 411 | attacks_to_run=['apgd-ce', 'apgd-dlr'], epsilon=0): 412 | 413 | forward_pass = functools.partial( 414 | multiGPU_CLIP_image_logits, 415 | model=model, text_tokens=text_tokens, 416 | prompter=None, add_prompter=None 417 | ) 418 | 419 | adversary = AutoAttack(forward_pass, norm='Linf', eps=epsilon, version='standard', verbose=False) 420 | adversary.attacks_to_run = attacks_to_run 421 | x_adv = adversary.run_standard_evaluation(images, target, bs=images.shape[0]) 422 | return x_adv -------------------------------------------------------------------------------- /code/func.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | device = "cuda" if torch.cuda.is_available() else "cpu" 5 | 6 | CIFAR100_MEAN = (0.48145466, 0.4578275, 0.40821073) 7 | CIFAR100_STD = (0.26862954, 0.26130258, 0.27577711) 8 | 9 | mu = torch.tensor(CIFAR100_MEAN).view(3, 1, 1).to(device) 10 | std = torch.tensor(CIFAR100_STD).view(3, 1, 1).to(device) 11 | 12 | def normalize(X): 13 | return (X - mu) / std 14 | def clip_img_preprocessing(X): 15 | img_size = 224 16 | X = 
torch.nn.functional.interpolate(X, size=(img_size, img_size), mode='bicubic') 17 | X = normalize(X) 18 | return X 19 | 20 | def rev_normalize(X): 21 | return X * std + mu 22 | def reverse_clip_img_preprocessing(X): 23 | X = rev_normalize(X) 24 | return X 25 | 26 | def multiGPU_CLIP_image_logits(images, model, text_tokens, prompter=None, add_prompter=None): 27 | image_tokens = clip_img_preprocessing(images) 28 | prompt_token = None if add_prompter is None else add_prompter() 29 | if prompter is not None: 30 | image_tokens = prompter(image_tokens) 31 | return multiGPU_CLIP(None, None, None, model, image_tokens, text_tokens, prompt_token=prompt_token)[0] 32 | 33 | 34 | def multiGPU_CLIP(args, model_image, model_text, model, images, text_tokens=None, prompt_token=None, dataset_name=None): 35 | if prompt_token is not None: 36 | bs = images.size(0) 37 | prompt_token = prompt_token.repeat(bs, 1, 1) 38 | if args is not None and dataset_name is not None: 39 | cache_prompts = os.path.join(args.cache, f"refined_{dataset_name.lower()}_prompts.pt") 40 | cache_wordnet_def = os.path.join(args.cache, f"refined_{dataset_name.lower()}_wn_def.pt") 41 | else: 42 | cache_prompts, cache_wordnet_def = None, None 43 | if cache_prompts is not None and os.path.exists(cache_prompts): 44 | text_features = torch.load(cache_prompts).to('cpu') 45 | if args.advanced_text == "wordnet_def": 46 | a_text_features = torch.load(cache_wordnet_def).to('cpu') 47 | text_features = (text_features + a_text_features) * 0.5 48 | else: 49 | text_features = model.module.encode_text(text_tokens) 50 | text_features = text_features / text_features.norm(dim=1, keepdim=True) # [n_class, d_emb] 51 | text_features = text_features.to(device) 52 | image_features = model.module.encode_image(images, prompt_token) 53 | image_features = image_features / image_features.norm(dim=1, keepdim=True) # [bs, d_emb] 54 | logits_per_image = image_features @ text_features.t() * model.module.logit_scale.exp() 55 | logits_per_text = text_features @ image_features.t() * model.module.logit_scale.exp() 56 | 57 | return logits_per_image, logits_per_text, image_features, text_features 58 | 59 | def kl_div(p_logits, q_logits): 60 | # p_logits, q_logits [bs, n_class] both have been softmax normalized 61 | kl_divs = (p_logits * (p_logits.log() - q_logits.log())).sum(dim=1) # [bs,] 62 | return kl_divs.mean() 63 | 64 | def get_loss_general(tgt_logits, a_images, model_image_copy, text_features): 65 | # feed the perturbed image into the original visual encoder, regularise the predictive logits 66 | image_features = model_image_copy(a_images) # [bs, d_emb] 67 | image_features = image_features / image_features.norm(dim=1, keepdim=True) 68 | text_features = text_features / text_features.norm(dim=1, keepdim=True) 69 | logits_per_image_ = image_features @ text_features.t() * model_image_copy.module.logit_scale.exp() # [bs, n_class] 70 | l_general = kl_div(tgt_logits.softmax(dim=1), logits_per_image_.softmax(dim=1)) 71 | # l_general = criterion_(F.log_softmax(logits_per_image_, dim=1), F.softmax(tgt_logits)) 72 | return l_general 73 | 74 | def get_loss_clean(clean_images, tgt_logits, model, text_features, prompt_token=None): 75 | # feed the clean image into the visual encoder, regularise the predictive logits 76 | image_features = model.module.encode_image(clean_images, prompt_token) # [bs, d_emb] 77 | image_features = image_features / image_features.norm(dim=1, keepdim=True) 78 | logits_per_image = image_features @ text_features.t() * model.module.logit_scale.exp() # [bs, 
n_class] 79 | l_clean = kl_div(tgt_logits.softmax(dim=1), logits_per_image.softmax(dim=1)) 80 | # l_clean = criterion_(F.log_softmax(logits_per_image, dim=1), F.softmax(tgt_logits, dim=1)) 81 | return l_clean -------------------------------------------------------------------------------- /code/models/download_models.sh: -------------------------------------------------------------------------------- 1 | # CLIP 2 | pip install git+https://github.com/openai/CLIP.git 3 | -------------------------------------------------------------------------------- /code/models/model.py: -------------------------------------------------------------------------------- 1 | import torch, clip 2 | 3 | IMAGENET_MEAN = (0.48145466, 0.4578275, 0.40821073) 4 | IMAGENET_STD = (0.26862954, 0.26130258, 0.27577711) 5 | 6 | mu = torch.tensor(IMAGENET_MEAN).view(3, 1, 1).cuda() 7 | std = torch.tensor(IMAGENET_STD).view(3, 1, 1).cuda() 8 | 9 | def normalize(X): 10 | return (X - mu) / std 11 | 12 | def clip_img_preprocessing(X): 13 | img_size = 224 14 | X = torch.nn.functional.upsample(X, size=(img_size, img_size), mode='bicubic') 15 | X = normalize(X) 16 | return X 17 | 18 | def create_logits(x1, x2, logit_scale): 19 | x1 = x1 / x1.norm(dim=-1, keepdim=True) 20 | x2 = x2 / x2.norm(dim=-1, keepdim=True) 21 | # cosine similarity as logits 22 | logits_per_x1 = logit_scale * x1 @ x2.t() 23 | logits_per_x2 = logit_scale * x2 @ x1.t() 24 | return logits_per_x1, logits_per_x2 25 | 26 | def multiGPU_CLIP(clip_model, images, text_tokens, prompt_token=None): 27 | if prompt_token is not None: 28 | bs = images.size(0) 29 | prompt_token = prompt_token.repeat(bs, 1, 1) 30 | 31 | img_embed, scale_text_embed = clip_model(images, text_tokens, prompt_token) 32 | logits_per_image = img_embed @ scale_text_embed.t() 33 | logits_per_text = scale_text_embed @ img_embed.t() 34 | return logits_per_image, logits_per_text 35 | -------------------------------------------------------------------------------- /code/models/prompters.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | 6 | import torch 7 | from torch import nn 8 | 9 | from einops import rearrange, repeat 10 | from einops.layers.torch import Rearrange 11 | 12 | ############# 13 | 14 | class PreNorm(nn.Module): 15 | def __init__(self, dim, fn): 16 | super().__init__() 17 | self.norm = nn.LayerNorm(dim) 18 | self.fn = fn 19 | def forward(self, x, **kwargs): 20 | return self.fn(self.norm(x), **kwargs) 21 | 22 | class FeedForward(nn.Module): 23 | def __init__(self, dim, hidden_dim, dropout = 0.): 24 | super().__init__() 25 | self.net = nn.Sequential( 26 | nn.Linear(dim, hidden_dim), 27 | nn.GELU(), 28 | nn.Dropout(dropout), 29 | nn.Linear(hidden_dim, dim), 30 | nn.Dropout(dropout) 31 | ) 32 | def forward(self, x): 33 | return self.net(x) 34 | 35 | class Attention(nn.Module): 36 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): 37 | super().__init__() 38 | inner_dim = dim_head * heads 39 | project_out = not (heads == 1 and dim_head == dim) 40 | 41 | self.heads = heads 42 | self.scale = dim_head ** -0.5 43 | 44 | self.attend = nn.Softmax(dim = -1) 45 | self.dropout = nn.Dropout(dropout) 46 | 47 | self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 48 | 49 | self.to_out = nn.Sequential( 50 | nn.Linear(inner_dim, dim), 51 | nn.Dropout(dropout) 52 | ) if project_out else nn.Identity() 53 | 54 | def forward(self, x): 55 | qkv = self.to_qkv(x).chunk(3, dim = -1) 56 | q, 
k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv) 57 | 58 | dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale 59 | 60 | attn = self.attend(dots) 61 | attn = self.dropout(attn) 62 | 63 | out = torch.matmul(attn, v) 64 | out = rearrange(out, 'b h n d -> b n (h d)') 65 | return self.to_out(out) 66 | 67 | class Transformer(nn.Module): 68 | def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): 69 | super().__init__() 70 | self.layers = nn.ModuleList([]) 71 | for _ in range(depth): 72 | self.layers.append(nn.ModuleList([ 73 | PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), 74 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) 75 | ])) 76 | def forward(self, x): 77 | for attn, ff in self.layers: 78 | x = attn(x) + x 79 | x = ff(x) + x 80 | return x 81 | #############3 82 | 83 | 84 | 85 | 86 | class NullPrompter(nn.Module): 87 | def __init__(self): 88 | super(NullPrompter, self).__init__() 89 | pass 90 | 91 | def forward(self, x): 92 | return x 93 | 94 | 95 | class PadPrompter(nn.Module): 96 | def __init__(self, args): 97 | super(PadPrompter, self).__init__() 98 | pad_size = args.prompt_size 99 | image_size = args.image_size 100 | 101 | self.base_size = image_size - pad_size*2 102 | self.pad_up = nn.Parameter(torch.randn([1, 3, pad_size, image_size])) 103 | self.pad_down = nn.Parameter(torch.randn([1, 3, pad_size, image_size])) 104 | self.pad_left = nn.Parameter(torch.randn([1, 3, image_size - pad_size*2, pad_size])) 105 | self.pad_right = nn.Parameter(torch.randn([1, 3, image_size - pad_size*2, pad_size])) 106 | 107 | def forward(self, x): 108 | base = torch.zeros(1, 3, self.base_size, self.base_size).cuda() 109 | prompt = torch.cat([self.pad_left, base, self.pad_right], dim=3) 110 | prompt = torch.cat([self.pad_up, prompt, self.pad_down], dim=2) 111 | prompt = torch.cat(x.size(0) * [prompt]) 112 | 113 | return x + prompt 114 | 115 | class TokenPrompter(nn.Module): 116 | def __init__(self, prompt_len) -> None: 117 | super(TokenPrompter, self).__init__() 118 | 119 | self.prompt = nn.Parameter(torch.randn([1, prompt_len, 768])) 120 | 121 | def forward(self): 122 | return self.prompt 123 | 124 | 125 | class TokenPrompter_w_pos(nn.Module): 126 | def __init__(self, prompt_len) -> None: 127 | super(TokenPrompter_w_pos, self).__init__() 128 | 129 | self.prompt = nn.Parameter(torch.randn([1, prompt_len, 768])) 130 | self.pos_embedding = nn.Parameter(torch.randn(1, prompt_len, 1)) 131 | 132 | def forward(self): 133 | return self.prompt + self.pos_embedding 134 | 135 | 136 | class TokenPrompter_w_pos_TransformerGEN(nn.Module): 137 | def __init__(self, prompt_len) -> None: 138 | super(TokenPrompter_w_pos_TransformerGEN, self).__init__() 139 | 140 | self.prompt = nn.Parameter(torch.randn([1, prompt_len, 768])) 141 | 142 | self.dropout = nn.Dropout(0) 143 | self.transformer = Transformer(768, 3, 4, 768, 768) 144 | 145 | self.pos_embedding = nn.Parameter(torch.randn(1, prompt_len, 1)) 146 | 147 | def forward(self): 148 | return self.transformer(self.prompt + self.pos_embedding) 149 | 150 | class FixedPatchPrompter(nn.Module): 151 | def __init__(self, args): 152 | super(FixedPatchPrompter, self).__init__() 153 | self.isize = args.image_size 154 | self.psize = args.prompt_size 155 | self.patch = nn.Parameter(torch.randn([1, 3, self.psize, self.psize])) 156 | 157 | def forward(self, x): 158 | prompt = torch.zeros([1, 3, self.isize, self.isize]).cuda() 159 | prompt[:, :, :self.psize, :self.psize] = self.patch 160 
| 161 | return x + prompt 162 | 163 | 164 | class RandomPatchPrompter(nn.Module): 165 | def __init__(self, args): 166 | super(RandomPatchPrompter, self).__init__() 167 | self.isize = args.image_size 168 | self.psize = args.prompt_size 169 | self.patch = nn.Parameter(torch.randn([1, 3, self.psize, self.psize])) 170 | 171 | def forward(self, x): 172 | x_ = np.random.choice(self.isize - self.psize) 173 | y_ = np.random.choice(self.isize - self.psize) 174 | 175 | prompt = torch.zeros([1, 3, self.isize, self.isize]).cuda() 176 | prompt[:, :, x_:x_ + self.psize, y_:y_ + self.psize] = self.patch 177 | 178 | return x + prompt 179 | 180 | 181 | def padding(args): 182 | return PadPrompter(args) 183 | 184 | 185 | def fixed_patch(args): 186 | return FixedPatchPrompter(args) 187 | 188 | 189 | def random_patch(args): 190 | return RandomPatchPrompter(args) 191 | 192 | def null_patch(args): 193 | return NullPrompter() 194 | -------------------------------------------------------------------------------- /code/replace/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sxing2/CLIP-Test-time-Counterattacks/793fbff69cd7b63881e761e9a73a38e76b3ce5bf/code/replace/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /code/replace/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import urllib 4 | import warnings 5 | from typing import Any, Union, List 6 | from pkg_resources import packaging 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | from tqdm import tqdm 12 | 13 | from .model import build_model 14 | from .simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | 16 | try: 17 | from torchvision.transforms import InterpolationMode 18 | BICUBIC = InterpolationMode.BICUBIC 19 | except ImportError: 20 | BICUBIC = Image.BICUBIC 21 | 22 | 23 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"): 24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 25 | 26 | 27 | __all__ = ["available_models", "load", "tokenize"] 28 | _tokenizer = _Tokenizer() 29 | 30 | _MODELS = { 31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 35 | "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", 36 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 37 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 38 | "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", 39 | "ViT-L/14@336px": 
"https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt", 40 | } 41 | 42 | 43 | def _download(url: str, root: str): 44 | os.makedirs(root, exist_ok=True) 45 | filename = os.path.basename(url) 46 | 47 | expected_sha256 = url.split("/")[-2] 48 | download_target = os.path.join(root, filename) 49 | 50 | if os.path.exists(download_target) and not os.path.isfile(download_target): 51 | raise RuntimeError(f"{download_target} exists and is not a regular file") 52 | 53 | if os.path.isfile(download_target): 54 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 55 | return download_target 56 | else: 57 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 58 | 59 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 60 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop: 61 | while True: 62 | buffer = source.read(8192) 63 | if not buffer: 64 | break 65 | 66 | output.write(buffer) 67 | loop.update(len(buffer)) 68 | 69 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 70 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") 71 | 72 | return download_target 73 | 74 | 75 | def _convert_image_to_rgb(image): 76 | return image.convert("RGB") 77 | 78 | 79 | def _transform(n_px): 80 | return Compose([ 81 | Resize(n_px, interpolation=BICUBIC), 82 | CenterCrop(n_px), 83 | _convert_image_to_rgb, 84 | ToTensor(), 85 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 86 | ]) 87 | 88 | 89 | def available_models() -> List[str]: 90 | """Returns the names of available CLIP models""" 91 | return list(_MODELS.keys()) 92 | 93 | 94 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None, prompt_len: int = 0): 95 | """Load a CLIP model 96 | 97 | Parameters 98 | ---------- 99 | name : str 100 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 101 | 102 | device : Union[str, torch.device] 103 | The device to put the loaded model 104 | 105 | jit : bool 106 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 107 | 108 | download_root: str 109 | path to download the model files; by default, it uses "~/.cache/clip" 110 | 111 | Returns 112 | ------- 113 | model : torch.nn.Module 114 | The CLIP model 115 | 116 | preprocess : Callable[[PIL.Image], torch.Tensor] 117 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 118 | """ 119 | if name in _MODELS: 120 | model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip")) 121 | elif os.path.isfile(name): 122 | model_path = name 123 | else: 124 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 125 | 126 | with open(model_path, 'rb') as opened_file: 127 | try: 128 | # loading JIT archive 129 | model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval() 130 | state_dict = None 131 | except RuntimeError: 132 | # loading saved state dict 133 | if jit: 134 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 135 | jit = False 136 | state_dict = torch.load(opened_file, map_location="cpu") 137 | 138 | if not jit: 139 | model = build_model(state_dict or model.state_dict(), prompt_len).to(device) 140 | if str(device) == "cpu": 141 | model.float() 142 | return model, _transform(model.visual.input_resolution) 143 | 144 | # patch the device names 145 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 146 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 147 | 148 | def patch_device(module): 149 | try: 150 | graphs = [module.graph] if hasattr(module, "graph") else [] 151 | except RuntimeError: 152 | graphs = [] 153 | 154 | if hasattr(module, "forward1"): 155 | graphs.append(module.forward1.graph) 156 | 157 | for graph in graphs: 158 | for node in graph.findAllNodes("prim::Constant"): 159 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 160 | node.copyAttributes(device_node) 161 | 162 | model.apply(patch_device) 163 | patch_device(model.encode_image) 164 | patch_device(model.encode_text) 165 | 166 | # patch dtype to float32 on CPU 167 | if str(device) == "cpu": 168 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 169 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 170 | float_node = float_input.node() 171 | 172 | def patch_float(module): 173 | try: 174 | graphs = [module.graph] if hasattr(module, "graph") else [] 175 | except RuntimeError: 176 | graphs = [] 177 | 178 | if hasattr(module, "forward1"): 179 | graphs.append(module.forward1.graph) 180 | 181 | for graph in graphs: 182 | for node in graph.findAllNodes("aten::to"): 183 | inputs = list(node.inputs()) 184 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 185 | if inputs[i].node()["value"] == 5: 186 | inputs[i].node().copyAttributes(float_node) 187 | 188 | model.apply(patch_float) 189 | patch_float(model.encode_image) 190 | patch_float(model.encode_text) 191 | 192 | model.float() 193 | 194 | return model, _transform(model.input_resolution.item()) 195 | 196 | 197 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]: 198 | """ 199 | Returns the tokenized representation of given input string(s) 200 | 201 | Parameters 202 | ---------- 203 | texts : Union[str, List[str]] 204 | An input string or a list of input strings to tokenize 205 | 206 | context_length : int 207 | The context length to use; all CLIP models use 77 as the context length 208 | 209 | truncate: bool 210 | Whether to truncate the text in case its encoding is longer than the context length 211 | 212 | Returns 213 | ------- 214 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]. 215 | We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long. 
216 | """ 217 | if isinstance(texts, str): 218 | texts = [texts] 219 | 220 | sot_token = _tokenizer.encoder["<|startoftext|>"] 221 | eot_token = _tokenizer.encoder["<|endoftext|>"] 222 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 223 | if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"): 224 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 225 | else: 226 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.int) 227 | 228 | for i, tokens in enumerate(all_tokens): 229 | if len(tokens) > context_length: 230 | if truncate: 231 | tokens = tokens[:context_length] 232 | tokens[-1] = eot_token 233 | else: 234 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 235 | result[i, :len(tokens)] = torch.tensor(tokens) 236 | 237 | return result 238 | -------------------------------------------------------------------------------- /code/replace/datasets/caltech.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | from typing import Any, Callable, List, Optional, Union, Tuple 4 | 5 | from PIL import Image 6 | 7 | from torchvision.datasets.utils import download_and_extract_archive, verify_str_arg 8 | from torchvision.datasets.vision import VisionDataset 9 | 10 | 11 | class Caltech101(VisionDataset): 12 | """`Caltech 101 `_ Dataset. 13 | 14 | .. warning:: 15 | 16 | This class needs `scipy `_ to load target files from `.mat` format. 17 | 18 | Args: 19 | root (string): Root directory of dataset where directory 20 | ``caltech101`` exists or will be saved to if download is set to True. 21 | target_type (string or list, optional): Type of target to use, ``category`` or 22 | ``annotation``. Can also be a list to output a tuple with all specified 23 | target types. ``category`` represents the target class, and 24 | ``annotation`` is a list of points from a hand-generated outline. 25 | Defaults to ``category``. 26 | transform (callable, optional): A function/transform that takes in an PIL image 27 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 28 | target_transform (callable, optional): A function/transform that takes in the 29 | target and transforms it. 30 | download (bool, optional): If true, downloads the dataset from the internet and 31 | puts it in root directory. If dataset is already downloaded, it is not 32 | downloaded again. 33 | """ 34 | 35 | def __init__( 36 | self, 37 | root: str, 38 | target_type: Union[List[str], str] = "category", 39 | transform: Optional[Callable] = None, 40 | target_transform: Optional[Callable] = None, 41 | download: bool = False, 42 | ) -> None: 43 | super().__init__(os.path.join(root, "caltech101"), transform=transform, target_transform=target_transform) 44 | os.makedirs(self.root, exist_ok=True) 45 | if isinstance(target_type, str): 46 | target_type = [target_type] 47 | self.target_type = [verify_str_arg(t, "target_type", ("category", "annotation", "category_name")) for t in target_type] 48 | 49 | if download: 50 | self.download() 51 | 52 | if not self._check_integrity(): 53 | raise RuntimeError("Dataset not found or corrupted. 
You can use download=True to download it") 54 | 55 | self.categories = sorted(os.listdir(os.path.join(self.root, "101_ObjectCategories"))) 56 | self.categories.remove("BACKGROUND_Google") # this is not a real class 57 | 58 | # For some reason, the category names in "101_ObjectCategories" and 59 | # "Annotations" do not always match. This is a manual map between the 60 | # two. Defaults to using same name, since most names are fine. 61 | name_map = { 62 | "Faces": "Faces_2", 63 | "Faces_easy": "Faces_3", 64 | "Motorbikes": "Motorbikes_16", 65 | "airplanes": "Airplanes_Side_2", 66 | } 67 | self.annotation_categories = list(map(lambda x: name_map[x] if x in name_map else x, self.categories)) 68 | 69 | self.index: List[int] = [] 70 | self.y = [] 71 | for (i, c) in enumerate(self.categories): 72 | n = len(os.listdir(os.path.join(self.root, "101_ObjectCategories", c))) 73 | self.index.extend(range(1, n + 1)) 74 | self.y.extend(n * [i]) 75 | 76 | self.clip_categories = self.categories.copy() 77 | self.clip_categories[0] = "person" 78 | self.clip_categories[1] = "person" 79 | 80 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 81 | """ 82 | Args: 83 | index (int): Index 84 | 85 | Returns: 86 | tuple: (image, target) where the type of target specified by target_type. 87 | """ 88 | import scipy.io 89 | 90 | img = Image.open( 91 | os.path.join( 92 | self.root, 93 | "101_ObjectCategories", 94 | self.categories[self.y[index]], 95 | f"image_{self.index[index]:04d}.jpg", 96 | ) 97 | ) 98 | 99 | target: Any = [] 100 | for t in self.target_type: 101 | if t == "category": 102 | target.append(self.y[index]) 103 | elif t == "category_name": 104 | target.append(self.clip_categories[self.y[index]]) 105 | elif t == "annotation": 106 | data = scipy.io.loadmat( 107 | os.path.join( 108 | self.root, 109 | "Annotations", 110 | self.annotation_categories[self.y[index]], 111 | f"annotation_{self.index[index]:04d}.mat", 112 | ) 113 | ) 114 | target.append(data["obj_contour"]) 115 | target = tuple(target) if len(target) > 1 else target[0] 116 | 117 | if self.transform is not None: 118 | img = self.transform(img) 119 | 120 | if self.target_transform is not None: 121 | target = self.target_transform(target) 122 | 123 | return img, target 124 | 125 | def _check_integrity(self) -> bool: 126 | # can be more robust and check hash of files 127 | return os.path.exists(os.path.join(self.root, "101_ObjectCategories")) 128 | 129 | def __len__(self) -> int: 130 | return len(self.index) 131 | 132 | def download(self) -> None: 133 | if self._check_integrity(): 134 | print("Files already downloaded and verified") 135 | return 136 | 137 | download_and_extract_archive( 138 | "https://drive.google.com/file/d/137RyRjvTBkBiIfeYBNZBtViDHQ6_Ewsp", 139 | self.root, 140 | filename="101_ObjectCategories.tar.gz", 141 | md5="b224c7392d521a49829488ab0f1120d9", 142 | ) 143 | download_and_extract_archive( 144 | "https://drive.google.com/file/d/175kQy3UsZ0wUEHZjqkUDdNVssr7bgh_m", 145 | self.root, 146 | filename="Annotations.tar", 147 | md5="6f83eeb1f24d99cab4eb377263132c91", 148 | ) 149 | 150 | def extra_repr(self) -> str: 151 | return "Target type: {target_type}".format(**self.__dict__) 152 | 153 | 154 | class Caltech256(VisionDataset): 155 | """`Caltech 256 `_ Dataset. 156 | 157 | Args: 158 | root (string): Root directory of dataset where directory 159 | ``caltech256`` exists or will be saved to if download is set to True. 
160 | transform (callable, optional): A function/transform that takes in an PIL image 161 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 162 | target_transform (callable, optional): A function/transform that takes in the 163 | target and transforms it. 164 | download (bool, optional): If true, downloads the dataset from the internet and 165 | puts it in root directory. If dataset is already downloaded, it is not 166 | downloaded again. 167 | """ 168 | 169 | def __init__( 170 | self, 171 | root: str, 172 | transform: Optional[Callable] = None, 173 | target_transform: Optional[Callable] = None, 174 | download: bool = False, 175 | prompt_template = "A photo of a {}." 176 | ) -> None: 177 | super().__init__(os.path.join(root, "caltech256"), transform=transform, target_transform=target_transform) 178 | os.makedirs(self.root, exist_ok=True) 179 | 180 | if download: 181 | self.download() 182 | 183 | if not self._check_integrity(): 184 | raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it") 185 | 186 | self.categories = sorted(os.listdir(os.path.join(self.root, "256_ObjectCategories"))) 187 | self.index: List[int] = [] 188 | self.y = [] 189 | for (i, c) in enumerate(self.categories): 190 | n = len( 191 | [ 192 | item 193 | for item in os.listdir(os.path.join(self.root, "256_ObjectCategories", c)) 194 | if item.endswith(".jpg") 195 | ] 196 | ) 197 | self.index.extend(range(1, n + 1)) 198 | self.y.extend(n * [i]) 199 | 200 | refined_classes = [] 201 | for class_name in self.categories: 202 | class_name = class_name[4:] 203 | refined_classes.append(class_name.replace('-101', '')) 204 | 205 | self.prompt_template = prompt_template 206 | self.clip_prompts = [ 207 | prompt_template.format(label.lower().replace('_', ' ').replace('-', ' ').strip()) \ 208 | for label in refined_classes 209 | ] 210 | 211 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 212 | """ 213 | Args: 214 | index (int): Index 215 | 216 | Returns: 217 | tuple: (image, target) where target is index of the target class. 
218 | """ 219 | img = Image.open( 220 | os.path.join( 221 | self.root, 222 | "256_ObjectCategories", 223 | self.categories[self.y[index]], 224 | f"{self.y[index] + 1:03d}_{self.index[index]:04d}.jpg", 225 | ) 226 | ) 227 | 228 | target = self.y[index] 229 | 230 | if self.transform is not None: 231 | img = self.transform(img) 232 | 233 | if self.target_transform is not None: 234 | target = self.target_transform(target) 235 | 236 | return img, target 237 | 238 | def _check_integrity(self) -> bool: 239 | # can be more robust and check hash of files 240 | return os.path.exists(os.path.join(self.root, "256_ObjectCategories")) 241 | 242 | def __len__(self) -> int: 243 | return len(self.index) 244 | 245 | def download(self) -> None: 246 | if self._check_integrity(): 247 | print("Files already downloaded and verified") 248 | return 249 | 250 | download_and_extract_archive( 251 | "https://drive.google.com/file/d/1r6o0pSROcV1_VwT4oSjA2FBUSCWGuxLK", 252 | self.root, 253 | filename="256_ObjectCategories.tar", 254 | md5="67b4f42ca05d46448c6bb8ecd2220f6d", 255 | ) 256 | -------------------------------------------------------------------------------- /code/replace/datasets/country211.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Callable, Optional 3 | import pycountry 4 | 5 | from torchvision.datasets.folder import ImageFolder 6 | from torchvision.datasets.utils import verify_str_arg, download_and_extract_archive 7 | 8 | 9 | class Country211(ImageFolder): 10 | """`The Country211 Data Set `_ from OpenAI. 11 | 12 | This dataset was built by filtering the images from the YFCC100m dataset 13 | that have GPS coordinate corresponding to a ISO-3166 country code. The 14 | dataset is balanced by sampling 150 train images, 50 validation images, and 15 | 100 test images images for each country. 16 | 17 | Args: 18 | root (string): Root directory of the dataset. 19 | split (string, optional): The dataset split, supports ``"train"`` (default), ``"valid"`` and ``"test"``. 20 | transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed 21 | version. E.g, ``transforms.RandomCrop``. 22 | target_transform (callable, optional): A function/transform that takes in the target and transforms it. 23 | download (bool, optional): If True, downloads the dataset from the internet and puts it into 24 | ``root/country211/``. If dataset is already downloaded, it is not downloaded again. 25 | """ 26 | 27 | _URL = "https://openaipublic.azureedge.net/clip/data/country211.tgz" 28 | _MD5 = "84988d7644798601126c29e9877aab6a" 29 | 30 | def __init__( 31 | self, 32 | root: str, 33 | split: str = "train", 34 | transform: Optional[Callable] = None, 35 | target_transform: Optional[Callable] = None, 36 | download: bool = False, 37 | prompt_template = "A photo I took in {}." 38 | ) -> None: 39 | self._split = verify_str_arg(split, "split", ("train", "valid", "test")) 40 | 41 | root = Path(root).expanduser() 42 | self.root = str(root) 43 | self._base_folder = root / "country211" 44 | 45 | if download: 46 | self._download() 47 | 48 | if not self._check_exists(): 49 | raise RuntimeError("Dataset not found. 
You can use download=True to download it") 50 | 51 | super().__init__(str(self._base_folder / self._split), transform=transform, target_transform=target_transform) 52 | 53 | countries = {} 54 | self.code_to_country = dict() 55 | for country in pycountry.countries: 56 | countries[country.alpha_2] = country.name 57 | 58 | countries['XK'] = 'Kosovo' 59 | 60 | for code in self.class_to_idx: 61 | self.code_to_country[code] = countries[code].split(',')[0] 62 | 63 | self.prompt_template = prompt_template 64 | self.clip_prompts = [ 65 | prompt_template.format(countries[label].replace('_', ' ').replace('-', ' ')) \ 66 | for label in self.classes 67 | ] 68 | 69 | self.root = str(root) 70 | 71 | def _check_exists(self) -> bool: 72 | return self._base_folder.exists() and self._base_folder.is_dir() 73 | 74 | def _download(self) -> None: 75 | if self._check_exists(): 76 | return 77 | download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5) 78 | -------------------------------------------------------------------------------- /code/replace/datasets/dtd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | from typing import Optional, Callable 4 | 5 | import PIL.Image 6 | 7 | from torchvision.datasets.utils import verify_str_arg, download_and_extract_archive 8 | from torchvision.datasets.vision import VisionDataset 9 | 10 | 11 | class DTD(VisionDataset): 12 | """`Describable Textures Dataset (DTD) `_. 13 | 14 | Args: 15 | root (string): Root directory of the dataset. 16 | split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``. 17 | partition (int, optional): The dataset partition. Should be ``1 <= partition <= 10``. Defaults to ``1``. 18 | 19 | .. note:: 20 | 21 | The partition only changes which split each image belongs to. Thus, regardless of the selected 22 | partition, combining all splits will result in all images. 23 | 24 | transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed 25 | version. E.g, ``transforms.RandomCrop``. 26 | target_transform (callable, optional): A function/transform that takes in the target and transforms it. 27 | download (bool, optional): If True, downloads the dataset from the internet and 28 | puts it in root directory. If dataset is already downloaded, it is not 29 | downloaded again. Default is False. 30 | """ 31 | 32 | _URL = "https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz" 33 | _MD5 = "fff73e5086ae6bdbea199a49dfb8a4c1" 34 | 35 | def __init__( 36 | self, 37 | root: str, 38 | split: str = "train", 39 | partition: int = 1, 40 | transform: Optional[Callable] = None, 41 | target_transform: Optional[Callable] = None, 42 | download: bool = False, 43 | prompt_template = "A surface with a {} texture." 
44 | ) -> None: 45 | self._split = verify_str_arg(split, "split", ("train", "val", "test")) 46 | if not isinstance(partition, int) and not (1 <= partition <= 10): 47 | raise ValueError( 48 | f"Parameter 'partition' should be an integer with `1 <= partition <= 10`, " 49 | f"but got {partition} instead" 50 | ) 51 | self._partition = partition 52 | 53 | super().__init__(root, transform=transform, target_transform=target_transform) 54 | self._base_folder = pathlib.Path(self.root) / type(self).__name__.lower() 55 | self._data_folder = self._base_folder / "dtd" 56 | self._meta_folder = self._data_folder / "labels" 57 | self._images_folder = self._data_folder / "images" 58 | 59 | if download: 60 | self._download() 61 | 62 | if not self._check_exists(): 63 | raise RuntimeError("Dataset not found. You can use download=True to download it") 64 | 65 | self._image_files = [] 66 | classes = [] 67 | with open(self._meta_folder / f"{self._split}{self._partition}.txt") as file: 68 | for line in file: 69 | cls, name = line.strip().split("/") 70 | self._image_files.append(self._images_folder.joinpath(cls, name)) 71 | classes.append(cls) 72 | 73 | self.classes = sorted(set(classes)) 74 | self.class_to_idx = dict(zip(self.classes, range(len(self.classes)))) 75 | self.idx_to_class = dict(zip(range(len(self.classes)), self.classes)) 76 | self._labels = [self.class_to_idx[cls] for cls in classes] 77 | 78 | self.prompt_template = prompt_template 79 | 80 | self.clip_prompts = [ 81 | prompt_template.format(label.lower().replace('_', ' ').replace('-', ' ')) \ 82 | for label in self.classes 83 | ] 84 | 85 | def __len__(self) -> int: 86 | return len(self._image_files) 87 | 88 | def __getitem__(self, idx): 89 | image_file, label = self._image_files[idx], self._labels[idx] 90 | image = PIL.Image.open(image_file).convert("RGB") 91 | 92 | if self.transform: 93 | image = self.transform(image) 94 | 95 | if self.target_transform: 96 | label = self.target_transform(label) 97 | 98 | return image, label 99 | 100 | def extra_repr(self) -> str: 101 | return f"split={self._split}, partition={self._partition}" 102 | 103 | def _check_exists(self) -> bool: 104 | return os.path.exists(self._data_folder) and os.path.isdir(self._data_folder) 105 | 106 | def _download(self) -> None: 107 | if self._check_exists(): 108 | return 109 | download_and_extract_archive(self._URL, download_root=str(self._base_folder), md5=self._MD5) 110 | -------------------------------------------------------------------------------- /code/replace/datasets/eurosat.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Callable, Optional 3 | 4 | from torchvision.datasets.folder import ImageFolder 5 | from torchvision.datasets.utils import download_and_extract_archive 6 | 7 | 8 | class EuroSAT(ImageFolder): 9 | """RGB version of the `EuroSAT `_ Dataset. 10 | 11 | Args: 12 | root (string): Root directory of dataset where ``root/eurosat`` exists. 13 | transform (callable, optional): A function/transform that takes in an PIL image 14 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 15 | target_transform (callable, optional): A function/transform that takes in the 16 | target and transforms it. 17 | download (bool, optional): If True, downloads the dataset from the internet and 18 | puts it in root directory. If dataset is already downloaded, it is not 19 | downloaded again. Default is False. 
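Most of the dataset classes in `code/replace/datasets` follow the same recipe for their zero-shot prompts: the class name is lower-cased, underscores and hyphens are replaced with spaces, and the result is filled into a dataset-specific `prompt_template`. A minimal sketch of that recipe, using DTD's template and a few illustrative texture names (not copied from the repo):

```python
# Sketch of the clip_prompts construction shared by the replacement datasets.
# The texture names below are illustrative examples, not the full DTD label set.
template = "A surface with a {} texture."  # DTD's default prompt_template
classes = ["banded", "polka-dotted", "zigzagged"]

clip_prompts = [
    template.format(c.lower().replace("_", " ").replace("-", " "))
    for c in classes
]
print(clip_prompts[1])  # -> "A surface with a polka dotted texture."
```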
20 | """ 21 | idx_to_class = { 22 | 0: 'annual crop land', 23 | 1: 'a forest', 24 | 2: 'brushland or shrubland', 25 | 3: 'a highway or a road', 26 | 4: 'industrial buildings', 27 | 5: 'pasture land', 28 | 6: 'permanent crop land', 29 | 7: 'residential buildings', 30 | 8: 'a river', 31 | 9: 'a sea or a lake' 32 | } 33 | 34 | def __init__( 35 | self, 36 | root: str, 37 | transform: Optional[Callable] = None, 38 | target_transform: Optional[Callable] = None, 39 | download: bool = False, 40 | prompt_template = "A centered satellite photo of {}." 41 | ) -> None: 42 | self.root = os.path.expanduser(root) 43 | self._base_folder = os.path.join(self.root, "eurosat") 44 | self._data_folder = os.path.join(self._base_folder, "2750") 45 | 46 | if download: 47 | self.download() 48 | 49 | if not self._check_exists(): 50 | raise RuntimeError("Dataset not found. You can use download=True to download it") 51 | 52 | super().__init__(self._data_folder, transform=transform, target_transform=target_transform) 53 | self.root = os.path.expanduser(root) 54 | 55 | self.prompt_template = prompt_template 56 | self.clip_prompts = [ 57 | prompt_template.format(label.lower().replace('_', ' ').replace('-', ' ')) \ 58 | for label in self.idx_to_class.values() 59 | ] 60 | 61 | def __len__(self) -> int: 62 | return len(self.samples) 63 | 64 | def _check_exists(self) -> bool: 65 | return os.path.exists(self._data_folder) 66 | 67 | def download(self) -> None: 68 | 69 | if self._check_exists(): 70 | return 71 | 72 | os.makedirs(self._base_folder, exist_ok=True) 73 | download_and_extract_archive( 74 | "https://madm.dfki.de/files/sentinel/EuroSAT.zip", 75 | download_root=self._base_folder, 76 | md5="c8fa014336c82ac7804f0398fcb19387", 77 | ) 78 | -------------------------------------------------------------------------------- /code/replace/datasets/fgvc_aircraft.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | from typing import Any, Callable, Optional, Tuple 5 | 6 | import PIL.Image 7 | 8 | from torchvision.datasets.utils import download_and_extract_archive, verify_str_arg 9 | from torchvision.datasets.vision import VisionDataset 10 | 11 | 12 | class FGVCAircraft(VisionDataset): 13 | """`FGVC Aircraft `_ Dataset. 14 | 15 | The dataset contains 10,000 images of aircraft, with 100 images for each of 100 16 | different aircraft model variants, most of which are airplanes. 17 | Aircraft models are organized in a three-levels hierarchy. The three levels, from 18 | finer to coarser, are: 19 | 20 | - ``variant``, e.g. Boeing 737-700. A variant collapses all the models that are visually 21 | indistinguishable into one class. The dataset comprises 100 different variants. 22 | - ``family``, e.g. Boeing 737. The dataset comprises 70 different families. 23 | - ``manufacturer``, e.g. Boeing. The dataset comprises 30 different manufacturers. 24 | 25 | Args: 26 | root (string): Root directory of the FGVC Aircraft dataset. 27 | split (string, optional): The dataset split, supports ``train``, ``val``, 28 | ``trainval`` and ``test``. 29 | annotation_level (str, optional): The annotation level, supports ``variant``, 30 | ``family`` and ``manufacturer``. 31 | transform (callable, optional): A function/transform that takes in an PIL image 32 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 33 | target_transform (callable, optional): A function/transform that takes in the 34 | target and transforms it. 
35 | download (bool, optional): If True, downloads the dataset from the internet and 36 | puts it in root directory. If dataset is already downloaded, it is not 37 | downloaded again. 38 | """ 39 | 40 | _URL = "https://www.robots.ox.ac.uk/~vgg/data/fgvc-aircraft/archives/fgvc-aircraft-2013b.tar.gz" 41 | 42 | def __init__( 43 | self, 44 | root: str, 45 | split: str = "trainval", 46 | annotation_level: str = "variant", 47 | transform: Optional[Callable] = None, 48 | target_transform: Optional[Callable] = None, 49 | download: bool = False, 50 | prompt_template = "A photo of a {}, a type of aircraft." 51 | ) -> None: 52 | super().__init__(root, transform=transform, target_transform=target_transform) 53 | self._split = verify_str_arg(split, "split", ("train", "val", "trainval", "test")) 54 | self._annotation_level = verify_str_arg( 55 | annotation_level, "annotation_level", ("variant", "family", "manufacturer") 56 | ) 57 | 58 | self._data_path = os.path.join(self.root, "fgvc-aircraft-2013b") 59 | if download: 60 | self._download() 61 | 62 | if not self._check_exists(): 63 | raise RuntimeError("Dataset not found. You can use download=True to download it") 64 | 65 | annotation_file = os.path.join( 66 | self._data_path, 67 | "data", 68 | { 69 | "variant": "variants.txt", 70 | "family": "families.txt", 71 | "manufacturer": "manufacturers.txt", 72 | }[self._annotation_level], 73 | ) 74 | with open(annotation_file, "r") as f: 75 | self.classes = [line.strip() for line in f] 76 | 77 | self.class_to_idx = dict(zip(self.classes, range(len(self.classes)))) 78 | 79 | image_data_folder = os.path.join(self._data_path, "data", "images") 80 | labels_file = os.path.join(self._data_path, "data", f"images_{self._annotation_level}_{self._split}.txt") 81 | 82 | self._image_files = [] 83 | self._labels = [] 84 | 85 | with open(labels_file, "r") as f: 86 | for line in f: 87 | image_name, label_name = line.strip().split(" ", 1) 88 | self._image_files.append(os.path.join(image_data_folder, f"{image_name}.jpg")) 89 | self._labels.append(self.class_to_idx[label_name]) 90 | 91 | refined_classes = [] 92 | for label in self.classes: 93 | if label[0] == '7': 94 | refined_classes.append(f"Boeing {label}") 95 | else: 96 | refined_classes.append(label) 97 | 98 | self.prompt_template = prompt_template 99 | self.clip_prompts = [ 100 | prompt_template.format(label.lower().replace('_', ' ')) \ 101 | for label in refined_classes 102 | ] 103 | 104 | def __len__(self) -> int: 105 | return len(self._image_files) 106 | 107 | def __getitem__(self, idx) -> Tuple[Any, Any]: 108 | image_file, label = self._image_files[idx], self._labels[idx] 109 | image = PIL.Image.open(image_file).convert("RGB") 110 | 111 | if self.transform: 112 | image = self.transform(image) 113 | 114 | if self.target_transform: 115 | label = self.target_transform(label) 116 | 117 | return image, label 118 | 119 | def _download(self) -> None: 120 | """ 121 | Download the FGVC Aircraft dataset archive and extract it under root. 
122 | """ 123 | if self._check_exists(): 124 | return 125 | download_and_extract_archive(self._URL, self.root) 126 | 127 | def _check_exists(self) -> bool: 128 | return os.path.exists(self._data_path) and os.path.isdir(self._data_path) 129 | -------------------------------------------------------------------------------- /code/replace/datasets/flowers102.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any, Tuple, Callable, Optional 3 | 4 | import PIL.Image 5 | 6 | from torchvision.datasets.utils import check_integrity, download_and_extract_archive, download_url, verify_str_arg 7 | from torchvision.datasets.vision import VisionDataset 8 | 9 | 10 | class Flowers102(VisionDataset): 11 | """`Oxford 102 Flower `_ Dataset. 12 | 13 | .. warning:: 14 | 15 | This class needs `scipy `_ to load target files from `.mat` format. 16 | 17 | Oxford 102 Flower is an image classification dataset consisting of 102 flower categories. The 18 | flowers were chosen to be flowers commonly occurring in the United Kingdom. Each class consists of 19 | between 40 and 258 images. 20 | 21 | The images have large scale, pose and light variations. In addition, there are categories that 22 | have large variations within the category, and several very similar categories. 23 | 24 | Args: 25 | root (string): Root directory of the dataset. 26 | split (string, optional): The dataset split, supports ``"train"`` (default), ``"val"``, or ``"test"``. 27 | transform (callable, optional): A function/transform that takes in an PIL image and returns a 28 | transformed version. E.g, ``transforms.RandomCrop``. 29 | target_transform (callable, optional): A function/transform that takes in the target and transforms it. 30 | download (bool, optional): If true, downloads the dataset from the internet and 31 | puts it in root directory. If dataset is already downloaded, it is not 32 | downloaded again. 
33 | """ 34 | 35 | _download_url_prefix = "https://www.robots.ox.ac.uk/~vgg/data/flowers/102/" 36 | _file_dict = { # filename, md5 37 | "image": ("102flowers.tgz", "52808999861908f626f3c1f4e79d11fa"), 38 | "label": ("imagelabels.mat", "e0620be6f572b9609742df49c70aed4d"), 39 | "setid": ("setid.mat", "a5357ecc9cb78c4bef273ce3793fc85c"), 40 | } 41 | _splits_map = {"train": "trnid", "val": "valid", "test": "tstid"} 42 | 43 | _classes = ['pink primrose', 'hard-leaved pocket orchid', 'canterbury bells', 44 | 'sweet pea', 'english marigold', 'tiger lily', 'moon orchid', 'bird of paradise', 45 | 'monkshood', 'globe thistle', 'snapdragon', 'colts foot', 'king protea', 'spear thistle', 46 | 'yellow iris', 'globe-flower', 'purple coneflower', 'peruvian lily', 'balloon flower', 47 | 'giant white arum lily', 'fire lily', 'pincushion flower', 'fritillary', 'red ginger', 48 | 'grape hyacinth', 'corn poppy', 'prince of wales feathers', 'stemless gentian', 'artichoke', 49 | 'sweet william', 'carnation', 'garden phlox', 'love in the mist', 'mexican aster', 50 | 'alpine sea holly', 'ruby-lipped cattleya', 'cape flower', 'great masterwort', 'siam tulip', 51 | 'lenten rose', 'barbeton daisy', 'daffodil', 'sword lily', 'poinsettia', 'bolero deep blue', 52 | 'wallflower', 'marigold', 'buttercup', 'oxeye daisy', 'common dandelion', 'petunia', 53 | 'wild pansy', 'primula', 'sunflower', 'pelargonium', 'bishop of llandaff', 'gaura', 'geranium', 54 | 'orange dahlia', 'pink-yellow dahlia?', 'cautleya spicata', 'japanese anemone', 'black-eyed susan', 55 | 'silverbush', 'californian poppy', 'osteospermum', 'spring crocus', 'bearded iris', 'windflower', 56 | 'tree poppy', 'gazania', 'azalea', 'water lily', 'rose', 'thorn apple', 'morning glory', 57 | 'passion flower', 'lotus', 'toad lily', 'anthurium', 'frangipani', 'clematis', 'hibiscus', 58 | 'columbine', 'desert-rose', 'tree mallow', 'magnolia', 'cyclamen ', 'watercress', 'canna lily', 59 | 'hippeastrum ', 'bee balm', 'ball moss', 'foxglove', 'bougainvillea', 'camellia', 'mallow', 60 | 'mexican petunia', 'bromelia', 'blanket flower', 'trumpet creeper', 'blackberry lily'] 61 | 62 | def __init__( 63 | self, 64 | root: str, 65 | split: str = "train", 66 | transform: Optional[Callable] = None, 67 | target_transform: Optional[Callable] = None, 68 | download: bool = False, 69 | prompt_template = "A photo of a {}, a type of flower." 70 | ) -> None: 71 | super().__init__(root, transform=transform, target_transform=target_transform) 72 | self._split = verify_str_arg(split, "split", ("train", "val", "test")) 73 | self._base_folder = Path(self.root) / "flowers-102" 74 | self._images_folder = self._base_folder / "jpg" 75 | 76 | if download: 77 | self.download() 78 | 79 | if not self._check_integrity(): 80 | raise RuntimeError("Dataset not found or corrupted. 
You can use download=True to download it") 81 | 82 | from scipy.io import loadmat 83 | 84 | set_ids = loadmat(self._base_folder / self._file_dict["setid"][0], squeeze_me=True) 85 | image_ids = set_ids[self._splits_map[self._split]].tolist() 86 | 87 | labels = loadmat(self._base_folder / self._file_dict["label"][0], squeeze_me=True) 88 | image_id_to_label = dict(enumerate((labels["labels"] - 1).tolist(), 1)) 89 | 90 | self._labels = [] 91 | self._image_files = [] 92 | for image_id in image_ids: 93 | self._labels.append(image_id_to_label[image_id]) 94 | self._image_files.append(self._images_folder / f"image_{image_id:05d}.jpg") 95 | 96 | self.idx_to_class = {label:self._classes[label] for label in range(102)} 97 | 98 | self.prompt_template = prompt_template 99 | self.clip_prompts = [ 100 | prompt_template.format(label.lower().replace('_', ' ').replace('-', ' ')) \ 101 | for label in self._classes 102 | ] 103 | 104 | 105 | def __len__(self) -> int: 106 | return len(self._image_files) 107 | 108 | def __getitem__(self, idx) -> Tuple[Any, Any]: 109 | image_file, label = self._image_files[idx], self._labels[idx] 110 | image = PIL.Image.open(image_file).convert("RGB") 111 | 112 | if self.transform: 113 | image = self.transform(image) 114 | 115 | if self.target_transform: 116 | label = self.target_transform(label) 117 | 118 | return image, label 119 | 120 | def extra_repr(self) -> str: 121 | return f"split={self._split}" 122 | 123 | def _check_integrity(self): 124 | if not (self._images_folder.exists() and self._images_folder.is_dir()): 125 | return False 126 | 127 | for id in ["label", "setid"]: 128 | filename, md5 = self._file_dict[id] 129 | if not check_integrity(str(self._base_folder / filename), md5): 130 | return False 131 | return True 132 | 133 | def download(self): 134 | if self._check_integrity(): 135 | return 136 | download_and_extract_archive( 137 | f"{self._download_url_prefix}{self._file_dict['image'][0]}", 138 | str(self._base_folder), 139 | md5=self._file_dict["image"][1], 140 | ) 141 | for id in ["label", "setid"]: 142 | filename, md5 = self._file_dict[id] 143 | download_url(self._download_url_prefix + filename, str(self._base_folder), md5=md5) 144 | -------------------------------------------------------------------------------- /code/replace/datasets/folder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union 4 | 5 | from PIL import Image 6 | 7 | from torchvision.datasets.vision import VisionDataset 8 | 9 | 10 | def has_file_allowed_extension(filename: str, extensions: Union[str, Tuple[str, ...]]) -> bool: 11 | """Checks if a file is an allowed extension. 12 | 13 | Args: 14 | filename (string): path to a file 15 | extensions (tuple of strings): extensions to consider (lowercase) 16 | 17 | Returns: 18 | bool: True if the filename ends with one of given extensions 19 | """ 20 | return filename.lower().endswith(extensions if isinstance(extensions, str) else tuple(extensions)) 21 | 22 | 23 | def is_image_file(filename: str) -> bool: 24 | """Checks if a file is an allowed image extension. 25 | 26 | Args: 27 | filename (string): path to a file 28 | 29 | Returns: 30 | bool: True if the filename ends with a known image extension 31 | """ 32 | return has_file_allowed_extension(filename, IMG_EXTENSIONS) 33 | 34 | 35 | def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]: 36 | """Finds the class folders in a dataset. 
37 | 38 | See :class:`DatasetFolder` for details. 39 | """ 40 | classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir()) 41 | if not classes: 42 | raise FileNotFoundError(f"Couldn't find any class folder in {directory}.") 43 | 44 | class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} 45 | return classes, class_to_idx 46 | 47 | 48 | def make_dataset( 49 | directory: str, 50 | class_to_idx: Optional[Dict[str, int]] = None, 51 | extensions: Optional[Union[str, Tuple[str, ...]]] = None, 52 | is_valid_file: Optional[Callable[[str], bool]] = None, 53 | ) -> List[Tuple[str, int]]: 54 | """Generates a list of samples of a form (path_to_sample, class). 55 | 56 | See :class:`DatasetFolder` for details. 57 | 58 | Note: The class_to_idx parameter is here optional and will use the logic of the ``find_classes`` function 59 | by default. 60 | """ 61 | directory = os.path.expanduser(directory) 62 | 63 | if class_to_idx is None: 64 | _, class_to_idx = find_classes(directory) 65 | elif not class_to_idx: 66 | raise ValueError("'class_to_index' must have at least one entry to collect any samples.") 67 | 68 | both_none = extensions is None and is_valid_file is None 69 | both_something = extensions is not None and is_valid_file is not None 70 | if both_none or both_something: 71 | raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time") 72 | 73 | if extensions is not None: 74 | 75 | def is_valid_file(x: str) -> bool: 76 | return has_file_allowed_extension(x, extensions) # type: ignore[arg-type] 77 | 78 | is_valid_file = cast(Callable[[str], bool], is_valid_file) 79 | 80 | instances = [] 81 | available_classes = set() 82 | for target_class in sorted(class_to_idx.keys()): 83 | class_index = class_to_idx[target_class] 84 | target_dir = os.path.join(directory, target_class) 85 | if not os.path.isdir(target_dir): 86 | continue 87 | for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)): 88 | for fname in sorted(fnames): 89 | path = os.path.join(root, fname) 90 | if is_valid_file(path): 91 | item = path, class_index 92 | instances.append(item) 93 | 94 | if target_class not in available_classes: 95 | available_classes.add(target_class) 96 | 97 | empty_classes = set(class_to_idx.keys()) - available_classes 98 | if empty_classes: 99 | msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. " 100 | if extensions is not None: 101 | msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}" 102 | raise FileNotFoundError(msg) 103 | 104 | return instances 105 | 106 | 107 | class DatasetFolder(VisionDataset): 108 | """A generic data loader. 109 | 110 | This default directory structure can be customized by overriding the 111 | :meth:`find_classes` method. 112 | 113 | Args: 114 | root (string): Root directory path. 115 | loader (callable): A function to load a sample given its path. 116 | extensions (tuple[string]): A list of allowed extensions. 117 | both extensions and is_valid_file should not be passed. 118 | transform (callable, optional): A function/transform that takes in 119 | a sample and returns a transformed version. 120 | E.g, ``transforms.RandomCrop`` for images. 121 | target_transform (callable, optional): A function/transform that takes 122 | in the target and transforms it. 
123 | is_valid_file (callable, optional): A function that takes path of a file 124 | and check if the file is a valid file (used to check of corrupt files) 125 | both extensions and is_valid_file should not be passed. 126 | 127 | Attributes: 128 | classes (list): List of the class names sorted alphabetically. 129 | class_to_idx (dict): Dict with items (class_name, class_index). 130 | samples (list): List of (sample path, class_index) tuples 131 | targets (list): The class_index value for each image in the dataset 132 | """ 133 | 134 | def __init__( 135 | self, 136 | root: str, 137 | loader: Callable[[str], Any], 138 | extensions: Optional[Tuple[str, ...]] = None, 139 | transform: Optional[Callable] = None, 140 | target_transform: Optional[Callable] = None, 141 | is_valid_file: Optional[Callable[[str], bool]] = None, 142 | ) -> None: 143 | super().__init__(root, transform=transform, target_transform=target_transform) 144 | classes, class_to_idx = self.find_classes(self.root) 145 | samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file) 146 | 147 | self.loader = loader 148 | self.extensions = extensions 149 | 150 | self.classes = classes 151 | self.class_to_idx = class_to_idx 152 | self.samples = samples 153 | self.targets = [s[1] for s in samples] 154 | 155 | @staticmethod 156 | def make_dataset( 157 | directory: str, 158 | class_to_idx: Dict[str, int], 159 | extensions: Optional[Tuple[str, ...]] = None, 160 | is_valid_file: Optional[Callable[[str], bool]] = None, 161 | ) -> List[Tuple[str, int]]: 162 | """Generates a list of samples of a form (path_to_sample, class). 163 | 164 | This can be overridden to e.g. read files from a compressed zip file instead of from the disk. 165 | 166 | Args: 167 | directory (str): root dataset directory, corresponding to ``self.root``. 168 | class_to_idx (Dict[str, int]): Dictionary mapping class name to class index. 169 | extensions (optional): A list of allowed extensions. 170 | Either extensions or is_valid_file should be passed. Defaults to None. 171 | is_valid_file (optional): A function that takes path of a file 172 | and checks if the file is a valid file 173 | (used to check of corrupt files) both extensions and 174 | is_valid_file should not be passed. Defaults to None. 175 | 176 | Raises: 177 | ValueError: In case ``class_to_idx`` is empty. 178 | ValueError: In case ``extensions`` and ``is_valid_file`` are None or both are not None. 179 | FileNotFoundError: In case no valid file was found for any class. 180 | 181 | Returns: 182 | List[Tuple[str, int]]: samples of a form (path_to_sample, class) 183 | """ 184 | if class_to_idx is None: 185 | # prevent potential bug since make_dataset() would use the class_to_idx logic of the 186 | # find_classes() function, instead of using that of the find_classes() method, which 187 | # is potentially overridden and thus could have a different logic. 188 | raise ValueError("The class_to_idx parameter cannot be None.") 189 | return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file) 190 | 191 | def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]: 192 | """Find the class folders in a dataset structured as follows:: 193 | 194 | directory/ 195 | ├── class_x 196 | │ ├── xxx.ext 197 | │ ├── xxy.ext 198 | │ └── ... 199 | │ └── xxz.ext 200 | └── class_y 201 | ├── 123.ext 202 | ├── nsdf3.ext 203 | └── ... 
204 | └── asd932_.ext 205 | 206 | This method can be overridden to only consider 207 | a subset of classes, or to adapt to a different dataset directory structure. 208 | 209 | Args: 210 | directory(str): Root directory path, corresponding to ``self.root`` 211 | 212 | Raises: 213 | FileNotFoundError: If ``dir`` has no class folders. 214 | 215 | Returns: 216 | (Tuple[List[str], Dict[str, int]]): List of all classes and dictionary mapping each class to an index. 217 | """ 218 | return find_classes(directory) 219 | 220 | def __getitem__(self, index: int) -> Tuple[Any, Any]: 221 | """ 222 | Args: 223 | index (int): Index 224 | 225 | Returns: 226 | tuple: (sample, target) where target is class_index of the target class. 227 | """ 228 | path, target = self.samples[index] 229 | sample = self.loader(path) 230 | if self.transform is not None: 231 | sample = self.transform(sample) 232 | if self.target_transform is not None: 233 | target = self.target_transform(target) 234 | 235 | return sample, target 236 | 237 | def __len__(self) -> int: 238 | return len(self.samples) 239 | 240 | 241 | IMG_EXTENSIONS = (".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp") 242 | 243 | 244 | def pil_loader(path: str) -> Image.Image: 245 | # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) 246 | with open(path, "rb") as f: 247 | img = Image.open(f) 248 | return img.convert("RGB") 249 | 250 | 251 | # TODO: specify the return type 252 | def accimage_loader(path: str) -> Any: 253 | import accimage 254 | 255 | try: 256 | return accimage.Image(path) 257 | except OSError: 258 | # Potentially a decoding problem, fall back to PIL.Image 259 | return pil_loader(path) 260 | 261 | 262 | def default_loader(path: str) -> Any: 263 | from torchvision import get_image_backend 264 | 265 | if get_image_backend() == "accimage": 266 | return accimage_loader(path) 267 | else: 268 | return pil_loader(path) 269 | 270 | 271 | class ImageFolder(DatasetFolder): 272 | """A generic data loader where the images are arranged in this way by default: :: 273 | 274 | root/dog/xxx.png 275 | root/dog/xxy.png 276 | root/dog/[...]/xxz.png 277 | 278 | root/cat/123.png 279 | root/cat/nsdf3.png 280 | root/cat/[...]/asd932_.png 281 | 282 | This class inherits from :class:`~torchvision.datasets.DatasetFolder` so 283 | the same methods can be overridden to customize the dataset. 284 | 285 | Args: 286 | root (string): Root directory path. 287 | transform (callable, optional): A function/transform that takes in an PIL image 288 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 289 | target_transform (callable, optional): A function/transform that takes in the 290 | target and transforms it. 291 | loader (callable, optional): A function to load an image given its path. 292 | is_valid_file (callable, optional): A function that takes path of an Image file 293 | and check if the file is a valid file (used to check of corrupt files) 294 | 295 | Attributes: 296 | classes (list): List of the class names sorted alphabetically. 297 | class_to_idx (dict): Dict with items (class_name, class_index). 
298 | imgs (list): List of (image path, class_index) tuples 299 | """ 300 | 301 | def __init__( 302 | self, 303 | root: str, 304 | transform: Optional[Callable] = None, 305 | target_transform: Optional[Callable] = None, 306 | loader: Callable[[str], Any] = default_loader, 307 | is_valid_file: Optional[Callable[[str], bool]] = None, 308 | ): 309 | super().__init__( 310 | root, 311 | loader, 312 | IMG_EXTENSIONS if is_valid_file is None else None, 313 | transform=transform, 314 | target_transform=target_transform, 315 | is_valid_file=is_valid_file, 316 | ) 317 | self.imgs = self.samples 318 | 319 | class ImageNetFolder(DatasetFolder): 320 | """A generic data loader where the images are arranged in this way by default: :: 321 | 322 | root/dog/xxx.png 323 | root/dog/xxy.png 324 | root/dog/[...]/xxz.png 325 | 326 | root/cat/123.png 327 | root/cat/nsdf3.png 328 | root/cat/[...]/asd932_.png 329 | 330 | This class inherits from :class:`~torchvision.datasets.DatasetFolder` so 331 | the same methods can be overridden to customize the dataset. 332 | 333 | Args: 334 | root (string): Root directory path. 335 | transform (callable, optional): A function/transform that takes in an PIL image 336 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 337 | target_transform (callable, optional): A function/transform that takes in the 338 | target and transforms it. 339 | loader (callable, optional): A function to load an image given its path. 340 | is_valid_file (callable, optional): A function that takes path of an Image file 341 | and check if the file is a valid file (used to check of corrupt files) 342 | 343 | Attributes: 344 | classes (list): List of the class names sorted alphabetically. 345 | class_to_idx (dict): Dict with items (class_name, class_index). 346 | imgs (list): List of (image path, class_index) tuples 347 | """ 348 | 349 | def find_classes(self, directory: str) -> Tuple[List[str] | Dict[str, int]]: 350 | classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir()) 351 | if not classes: 352 | raise FileNotFoundError(f"Couldn't find any class folder in {directory}.") 353 | # remove duplicate class 'maillot' ('n03710637' and 'n03710721'). We retain 'n03710721' only. 
354 | # resulting in 999 classes 355 | if 'n03710637' in classes: 356 | classes.remove('n03710637') 357 | class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} 358 | return classes, class_to_idx 359 | 360 | def make_dataset( 361 | self, 362 | directory: str, 363 | class_to_idx: Dict[str, int], 364 | extensions: Optional[Tuple[str, ...]] = None, 365 | is_valid_file: Optional[Callable[[str], bool]] = None, 366 | ) -> List[Tuple[str, int]]: 367 | 368 | directory = os.path.expanduser(directory) 369 | 370 | if class_to_idx is None: 371 | _, class_to_idx = self.find_classes(directory) 372 | elif not class_to_idx: 373 | raise ValueError("'class_to_index' must have at least one entry to collect any samples.") 374 | 375 | both_none = extensions is None and is_valid_file is None 376 | both_something = extensions is not None and is_valid_file is not None 377 | if both_none or both_something: 378 | raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time") 379 | 380 | if extensions is not None: 381 | 382 | def is_valid_file(x: str) -> bool: 383 | return has_file_allowed_extension(x, extensions) # type: ignore[arg-type] 384 | 385 | is_valid_file = cast(Callable[[str], bool], is_valid_file) 386 | 387 | instances = [] 388 | available_classes = set() 389 | for target_class in sorted(class_to_idx.keys()): 390 | class_index = class_to_idx[target_class] 391 | target_dir = os.path.join(directory, target_class) 392 | select_files = self.select_files[target_class] if self.select_files is not None else None 393 | if not os.path.isdir(target_dir): 394 | continue 395 | for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)): 396 | for fname in sorted(fnames): 397 | if select_files is not None and fname not in select_files: 398 | continue 399 | path = os.path.join(root, fname) 400 | if is_valid_file(path): 401 | item = path, class_index 402 | instances.append(item) 403 | 404 | if target_class not in available_classes: 405 | available_classes.add(target_class) 406 | 407 | empty_classes = set(class_to_idx.keys()) - available_classes 408 | if empty_classes: 409 | msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. 
" 410 | if extensions is not None: 411 | msg += f"Supported extensions are: {extensions if isinstance(extensions, str) else ', '.join(extensions)}" 412 | raise FileNotFoundError(msg) 413 | 414 | return instances 415 | 416 | def __init__( 417 | self, 418 | root: str, 419 | transform: Optional[Callable] = None, 420 | target_transform: Optional[Callable] = None, 421 | loader: Callable[[str], Any] = default_loader, 422 | is_valid_file: Optional[Callable[[str], bool]] = None, 423 | select_files: Optional[Dict[str, list]] = None, # an argument that we add to hold out some samples for evaluation 424 | ): 425 | self.select_files = select_files 426 | super().__init__( 427 | root, 428 | loader, 429 | IMG_EXTENSIONS if is_valid_file is None else None, 430 | transform=transform, 431 | target_transform=target_transform, 432 | is_valid_file=is_valid_file, 433 | ) 434 | self.imgs = self.samples 435 | -------------------------------------------------------------------------------- /code/replace/datasets/food101.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | from typing import Any, Tuple, Callable, Optional 4 | 5 | import PIL.Image 6 | 7 | from torchvision.datasets.utils import verify_str_arg, download_and_extract_archive 8 | from torchvision.datasets.vision import VisionDataset 9 | 10 | 11 | class Food101(VisionDataset): 12 | """`The Food-101 Data Set `_. 13 | 14 | The Food-101 is a challenging data set of 101 food categories, with 101'000 images. 15 | For each class, 250 manually reviewed test images are provided as well as 750 training images. 16 | On purpose, the training images were not cleaned, and thus still contain some amount of noise. 17 | This comes mostly in the form of intense colors and sometimes wrong labels. All images were 18 | rescaled to have a maximum side length of 512 pixels. 19 | 20 | 21 | Args: 22 | root (string): Root directory of the dataset. 23 | split (string, optional): The dataset split, supports ``"train"`` (default) and ``"test"``. 24 | transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed 25 | version. E.g, ``transforms.RandomCrop``. 26 | target_transform (callable, optional): A function/transform that takes in the target and transforms it. 27 | download (bool, optional): If True, downloads the dataset from the internet and 28 | puts it in root directory. If dataset is already downloaded, it is not 29 | downloaded again. Default is False. 30 | """ 31 | 32 | _URL = "http://data.vision.ee.ethz.ch/cvl/food-101.tar.gz" 33 | _MD5 = "85eeb15f3717b99a5da872d97d918f87" 34 | 35 | def __init__( 36 | self, 37 | root: str, 38 | split: str = "train", 39 | transform: Optional[Callable] = None, 40 | target_transform: Optional[Callable] = None, 41 | download: bool = False, 42 | prompt_template = "A photo of {}, a type of food." 43 | ) -> None: 44 | super().__init__(root, transform=transform, target_transform=target_transform) 45 | self._split = verify_str_arg(split, "split", ("train", "test")) 46 | self._base_folder = Path(self.root) / "food-101" 47 | self._meta_folder = self._base_folder / "meta" 48 | self._images_folder = self._base_folder / "images" 49 | 50 | if download: 51 | self._download() 52 | 53 | if not self._check_exists(): 54 | raise RuntimeError("Dataset not found. 
You can use download=True to download it") 55 | 56 | self._labels = [] 57 | self._image_files = [] 58 | with open(self._meta_folder / f"{split}.json") as f: 59 | metadata = json.loads(f.read()) 60 | 61 | self.classes = sorted(metadata.keys()) 62 | self.class_to_idx = dict(zip(self.classes, range(len(self.classes)))) 63 | 64 | for class_label, im_rel_paths in metadata.items(): 65 | self._labels += [self.class_to_idx[class_label]] * len(im_rel_paths) 66 | self._image_files += [ 67 | self._images_folder.joinpath(*f"{im_rel_path}.jpg".split("/")) for im_rel_path in im_rel_paths 68 | ] 69 | 70 | self.prompt_template = prompt_template 71 | self.clip_prompts = [ 72 | prompt_template.format(label.lower().replace('_', ' ').replace('-', ' ')) \ 73 | for label in self.classes 74 | ] 75 | 76 | def __len__(self) -> int: 77 | return len(self._image_files) 78 | 79 | def __getitem__(self, idx) -> Tuple[Any, Any]: 80 | image_file, label = self._image_files[idx], self._labels[idx] 81 | image = PIL.Image.open(image_file).convert("RGB") 82 | 83 | if self.transform: 84 | image = self.transform(image) 85 | 86 | if self.target_transform: 87 | label = self.target_transform(label) 88 | 89 | return image, label 90 | 91 | def extra_repr(self) -> str: 92 | return f"split={self._split}" 93 | 94 | def _check_exists(self) -> bool: 95 | return all(folder.exists() and folder.is_dir() for folder in (self._meta_folder, self._images_folder)) 96 | 97 | def _download(self) -> None: 98 | if self._check_exists(): 99 | return 100 | download_and_extract_archive(self._URL, download_root=self.root, md5=self._MD5) 101 | -------------------------------------------------------------------------------- /code/replace/datasets/oxford_iiit_pet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import pathlib 4 | from typing import Any, Callable, Optional, Union, Tuple 5 | from typing import Sequence 6 | 7 | from PIL import Image 8 | 9 | from torchvision.datasets.utils import download_and_extract_archive, verify_str_arg 10 | from torchvision.datasets.vision import VisionDataset 11 | 12 | 13 | class OxfordIIITPet(VisionDataset): 14 | """`Oxford-IIIT Pet Dataset `_. 15 | 16 | Args: 17 | root (string): Root directory of the dataset. 18 | split (string, optional): The dataset split, supports ``"trainval"`` (default) or ``"test"``. 19 | target_types (string, sequence of strings, optional): Types of target to use. Can be ``category`` (default) or 20 | ``segmentation``. Can also be a list to output a tuple with all specified target types. The types represent: 21 | 22 | - ``category`` (int): Label for one of the 37 pet categories. 23 | - ``segmentation`` (PIL image): Segmentation trimap of the image. 24 | 25 | If empty, ``None`` will be returned as target. 26 | 27 | transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed 28 | version. E.g, ``transforms.RandomCrop``. 29 | target_transform (callable, optional): A function/transform that takes in the target and transforms it. 30 | download (bool, optional): If True, downloads the dataset from the internet and puts it into 31 | ``root/oxford-iiit-pet``. If dataset is already downloaded, it is not downloaded again. 
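As the docstring above notes, `target_types` controls what `__getitem__` returns: a single value when one target type is requested, or a tuple when several are. A minimal usage sketch (the root path, split, and transform are placeholder assumptions, and the import path depends on how the replacement modules are wired in):

```python
from torchvision import transforms

# Sketch only: request both the category label and the segmentation trimap.
# root/split/download below are placeholders, not values taken from the repo.
pets = OxfordIIITPet(
    root="./data",
    split="test",
    target_types=["category", "segmentation"],
    transform=transforms.ToTensor(),
    download=True,
)
image, (label, trimap) = pets[0]  # label is an int, trimap a PIL image
```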
32 | """ 33 | 34 | _RESOURCES = ( 35 | ("https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz", "5c4f3ee8e5d25df40f4fd59a7f44e54c"), 36 | ("https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz", "95a8c909bbe2e81eed6a22bccdf3f68f"), 37 | ) 38 | _VALID_TARGET_TYPES = ("category", "segmentation") 39 | 40 | def __init__( 41 | self, 42 | root: str, 43 | split: str = "trainval", 44 | target_types: Union[Sequence[str], str] = "category", 45 | transforms: Optional[Callable] = None, 46 | transform: Optional[Callable] = None, 47 | target_transform: Optional[Callable] = None, 48 | download: bool = False, 49 | prompt_template = "A photo of a {}, a type of pet." 50 | ): 51 | self._split = verify_str_arg(split, "split", ("trainval", "test")) 52 | if isinstance(target_types, str): 53 | target_types = [target_types] 54 | self._target_types = [ 55 | verify_str_arg(target_type, "target_types", self._VALID_TARGET_TYPES) for target_type in target_types 56 | ] 57 | 58 | super().__init__(root, transforms=transforms, transform=transform, target_transform=target_transform) 59 | self._base_folder = pathlib.Path(self.root) / "oxford-iiit-pet" 60 | self._images_folder = self._base_folder / "images" 61 | self._anns_folder = self._base_folder / "annotations" 62 | self._segs_folder = self._anns_folder / "trimaps" 63 | 64 | if download: 65 | self._download() 66 | 67 | if not self._check_exists(): 68 | raise RuntimeError("Dataset not found. You can use download=True to download it") 69 | 70 | image_ids = [] 71 | self._labels = [] 72 | with open(self._anns_folder / f"{self._split}.txt") as file: 73 | for line in file: 74 | image_id, label, *_ = line.strip().split() 75 | image_ids.append(image_id) 76 | self._labels.append(int(label) - 1) 77 | 78 | self.classes = [ 79 | " ".join(part.title() for part in raw_cls.split("_")) 80 | for raw_cls, _ in sorted( 81 | {(image_id.rsplit("_", 1)[0], label) for image_id, label in zip(image_ids, self._labels)}, 82 | key=lambda image_id_and_label: image_id_and_label[1], 83 | ) 84 | ] 85 | self.class_to_idx = dict(zip(self.classes, range(len(self.classes)))) 86 | 87 | self._images = [self._images_folder / f"{image_id}.jpg" for image_id in image_ids] 88 | self._segs = [self._segs_folder / f"{image_id}.png" for image_id in image_ids] 89 | 90 | self.prompt_template = prompt_template 91 | self.clip_prompts = [ 92 | prompt_template.format(label.lower().replace('_', ' ').replace('-', ' ')) \ 93 | for label in self.classes 94 | ] 95 | 96 | def __len__(self) -> int: 97 | return len(self._images) 98 | 99 | def __getitem__(self, idx: int) -> Tuple[Any, Any]: 100 | image = Image.open(self._images[idx]).convert("RGB") 101 | 102 | target: Any = [] 103 | for target_type in self._target_types: 104 | if target_type == "category": 105 | target.append(self._labels[idx]) 106 | else: # target_type == "segmentation" 107 | target.append(Image.open(self._segs[idx])) 108 | 109 | if not target: 110 | target = None 111 | elif len(target) == 1: 112 | target = target[0] 113 | else: 114 | target = tuple(target) 115 | 116 | if self.transforms: 117 | image, target = self.transforms(image, target) 118 | 119 | return image, target 120 | 121 | def _check_exists(self) -> bool: 122 | for folder in (self._images_folder, self._anns_folder): 123 | if not (os.path.exists(folder) and os.path.isdir(folder)): 124 | return False 125 | else: 126 | return True 127 | 128 | def _download(self) -> None: 129 | if self._check_exists(): 130 | return 131 | 132 | for url, md5 in self._RESOURCES: 133 | 
download_and_extract_archive(url, download_root=str(self._base_folder), md5=md5) 134 | -------------------------------------------------------------------------------- /code/replace/datasets/pcam.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from typing import Any, Callable, Optional, Tuple 3 | 4 | from PIL import Image 5 | 6 | from torchvision.datasets.utils import download_file_from_google_drive, _decompress, verify_str_arg 7 | from torchvision.datasets.vision import VisionDataset 8 | 9 | 10 | class PCAM(VisionDataset): 11 | """`PCAM Dataset `_. 12 | 13 | The PatchCamelyon dataset is a binary classification dataset with 327,680 14 | color images (96px x 96px), extracted from histopathologic scans of lymph node 15 | sections. Each image is annotated with a binary label indicating presence of 16 | metastatic tissue. 17 | 18 | This dataset requires the ``h5py`` package which you can install with ``pip install h5py``. 19 | 20 | Args: 21 | root (string): Root directory of the dataset. 22 | split (string, optional): The dataset split, supports ``"train"`` (default), ``"test"`` or ``"val"``. 23 | transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed 24 | version. E.g, ``transforms.RandomCrop``. 25 | target_transform (callable, optional): A function/transform that takes in the target and transforms it. 26 | download (bool, optional): If True, downloads the dataset from the internet and puts it into ``root/pcam``. If 27 | dataset is already downloaded, it is not downloaded again. 28 | """ 29 | 30 | _FILES = { 31 | "train": { 32 | "images": ( 33 | "camelyonpatch_level_2_split_train_x.h5", # Data file name 34 | "1Ka0XfEMiwgCYPdTI-vv6eUElOBnKFKQ2", # Google Drive ID 35 | "1571f514728f59376b705fc836ff4b63", # md5 hash 36 | ), 37 | "targets": ( 38 | "camelyonpatch_level_2_split_train_y.h5", 39 | "1269yhu3pZDP8UYFQs-NYs3FPwuK-nGSG", 40 | "35c2d7259d906cfc8143347bb8e05be7", 41 | ), 42 | }, 43 | "test": { 44 | "images": ( 45 | "camelyonpatch_level_2_split_test_x.h5", 46 | "1qV65ZqZvWzuIVthK8eVDhIwrbnsJdbg_", 47 | "d5b63470df7cfa627aeec8b9dc0c066e", 48 | ), 49 | "targets": ( 50 | "camelyonpatch_level_2_split_test_y.h5", 51 | "17BHrSrwWKjYsOgTMmoqrIjDy6Fa2o_gP", 52 | "2b85f58b927af9964a4c15b8f7e8f179", 53 | ), 54 | }, 55 | "val": { 56 | "images": ( 57 | "camelyonpatch_level_2_split_valid_x.h5", 58 | "1hgshYGWK8V-eGRy8LToWJJgDU_rXWVJ3", 59 | "d8c2d60d490dbd479f8199bdfa0cf6ec", 60 | ), 61 | "targets": ( 62 | "camelyonpatch_level_2_split_valid_y.h5", 63 | "1bH8ZRbhSVAhScTS0p9-ZzGnX91cHT3uO", 64 | "60a7035772fbdb7f34eb86d4420cf66a", 65 | ), 66 | }, 67 | } 68 | 69 | clip_prompts = [ 70 | "This is a photo of a healthy lymph node tissue.", 71 | "This is a photo of a lymph node tumor tissue." 72 | ] 73 | 74 | def __init__( 75 | self, 76 | root: str, 77 | split: str = "train", 78 | transform: Optional[Callable] = None, 79 | target_transform: Optional[Callable] = None, 80 | download: bool = False, 81 | ): 82 | try: 83 | import h5py 84 | 85 | self.h5py = h5py 86 | except ImportError: 87 | raise RuntimeError( 88 | "h5py is not found. 
This dataset needs to have h5py installed: please run pip install h5py" 89 | ) 90 | 91 | self._split = verify_str_arg(split, "split", ("train", "test", "val")) 92 | 93 | super().__init__(root, transform=transform, target_transform=target_transform) 94 | self._base_folder = pathlib.Path(self.root) / "pcam" 95 | 96 | if download: 97 | self._download() 98 | 99 | if not self._check_exists(): 100 | raise RuntimeError("Dataset not found. You can use download=True to download it") 101 | 102 | def __len__(self) -> int: 103 | images_file = self._FILES[self._split]["images"][0] 104 | with self.h5py.File(self._base_folder / images_file) as images_data: 105 | return images_data["x"].shape[0] 106 | 107 | def __getitem__(self, idx: int) -> Tuple[Any, Any]: 108 | images_file = self._FILES[self._split]["images"][0] 109 | with self.h5py.File(self._base_folder / images_file) as images_data: 110 | image = Image.fromarray(images_data["x"][idx]).convert("RGB") 111 | 112 | targets_file = self._FILES[self._split]["targets"][0] 113 | with self.h5py.File(self._base_folder / targets_file) as targets_data: 114 | target = int(targets_data["y"][idx, 0, 0, 0]) # shape is [num_images, 1, 1, 1] 115 | 116 | if self.transform: 117 | image = self.transform(image) 118 | if self.target_transform: 119 | target = self.target_transform(target) 120 | 121 | return image, target 122 | 123 | def _check_exists(self) -> bool: 124 | images_file = self._FILES[self._split]["images"][0] 125 | targets_file = self._FILES[self._split]["targets"][0] 126 | return all(self._base_folder.joinpath(h5_file).exists() for h5_file in (images_file, targets_file)) 127 | 128 | def _download(self) -> None: 129 | if self._check_exists(): 130 | return 131 | 132 | for file_name, file_id, md5 in self._FILES[self._split].values(): 133 | archive_name = file_name + ".gz" 134 | download_file_from_google_drive(file_id, str(self._base_folder), filename=archive_name, md5=md5) 135 | _decompress(str(self._base_folder / archive_name)) 136 | -------------------------------------------------------------------------------- /code/replace/datasets/stanford_cars.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from typing import Callable, Optional, Any, Tuple 3 | 4 | from PIL import Image 5 | 6 | from torchvision.datasets.utils import download_and_extract_archive, download_url, verify_str_arg 7 | from torchvision.datasets.vision import VisionDataset 8 | 9 | 10 | class StanfordCars(VisionDataset): 11 | """`Stanford Cars `_ Dataset 12 | 13 | The Cars dataset contains 16,185 images of 196 classes of cars. The data is 14 | split into 8,144 training images and 8,041 testing images, where each class 15 | has been split roughly in a 50-50 split 16 | 17 | .. note:: 18 | 19 | This class needs `scipy `_ to load target files from `.mat` format. 20 | 21 | Args: 22 | root (string): Root directory of dataset 23 | split (string, optional): The dataset split, supports ``"train"`` (default) or ``"test"``. 24 | transform (callable, optional): A function/transform that takes in an PIL image 25 | and returns a transformed version. E.g, ``transforms.RandomCrop`` 26 | target_transform (callable, optional): A function/transform that takes in the 27 | target and transforms it. 28 | download (bool, optional): If True, downloads the dataset from the internet and 29 | puts it in root directory. 
If dataset is already downloaded, it is not 30 | downloaded again.""" 31 | 32 | def __init__( 33 | self, 34 | root: str, 35 | split: str = "train", 36 | transform: Optional[Callable] = None, 37 | target_transform: Optional[Callable] = None, 38 | download: bool = False, 39 | prompt_template = "A photo of a {}." 40 | ) -> None: 41 | 42 | try: 43 | import scipy.io as sio 44 | except ImportError: 45 | raise RuntimeError("Scipy is not found. This dataset needs to have scipy installed: pip install scipy") 46 | 47 | super().__init__(root, transform=transform, target_transform=target_transform) 48 | 49 | self._split = verify_str_arg(split, "split", ("train", "test")) 50 | self._base_folder = pathlib.Path(root) / "stanford_cars" 51 | devkit = self._base_folder / "devkit" 52 | 53 | if self._split == "train": 54 | self._annotations_mat_path = devkit / "cars_train_annos.mat" 55 | self._images_base_path = self._base_folder / "cars_train" 56 | else: 57 | self._annotations_mat_path = self._base_folder / "cars_test_annos_withlabels.mat" 58 | self._images_base_path = self._base_folder / "cars_test" 59 | 60 | if download: 61 | self.download() 62 | 63 | if not self._check_exists(): 64 | raise RuntimeError("Dataset not found. You can use download=True to download it") 65 | 66 | self._samples = [ 67 | ( 68 | str(self._images_base_path / annotation["fname"]), 69 | annotation["class"] - 1, # Original target mapping starts from 1, hence -1 70 | ) 71 | for annotation in sio.loadmat(self._annotations_mat_path, squeeze_me=True)["annotations"] 72 | ] 73 | 74 | self.classes = sio.loadmat(str(devkit / "cars_meta.mat"), squeeze_me=True)["class_names"].tolist() 75 | self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)} 76 | 77 | self.prompt_template = prompt_template 78 | self.clip_prompts = [ 79 | prompt_template.format(label[:-5].lower().replace('_', ' ').replace('-', ' ')) \ 80 | for label in self.classes 81 | ] 82 | 83 | def __len__(self) -> int: 84 | return len(self._samples) 85 | 86 | def __getitem__(self, idx: int) -> Tuple[Any, Any]: 87 | """Returns pil_image and class_id for given index""" 88 | image_path, target = self._samples[idx] 89 | pil_image = Image.open(image_path).convert("RGB") 90 | 91 | if self.transform is not None: 92 | pil_image = self.transform(pil_image) 93 | if self.target_transform is not None: 94 | target = self.target_transform(target) 95 | return pil_image, target 96 | 97 | def download(self) -> None: 98 | if self._check_exists(): 99 | return 100 | 101 | download_and_extract_archive( 102 | url="https://ai.stanford.edu/~jkrause/cars/car_devkit.tgz", 103 | download_root=str(self._base_folder), 104 | md5="c3b158d763b6e2245038c8ad08e45376", 105 | ) 106 | if self._split == "train": 107 | download_and_extract_archive( 108 | url="https://ai.stanford.edu/~jkrause/car196/cars_train.tgz", 109 | download_root=str(self._base_folder), 110 | md5="065e5b463ae28d29e77c1b4b166cfe61", 111 | ) 112 | else: 113 | download_and_extract_archive( 114 | url="https://ai.stanford.edu/~jkrause/car196/cars_test.tgz", 115 | download_root=str(self._base_folder), 116 | md5="4ce7ebf6a94d07f1952d94dd34c4d501", 117 | ) 118 | download_url( 119 | url="https://ai.stanford.edu/~jkrause/car196/cars_test_annos_withlabels.mat", 120 | root=str(self._base_folder), 121 | md5="b0a2b23655a3edd16d84508592a98d10", 122 | ) 123 | 124 | def _check_exists(self) -> bool: 125 | if not (self._base_folder / "devkit").is_dir(): 126 | return False 127 | 128 | return self._annotations_mat_path.exists() and self._images_base_path.is_dir() 
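The `clip_prompts` attribute exposed by these dataset classes is what gets encoded by CLIP's text encoder to form the per-class text features used for zero-shot classification. A minimal, hedged sketch of that step, assuming an OpenAI-CLIP-style `clip.load` / `clip.tokenize` interface (which `code/replace/clip.py` mirrors); the backbone name and the two prompts are illustrative, and the actual entry point used by `test_time_counterattack.py` may differ:

```python
import torch
import clip  # assumes the OpenAI-CLIP-style interface mirrored by code/replace/clip.py

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)  # backbone name is an assumption

# Illustrative prompts in the StanfordCars format produced above
# (template "A photo of a {}." with the trailing year stripped from the class name).
clip_prompts = [
    "A photo of a am general hummer suv.",
    "A photo of a acura rl sedan.",
]

with torch.no_grad():
    text_features = model.encode_text(clip.tokenize(clip_prompts).to(device))
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

def classify(pil_image):
    """Return the index of the prompt whose text feature best matches the image."""
    with torch.no_grad():
        image_features = model.encode_image(preprocess(pil_image).unsqueeze(0).to(device))
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        logits = 100.0 * image_features @ text_features.T
    return logits.argmax(dim=-1).item()
```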
129 | -------------------------------------------------------------------------------- /code/replace/datasets/sun397.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any, Tuple, Callable, Optional 3 | 4 | import PIL.Image 5 | 6 | from torchvision.datasets.utils import download_and_extract_archive 7 | from torchvision.datasets.vision import VisionDataset 8 | 9 | 10 | class SUN397(VisionDataset): 11 | """`The SUN397 Data Set `_. 12 | 13 | The SUN397 or Scene UNderstanding (SUN) is a dataset for scene recognition consisting of 14 | 397 categories with 108'754 images. 15 | 16 | Args: 17 | root (string): Root directory of the dataset. 18 | transform (callable, optional): A function/transform that takes in an PIL image and returns a transformed 19 | version. E.g, ``transforms.RandomCrop``. 20 | target_transform (callable, optional): A function/transform that takes in the target and transforms it. 21 | download (bool, optional): If true, downloads the dataset from the internet and 22 | puts it in root directory. If dataset is already downloaded, it is not 23 | downloaded again. 24 | """ 25 | 26 | _DATASET_URL = "http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz" 27 | _DATASET_MD5 = "8ca2778205c41d23104230ba66911c7a" 28 | 29 | def __init__( 30 | self, 31 | root: str, 32 | transform: Optional[Callable] = None, 33 | target_transform: Optional[Callable] = None, 34 | download: bool = False, 35 | ) -> None: 36 | super().__init__(root, transform=transform, target_transform=target_transform) 37 | self._data_dir = Path(self.root) / "SUN397" 38 | 39 | if download: 40 | self._download() 41 | 42 | if not self._check_exists(): 43 | raise RuntimeError("Dataset not found. You can use download=True to download it") 44 | 45 | with open(self._data_dir / "ClassName.txt") as f: 46 | self.classes = [c[3:].strip() for c in f] 47 | 48 | self.class_to_idx = dict(zip(self.classes, range(len(self.classes)))) 49 | self._image_files = list(self._data_dir.rglob("sun_*.jpg")) 50 | 51 | self._labels = [ 52 | self.class_to_idx["/".join(path.relative_to(self._data_dir).parts[1:-1])] for path in self._image_files 53 | ] 54 | 55 | def __len__(self) -> int: 56 | return len(self._image_files) 57 | 58 | def __getitem__(self, idx) -> Tuple[Any, Any]: 59 | image_file, label = self._image_files[idx], self._labels[idx] 60 | image = PIL.Image.open(image_file).convert("RGB") 61 | 62 | if self.transform: 63 | image = self.transform(image) 64 | 65 | if self.target_transform: 66 | label = self.target_transform(label) 67 | 68 | return image, label 69 | 70 | def _check_exists(self) -> bool: 71 | return self._data_dir.is_dir() 72 | 73 | def _download(self) -> None: 74 | if self._check_exists(): 75 | return 76 | download_and_extract_archive(self._DATASET_URL, download_root=self.root, md5=self._DATASET_MD5) 77 | -------------------------------------------------------------------------------- /code/replace/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. 
an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | self.relu1 = nn.ReLU(inplace=True) 20 | 21 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 22 | self.bn2 = nn.BatchNorm2d(planes) 23 | self.relu2 = nn.ReLU(inplace=True) 24 | 25 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 26 | 27 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 28 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 29 | self.relu3 = nn.ReLU(inplace=True) 30 | 31 | self.downsample = None 32 | self.stride = stride 33 | 34 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 35 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 36 | self.downsample = nn.Sequential(OrderedDict([ 37 | ("-1", nn.AvgPool2d(stride)), 38 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 39 | ("1", nn.BatchNorm2d(planes * self.expansion)) 40 | ])) 41 | 42 | def forward(self, x: torch.Tensor): 43 | identity = x 44 | 45 | out = self.relu1(self.bn1(self.conv1(x))) 46 | out = self.relu2(self.bn2(self.conv2(out))) 47 | out = self.avgpool(out) 48 | out = self.bn3(self.conv3(out)) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.relu3(out) 55 | return out 56 | 57 | 58 | class AttentionPool2d(nn.Module): 59 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 60 | super().__init__() 61 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 62 | self.k_proj = nn.Linear(embed_dim, embed_dim) 63 | self.q_proj = nn.Linear(embed_dim, embed_dim) 64 | self.v_proj = nn.Linear(embed_dim, embed_dim) 65 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 66 | self.num_heads = num_heads 67 | 68 | def forward(self, x): 69 | x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC 70 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 71 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 72 | x, _ = F.multi_head_attention_forward( 73 | query=x[:1], key=x, value=x, 74 | embed_dim_to_check=x.shape[-1], 75 | num_heads=self.num_heads, 76 | q_proj_weight=self.q_proj.weight, 77 | k_proj_weight=self.k_proj.weight, 78 | v_proj_weight=self.v_proj.weight, 79 | in_proj_weight=None, 80 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 81 | bias_k=None, 82 | bias_v=None, 83 | add_zero_attn=False, 84 | dropout_p=0, 85 | out_proj_weight=self.c_proj.weight, 86 | out_proj_bias=self.c_proj.bias, 87 | use_separate_proj_weight=True, 88 | training=self.training, 89 | need_weights=False 90 | ) 91 | return x.squeeze(0) 92 | 93 | 94 | class ModifiedResNet(nn.Module): 95 | """ 96 | A ResNet class that is similar to torchvision's but contains the following changes: 97 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 
98 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 99 | - The final pooling layer is a QKV attention instead of an average pool 100 | """ 101 | 102 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 103 | super().__init__() 104 | self.output_dim = output_dim 105 | self.input_resolution = input_resolution 106 | 107 | # the 3-layer stem 108 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 109 | self.bn1 = nn.BatchNorm2d(width // 2) 110 | self.relu1 = nn.ReLU(inplace=True) 111 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 112 | self.bn2 = nn.BatchNorm2d(width // 2) 113 | self.relu2 = nn.ReLU(inplace=True) 114 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 115 | self.bn3 = nn.BatchNorm2d(width) 116 | self.relu3 = nn.ReLU(inplace=True) 117 | self.avgpool = nn.AvgPool2d(2) 118 | 119 | # residual layers 120 | self._inplanes = width # this is a *mutable* variable used during construction 121 | self.layer1 = self._make_layer(width, layers[0]) 122 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 123 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 124 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 125 | 126 | embed_dim = width * 32 # the ResNet feature dimension 127 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 128 | 129 | def _make_layer(self, planes, blocks, stride=1): 130 | layers = [Bottleneck(self._inplanes, planes, stride)] 131 | 132 | self._inplanes = planes * Bottleneck.expansion 133 | for _ in range(1, blocks): 134 | layers.append(Bottleneck(self._inplanes, planes)) 135 | 136 | return nn.Sequential(*layers) 137 | 138 | def forward(self, x): 139 | def stem(x): 140 | x = self.relu1(self.bn1(self.conv1(x))) 141 | x = self.relu2(self.bn2(self.conv2(x))) 142 | x = self.relu3(self.bn3(self.conv3(x))) 143 | x = self.avgpool(x) 144 | return x 145 | 146 | x = x.type(self.conv1.weight.dtype) 147 | x = stem(x) 148 | x = self.layer1(x) 149 | x = self.layer2(x) 150 | x = self.layer3(x) 151 | x = self.layer4(x) 152 | x = self.attnpool(x) 153 | 154 | return x 155 | 156 | 157 | class LayerNorm(nn.LayerNorm): 158 | """Subclass torch's LayerNorm to handle fp16.""" 159 | 160 | def forward(self, x: torch.Tensor): 161 | orig_type = x.dtype 162 | ret = super().forward(x.type(torch.float32)) 163 | return ret.type(orig_type) 164 | 165 | 166 | class QuickGELU(nn.Module): 167 | def forward(self, x: torch.Tensor): 168 | return x * torch.sigmoid(1.702 * x) 169 | 170 | 171 | class ResidualAttentionBlock(nn.Module): 172 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 173 | super().__init__() 174 | 175 | self.attn = nn.MultiheadAttention(d_model, n_head) 176 | self.ln_1 = LayerNorm(d_model) 177 | self.mlp = nn.Sequential(OrderedDict([ 178 | ("c_fc", nn.Linear(d_model, d_model * 4)), 179 | ("gelu", QuickGELU()), 180 | ("c_proj", nn.Linear(d_model * 4, d_model)) 181 | ])) 182 | self.ln_2 = LayerNorm(d_model) 183 | self.attn_mask = attn_mask 184 | 185 | def attention(self, x: torch.Tensor): 186 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 187 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 188 | 189 | def forward(self, x: torch.Tensor): 190 | x = x + self.attention(self.ln_1(x)) 191 | x = x + 
self.mlp(self.ln_2(x)) 192 | return x 193 | 194 | 195 | class Transformer(nn.Module): 196 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 197 | super().__init__() 198 | self.width = width 199 | self.layers = layers 200 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 201 | 202 | def forward(self, x: torch.Tensor): 203 | return self.resblocks(x) 204 | 205 | 206 | class VisionTransformer(nn.Module): 207 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int, prompt_len: int): 208 | super().__init__() 209 | self.input_resolution = input_resolution 210 | self.output_dim = output_dim 211 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 212 | self.prompt_len = prompt_len 213 | scale = width ** -0.5 214 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 215 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 216 | if prompt_len > 0: 217 | self.prompt_pos_embedding = nn.Parameter(scale * torch.zeros(prompt_len, width)) 218 | 219 | self.ln_pre = LayerNorm(width) 220 | 221 | self.transformer = Transformer(width, layers, heads) 222 | 223 | self.ln_post = LayerNorm(width) 224 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 225 | 226 | def forward(self, x: torch.Tensor, ind_prompt: torch.Tensor = None): 227 | x = self.conv1(x) # shape = [*, width, grid, grid] 228 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 229 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 230 | if self.prompt_len>0 and ind_prompt != None: 231 | tmp_ind_prompt = ind_prompt + torch.zeros(x.shape[0],1,1, dtype=x.dtype, device=x.device) 232 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x, tmp_ind_prompt], dim=1) # shape = [*, grid ** 2 + 1, width] 233 | x = x + torch.cat([self.positional_embedding.to(x.dtype), self.prompt_pos_embedding], dim=0) 234 | else: 235 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 236 | x = x + self.positional_embedding.to(x.dtype) 237 | 238 | x = self.ln_pre(x) 239 | 240 | x = x.permute(1, 0, 2) # NLD -> LND 241 | x = self.transformer(x) 242 | x = x.permute(1, 0, 2) # LND -> NLD 243 | 244 | x = self.ln_post(x[:, 0, :]) 245 | 246 | if self.proj is not None: 247 | x = x @ self.proj 248 | 249 | return x 250 | 251 | 252 | class CLIP(nn.Module): 253 | def __init__(self, 254 | embed_dim: int, 255 | # vision 256 | image_resolution: int, 257 | vision_layers: Union[Tuple[int, int, int, int], int], 258 | vision_width: int, 259 | vision_patch_size: int, 260 | # text 261 | context_length: int, 262 | vocab_size: int, 263 | transformer_width: int, 264 | transformer_heads: int, 265 | transformer_layers: int, 266 | prompt_len: int 267 | ): 268 | super().__init__() 269 | 270 | self.context_length = context_length 271 | 272 | if isinstance(vision_layers, (tuple, list)): 273 | vision_heads = vision_width * 32 // 64 274 | self.visual = ModifiedResNet( 275 | layers=vision_layers, 276 | output_dim=embed_dim, 277 | heads=vision_heads, 278 | input_resolution=image_resolution, 279 | width=vision_width, 280 | prompt_len=prompt_len 281 | ) 282 | else: 283 | vision_heads = 
vision_width // 64 284 | self.visual = VisionTransformer( 285 | input_resolution=image_resolution, 286 | patch_size=vision_patch_size, 287 | width=vision_width, 288 | layers=vision_layers, 289 | heads=vision_heads, 290 | output_dim=embed_dim, 291 | prompt_len=prompt_len 292 | ) 293 | 294 | self.transformer = Transformer( 295 | width=transformer_width, 296 | layers=transformer_layers, 297 | heads=transformer_heads, 298 | attn_mask=self.build_attention_mask() 299 | ) 300 | 301 | self.vocab_size = vocab_size 302 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 303 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 304 | self.ln_final = LayerNorm(transformer_width) 305 | 306 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 307 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 308 | 309 | self.initialize_parameters() 310 | 311 | def initialize_parameters(self): 312 | nn.init.normal_(self.token_embedding.weight, std=0.02) 313 | nn.init.normal_(self.positional_embedding, std=0.01) 314 | 315 | if isinstance(self.visual, ModifiedResNet): 316 | if self.visual.attnpool is not None: 317 | std = self.visual.attnpool.c_proj.in_features ** -0.5 318 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 319 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 320 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 321 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 322 | 323 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 324 | for name, param in resnet_block.named_parameters(): 325 | if name.endswith("bn3.weight"): 326 | nn.init.zeros_(param) 327 | 328 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 329 | attn_std = self.transformer.width ** -0.5 330 | fc_std = (2 * self.transformer.width) ** -0.5 331 | for block in self.transformer.resblocks: 332 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 333 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 334 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 335 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 336 | 337 | if self.text_projection is not None: 338 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 339 | 340 | def build_attention_mask(self): 341 | # lazily create causal attention mask, with full attention between the vision tokens 342 | # pytorch uses additive attention mask; fill with -inf 343 | mask = torch.empty(self.context_length, self.context_length) 344 | mask.fill_(float("-inf")) 345 | mask.triu_(1) # zero out the lower diagonal 346 | return mask 347 | 348 | @property 349 | def dtype(self): 350 | return self.visual.conv1.weight.dtype 351 | 352 | def encode_image(self, image, ind_prompt): 353 | if ind_prompt==None: 354 | return self.visual(image.type(self.dtype)) 355 | return self.visual(image.type(self.dtype), ind_prompt.type(self.dtype)) 356 | 357 | def encode_text(self, text): 358 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 359 | 360 | x = x + self.positional_embedding.type(self.dtype) 361 | x = x.permute(1, 0, 2) # NLD -> LND 362 | x = self.transformer(x) 363 | x = x.permute(1, 0, 2) # LND -> NLD 364 | x = self.ln_final(x).type(self.dtype) 365 | 366 | # x.shape = [batch_size, n_ctx, transformer.width] 367 | # take features from the eot embedding (eot_token is the highest number in each 
sequence) 368 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 369 | 370 | return x 371 | 372 | def forward(self, image, text, ind_prompt=None): 373 | image_features = self.encode_image(image, ind_prompt) 374 | text_features = self.encode_text(text) 375 | 376 | # normalized features 377 | image_features = image_features / image_features.norm(dim=1, keepdim=True) 378 | text_features = text_features / text_features.norm(dim=1, keepdim=True) 379 | 380 | # cosine similarity as logits 381 | logit_scale = self.logit_scale.exp() 382 | return image_features, logit_scale * text_features 383 | 384 | logits_per_image = logit_scale * image_features @ text_features.t() 385 | logits_per_text = logits_per_image.t() 386 | 387 | # shape = [global_batch_size, global_batch_size] 388 | return logits_per_image, logits_per_text 389 | 390 | 391 | def convert_weights(model: nn.Module): 392 | """Convert applicable model parameters to fp16""" 393 | 394 | def _convert_weights_to_fp16(l): 395 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 396 | l.weight.data = l.weight.data.half() 397 | if l.bias is not None: 398 | l.bias.data = l.bias.data.half() 399 | 400 | if isinstance(l, nn.MultiheadAttention): 401 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 402 | tensor = getattr(l, attr) 403 | if tensor is not None: 404 | tensor.data = tensor.data.half() 405 | 406 | for name in ["text_projection", "proj"]: 407 | if hasattr(l, name): 408 | attr = getattr(l, name) 409 | if attr is not None: 410 | attr.data = attr.data.half() 411 | 412 | model.apply(_convert_weights_to_fp16) 413 | 414 | 415 | def build_model(state_dict: dict, prompt_len: int): 416 | vit = "visual.proj" in state_dict 417 | 418 | if vit: 419 | vision_width = state_dict["visual.conv1.weight"].shape[0] 420 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 421 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 422 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 423 | image_resolution = vision_patch_size * grid_size 424 | else: 425 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 426 | vision_layers = tuple(counts) 427 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 428 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 429 | vision_patch_size = None 430 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 431 | image_resolution = output_width * 32 432 | 433 | embed_dim = state_dict["text_projection"].shape[1] 434 | context_length = state_dict["positional_embedding"].shape[0] 435 | vocab_size = state_dict["token_embedding.weight"].shape[0] 436 | transformer_width = state_dict["ln_final.weight"].shape[0] 437 | transformer_heads = transformer_width // 64 438 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) 439 | 440 | model = CLIP( 441 | embed_dim, 442 | image_resolution, vision_layers, vision_width, vision_patch_size, 443 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers, prompt_len 444 | ) 445 | 446 | for key in ["input_resolution", "context_length", "vocab_size"]: 447 | if key in state_dict: 448 | del state_dict[key] 449 | 450 | convert_weights(model) 451 | 
model.load_state_dict(state_dict, strict=False) 452 | return model.eval() 453 | -------------------------------------------------------------------------------- /code/replace/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j 
= word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /code/test_time_counterattack.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import argparse 4 | import os 5 | import time 6 | import random 7 | import logging 8 | from tqdm import tqdm 9 | from copy import deepcopy as dcopy 10 | 11 | import torch 12 | from torch.cuda.amp import GradScaler, autocast 13 | 14 | from replace import clip 15 | from models.prompters import TokenPrompter, NullPrompter 16 | from utils import * 17 | from attacks import * 18 | from func import clip_img_preprocessing, multiGPU_CLIP 19 | 20 | def parse_options(): 21 | parser = argparse.ArgumentParser() 22 | 23 | parser.add_argument('--evaluate', type=bool, default=True) # eval mode 24 | parser.add_argument('--batch_size', type=int, default=256, help='batch_size') 25 | parser.add_argument('--num_workers', type=int, default=32, help='num of workers to use') 26 | parser.add_argument('--cache', type=str, default='./cache') 27 | 28 | # test setting 29 | parser.add_argument('--test_set', default=[], type=str, nargs='*') # defaults to 17 datasets, if not specified 30 | parser.add_argument('--test_attack_type', type=str, default="pgd", choices=['pgd', 'CW', 'autoattack',]) 31 | parser.add_argument('--test_eps', type=float, default=1,help='test attack budget') 32 | parser.add_argument('--test_numsteps', type=int, default=10) 33 | parser.add_argument('--test_stepsize', type=int, default=1) 34 | 35 | # model 36 | parser.add_argument('--model', type=str, default='clip') 37 | parser.add_argument('--arch', type=str, default='vit_b32') 38 | parser.add_argument('--method', type=str, default='null_patch', 39 | choices=['null_patch'], help='choose visual prompting method') 40 | parser.add_argument('--name', type=str, default='') 41 | parser.add_argument('--prompt_size', type=int, default=30, help='size for visual prompts') 42 | parser.add_argument('--add_prompt_size', type=int, default=0, help='size for additional visual prompts') 43 | 44 | # data 45 | parser.add_argument('--root', type=str, default='./data', help='dataset path') 46 | parser.add_argument('--dataset', type=str, default='tinyImageNet', help='dataset used for AFT methods') 47 | parser.add_argument('--image_size', type=int, default=224, help='image size') 48 | 49 | # TTC config 50 | parser.add_argument('--seed', 
type=int, default=0, help='seed for initializing training') 51 | parser.add_argument('--victim_resume', type=str, default=None, help='model weights of victim to attack.') 52 | parser.add_argument('--outdir', type=str, default=None, help='output directory for results') 53 | parser.add_argument('--tau_thres', type=float, default=0.2) 54 | parser.add_argument('--beta', type=float, default=2.,) 55 | parser.add_argument('--ttc_eps', type=float, default=4.) 56 | parser.add_argument('--ttc_numsteps', type=int, default=2) 57 | parser.add_argument('--ttc_stepsize', type=float, default=1.) 58 | 59 | args = parser.parse_args() 60 | return args 61 | 62 | def compute_tau(clip_visual, images, n): 63 | orig_feat = clip_visual(clip_img_preprocessing(images), None) # [bs, 512] 64 | noisy_feat = clip_visual(clip_img_preprocessing(images + n), None) 65 | diff_ratio = (noisy_feat - orig_feat).norm(dim=-1) / orig_feat.norm(dim=-1) # [bs] 66 | return diff_ratio 67 | 68 | def tau_thres_weighted_counterattacks(model, X, prompter, add_prompter, alpha, attack_iters, 69 | norm="l_inf", epsilon=0, visual_model_orig=None, 70 | tau_thres:float=None, beta:float=None, clip_visual=None): 71 | delta = torch.zeros_like(X) 72 | if epsilon <= 0.: 73 | return delta 74 | 75 | if norm == "l_inf": 76 | delta.uniform_(-epsilon, epsilon) 77 | elif norm == "l_2": 78 | delta.normal_() 79 | d_flat = delta.view(delta.size(0), -1) 80 | n = d_flat.norm(p=2, dim=1).view(delta.size(0), 1, 1, 1) 81 | r = torch.zeros_like(n).uniform_(0, 1) 82 | delta *= r / n * epsilon 83 | else: 84 | raise ValueError 85 | 86 | delta = clamp(delta, lower_limit - X, upper_limit - X) 87 | delta.requires_grad = True 88 | 89 | if attack_iters == 0: # apply random noise (RN) 90 | return delta.data 91 | 92 | diff_ratio = compute_tau(clip_visual, X, delta.data) if clip_visual is not None else None 93 | 94 | # Freeze model parameters temporarily. 
Not necessary but for completeness of code 95 | tunable_param_names = [] 96 | for n,p in model.module.named_parameters(): 97 | if p.requires_grad: 98 | tunable_param_names.append(n) 99 | p.requires_grad = False 100 | 101 | prompt_token = add_prompter() 102 | with torch.no_grad(): 103 | X_ori_reps = model.module.encode_image( 104 | prompter(clip_img_preprocessing(X)), prompt_token 105 | ) 106 | X_ori_norm = torch.norm(X_ori_reps, dim=-1) # [ bs] 107 | 108 | deltas_per_step = [] 109 | deltas_per_step.append(delta.data.clone()) 110 | 111 | for _step_id in range(attack_iters): 112 | 113 | prompted_images = prompter(clip_img_preprocessing(X + delta)) 114 | X_att_reps = model.module.encode_image(prompted_images, prompt_token) 115 | if _step_id == 0 and diff_ratio is None: # compute tau at the zero-th step 116 | feature_diff = X_att_reps - X_ori_reps # [bs, 512] 117 | diff_ratio = torch.norm(feature_diff, dim=-1) / X_ori_norm # [bs] 118 | 119 | scheme_sign = (tau_thres - diff_ratio).sign() 120 | 121 | l2_loss = ((((X_att_reps - X_ori_reps)**2).sum(1))).sum() 122 | grad = torch.autograd.grad(l2_loss, delta)[0] 123 | d = delta[:, :, :, :] 124 | g = grad[:, :, :, :] 125 | x = X[:, :, :, :] 126 | 127 | if norm == "l_inf": 128 | d = torch.clamp(d + alpha * torch.sign(g), min=-epsilon, max=epsilon) 129 | elif norm == "l_2": 130 | g_norm = torch.norm(g.view(g.shape[0], -1), dim=1).view(-1, 1, 1, 1) 131 | scaled_g = g / (g_norm + 1e-10) 132 | d = (d + scaled_g * alpha).view(d.size(0), -1).renorm(p=2, dim=0, maxnorm=epsilon).view_as(d) 133 | d = clamp(d, lower_limit - x, upper_limit - x) 134 | delta.data[:, :, :, :] = d 135 | deltas_per_step.append(delta.data.clone()) 136 | 137 | Delta = torch.stack(deltas_per_step, dim=1) # [bs, numsteps+1, C, W, H] 138 | 139 | # create weights across steps 140 | weights = torch.arange(attack_iters+1).unsqueeze(0).expand(X.size(0), -1).to(device) # [bs, numsteps+1] 141 | weights = torch.exp( 142 | scheme_sign.view(-1, 1) * weights * beta 143 | ) # [bs, numsteps+1] 144 | weights /= weights.sum(dim=1, keepdim=True) 145 | 146 | weights_hard = torch.zeros_like(weights) # [bs, numsteps+1] 147 | weights_hard[:,0] = 1. 148 | 149 | weights = torch.where(scheme_sign.unsqueeze(1)>0, weights, weights_hard) 150 | weights = weights.view(X.size(0), attack_iters+1, 1, 1, 1) 151 | 152 | Delta = (weights * Delta).sum(dim=1) 153 | 154 | # Unfreeze model parameters. 
Only for completeness of code 155 | for n,p in model.module.named_parameters(): 156 | if n in tunable_param_names: 157 | p.requires_grad = True 158 | 159 | return Delta 160 | 161 | 162 | def validate(args, val_dataset_name, model, model_text, model_image, 163 | prompter, add_prompter, criterion, visual_model_orig=None, 164 | clip_visual=None 165 | ): 166 | 167 | logging.info(f"Evaluate with Attack method: {args.test_attack_type}") 168 | 169 | dataset_num = len(val_dataset_name) 170 | all_clean_org, all_clean_ttc, all_adv_org, all_adv_ttc = {},{},{},{} 171 | 172 | test_stepsize = args.test_stepsize 173 | 174 | ttc_eps = args.ttc_eps 175 | ttc_numsteps = args.ttc_numsteps 176 | ttc_stepsize = args.ttc_stepsize 177 | beta = args.beta 178 | tau_thres = args.tau_thres 179 | 180 | for cnt in range(dataset_num): 181 | val_dataset, val_loader = load_val_dataset(args, val_dataset_name[cnt]) 182 | dataset_name = val_dataset_name[cnt] 183 | texts = get_text_prompts_val([val_dataset], [dataset_name])[0] 184 | 185 | binary = ['PCAM', 'hateful_memes'] 186 | attacks_to_run=['apgd-ce', 'apgd-dlr'] 187 | if dataset_name in binary: 188 | attacks_to_run=['apgd-ce'] 189 | 190 | batch_time = AverageMeter('Time', ':6.3f') 191 | losses = AverageMeter('Loss', ':.4e') 192 | top1_org = AverageMeter('Original Acc@1', ':6.2f') 193 | top1_org_ttc = AverageMeter('Prompt Acc@1', ':6.2f') 194 | top1_adv = AverageMeter('Adv Original Acc@1', ':6.2f') 195 | top1_adv_ttc = AverageMeter('Adv Prompt Acc@1', ':6.2f') 196 | 197 | # switch to evaluation mode 198 | prompter.eval() 199 | add_prompter.eval() 200 | model.eval() 201 | 202 | text_tokens = clip.tokenize(texts).to(device) 203 | end = time.time() 204 | 205 | for i, (images, target) in enumerate(tqdm(val_loader)): 206 | 207 | images = images.to(device) 208 | target = target.to(device) 209 | 210 | with autocast(): 211 | 212 | # original acc of clean images 213 | with torch.no_grad(): 214 | clean_output,_,_,_ = multiGPU_CLIP( 215 | None, None, None, model, prompter(clip_img_preprocessing(images)), 216 | text_tokens = text_tokens, 217 | prompt_token = None, dataset_name = dataset_name 218 | ) 219 | clean_acc = accuracy(clean_output, target, topk=(1,)) 220 | top1_org.update(clean_acc[0].item(), images.size(0)) 221 | 222 | # TTC on clean images 223 | ttc_delta_clean = tau_thres_weighted_counterattacks( 224 | model, images, prompter, add_prompter, 225 | alpha=ttc_stepsize, attack_iters=ttc_numsteps, 226 | norm='l_inf', epsilon=ttc_eps, visual_model_orig=None, 227 | tau_thres=tau_thres, beta = beta, 228 | clip_visual=clip_visual 229 | ) 230 | with torch.no_grad(): 231 | clean_output_ttc,_,_,_ = multiGPU_CLIP( 232 | None, None, None, model, prompter(clip_img_preprocessing(images+ttc_delta_clean)), 233 | text_tokens = text_tokens, 234 | prompt_token = None, dataset_name = dataset_name 235 | ) 236 | clean_acc_ttc = accuracy(clean_output_ttc, target, topk=(1,)) 237 | top1_org_ttc.update(clean_acc_ttc[0].item(), images.size(0)) 238 | 239 | # generate adv samples for this batch 240 | torch.cuda.empty_cache() 241 | if args.test_attack_type == "pgd": 242 | delta_prompt = attack_pgd(args, prompter, model, model_text, model_image, add_prompter, criterion, 243 | images, target, test_stepsize, args.test_numsteps, 'l_inf', 244 | text_tokens=text_tokens, epsilon=args.test_eps, dataset_name=dataset_name) 245 | attacked_images = images + delta_prompt 246 | elif args.test_attack_type == "CW": 247 | delta_prompt = attack_CW(args, prompter, model, model_text, model_image, add_prompter, criterion, 
248 | images, target, text_tokens, 249 | test_stepsize, args.test_numsteps, 'l_inf', epsilon=args.test_eps) 250 | attacked_images = images + delta_prompt 251 | elif args.test_attack_type == "autoattack": 252 | attacked_images = attack_auto(model, images, target, text_tokens, 253 | None, None, epsilon=args.test_eps, attacks_to_run=attacks_to_run) 254 | 255 | # acc of adv images without ttc 256 | with torch.no_grad(): 257 | adv_output,_,_,_ = multiGPU_CLIP( 258 | None,None,None, model, prompter(clip_img_preprocessing(attacked_images)), 259 | text_tokens, prompt_token=None, dataset_name=dataset_name 260 | ) 261 | adv_acc = accuracy(adv_output, target, topk=(1,)) 262 | top1_adv.update(adv_acc[0].item(), images.size(0)) 263 | 264 | ttc_delta_adv = tau_thres_weighted_counterattacks( 265 | model, attacked_images.data, prompter, add_prompter, 266 | alpha=ttc_stepsize, attack_iters=ttc_numsteps, 267 | norm='l_inf', epsilon=ttc_eps, visual_model_orig=None, 268 | tau_thres=tau_thres, beta = beta, 269 | clip_visual = clip_visual 270 | ) 271 | with torch.no_grad(): 272 | adv_output_ttc,_,_,_ = multiGPU_CLIP( 273 | None,None,None, model, prompter(clip_img_preprocessing(attacked_images+ttc_delta_adv)), 274 | text_tokens, prompt_token=None, dataset_name=dataset_name 275 | ) 276 | adv_output_acc = accuracy(adv_output_ttc, target, topk=(1,)) 277 | top1_adv_ttc.update(adv_output_acc[0].item(), images.size(0)) 278 | 279 | batch_time.update(time.time() - end) 280 | end = time.time() 281 | 282 | torch.cuda.empty_cache() 283 | clean_acc = top1_org.avg 284 | clean_ttc_acc = top1_org_ttc.avg 285 | adv_acc = top1_adv.avg 286 | adv_ttc_acc = top1_adv_ttc.avg 287 | 288 | all_clean_org[dataset_name] = clean_acc 289 | all_clean_ttc[dataset_name] = clean_ttc_acc 290 | all_adv_org[dataset_name] = adv_acc 291 | all_adv_ttc[dataset_name] = adv_ttc_acc 292 | 293 | show_text = f"{dataset_name}:\n\t" 294 | show_text += f"- clean acc. {clean_acc:.2f} (ttc: {clean_ttc_acc:.2f})\n\t" 295 | show_text += f"- robust acc. {adv_acc:.2f} (ttc: {adv_ttc_acc:.2f})" 296 | 297 | logging.info(show_text) 298 | 299 | all_clean_org_avg = np.mean([all_clean_org[name] for name in all_clean_org]).item() 300 | all_clean_ttc_avg = np.mean([all_clean_ttc[name] for name in all_clean_ttc]).item() 301 | all_adv_org_avg = np.mean([all_adv_org[name] for name in all_adv_org]).item() 302 | all_adv_ttc_avg = np.mean([all_adv_ttc[name] for name in all_adv_ttc]).item() 303 | show_text = f"===== SUMMARY ACROSS {dataset_num} DATASETS =====\n\t" 304 | show_text += f"AVG acc. {all_clean_org_avg:.2f} (ttc: {all_clean_ttc_avg:.2f})\n\t" 305 | show_text += f"AVG acc. {all_adv_org_avg:.2f} (ttc: {all_adv_ttc_avg:.2f})" 306 | logging.info(show_text) 307 | 308 | # Exclude the dataset used for implementing AFT methods (tinyImageNet in the paper) 309 | zs_clean_org_avg = np.mean([all_clean_org[name] for name in val_dataset_name if name != args.dataset]).item() 310 | zs_clean_ttc_avg = np.mean([all_clean_ttc[name] for name in val_dataset_name if name != args.dataset]).item() 311 | zs_adv_org_avg = np.mean([all_adv_org[name] for name in val_dataset_name if name != args.dataset]).item() 312 | zs_adv_ttc_avg = np.mean([all_adv_ttc[name] for name in val_dataset_name if name != args.dataset]).item() 313 | valid_dataset_num = dataset_num - 1 if args.dataset in val_dataset_name else dataset_num 314 | show_text = f"===== SUMMARY ACROSS {valid_dataset_num} DATASETS (EXCEPT {args.dataset}) =====\n\t" 315 | show_text += f"AVG acc. 
{zs_clean_org_avg:.2f} (ttc: {zs_clean_ttc_avg:.2f})\n\t" 316 | show_text += f"AVG acc. {zs_adv_org_avg:.2f} (ttc: {zs_adv_ttc_avg:.2f})" 317 | logging.info(show_text) 318 | 319 | return all_clean_org_avg, all_clean_ttc_avg, all_adv_org_avg, all_adv_ttc_avg 320 | 321 | device = "cuda" if torch.cuda.is_available() else "cpu" 322 | 323 | def main(): 324 | 325 | args = parse_options() 326 | 327 | outdir = args.outdir if args.outdir is not None else "TTC_results" 328 | outdir = os.path.join(outdir, f"{args.test_attack_type}_eps_{args.test_eps}_numsteps_{args.test_numsteps}") 329 | os.makedirs(outdir, exist_ok=True) 330 | 331 | args.test_eps = args.test_eps / 255. 332 | args.test_stepsize = args.test_stepsize / 255. 333 | 334 | seed = args.seed 335 | random.seed(seed) 336 | np.random.seed(seed) 337 | torch.manual_seed(seed) 338 | torch.cuda.manual_seed(seed) 339 | torch.cuda.manual_seed_all(seed) 340 | torch.backends.cudnn.deterministic = True 341 | torch.backends.cudnn.benchmark = False 342 | 343 | log_filename = "" 344 | log_filename += f"ttc_eps_{args.ttc_eps}_thres_{args.tau_thres}_beta_{args.beta}_numsteps_{args.ttc_numsteps}_stepsize_{int(args.ttc_stepsize)}_seed_{seed}.log".replace(" ", "") 345 | log_filename = os.path.join(outdir, log_filename) 346 | logging.basicConfig( 347 | filename = log_filename, 348 | level = logging.INFO, 349 | format="%(asctime)s - %(levelname)s - %(message)s" 350 | ) 351 | logging.info(args) 352 | 353 | args.ttc_stepsize = args.ttc_stepsize / 255. 354 | args.ttc_eps = args.ttc_eps / 255. 355 | 356 | imagenet_root = './data/ImageNet' 357 | tinyimagenet_root = "./data/tiny-imagenet-200" 358 | args.imagenet_root = imagenet_root 359 | args.tinyimagenet_root = tinyimagenet_root 360 | 361 | # load model 362 | model, _ = clip.load('ViT-B/32', device, jit=False, prompt_len=0) 363 | for p in model.parameters(): 364 | p.requires_grad = False 365 | convert_models_to_fp32(model) 366 | 367 | if args.victim_resume: # employ TTC on AFT checkpoints 368 | clip_visual = dcopy(model.visual) 369 | model = load_checkpoints2(args, args.victim_resume, model, None) 370 | else: # employ TTC on the original CLIP 371 | clip_visual = None 372 | 373 | model = torch.nn.DataParallel(model) 374 | model.eval() 375 | prompter = NullPrompter() 376 | add_prompter = TokenPrompter(0) 377 | prompter = torch.nn.DataParallel(prompter).cuda() 378 | add_prompter = torch.nn.DataParallel(add_prompter).cuda() 379 | logging.info("done loading model.") 380 | 381 | if len(args.test_set) == 0: 382 | test_set = DATASETS 383 | else: 384 | test_set = args.test_set 385 | 386 | # criterion to compute attack loss, the reduction of 'sum' is important for effective attacks 387 | criterion_attack = torch.nn.CrossEntropyLoss(reduction='sum').to(device) 388 | 389 | validate( 390 | args, test_set, model, None, None, prompter, 391 | add_prompter, criterion_attack, None, clip_visual 392 | ) 393 | 394 | if __name__ == "__main__": 395 | main() 396 | -------------------------------------------------------------------------------- /code/utils.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import os 3 | import torch 4 | import numpy as np 5 | import torchvision.transforms as transforms 6 | from torchvision.datasets import CIFAR10, CIFAR100, STL10, ImageFolder 7 | from replace.datasets import caltech, country211, dtd,eurosat, fgvc_aircraft, food101, \ 8 | flowers102, oxford_iiit_pet, pcam, stanford_cars, sun397 9 | from torch.utils.data import DataLoader 10 | import json 
11 | 12 | def convert_models_to_fp32(model): 13 | for p in model.parameters(): 14 | p.data = p.data.float() 15 | if p.grad: 16 | p.grad.data = p.grad.data.float() 17 | 18 | 19 | def refine_classname(class_names): 20 | for i, class_name in enumerate(class_names): 21 | class_names[i] = class_name.lower().replace('_', ' ').replace('-', ' ').replace('/', ' ') 22 | return class_names 23 | 24 | def save_checkpoint(state, save_folder, is_best=False, filename='checkpoint.pth.tar'): 25 | savefile = os.path.join(save_folder, filename) 26 | bestfile = os.path.join(save_folder, 'model_best.pth.tar') 27 | torch.save(state, savefile) 28 | if is_best: 29 | shutil.copyfile(savefile, bestfile) 30 | print ('saved best file') 31 | 32 | def assign_learning_rate(optimizer, new_lr, tgt_group_idx=None): 33 | for group_idx, param_group in enumerate(optimizer.param_groups): 34 | if tgt_group_idx is None or tgt_group_idx==group_idx: 35 | param_group["lr"] = new_lr 36 | 37 | def _warmup_lr(base_lr, warmup_length, step): 38 | return base_lr * (step + 1) / warmup_length 39 | 40 | def cosine_lr(optimizer, base_lr, warmup_length, steps, tgt_group_idx=None): 41 | def _lr_adjuster(step): 42 | if step < warmup_length: 43 | lr = _warmup_lr(base_lr, warmup_length, step) 44 | else: 45 | e = step - warmup_length 46 | es = steps - warmup_length 47 | lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr 48 | assign_learning_rate(optimizer, lr, tgt_group_idx) 49 | return lr 50 | return _lr_adjuster 51 | 52 | def null_scheduler(init_lr): 53 | return lambda step:init_lr 54 | 55 | def accuracy(output, target, topk=(1,)): 56 | """Computes the accuracy over the k top predictions for the specified values of k""" 57 | with torch.no_grad(): 58 | maxk = max(topk) 59 | batch_size = target.size(0) 60 | 61 | _, pred = output.topk(maxk, 1, True, True) 62 | pred = pred.t() 63 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 64 | 65 | res = [] 66 | for k in topk: 67 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 68 | res.append(correct_k.mul_(100.0 / batch_size)) 69 | return res 70 | 71 | 72 | class AverageMeter(object): 73 | """Computes and stores the average and current value""" 74 | def __init__(self, name, fmt=':f'): 75 | self.name = name 76 | self.fmt = fmt 77 | self.reset() 78 | 79 | def reset(self): 80 | self.val = 0 81 | self.avg = 0 82 | self.sum = 0 83 | self.count = 0 84 | 85 | def update(self, val, n=1): 86 | self.val = val 87 | self.sum += val * n 88 | self.count += n 89 | self.avg = self.sum / self.count 90 | 91 | def __str__(self): 92 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' 93 | return fmtstr.format(**self.__dict__) 94 | 95 | 96 | class ProgressMeter(object): 97 | def __init__(self, num_batches, meters, prefix=""): 98 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 99 | self.meters = meters 100 | self.prefix = prefix 101 | 102 | def display(self, batch): 103 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 104 | entries += [str(meter) for meter in self.meters] 105 | print('\t'.join(entries)) 106 | 107 | def _get_batch_fmtstr(self, num_batches): 108 | num_digits = len(str(num_batches // 1)) 109 | fmt = '{:' + str(num_digits) + 'd}' 110 | return '[' + fmt + '/' + fmt.format(num_batches) + ']' 111 | 112 | 113 | def load_imagenet_folder2name(path): 114 | dict_imagenet_folder2name = {} 115 | with open(path) as f: 116 | line = f.readline() 117 | while line: 118 | split_name = line.strip().split() 119 | cat_name = split_name[2] 120 | id = split_name[0] 121 | 
dict_imagenet_folder2name[id] = cat_name 122 | line = f.readline() 123 | # print(dict_imagenet_folder2name) 124 | return dict_imagenet_folder2name 125 | 126 | 127 | 128 | def one_hot_embedding(labels, num_classes): 129 | """Embedding labels to one-hot form. 130 | Args: 131 | labels: (LongTensor) class labels, sized [N,]. 132 | num_classes: (int) number of classes. 133 | Returns: 134 | (tensor) encoded labels, sized [N, #classes]. 135 | """ 136 | y = torch.eye(num_classes) 137 | return y[labels.cpu()] 138 | 139 | 140 | preprocess = transforms.Compose([ 141 | transforms.ToTensor() 142 | ]) 143 | preprocess224 = transforms.Compose([ 144 | transforms.Resize(256), 145 | transforms.CenterCrop(224), 146 | transforms.ToTensor() 147 | ]) 148 | preprocess224_caltech = transforms.Compose([ 149 | transforms.Resize(256), 150 | transforms.CenterCrop(224), 151 | transforms.Lambda(lambda img: img.convert("RGB")), 152 | transforms.ToTensor() 153 | ]) 154 | preprocess224_interpolate = transforms.Compose([ 155 | transforms.Resize((224, 224)), 156 | transforms.ToTensor() 157 | ]) 158 | 159 | def load_train_dataset(args): 160 | if args.dataset == 'cifar100': 161 | return CIFAR100(args.root, transform=preprocess, download=True, train=True) 162 | elif args.dataset == 'cifar10': 163 | return CIFAR10(args.root, transform=preprocess, download=True, train=True) 164 | elif args.dataset == 'ImageNet': 165 | assert args.imagenet_root is not None 166 | print(f"Loading ImageNet from {args.imagenet_root}") 167 | return ImageFolder(os.path.join(args.imagenet_root, 'train'), transform=preprocess224) 168 | else: 169 | print(f"Train dataset {args.dataset} not implemented") 170 | raise NotImplementedError 171 | 172 | def get_eval_files(dataset_name): 173 | # only for imaegnet and tinyimagenet 174 | refined_data_file = f"./support/{dataset_name.lower()}_refined_labels.json" 175 | refined_data = read_json(refined_data_file) 176 | eval_select = {ssid:refined_data[ssid]['eval_files'] for ssid in refined_data} 177 | return eval_select 178 | 179 | def load_val_dataset(args, val_dataset_name): 180 | 181 | if val_dataset_name == 'cifar10': 182 | val_dataset = CIFAR10(args.root, transform=preprocess224, download=True, train=False) 183 | 184 | elif val_dataset_name == 'cifar100': 185 | val_dataset = CIFAR100(args.root, transform=preprocess224, download=True, train=False) 186 | 187 | elif val_dataset_name == 'Caltech101': 188 | val_dataset = caltech.Caltech101(args.root, target_type='category', transform=preprocess224_caltech, download=True) 189 | 190 | elif val_dataset_name == 'PCAM': 191 | val_dataset = pcam.PCAM(args.root, split='test', transform=preprocess224, download=True) 192 | 193 | elif val_dataset_name == 'STL10': 194 | val_dataset = STL10(args.root, split='test', transform=preprocess224, download=True) 195 | 196 | elif val_dataset_name == 'SUN397': 197 | val_dataset = sun397.SUN397(args.root, transform=preprocess224, download=True) 198 | 199 | elif val_dataset_name == 'StanfordCars': 200 | val_dataset = stanford_cars.StanfordCars(args.root, split='test', transform=preprocess224, download=True) 201 | 202 | elif val_dataset_name == 'Food101': 203 | val_dataset = food101.Food101(args.root, split='test', transform=preprocess224, download=True) 204 | 205 | elif val_dataset_name == 'oxfordpet': 206 | val_dataset = oxford_iiit_pet.OxfordIIITPet(args.root, split='test', transform=preprocess224, download=True) 207 | 208 | elif val_dataset_name == 'EuroSAT': 209 | val_dataset = eurosat.EuroSAT(args.root, transform=preprocess224, 
download=True) 210 | 211 | elif val_dataset_name == 'Caltech256': 212 | val_dataset = caltech.Caltech256(args.root, transform=preprocess224_caltech, download=True) 213 | 214 | elif val_dataset_name == 'flowers102': 215 | val_dataset = flowers102.Flowers102(args.root, split='test', transform=preprocess224, download=True) 216 | 217 | elif val_dataset_name == 'Country211': 218 | val_dataset = country211.Country211(args.root, split='test', transform=preprocess224, download=True) 219 | 220 | elif val_dataset_name == 'dtd': 221 | val_dataset = dtd.DTD(args.root, split='test', transform=preprocess224, download=True) 222 | 223 | elif val_dataset_name == 'fgvc_aircraft': 224 | val_dataset = fgvc_aircraft.FGVCAircraft(args.root, split='test', transform=preprocess224, download=True) 225 | 226 | elif val_dataset_name == 'ImageNet': 227 | from replace.datasets.folder import ImageNetFolder 228 | if args.evaluate: 229 | val_dataset = ImageNetFolder(os.path.join(args.imagenet_root, 'val'), transform=preprocess224) 230 | else: 231 | eval_select = get_eval_files(val_dataset_name) 232 | val_dataset = ImageNetFolder(os.path.join(args.imagenet_root, 'train'), transform=preprocess224, select_files=eval_select) 233 | 234 | elif val_dataset_name == 'tinyImageNet': 235 | from replace.datasets.folder import ImageNetFolder 236 | if args.evaluate: 237 | val_dataset = ImageNetFolder(os.path.join(args.tinyimagenet_root, 'val_'), transform=preprocess224) 238 | else: 239 | eval_select = get_eval_files(val_dataset_name) 240 | val_dataset = ImageNetFolder(os.path.join(args.tinyimagenet_root, 'train'), transform=preprocess224, select_files=eval_select) 241 | else: 242 | print(f"Val dataset {val_dataset_name} not implemented") 243 | raise NotImplementedError 244 | 245 | val_loader = DataLoader(val_dataset, batch_size=args.batch_size, pin_memory=True, 246 | num_workers=args.num_workers, shuffle=False,) 247 | 248 | return val_dataset, val_loader 249 | 250 | def get_text_prompts_train(args, train_dataset, template='This is a photo of a {}'): 251 | class_names = train_dataset.classes 252 | if args.dataset == 'ImageNet': 253 | folder2name = load_imagenet_folder2name('support/imagenet_classes_names.txt') 254 | new_class_names = [] 255 | for each in class_names: 256 | new_class_names.append(folder2name[each]) 257 | 258 | class_names = new_class_names 259 | 260 | class_names = refine_classname(class_names) 261 | texts_train = [template.format(label) for label in class_names] 262 | return texts_train 263 | 264 | def get_text_prompts_val(val_dataset_list, val_dataset_name, template='This is a photo of a {}.'): 265 | texts_list = [] 266 | for cnt, each in enumerate(val_dataset_list): 267 | if hasattr(each, 'clip_prompts'): 268 | texts_tmp = each.clip_prompts 269 | else: 270 | class_names = each.classes if hasattr(each, 'classes') else each.clip_categories 271 | if val_dataset_name[cnt] in ['ImageNet', 'tinyImageNet']: 272 | refined_data = read_json(f"./support/{val_dataset_name[cnt].lower()}_refined_labels.json") 273 | clean_class_names = [refined_data[ssid]['clean_name'] for ssid in class_names] 274 | class_names = clean_class_names 275 | 276 | texts_tmp = [template.format(label) for label in class_names] 277 | texts_list.append(texts_tmp) 278 | assert len(texts_list) == len(val_dataset_list) 279 | return texts_list 280 | 281 | 282 | def read_json(json_file:str): 283 | with open(json_file, 'r') as f: 284 | data = json.load(f) 285 | return data 286 | 287 | # def get_prompts(class_names): 288 | # # consider using correct articles 289 | 
# template = "This is a photo of a {}." 290 | # template_v = "This is a photo of an {}." 291 | # prompts = [] 292 | # for class_name in class_names: 293 | # if class_name[0].lower() in ['a','e','i','o','u'] or class_name == "hourglass": 294 | # prompts.append(template_v.format(class_name)) 295 | # else: 296 | # prompts.append(template.format(class_name)) 297 | # return prompts 298 | 299 | def freeze(model:torch.nn.Module): 300 | for param in model.parameters(): 301 | param.requires_grad=False 302 | return 303 | 304 | DATASETS = [ 305 | 'cifar10', 'cifar100', 'STL10','ImageNet', 306 | 'Caltech101', 'Caltech256', 'oxfordpet', 'flowers102', 'fgvc_aircraft', 307 | 'StanfordCars', 'SUN397', 'Country211', 'Food101', 'EuroSAT', 308 | 'dtd', 'PCAM', 'tinyImageNet', 309 | ] 310 | 311 | 312 | def write_file(txt:str, file:str, mode='a'): 313 | with open(file, mode) as f: 314 | f.write(txt) 315 | 316 | 317 | def load_resume_file(file:str, gpu:int): 318 | if os.path.isfile(file): 319 | print("=> loading checkpoint '{}'".format(file)) 320 | if gpu is None: 321 | checkpoint = torch.load(file) 322 | else: 323 | loc = 'cuda:{}'.format(gpu) 324 | checkpoint = torch.load(file, map_location=loc) 325 | print("=> loaded checkpoint '{}' (epoch {})".format(file, checkpoint['epoch'])) 326 | return checkpoint 327 | else: 328 | print("=> no checkpoint found at '{}'".format(file)) 329 | return None 330 | 331 | 332 | def load_checkpoints2(args, resume_file, model, optimizer=None): 333 | checkpoint = load_resume_file(resume_file, None) 334 | try: 335 | model.module.visual.load_state_dict(checkpoint['vision_encoder_state_dict'], strict=False) 336 | except: 337 | model.visual.load_state_dict(checkpoint['vision_encoder_state_dict'], strict=False) 338 | if optimizer is not None: 339 | optimizer.load_state_dict(checkpoint['optimizer']) 340 | return model 341 | -------------------------------------------------------------------------------- /data/readme.md: -------------------------------------------------------------------------------- 1 | This directory is the root path for all datasets. Please download raw datasets in this folder. 
2 | -------------------------------------------------------------------------------- /download_weights.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gdown 3 | 4 | file_ids = { 5 | "FARE": "1IMtb5SG1ajYphR8cK-3w3Nr7zvAHa8bi", 6 | "TeCoA": "1m4Iw9pCjtBHj7OVqHFlRu2OkO0rd7C7j", 7 | "PMG_AFT": "1JMXdMheNaYWiwqcWI0tvRrp6MBweQagn", 8 | } 9 | 10 | os.makedirs("./AFT_model_weights", exist_ok=True) 11 | 12 | for name, fid in file_ids.items(): 13 | url = f"https://drive.google.com/uc?id={fid}" 14 | output = f"./AFT_model_weights/{name}.pth.tar" 15 | gdown.download(url, output, quiet=False) 16 | print(f"Downloaded {name} weights to {output}") -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: TTC 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=5.1=1_gnu 9 | - blas=1.0=mkl 10 | - brotli-python=1.0.9=py311h6a678d5_9 11 | - bzip2=1.0.8=h5eee18b_6 12 | - c-ares=1.19.1=h5eee18b_0 13 | - ca-certificates=2025.2.25=h06a4308_0 14 | - certifi=2025.1.31=py311h06a4308_0 15 | - charset-normalizer=3.3.2=pyhd3eb1b0_0 16 | - cuda-cudart=11.8.89=0 17 | - cuda-cupti=11.8.87=0 18 | - cuda-libraries=11.8.0=0 19 | - cuda-nvrtc=11.8.89=0 20 | - cuda-nvtx=11.8.86=0 21 | - cuda-runtime=11.8.0=0 22 | - cuda-version=12.8=3 23 | - ffmpeg=4.3=hf484d3e_0 24 | - filelock=3.13.1=py311h06a4308_0 25 | - freetype=2.12.1=h4a9f257_0 26 | - ftfy=5.8=py_0 27 | - gmp=6.3.0=h6a678d5_0 28 | - gmpy2=2.2.1=py311h5eee18b_0 29 | - gnutls=3.6.15=he1e5248_0 30 | - h5py=3.12.1=py311h5842655_1 31 | - hdf5=1.14.5=h2b7332f_2 32 | - idna=3.7=py311h06a4308_0 33 | - intel-openmp=2023.1.0=hdb19cb5_46306 34 | - jinja2=3.1.5=py311h06a4308_0 35 | - jpeg=9e=h5eee18b_3 36 | - krb5=1.20.1=h143b758_1 37 | - lame=3.100=h7b6447c_0 38 | - lcms2=2.16=hb9589c4_0 39 | - ld_impl_linux-64=2.40=h12ee557_0 40 | - lerc=4.0.0=h6a678d5_0 41 | - libcublas=11.11.3.6=0 42 | - libcufft=10.9.0.58=0 43 | - libcufile=1.13.1.3=0 44 | - libcurand=10.3.9.90=0 45 | - libcurl=8.12.1=hc9e6f67_0 46 | - libcusolver=11.4.1.48=0 47 | - libcusparse=11.7.5.86=0 48 | - libdeflate=1.22=h5eee18b_0 49 | - libedit=3.1.20230828=h5eee18b_0 50 | - libev=4.33=h7f8727e_1 51 | - libffi=3.4.4=h6a678d5_1 52 | - libgcc-ng=11.2.0=h1234567_1 53 | - libgfortran-ng=11.2.0=h00389a5_1 54 | - libgfortran5=11.2.0=h1234567_1 55 | - libgomp=11.2.0=h1234567_1 56 | - libiconv=1.16=h5eee18b_3 57 | - libidn2=2.3.4=h5eee18b_0 58 | - libnghttp2=1.57.0=h2d74bed_0 59 | - libnpp=11.8.0.86=0 60 | - libnvjpeg=11.9.0.86=0 61 | - libpng=1.6.39=h5eee18b_0 62 | - libssh2=1.11.1=h251f7ec_0 63 | - libstdcxx-ng=11.2.0=h1234567_1 64 | - libtasn1=4.19.0=h5eee18b_0 65 | - libtiff=4.5.1=hffd6297_1 66 | - libunistring=0.9.10=h27cfd23_0 67 | - libuuid=1.41.5=h5eee18b_0 68 | - libwebp-base=1.3.2=h5eee18b_1 69 | - lz4-c=1.9.4=h6a678d5_1 70 | - markupsafe=3.0.2=py311h5eee18b_0 71 | - mkl=2023.1.0=h213fc3f_46344 72 | - mkl-service=2.4.0=py311h5eee18b_2 73 | - mkl_fft=1.3.11=py311h5eee18b_0 74 | - mkl_random=1.2.8=py311ha02d727_0 75 | - mpc=1.3.1=h5eee18b_0 76 | - mpfr=4.2.1=h5eee18b_0 77 | - mpmath=1.3.0=py311h06a4308_0 78 | - ncurses=6.4=h6a678d5_0 79 | - nettle=3.7.3=hbbd107a_1 80 | - networkx=3.4.2=py311h06a4308_0 81 | - numpy=1.26.0=py311h08b1b3b_0 82 | - numpy-base=1.26.0=py311hf175353_0 83 | - openh264=2.1.1=h4ff587b_0 84 | - openjpeg=2.5.2=he7f1fd0_0 85 | - 
openssl=3.0.16=h5eee18b_0 86 | - pillow=11.1.0=py311hcea889d_0 87 | - pip=25.0=py311h06a4308_0 88 | - pysocks=1.7.1=py311h06a4308_0 89 | - python=3.11.11=he870216_0 90 | - pytorch=2.0.1=py3.11_cuda11.8_cudnn8.7.0_0 91 | - pytorch-cuda=11.8=h7e8668a_6 92 | - pytorch-mutex=1.0=cuda 93 | - readline=8.2=h5eee18b_0 94 | - regex=2024.11.6=py311h5eee18b_0 95 | - requests=2.32.3=py311h06a4308_1 96 | - scipy=1.15.1=py311h08b1b3b_0 97 | - setuptools=72.1.0=py311h06a4308_0 98 | - sqlite=3.45.3=h5eee18b_0 99 | - sympy=1.13.3=py311h06a4308_1 100 | - tbb=2021.8.0=hdb19cb5_0 101 | - tk=8.6.14=h39e8969_0 102 | - torchtriton=2.0.0=py311 103 | - torchvision=0.15.2=py311_cu118 104 | - tqdm=4.67.1=py311h92b7b1e_0 105 | - typing_extensions=4.12.2=py311h06a4308_0 106 | - tzdata=2025a=h04d1e81_0 107 | - urllib3=2.3.0=py311h06a4308_0 108 | - wcwidth=0.2.5=pyhd3eb1b0_0 109 | - wheel=0.45.1=py311h06a4308_0 110 | - xz=5.6.4=h5eee18b_1 111 | - zlib=1.2.13=h5eee18b_1 112 | - zstd=1.5.6=hc292b87_0 113 | prefix: /home/sxing/miniconda3/envs/TTC 114 | -------------------------------------------------------------------------------- /figures/fig2b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sxing2/CLIP-Test-time-Counterattacks/793fbff69cd7b63881e761e9a73a38e76b3ce5bf/figures/fig2b.png -------------------------------------------------------------------------------- /figures/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sxing2/CLIP-Test-time-Counterattacks/793fbff69cd7b63881e761e9a73a38e76b3ce5bf/figures/teaser.png -------------------------------------------------------------------------------- /poster_CVPR_XING.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sxing2/CLIP-Test-time-Counterattacks/793fbff69cd7b63881e761e9a73a38e76b3ce5bf/poster_CVPR_XING.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/openai/CLIP.git 2 | git+https://github.com/fra31/auto-attack 3 | einops==0.8.1 4 | pycountry==24.6.1 5 | packaging==24.2 6 | gdown==5.2.0 7 | -------------------------------------------------------------------------------- /support/imagenet_classes_names.txt: -------------------------------------------------------------------------------- 1 | n02119789 1 kit_fox 2 | n02100735 2 English_setter 3 | n02110185 3 Siberian_husky 4 | n02096294 4 Australian_terrier 5 | n02102040 5 English_springer 6 | n02066245 6 grey_whale 7 | n02509815 7 lesser_panda 8 | n02124075 8 Egyptian_cat 9 | n02417914 9 ibex 10 | n02123394 10 Persian_cat 11 | n02125311 11 cougar 12 | n02423022 12 gazelle 13 | n02346627 13 porcupine 14 | n02077923 14 sea_lion 15 | n02110063 15 malamute 16 | n02447366 16 badger 17 | n02109047 17 Great_Dane 18 | n02089867 18 Walker_hound 19 | n02102177 19 Welsh_springer_spaniel 20 | n02091134 20 whippet 21 | n02092002 21 Scottish_deerhound 22 | n02071294 22 killer_whale 23 | n02442845 23 mink 24 | n02504458 24 African_elephant 25 | n02092339 25 Weimaraner 26 | n02098105 26 soft-coated_wheaten_terrier 27 | n02096437 27 Dandie_Dinmont 28 | n02114712 28 red_wolf 29 | n02105641 29 Old_English_sheepdog 30 | n02128925 30 jaguar 31 | n02091635 31 otterhound 32 | n02088466 32 bloodhound 33 | n02096051 33 Airedale 34 | n02117135 34 hyena 35 | n02138441 35 
meerkat 36 | n02097130 36 giant_schnauzer 37 | n02493509 37 titi 38 | n02457408 38 three-toed_sloth 39 | n02389026 39 sorrel 40 | n02443484 40 black-footed_ferret 41 | n02110341 41 dalmatian 42 | n02089078 42 black-and-tan_coonhound 43 | n02086910 43 papillon 44 | n02445715 44 skunk 45 | n02093256 45 Staffordshire_bullterrier 46 | n02113978 46 Mexican_hairless 47 | n02106382 47 Bouvier_des_Flandres 48 | n02441942 48 weasel 49 | n02113712 49 miniature_poodle 50 | n02113186 50 Cardigan 51 | n02105162 51 malinois 52 | n02415577 52 bighorn 53 | n02356798 53 fox_squirrel 54 | n02488702 54 colobus 55 | n02123159 55 tiger_cat 56 | n02098413 56 Lhasa 57 | n02422699 57 impala 58 | n02114855 58 coyote 59 | n02094433 59 Yorkshire_terrier 60 | n02111277 60 Newfoundland 61 | n02132136 61 brown_bear 62 | n02119022 62 red_fox 63 | n02091467 63 Norwegian_elkhound 64 | n02106550 64 Rottweiler 65 | n02422106 65 hartebeest 66 | n02091831 66 Saluki 67 | n02120505 67 grey_fox 68 | n02104365 68 schipperke 69 | n02086079 69 Pekinese 70 | n02112706 70 Brabancon_griffon 71 | n02098286 71 West_Highland_white_terrier 72 | n02095889 72 Sealyham_terrier 73 | n02484975 73 guenon 74 | n02137549 74 mongoose 75 | n02500267 75 indri 76 | n02129604 76 tiger 77 | n02090721 77 Irish_wolfhound 78 | n02396427 78 wild_boar 79 | n02108000 79 EntleBucher 80 | n02391049 80 zebra 81 | n02412080 81 ram 82 | n02108915 82 French_bulldog 83 | n02480495 83 orangutan 84 | n02110806 84 basenji 85 | n02128385 85 leopard 86 | n02107683 86 Bernese_mountain_dog 87 | n02085936 87 Maltese_dog 88 | n02094114 88 Norfolk_terrier 89 | n02087046 89 toy_terrier 90 | n02100583 90 vizsla 91 | n02096177 91 cairn 92 | n02494079 92 squirrel_monkey 93 | n02105056 93 groenendael 94 | n02101556 94 clumber 95 | n02123597 95 Siamese_cat 96 | n02481823 96 chimpanzee 97 | n02105505 97 komondor 98 | n02088094 98 Afghan_hound 99 | n02085782 99 Japanese_spaniel 100 | n02489166 100 proboscis_monkey 101 | n02364673 101 guinea_pig 102 | n02114548 102 white_wolf 103 | n02134084 103 ice_bear 104 | n02480855 104 gorilla 105 | n02090622 105 borzoi 106 | n02113624 106 toy_poodle 107 | n02093859 107 Kerry_blue_terrier 108 | n02403003 108 ox 109 | n02097298 109 Scotch_terrier 110 | n02108551 110 Tibetan_mastiff 111 | n02493793 111 spider_monkey 112 | n02107142 112 Doberman 113 | n02096585 113 Boston_bull 114 | n02107574 114 Greater_Swiss_Mountain_dog 115 | n02107908 115 Appenzeller 116 | n02086240 116 Shih-Tzu 117 | n02102973 117 Irish_water_spaniel 118 | n02112018 118 Pomeranian 119 | n02093647 119 Bedlington_terrier 120 | n02397096 120 warthog 121 | n02437312 121 Arabian_camel 122 | n02483708 122 siamang 123 | n02097047 123 miniature_schnauzer 124 | n02106030 124 collie 125 | n02099601 125 golden_retriever 126 | n02093991 126 Irish_terrier 127 | n02110627 127 affenpinscher 128 | n02106166 128 Border_collie 129 | n02326432 129 hare 130 | n02108089 130 boxer 131 | n02097658 131 silky_terrier 132 | n02088364 132 beagle 133 | n02111129 133 Leonberg 134 | n02100236 134 German_short-haired_pointer 135 | n02486261 135 patas 136 | n02115913 136 dhole 137 | n02486410 137 baboon 138 | n02487347 138 macaque 139 | n02099849 139 Chesapeake_Bay_retriever 140 | n02108422 140 bull_mastiff 141 | n02104029 141 kuvasz 142 | n02492035 142 capuchin 143 | n02110958 143 pug 144 | n02099429 144 curly-coated_retriever 145 | n02094258 145 Norwich_terrier 146 | n02099267 146 flat-coated_retriever 147 | n02395406 147 hog 148 | n02112350 148 keeshond 149 | n02109961 149 Eskimo_dog 150 | n02101388 150 
Brittany_spaniel 151 | n02113799 151 standard_poodle 152 | n02095570 152 Lakeland_terrier 153 | n02128757 153 snow_leopard 154 | n02101006 154 Gordon_setter 155 | n02115641 155 dingo 156 | n02097209 156 standard_schnauzer 157 | n02342885 157 hamster 158 | n02097474 158 Tibetan_terrier 159 | n02120079 159 Arctic_fox 160 | n02095314 160 wire-haired_fox_terrier 161 | n02088238 161 basset 162 | n02408429 162 water_buffalo 163 | n02133161 163 American_black_bear 164 | n02328150 164 Angora 165 | n02410509 165 bison 166 | n02492660 166 howler_monkey 167 | n02398521 167 hippopotamus 168 | n02112137 168 chow 169 | n02510455 169 giant_panda 170 | n02093428 170 American_Staffordshire_terrier 171 | n02105855 171 Shetland_sheepdog 172 | n02111500 172 Great_Pyrenees 173 | n02085620 173 Chihuahua 174 | n02123045 174 tabby 175 | n02490219 175 marmoset 176 | n02099712 176 Labrador_retriever 177 | n02109525 177 Saint_Bernard 178 | n02454379 178 armadillo 179 | n02111889 179 Samoyed 180 | n02088632 180 bluetick 181 | n02090379 181 redbone 182 | n02443114 182 polecat 183 | n02361337 183 marmot 184 | n02105412 184 kelpie 185 | n02483362 185 gibbon 186 | n02437616 186 llama 187 | n02107312 187 miniature_pinscher 188 | n02325366 188 wood_rabbit 189 | n02091032 189 Italian_greyhound 190 | n02129165 190 lion 191 | n02102318 191 cocker_spaniel 192 | n02100877 192 Irish_setter 193 | n02074367 193 dugong 194 | n02504013 194 Indian_elephant 195 | n02363005 195 beaver 196 | n02102480 196 Sussex_spaniel 197 | n02113023 197 Pembroke 198 | n02086646 198 Blenheim_spaniel 199 | n02497673 199 Madagascar_cat 200 | n02087394 200 Rhodesian_ridgeback 201 | n02127052 201 lynx 202 | n02116738 202 African_hunting_dog 203 | n02488291 203 langur 204 | n02091244 204 Ibizan_hound 205 | n02114367 205 timber_wolf 206 | n02130308 206 cheetah 207 | n02089973 207 English_foxhound 208 | n02105251 208 briard 209 | n02134418 209 sloth_bear 210 | n02093754 210 Border_terrier 211 | n02106662 211 German_shepherd 212 | n02444819 212 otter 213 | n01882714 213 koala 214 | n01871265 214 tusker 215 | n01872401 215 echidna 216 | n01877812 216 wallaby 217 | n01873310 217 platypus 218 | n01883070 218 wombat 219 | n04086273 219 revolver 220 | n04507155 220 umbrella 221 | n04147183 221 schooner 222 | n04254680 222 soccer_ball 223 | n02672831 223 accordion 224 | n02219486 224 ant 225 | n02317335 225 starfish 226 | n01968897 226 chambered_nautilus 227 | n03452741 227 grand_piano 228 | n03642806 228 laptop 229 | n07745940 229 strawberry 230 | n02690373 230 airliner 231 | n04552348 231 warplane 232 | n02692877 232 airship 233 | n02782093 233 balloon 234 | n04266014 234 space_shuttle 235 | n03344393 235 fireboat 236 | n03447447 236 gondola 237 | n04273569 237 speedboat 238 | n03662601 238 lifeboat 239 | n02951358 239 canoe 240 | n04612504 240 yawl 241 | n02981792 241 catamaran 242 | n04483307 242 trimaran 243 | n03095699 243 container_ship 244 | n03673027 244 liner 245 | n03947888 245 pirate 246 | n02687172 246 aircraft_carrier 247 | n04347754 247 submarine 248 | n04606251 248 wreck 249 | n03478589 249 half_track 250 | n04389033 250 tank 251 | n03773504 251 missile 252 | n02860847 252 bobsled 253 | n03218198 253 dogsled 254 | n02835271 254 bicycle-built-for-two 255 | n03792782 255 mountain_bike 256 | n03393912 256 freight_car 257 | n03895866 257 passenger_car 258 | n02797295 258 barrow 259 | n04204347 259 shopping_cart 260 | n03791053 260 motor_scooter 261 | n03384352 261 forklift 262 | n03272562 262 electric_locomotive 263 | n04310018 263 steam_locomotive 264 
| n02704792 264 amphibian 265 | n02701002 265 ambulance 266 | n02814533 266 beach_wagon 267 | n02930766 267 cab 268 | n03100240 268 convertible 269 | n03594945 269 jeep 270 | n03670208 270 limousine 271 | n03770679 271 minivan 272 | n03777568 272 Model_T 273 | n04037443 273 racer 274 | n04285008 274 sports_car 275 | n03444034 275 go-kart 276 | n03445924 276 golfcart 277 | n03785016 277 moped 278 | n04252225 278 snowplow 279 | n03345487 279 fire_engine 280 | n03417042 280 garbage_truck 281 | n03930630 281 pickup 282 | n04461696 282 tow_truck 283 | n04467665 283 trailer_truck 284 | n03796401 284 moving_van 285 | n03977966 285 police_van 286 | n04065272 286 recreational_vehicle 287 | n04335435 287 streetcar 288 | n04252077 288 snowmobile 289 | n04465501 289 tractor 290 | n03776460 290 mobile_home 291 | n04482393 291 tricycle 292 | n04509417 292 unicycle 293 | n03538406 293 horse_cart 294 | n03599486 294 jinrikisha 295 | n03868242 295 oxcart 296 | n02804414 296 bassinet 297 | n03125729 297 cradle 298 | n03131574 298 crib 299 | n03388549 299 four-poster 300 | n02870880 300 bookcase 301 | n03018349 301 china_cabinet 302 | n03742115 302 medicine_chest 303 | n03016953 303 chiffonier 304 | n04380533 304 table_lamp 305 | n03337140 305 file 306 | n03891251 306 park_bench 307 | n02791124 307 barber_chair 308 | n04429376 308 throne 309 | n03376595 309 folding_chair 310 | n04099969 310 rocking_chair 311 | n04344873 311 studio_couch 312 | n04447861 312 toilet_seat 313 | n03179701 313 desk 314 | n03982430 314 pool_table 315 | n03201208 315 dining_table 316 | n03290653 316 entertainment_center 317 | n04550184 317 wardrobe 318 | n07742313 318 Granny_Smith 319 | n07747607 319 orange 320 | n07749582 320 lemon 321 | n07753113 321 fig 322 | n07753275 322 pineapple 323 | n07753592 323 banana 324 | n07754684 324 jackfruit 325 | n07760859 325 custard_apple 326 | n07768694 326 pomegranate 327 | n12267677 327 acorn 328 | n12620546 328 hip 329 | n13133613 329 ear 330 | n11879895 330 rapeseed 331 | n12144580 331 corn 332 | n12768682 332 buckeye 333 | n03854065 333 organ 334 | n04515003 334 upright 335 | n03017168 335 chime 336 | n03249569 336 drum 337 | n03447721 337 gong 338 | n03720891 338 maraca 339 | n03721384 339 marimba 340 | n04311174 340 steel_drum 341 | n02787622 341 banjo 342 | n02992211 342 cello 343 | n04536866 343 violin 344 | n03495258 344 harp 345 | n02676566 345 acoustic_guitar 346 | n03272010 346 electric_guitar 347 | n03110669 347 cornet 348 | n03394916 348 French_horn 349 | n04487394 349 trombone 350 | n03494278 350 harmonica 351 | n03840681 351 ocarina 352 | n03884397 352 panpipe 353 | n02804610 353 bassoon 354 | n03838899 354 oboe 355 | n04141076 355 sax 356 | n03372029 356 flute 357 | n11939491 357 daisy 358 | n12057211 358 yellow_lady's_slipper 359 | n09246464 359 cliff 360 | n09468604 360 valley 361 | n09193705 361 alp 362 | n09472597 362 volcano 363 | n09399592 363 promontory 364 | n09421951 364 sandbar 365 | n09256479 365 coral_reef 366 | n09332890 366 lakeside 367 | n09428293 367 seashore 368 | n09288635 368 geyser 369 | n03498962 369 hatchet 370 | n03041632 370 cleaver 371 | n03658185 371 letter_opener 372 | n03954731 372 plane 373 | n03995372 373 power_drill 374 | n03649909 374 lawn_mower 375 | n03481172 375 hammer 376 | n03109150 376 corkscrew 377 | n02951585 377 can_opener 378 | n03970156 378 plunger 379 | n04154565 379 screwdriver 380 | n04208210 380 shovel 381 | n03967562 381 plow 382 | n03000684 382 chain_saw 383 | n01514668 383 cock 384 | n01514859 384 hen 385 | n01518878 385 ostrich 
386 | n01530575 386 brambling 387 | n01531178 387 goldfinch 388 | n01532829 388 house_finch 389 | n01534433 389 junco 390 | n01537544 390 indigo_bunting 391 | n01558993 391 robin 392 | n01560419 392 bulbul 393 | n01580077 393 jay 394 | n01582220 394 magpie 395 | n01592084 395 chickadee 396 | n01601694 396 water_ouzel 397 | n01608432 397 kite 398 | n01614925 398 bald_eagle 399 | n01616318 399 vulture 400 | n01622779 400 great_grey_owl 401 | n01795545 401 black_grouse 402 | n01796340 402 ptarmigan 403 | n01797886 403 ruffed_grouse 404 | n01798484 404 prairie_chicken 405 | n01806143 405 peacock 406 | n01806567 406 quail 407 | n01807496 407 partridge 408 | n01817953 408 African_grey 409 | n01818515 409 macaw 410 | n01819313 410 sulphur-crested_cockatoo 411 | n01820546 411 lorikeet 412 | n01824575 412 coucal 413 | n01828970 413 bee_eater 414 | n01829413 414 hornbill 415 | n01833805 415 hummingbird 416 | n01843065 416 jacamar 417 | n01843383 417 toucan 418 | n01847000 418 drake 419 | n01855032 419 red-breasted_merganser 420 | n01855672 420 goose 421 | n01860187 421 black_swan 422 | n02002556 422 white_stork 423 | n02002724 423 black_stork 424 | n02006656 424 spoonbill 425 | n02007558 425 flamingo 426 | n02009912 426 American_egret 427 | n02009229 427 little_blue_heron 428 | n02011460 428 bittern 429 | n02012849 429 crane 430 | n02013706 430 limpkin 431 | n02018207 431 American_coot 432 | n02018795 432 bustard 433 | n02025239 433 ruddy_turnstone 434 | n02027492 434 red-backed_sandpiper 435 | n02028035 435 redshank 436 | n02033041 436 dowitcher 437 | n02037110 437 oystercatcher 438 | n02017213 438 European_gallinule 439 | n02051845 439 pelican 440 | n02056570 440 king_penguin 441 | n02058221 441 albatross 442 | n01484850 442 great_white_shark 443 | n01491361 443 tiger_shark 444 | n01494475 444 hammerhead 445 | n01496331 445 electric_ray 446 | n01498041 446 stingray 447 | n02514041 447 barracouta 448 | n02536864 448 coho 449 | n01440764 449 tench 450 | n01443537 450 goldfish 451 | n02526121 451 eel 452 | n02606052 452 rock_beauty 453 | n02607072 453 anemone_fish 454 | n02643566 454 lionfish 455 | n02655020 455 puffer 456 | n02640242 456 sturgeon 457 | n02641379 457 gar 458 | n01664065 458 loggerhead 459 | n01665541 459 leatherback_turtle 460 | n01667114 460 mud_turtle 461 | n01667778 461 terrapin 462 | n01669191 462 box_turtle 463 | n01675722 463 banded_gecko 464 | n01677366 464 common_iguana 465 | n01682714 465 American_chameleon 466 | n01685808 466 whiptail 467 | n01687978 467 agama 468 | n01688243 468 frilled_lizard 469 | n01689811 469 alligator_lizard 470 | n01692333 470 Gila_monster 471 | n01693334 471 green_lizard 472 | n01694178 472 African_chameleon 473 | n01695060 473 Komodo_dragon 474 | n01704323 474 triceratops 475 | n01697457 475 African_crocodile 476 | n01698640 476 American_alligator 477 | n01728572 477 thunder_snake 478 | n01728920 478 ringneck_snake 479 | n01729322 479 hognose_snake 480 | n01729977 480 green_snake 481 | n01734418 481 king_snake 482 | n01735189 482 garter_snake 483 | n01737021 483 water_snake 484 | n01739381 484 vine_snake 485 | n01740131 485 night_snake 486 | n01742172 486 boa_constrictor 487 | n01744401 487 rock_python 488 | n01748264 488 Indian_cobra 489 | n01749939 489 green_mamba 490 | n01751748 490 sea_snake 491 | n01753488 491 horned_viper 492 | n01755581 492 diamondback 493 | n01756291 493 sidewinder 494 | n01629819 494 European_fire_salamander 495 | n01630670 495 common_newt 496 | n01631663 496 eft 497 | n01632458 497 spotted_salamander 498 | n01632777 498 
axolotl 499 | n01641577 499 bullfrog 500 | n01644373 500 tree_frog 501 | n01644900 501 tailed_frog 502 | n04579432 502 whistle 503 | n04592741 503 wing 504 | n03876231 504 paintbrush 505 | n03483316 505 hand_blower 506 | n03868863 506 oxygen_mask 507 | n04251144 507 snorkel 508 | n03691459 508 loudspeaker 509 | n03759954 509 microphone 510 | n04152593 510 screen 511 | n03793489 511 mouse 512 | n03271574 512 electric_fan 513 | n03843555 513 oil_filter 514 | n04332243 514 strainer 515 | n04265275 515 space_heater 516 | n04330267 516 stove 517 | n03467068 517 guillotine 518 | n02794156 518 barometer 519 | n04118776 519 rule 520 | n03841143 520 odometer 521 | n04141975 521 scale 522 | n02708093 522 analog_clock 523 | n03196217 523 digital_clock 524 | n04548280 524 wall_clock 525 | n03544143 525 hourglass 526 | n04355338 526 sundial 527 | n03891332 527 parking_meter 528 | n04328186 528 stopwatch 529 | n03197337 529 digital_watch 530 | n04317175 530 stethoscope 531 | n04376876 531 syringe 532 | n03706229 532 magnetic_compass 533 | n02841315 533 binoculars 534 | n04009552 534 projector 535 | n04356056 535 sunglasses 536 | n03692522 536 loupe 537 | n04044716 537 radio_telescope 538 | n02879718 538 bow 539 | n02950826 539 cannon 540 | n02749479 540 assault_rifle 541 | n04090263 541 rifle 542 | n04008634 542 projectile 543 | n03085013 543 computer_keyboard 544 | n04505470 544 typewriter_keyboard 545 | n03126707 545 crane 546 | n03666591 546 lighter 547 | n02666196 547 abacus 548 | n02977058 548 cash_machine 549 | n04238763 549 slide_rule 550 | n03180011 550 desktop_computer 551 | n03485407 551 hand-held_computer 552 | n03832673 552 notebook 553 | n06359193 553 web_site 554 | n03496892 554 harvester 555 | n04428191 555 thresher 556 | n04004767 556 printer 557 | n04243546 557 slot 558 | n04525305 558 vending_machine 559 | n04179913 559 sewing_machine 560 | n03602883 560 joystick 561 | n04372370 561 switch 562 | n03532672 562 hook 563 | n02974003 563 car_wheel 564 | n03874293 564 paddlewheel 565 | n03944341 565 pinwheel 566 | n03992509 566 potter's_wheel 567 | n03425413 567 gas_pump 568 | n02966193 568 carousel 569 | n04371774 569 swing 570 | n04067472 570 reel 571 | n04040759 571 radiator 572 | n04019541 572 puck 573 | n03492542 573 hard_disc 574 | n04355933 574 sunglass 575 | n03929660 575 pick 576 | n02965783 576 car_mirror 577 | n04258138 577 solar_dish 578 | n04074963 578 remote_control 579 | n03208938 579 disk_brake 580 | n02910353 580 buckle 581 | n03476684 581 hair_slide 582 | n03627232 582 knot 583 | n03075370 583 combination_lock 584 | n03874599 584 padlock 585 | n03804744 585 nail 586 | n04127249 586 safety_pin 587 | n04153751 587 screw 588 | n03803284 588 muzzle 589 | n04162706 589 seat_belt 590 | n04228054 590 ski 591 | n02948072 591 candle 592 | n03590841 592 jack-o'-lantern 593 | n04286575 593 spotlight 594 | n04456115 594 torch 595 | n03814639 595 neck_brace 596 | n03933933 596 pier 597 | n04485082 597 tripod 598 | n03733131 598 maypole 599 | n03794056 599 mousetrap 600 | n04275548 600 spider_web 601 | n01768244 601 trilobite 602 | n01770081 602 harvestman 603 | n01770393 603 scorpion 604 | n01773157 604 black_and_gold_garden_spider 605 | n01773549 605 barn_spider 606 | n01773797 606 garden_spider 607 | n01774384 607 black_widow 608 | n01774750 608 tarantula 609 | n01775062 609 wolf_spider 610 | n01776313 610 tick 611 | n01784675 611 centipede 612 | n01990800 612 isopod 613 | n01978287 613 Dungeness_crab 614 | n01978455 614 rock_crab 615 | n01980166 615 fiddler_crab 616 | n01981276 616 
king_crab 617 | n01983481 617 American_lobster 618 | n01984695 618 spiny_lobster 619 | n01985128 619 crayfish 620 | n01986214 620 hermit_crab 621 | n02165105 621 tiger_beetle 622 | n02165456 622 ladybug 623 | n02167151 623 ground_beetle 624 | n02168699 624 long-horned_beetle 625 | n02169497 625 leaf_beetle 626 | n02172182 626 dung_beetle 627 | n02174001 627 rhinoceros_beetle 628 | n02177972 628 weevil 629 | n02190166 629 fly 630 | n02206856 630 bee 631 | n02226429 631 grasshopper 632 | n02229544 632 cricket 633 | n02231487 633 walking_stick 634 | n02233338 634 cockroach 635 | n02236044 635 mantis 636 | n02256656 636 cicada 637 | n02259212 637 leafhopper 638 | n02264363 638 lacewing 639 | n02268443 639 dragonfly 640 | n02268853 640 damselfly 641 | n02276258 641 admiral 642 | n02277742 642 ringlet 643 | n02279972 643 monarch 644 | n02280649 644 cabbage_butterfly 645 | n02281406 645 sulphur_butterfly 646 | n02281787 646 lycaenid 647 | n01910747 647 jellyfish 648 | n01914609 648 sea_anemone 649 | n01917289 649 brain_coral 650 | n01924916 650 flatworm 651 | n01930112 651 nematode 652 | n01943899 652 conch 653 | n01944390 653 snail 654 | n01945685 654 slug 655 | n01950731 655 sea_slug 656 | n01955084 656 chiton 657 | n02319095 657 sea_urchin 658 | n02321529 658 sea_cucumber 659 | n03584829 659 iron 660 | n03297495 660 espresso_maker 661 | n03761084 661 microwave 662 | n03259280 662 Dutch_oven 663 | n04111531 663 rotisserie 664 | n04442312 664 toaster 665 | n04542943 665 waffle_iron 666 | n04517823 666 vacuum 667 | n03207941 667 dishwasher 668 | n04070727 668 refrigerator 669 | n04554684 669 washer 670 | n03133878 670 Crock_Pot 671 | n03400231 671 frying_pan 672 | n04596742 672 wok 673 | n02939185 673 caldron 674 | n03063689 674 coffeepot 675 | n04398044 675 teapot 676 | n04270147 676 spatula 677 | n02699494 677 altar 678 | n04486054 678 triumphal_arch 679 | n03899768 679 patio 680 | n04311004 680 steel_arch_bridge 681 | n04366367 681 suspension_bridge 682 | n04532670 682 viaduct 683 | n02793495 683 barn 684 | n03457902 684 greenhouse 685 | n03877845 685 palace 686 | n03781244 686 monastery 687 | n03661043 687 library 688 | n02727426 688 apiary 689 | n02859443 689 boathouse 690 | n03028079 690 church 691 | n03788195 691 mosque 692 | n04346328 692 stupa 693 | n03956157 693 planetarium 694 | n04081281 694 restaurant 695 | n03032252 695 cinema 696 | n03529860 696 home_theater 697 | n03697007 697 lumbermill 698 | n03065424 698 coil 699 | n03837869 699 obelisk 700 | n04458633 700 totem_pole 701 | n02980441 701 castle 702 | n04005630 702 prison 703 | n03461385 703 grocery_store 704 | n02776631 704 bakery 705 | n02791270 705 barbershop 706 | n02871525 706 bookshop 707 | n02927161 707 butcher_shop 708 | n03089624 708 confectionery 709 | n04200800 709 shoe_shop 710 | n04443257 710 tobacco_shop 711 | n04462240 711 toyshop 712 | n03388043 712 fountain 713 | n03042490 713 cliff_dwelling 714 | n04613696 714 yurt 715 | n03216828 715 dock 716 | n02892201 716 brass 717 | n03743016 717 megalith 718 | n02788148 718 bannister 719 | n02894605 719 breakwater 720 | n03160309 720 dam 721 | n03000134 721 chainlink_fence 722 | n03930313 722 picket_fence 723 | n04604644 723 worm_fence 724 | n04326547 724 stone_wall 725 | n03459775 725 grille 726 | n04239074 726 sliding_door 727 | n04501370 727 turnstile 728 | n03792972 728 mountain_tent 729 | n04149813 729 scoreboard 730 | n03530642 730 honeycomb 731 | n03961711 731 plate_rack 732 | n03903868 732 pedestal 733 | n02814860 733 beacon 734 | n07711569 734 mashed_potato 735 | 
n07720875 735 bell_pepper 736 | n07714571 736 head_cabbage 737 | n07714990 737 broccoli 738 | n07715103 738 cauliflower 739 | n07716358 739 zucchini 740 | n07716906 740 spaghetti_squash 741 | n07717410 741 acorn_squash 742 | n07717556 742 butternut_squash 743 | n07718472 743 cucumber 744 | n07718747 744 artichoke 745 | n07730033 745 cardoon 746 | n07734744 746 mushroom 747 | n04209239 747 shower_curtain 748 | n03594734 748 jean 749 | n02971356 749 carton 750 | n03485794 750 handkerchief 751 | n04133789 751 sandal 752 | n02747177 752 ashcan 753 | n04125021 753 safe 754 | n07579787 754 plate 755 | n03814906 755 necklace 756 | n03134739 756 croquet_ball 757 | n03404251 757 fur_coat 758 | n04423845 758 thimble 759 | n03877472 759 pajama 760 | n04120489 760 running_shoe 761 | n03062245 761 cocktail_shaker 762 | n03014705 762 chest 763 | n03717622 763 manhole_cover 764 | n03777754 764 modem 765 | n04493381 765 tub 766 | n04476259 766 tray 767 | n02777292 767 balance_beam 768 | n07693725 768 bagel 769 | n03998194 769 prayer_rug 770 | n03617480 770 kimono 771 | n07590611 771 hot_pot 772 | n04579145 772 whiskey_jug 773 | n03623198 773 knee_pad 774 | n07248320 774 book_jacket 775 | n04277352 775 spindle 776 | n04229816 776 ski_mask 777 | n02823428 777 beer_bottle 778 | n03127747 778 crash_helmet 779 | n02877765 779 bottlecap 780 | n04435653 780 tile_roof 781 | n03724870 781 mask 782 | n03710637 782 maillot 783 | n03920288 783 Petri_dish 784 | n03379051 784 football_helmet 785 | n02807133 785 bathing_cap 786 | n04399382 786 teddy 787 | n03527444 787 holster 788 | n03983396 788 pop_bottle 789 | n03924679 789 photocopier 790 | n04532106 790 vestment 791 | n06785654 791 crossword_puzzle 792 | n03445777 792 golf_ball 793 | n07613480 793 trifle 794 | n04350905 794 suit 795 | n04562935 795 water_tower 796 | n03325584 796 feather_boa 797 | n03045698 797 cloak 798 | n07892512 798 red_wine 799 | n03250847 799 drumstick 800 | n04192698 800 shield 801 | n03026506 801 Christmas_stocking 802 | n03534580 802 hoopskirt 803 | n07565083 803 menu 804 | n04296562 804 stage 805 | n02869837 805 bonnet 806 | n07871810 806 meat_loaf 807 | n02799071 807 baseball 808 | n03314780 808 face_powder 809 | n04141327 809 scabbard 810 | n04357314 810 sunscreen 811 | n02823750 811 beer_glass 812 | n13052670 812 hen-of-the-woods 813 | n07583066 813 guacamole 814 | n03637318 814 lampshade 815 | n04599235 815 wool 816 | n07802026 816 hay 817 | n02883205 817 bow_tie 818 | n03709823 818 mailbag 819 | n04560804 819 water_jug 820 | n02909870 820 bucket 821 | n03207743 821 dishrag 822 | n04263257 822 soup_bowl 823 | n07932039 823 eggnog 824 | n03786901 824 mortar 825 | n04479046 825 trench_coat 826 | n03873416 826 paddle 827 | n02999410 827 chain 828 | n04367480 828 swab 829 | n03775546 829 mixing_bowl 830 | n07875152 830 potpie 831 | n04591713 831 wine_bottle 832 | n04201297 832 shoji 833 | n02916936 833 bulletproof_vest 834 | n03240683 834 drilling_platform 835 | n02840245 835 binder 836 | n02963159 836 cardigan 837 | n04370456 837 sweatshirt 838 | n03991062 838 pot 839 | n02843684 839 birdhouse 840 | n03482405 840 hamper 841 | n03942813 841 ping-pong_ball 842 | n03908618 842 pencil_box 843 | n03902125 843 pay-phone 844 | n07584110 844 consomme 845 | n02730930 845 apron 846 | n04023962 846 punching_bag 847 | n02769748 847 backpack 848 | n10148035 848 groom 849 | n02817516 849 bearskin 850 | n03908714 850 pencil_sharpener 851 | n02906734 851 broom 852 | n03788365 852 mosquito_net 853 | n02667093 853 abaya 854 | n03787032 854 mortarboard 855 
| n03980874 855 poncho 856 | n03141823 856 crutch 857 | n03976467 857 Polaroid_camera 858 | n04264628 858 space_bar 859 | n07930864 859 cup 860 | n04039381 860 racket 861 | n06874185 861 traffic_light 862 | n04033901 862 quill 863 | n04041544 863 radio 864 | n07860988 864 dough 865 | n03146219 865 cuirass 866 | n03763968 866 military_uniform 867 | n03676483 867 lipstick 868 | n04209133 868 shower_cap 869 | n03782006 869 monitor 870 | n03857828 870 oscilloscope 871 | n03775071 871 mitten 872 | n02892767 872 brassiere 873 | n07684084 873 French_loaf 874 | n04522168 874 vase 875 | n03764736 875 milk_can 876 | n04118538 876 rugby_ball 877 | n03887697 877 paper_towel 878 | n13044778 878 earthstar 879 | n03291819 879 envelope 880 | n03770439 880 miniskirt 881 | n03124170 881 cowboy_hat 882 | n04487081 882 trolleybus 883 | n03916031 883 perfume 884 | n02808440 884 bathtub 885 | n07697537 885 hotdog 886 | n12985857 886 coral_fungus 887 | n02917067 887 bullet_train 888 | n03938244 888 pillow 889 | n15075141 889 toilet_tissue 890 | n02978881 890 cassette 891 | n02966687 891 carpenter's_kit 892 | n03633091 892 ladle 893 | n13040303 893 stinkhorn 894 | n03690938 894 lotion 895 | n03476991 895 hair_spray 896 | n02669723 896 academic_gown 897 | n03220513 897 dome 898 | n03127925 898 crate 899 | n04584207 899 wig 900 | n07880968 900 burrito 901 | n03937543 901 pill_bottle 902 | n03000247 902 chain_mail 903 | n04418357 903 theater_curtain 904 | n04590129 904 window_shade 905 | n02795169 905 barrel 906 | n04553703 906 washbasin 907 | n02783161 907 ballpoint 908 | n02802426 908 basketball 909 | n02808304 909 bath_towel 910 | n03124043 910 cowboy_boot 911 | n03450230 911 gown 912 | n04589890 912 window_screen 913 | n12998815 913 agaric 914 | n02992529 914 cellular_telephone 915 | n03825788 915 nipple 916 | n02790996 916 barbell 917 | n03710193 917 mailbox 918 | n03630383 918 lab_coat 919 | n03347037 919 fire_screen 920 | n03769881 920 minibus 921 | n03871628 921 packet 922 | n03733281 922 maze 923 | n03976657 923 pole 924 | n03535780 924 horizontal_bar 925 | n04259630 925 sombrero 926 | n03929855 926 pickelhaube 927 | n04049303 927 rain_barrel 928 | n04548362 928 wallet 929 | n02979186 929 cassette_player 930 | n06596364 930 comic_book 931 | n03935335 931 piggy_bank 932 | n06794110 932 street_sign 933 | n02825657 933 bell_cote 934 | n03388183 934 fountain_pen 935 | n04591157 935 Windsor_tie 936 | n04540053 936 volleyball 937 | n03866082 937 overskirt 938 | n04136333 938 sarong 939 | n04026417 939 purse 940 | n02865351 940 bolo_tie 941 | n02834397 941 bib 942 | n03888257 942 parachute 943 | n04235860 943 sleeping_bag 944 | n04404412 944 television 945 | n04371430 945 swimming_trunks 946 | n03733805 946 measuring_cup 947 | n07920052 947 espresso 948 | n07873807 948 pizza 949 | n02895154 949 breastplate 950 | n04204238 950 shopping_basket 951 | n04597913 951 wooden_spoon 952 | n04131690 952 saltshaker 953 | n07836838 953 chocolate_sauce 954 | n09835506 954 ballplayer 955 | n03443371 955 goblet 956 | n13037406 956 gyromitra 957 | n04336792 957 stretcher 958 | n04557648 958 water_bottle 959 | n03187595 959 dial_telephone 960 | n04254120 960 soap_dispenser 961 | n03595614 961 jersey 962 | n04146614 962 school_bus 963 | n03598930 963 jigsaw_puzzle 964 | n03958227 964 plastic_bag 965 | n04069434 965 reflex_camera 966 | n03188531 966 diaper 967 | n02786058 967 Band_Aid 968 | n07615774 968 ice_lolly 969 | n04525038 969 velvet 970 | n04409515 970 tennis_ball 971 | n03424325 971 gasmask 972 | n03223299 972 doormat 973 | 
n03680355 973 Loafer 974 | n07614500 974 ice_cream 975 | n07695742 975 pretzel 976 | n04033995 976 quilt 977 | n03710721 977 maillot 978 | n04392985 978 tape_player 979 | n03047690 979 clog 980 | n03584254 980 iPod 981 | n13054560 981 bolete 982 | n10565667 982 scuba_diver 983 | n03950228 983 pitcher 984 | n03729826 984 matchstick 985 | n02837789 985 bikini 986 | n04254777 986 sock 987 | n02988304 987 CD_player 988 | n03657121 988 lens_cap 989 | n04417672 989 thatch 990 | n04523525 990 vault 991 | n02815834 991 beaker 992 | n09229709 992 bubble 993 | n07697313 993 cheeseburger 994 | n03888605 994 parallel_bars 995 | n03355925 995 flagpole 996 | n03063599 996 coffee_mug 997 | n04116512 997 rubber_eraser 998 | n04325704 998 stole 999 | n07831146 999 carbonara 1000 | n03255030 1000 dumbbell -------------------------------------------------------------------------------- /support/readme.md: -------------------------------------------------------------------------------- 1 | Files `tinyimagenet_refined_labels.json` and `imagenet_refined_labels.json` are used to specify train/val splits when we re-implement adversarial finetuning (AFT) methods. We re-implemented three AFT methods (TeCoA, PMG-AFT and FARE) in the paper on the tinyImageNet dataset. For each class, we randomly sample 10% of the instances in the training set for evaluation in the phase of finetuning. In the `.json` file, each class (identified with a synset id) has the following attributes: 2 | ``` 3 | { 4 | synset identifier (e.g., 'n01443537'): 5 | { 6 | "clean_name": cleansed textual name (e.g., 'goldfish'), 7 | "wordnet_def": definition given by WordNet (e.g., 'small golden or orange-red freshwater fishes of Eurasia used as pond or aquarium fishes'), 8 | "eval_files": list of instances selected as evaluation data (e.g., ["n01443537_266.JPEG", "n01443537_75.JPEG", ...]) 9 | }, 10 | } 11 | ``` 12 | 13 | We also provide the cleansed textual name for each class and its definition given by [WordNet](https://wordnet.princeton.edu/). 14 | --------------------------------------------------------------------------------
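
The `support/readme.md` above describes the structure of the refined-label JSON files. As a concrete illustration, here is a minimal sketch of how one might load `support/tinyimagenet_refined_labels.json` and use the `eval_files` attribute to recover the per-class train/eval split. This is an assumption-laden example, not code from the repository: the tinyImageNet directory layout (`<root>/<synset>/images/<file>.JPEG`), the paths, and the helper name `split_class` are all illustrative.

```python
import json
import os

# Load the refined labels described in support/readme.md
# (path assumed relative to the repo root).
with open("support/tinyimagenet_refined_labels.json", "r") as f:
    refined = json.load(f)

# Inspect one synset entry (the synset id 'n01443537' is the example
# used in support/readme.md; any key in the file works the same way).
entry = refined["n01443537"]
print(entry["clean_name"])       # cleansed class name, e.g. 'goldfish'
print(entry["wordnet_def"])      # WordNet definition string
print(len(entry["eval_files"]))  # number of images held out for evaluation


def split_class(root, synset, info):
    """Split one class's training images into train/eval lists,
    assuming a hypothetical layout <root>/<synset>/images/<file>.JPEG."""
    img_dir = os.path.join(root, synset, "images")
    all_files = sorted(os.listdir(img_dir))
    eval_set = set(info["eval_files"])
    train_files = [f for f in all_files if f not in eval_set]
    eval_files = [f for f in all_files if f in eval_set]
    return train_files, eval_files


# Example usage (assumed dataset location):
# train, val = split_class("./data/tiny-imagenet-200/train",
#                          "n01443537", refined["n01443537"])
```

Because `eval_files` lists the held-out images explicitly, the split is reproducible across runs without re-sampling; the same pattern applies to `imagenet_refined_labels.json`.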