├── pkgs ├── __init__.py └── openai │ ├── __init__.py │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── clip.py │ └── tokenizer.py ├── src ├── __init__.py ├── scheduler.py ├── logger.py ├── parser.py ├── main.py ├── train.py ├── evaluate.py └── data.py ├── utils ├── __init__.py ├── config.py ├── summary.py ├── asr_hyper.py ├── mscoco_data_creation.py ├── hypersphere_captions.py ├── augment_text.py ├── prepare_sbucaptions.py ├── superclass_distribution_plot_supplementary.py ├── datasize_ablation_plot.py ├── hypersphere_labels.py ├── zeroshot.py ├── linear_probe_plot.py ├── augment_image.py ├── flickr30k_data_creation.py ├── effective_robustness_plot.py ├── fine_coarse_plot.py ├── download.py ├── fine_coarse.py ├── fine_coarse_plot_supplementary.py ├── linear_probe.py ├── retrieval.py ├── eda.py └── embeddings.py ├── __init__.py ├── analysis ├── asr_blended.png ├── ca_blended.png └── labelled │ ├── tsne_blended_3m_1500.png │ ├── tsne_blended_3m_1500_sufi.png │ ├── tsne_blended_3m_1500_selfi.png │ └── tsne_blended_3m_1500_finetuning.png ├── docs └── images │ └── CleanCLIP_intro.png ├── backdoor ├── noise_grid_k=224_s=1_inputheight=224_gridrescale=1.pt ├── utils.py ├── create_backdoor_data.py ├── tsne.py ├── pca-tsne-labelled-dataset.py ├── tsne_detected_vs_undetected_vs_clean.py └── tsne-labelled.py ├── requirements.txt ├── data ├── ImageNet1K │ └── validation │ │ └── modify_images_directory.py ├── CIFAR10 │ └── test │ │ └── classes.py └── CIFAR100 │ └── test │ └── classes.py ├── environment.yml ├── .gitignore ├── LICENSE └── README.md /pkgs/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /pkgs/openai/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | root = os.path.dirname(os.path.abspath(__file__)) -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -------------------------------------------------------------------------------- /analysis/asr_blended.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nishadsinghi/CleanCLIP/HEAD/analysis/asr_blended.png -------------------------------------------------------------------------------- /analysis/ca_blended.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nishadsinghi/CleanCLIP/HEAD/analysis/ca_blended.png -------------------------------------------------------------------------------- /docs/images/CleanCLIP_intro.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/nishadsinghi/CleanCLIP/HEAD/docs/images/CleanCLIP_intro.png -------------------------------------------------------------------------------- /pkgs/openai/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nishadsinghi/CleanCLIP/HEAD/pkgs/openai/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /analysis/labelled/tsne_blended_3m_1500.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nishadsinghi/CleanCLIP/HEAD/analysis/labelled/tsne_blended_3m_1500.png -------------------------------------------------------------------------------- /analysis/labelled/tsne_blended_3m_1500_sufi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nishadsinghi/CleanCLIP/HEAD/analysis/labelled/tsne_blended_3m_1500_sufi.png -------------------------------------------------------------------------------- /analysis/labelled/tsne_blended_3m_1500_selfi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nishadsinghi/CleanCLIP/HEAD/analysis/labelled/tsne_blended_3m_1500_selfi.png -------------------------------------------------------------------------------- /analysis/labelled/tsne_blended_3m_1500_finetuning.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nishadsinghi/CleanCLIP/HEAD/analysis/labelled/tsne_blended_3m_1500_finetuning.png -------------------------------------------------------------------------------- /backdoor/noise_grid_k=224_s=1_inputheight=224_gridrescale=1.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nishadsinghi/CleanCLIP/HEAD/backdoor/noise_grid_k=224_s=1_inputheight=224_gridrescale=1.pt -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | albumentations 2 | dill 3 | einops 4 | ftfy 5 | git+https://github.com/openai/CLIP.git 6 | h5py 7 | matplotlib 8 | numpy 9 | nltk 10 | omegaconf 11 | pandas 12 | pillow 13 | pip 14 | pyflakes 15 | pytorch_lightning 16 | regex 17 | scikit-image 18 | seaborn 19 | tensorflow 20 | tokenizers 21 | torch 22 | torchvision 23 | tqdm 24 | transformers 25 | typing_extensions 26 | wandb 27 | wrapt -------------------------------------------------------------------------------- /data/ImageNet1K/validation/modify_images_directory.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | from pathlib import Path 3 | import os 4 | 5 | filepath = 'labels.csv' 6 | new_directory = '/work/nsinghi/SSL-Backdoor/imagenet/val' 7 | 8 | labels = pd.read_csv(filepath) 9 | 10 | for row in range(len(labels)): 11 | imagepath = labels.loc[row, 'image'] 12 | imagename = Path(imagepath).name 13 | newimagepath = os.path.join(new_directory, imagename) 14 | labels.loc[row, 'image'] = newimagepath 15 | 16 | labels.to_csv(filepath, index=False) -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: CleanCLIP 2 | channels: 3 | - defaults 4 | 
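# The pip dependencies below match requirements.txt, except that torch, torchvision, pytorch_lightning, and transformers are listed only in requirements.txt.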
dependencies: 5 | - pip 6 | - pip: 7 | - albumentations 8 | - dill 9 | - einops 10 | - ftfy 11 | - git+https://github.com/openai/CLIP.git 12 | - h5py 13 | - matplotlib 14 | - numpy 15 | - nltk 16 | - omegaconf 17 | - pandas 18 | - pillow 19 | - pip 20 | - pyflakes 21 | - regex 22 | - scikit-image 23 | - seaborn 24 | - tensorflow 25 | - tokenizers 26 | - tqdm 27 | - typing_extensions 28 | - wandb 29 | - wrapt -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Build and Release Folders 2 | bin-debug/ 3 | bin-release/ 4 | [Oo]bj/ 5 | [Bb]in/ 6 | 7 | # Other files and folders 8 | .settings/ 9 | .idea 10 | .cache 11 | .cache/ 12 | wandb 13 | wandb/ 14 | logs 15 | logs/ 16 | condor_jobs 17 | *.pkl 18 | *.pickle 19 | 20 | *pycache* 21 | 22 | # Executables 23 | *.swf 24 | *.air 25 | *.ipa 26 | *.apk 27 | 28 | # Project files, i.e. `.project`, `.actionScriptProperties` and `.flexProperties` 29 | # should NOT be excluded as they contain compiler settings and other important 30 | # information for Eclipse / Flash Builder. 31 | -------------------------------------------------------------------------------- /utils/summary.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import warnings 3 | from pkgs.openai.clip import load 4 | 5 | warnings.filterwarnings("ignore") 6 | 7 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 8 | model, processor = load(name = "RN50", pretrained = False) 9 | model.to(device) 10 | 11 | print(sum(parameter.numel() for parameter in model.visual.parameters() if parameter.requires_grad)) 12 | print(sum(parameter.numel() for parameter in model.parameters() if parameter.requires_grad) - sum(parameter.numel() for parameter in model.visual.parameters() if parameter.requires_grad) - 1) -------------------------------------------------------------------------------- /src/scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def cosine_scheduler(optimizer, base_lr, num_warmup_steps, total_steps): 4 | def _scheduler(current_step): 5 | if(current_step < num_warmup_steps): 6 | lr = base_lr * (current_step + 1) / num_warmup_steps 7 | else: 8 | n = current_step - num_warmup_steps 9 | d = total_steps - num_warmup_steps 10 | lr = 0.5 * (1 + np.cos(np.pi * n / d)) * base_lr 11 | 12 | for param_group in optimizer.param_groups: 13 | param_group["lr"] = lr 14 | 15 | return _scheduler -------------------------------------------------------------------------------- /utils/asr_hyper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import seaborn as sns 4 | import matplotlib.pyplot as plt 5 | 6 | params = { 7 | "figure.figsize": (5, 5), 8 | "legend.fontsize": 16, 9 | "axes.labelsize": 16, 10 | "xtick.labelsize": 16, 11 | "ytick.labelsize": 16, 12 | "font.family": "Liberation Mono" 13 | } 14 | 15 | plt.rcParams.update(params) 16 | plt.style.use("seaborn-whitegrid") 17 | sns.set_style("white") 18 | 19 | fig, ax1 = plt.subplots() 20 | 21 | lambda_2 = [0.5, 1, 2, 4, 8] 22 | asr = [7.48, 7.2, 4.27, 2.89, 1.83] 23 | ca = [17.8, 18.14, 17.8, 18, 17.4] 24 | 25 | plt.plot(lambda_2, ca, marker = 'o') 26 | # ax2 = ax1.twinx() 27 | # ax1.plot(lambda_2, asr, 'g-') 28 | # ax2.plot(lambda_2, ca, 'b-') 29 | 30 | plt.title('CA of SelFi Defense (Blended Attack)') 31 | 
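# Editorial sketch (not part of the original repository): a minimal example of how the
# cosine_scheduler factory from src/scheduler.py above is meant to be driven. The returned
# callable is invoked once per optimization step with the global step index and writes the
# warmup/cosine-decayed learning rate directly into optimizer.param_groups. The toy model,
# learning rate, and step counts are placeholder assumptions, as is having the repository
# root on PYTHONPATH so that `src.scheduler` resolves. The function is defined but never called.
def _example_cosine_scheduler_usage():
    import torch
    from src.scheduler import cosine_scheduler

    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr = 5e-4)
    scheduler = cosine_scheduler(optimizer, base_lr = 5e-4, num_warmup_steps = 100, total_steps = 1000)

    for step in range(1000):
        scheduler(step)                        # linear warmup, then cosine decay of the learning rate
        optimizer.zero_grad()
        loss = model(torch.randn(4, 8)).sum()  # placeholder loss
        loss.backward()
        optimizer.step()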
plt.xlabel(r'$\lambda_2$') 32 | plt.ylabel('Accuracy (%)') 33 | plt.grid() 34 | plt.tight_layout() 35 | plt.savefig('ca_blended.png') -------------------------------------------------------------------------------- /utils/mscoco_data_creation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import json 4 | import pandas as pd 5 | from tqdm import tqdm 6 | 7 | def move_files(images, split): 8 | 9 | for image in tqdm(images): 10 | current_loc = os.path.join(root, 'val2014', image) 11 | dest_loc = os.path.join(root, split) 12 | shutil.move(current_loc, dest_loc) 13 | 14 | if __name__ == "__main__": 15 | 16 | root = './data/MSCOCO' 17 | 18 | with open('./data/karpathy_splits/dataset_coco.json') as f: 19 | dataset = json.load(f) 20 | test = list(filter(lambda x: x['split'] == 'test', dataset['images'])) 21 | test_images = list(map(lambda x: x['filename'], test)) 22 | test_captions = list(map(lambda x: x['sentences'][0]['raw'], test)) 23 | # list_of_all_images = os.listdir(os.path.join(root, 'val2014')) 24 | # move_files(test_images, 'test') 25 | 26 | test_images = list(map(lambda x: f'test/{x}', test_images)) 27 | data = {'image': test_images, 28 | 'caption': test_captions} 29 | df = pd.DataFrame(data) 30 | df.to_csv(f'{root}/mscoco_test.csv') 31 | 32 | 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Hritikbansal 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /utils/hypersphere_captions.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pickle 4 | import argparse 5 | from tqdm import tqdm 6 | 7 | def run(options): 8 | with torch.no_grad(): 9 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 10 | 11 | data = pickle.load(open(options.embeddings, "rb")) 12 | image_embeddings, text_embeddings = torch.tensor(data["image_embeddings"]).to(device), torch.tensor(data["text_embeddings"]).to(device) 13 | 14 | align_loss = (image_embeddings - text_embeddings).square().sum(1).mean(0) 15 | uniform_loss = torch.masked_select(torch.cdist(image_embeddings.unsqueeze(0), text_embeddings.unsqueeze(0))[0], torch.ones((len(image_embeddings), len(text_embeddings))).to(device).tril(diagonal = -1) == 1).square().mul(-2).exp().mean().log() 16 | 17 | print(f"Align Loss: {align_loss.cpu().item()}") 18 | print(f"Uniform Loss: {uniform_loss.cpu().item()}") 19 | 20 | if(__name__ == "__main__"): 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument("-e,--embeddings", dest = "embeddings", type = str, default = "analysis/embeddings/clip/CC3M.validation.pkl", help = "Input file") 23 | options = parser.parse_args() 24 | run(options) -------------------------------------------------------------------------------- /utils/augment_text.py: -------------------------------------------------------------------------------- 1 | import os 2 | import nltk 3 | import argparse 4 | import pandas as pd 5 | from tqdm import tqdm 6 | from utils import config 7 | from .eda import * 8 | 9 | def _augment_text(caption): 10 | augmented_caption = eda(caption) 11 | return augmented_caption[0] 12 | 13 | def augment_text(options): 14 | df = pd.read_csv(os.path.join(config.root, options.input_file), delimiter = options.delimiter) 15 | captions = df[options.caption_key] 16 | 17 | augmented_captions = [] 18 | for caption in tqdm(captions): 19 | augmented_caption = eda(caption) 20 | augmented_captions.append(augmented_caption[0]) 21 | 22 | df["augmented_" + options.caption_key] = augmented_captions 23 | df.to_csv(os.path.join(config.root, options.output_file), index = False) 24 | 25 | if(__name__ == "__main__"): 26 | parser = argparse.ArgumentParser() 27 | 28 | parser.add_argument("-i,--input_file", dest = "input_file", type = str, required = True, help = "Input file") 29 | parser.add_argument("-o,--output_file", dest = "output_file", type = str, required = True, help = "Output file") 30 | parser.add_argument("--delimiter", type = str, default = ",", help = "Input file delimiter") 31 | parser.add_argument("--caption_key", type = str, default = "caption", help = "Caption column name") 32 | 33 | options = parser.parse_args() 34 | augment_text(options) -------------------------------------------------------------------------------- /utils/prepare_sbucaptions.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import tarfile 4 | import shutil 5 | import pandas as pd 6 | from tqdm import tqdm 7 | 8 | root = '/data0/datasets/sbucaptions/' 9 | 10 | tar_files = os.listdir(root) 11 | print(tar_files) 12 | 13 | 14 | for tar_file in tqdm(tar_files): 15 | 16 | folder = tar_file.split('.')[0] 17 | print(folder) 18 | if os.path.exists(os.path.join(root, folder)): 19 | print(1) 20 | shutil.rmtree(os.path.join(root, folder)) 21 | 22 | try: 23 | file = 
tarfile.open(os.path.join(root, tar_file)) 24 | file.extractall(os.path.join(root, folder)) 25 | file.close() 26 | 27 | all_files = os.listdir(os.path.join(root, folder)) 28 | txt_files = list(filter(lambda x: '.txt' in x, all_files)) 29 | 30 | for txt_file in txt_files: 31 | caption = open(os.path.join(root, folder, txt_file), 'r').readlines()[0].strip() 32 | image_location = os.path.join(root, folder, txt_file.replace('.txt', '.jpg')) 33 | os.remove(os.path.join(root, folder, txt_file)) 34 | os.remove(os.path.join(root, folder, txt_file.replace('.txt', '.json'))) 35 | with open(os.path.join(root, 'sbucaptions.csv'), 'a') as csvfile: 36 | csvwriter = csv.writer(csvfile) 37 | csvwriter.writerow([image_location, caption]) 38 | except: 39 | pass 40 | 41 | -------------------------------------------------------------------------------- /utils/superclass_distribution_plot_supplementary.py: -------------------------------------------------------------------------------- 1 | import os 2 | import seaborn as sns 3 | import matplotlib.pyplot as plt 4 | from matplotlib.ticker import MaxNLocator 5 | 6 | params = { 7 | "figure.figsize": (8, 6), 8 | "axes.titlesize": 22, 9 | "legend.fontsize": 20, 10 | "axes.labelsize": 20, 11 | "axes.titlesize": 20, 12 | "xtick.labelsize": 20, 13 | "ytick.labelsize": 20, 14 | "figure.titlesize": 22, 15 | "font.family": "Liberation Mono" 16 | } 17 | 18 | plt.rcParams.update(params) 19 | plt.style.use("seaborn-whitegrid") 20 | sns.set_style("white") 21 | 22 | 23 | ax = plt.figure().gca() 24 | ax.xaxis.set_major_locator(MaxNLocator(integer = True)) 25 | 26 | file = "data/ImageNet-A/classes.py" 27 | superclasses = eval(open(file).read())["superclasses"] 28 | 29 | data = list(map(len, superclasses)) 30 | 31 | _, bins, patches = plt.hist(data, 25, color = "green", edgecolor = "black", linewidth = 0.75) 32 | centers = 0.5 * (bins[:-1] + bins[1:]) 33 | colors = centers - min(centers) 34 | colors /= max(colors) 35 | 36 | colormap = plt.cm.get_cmap("RdYlBu_r") 37 | for color, patch in zip(colors, patches): 38 | plt.setp(patch, "facecolor", colormap(color)) 39 | 40 | plt.xlabel("Number of subclasses per superclass", labelpad = 12) 41 | plt.tight_layout() 42 | 43 | os.makedirs("analysis/plots", exist_ok = True) 44 | plt.savefig(f"analysis/plots/superclass_distribution.ImageNet-A.png") -------------------------------------------------------------------------------- /utils/datasize_ablation_plot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import seaborn as sns 4 | import matplotlib.pyplot as plt 5 | 6 | params = { 7 | "legend.fontsize": 18, 8 | "axes.labelsize": 16, 9 | "xtick.labelsize": 16, 10 | "ytick.labelsize": 18, 11 | "font.family": "Liberation Mono" 12 | } 13 | 14 | plt.rcParams.update(params) 15 | plt.style.use("seaborn-whitegrid") 16 | sns.set_style("white") 17 | 18 | data = { 19 | "Metric": ["500K", "1M", "2M", "3M"], 20 | "CLIP": [5.03, 11.1, 15.10, 17.28], 21 | "CyCLIP": [6.19, 12.12, 16.57, 18.68], 22 | } 23 | 24 | _, axes = plt.subplots(1, 1, figsize = (8, 5)) 25 | 26 | axes.set_ylim(2.5, 20.0) 27 | axes.set_yticks(np.arange(5, 22.5, 2.5)) 28 | axes.set_xlabel("Dataset size", fontsize = 18, labelpad = 12) 29 | axes.set_ylabel("Top1 Accuracy (%) on ImageNet1K", fontsize = 18, labelpad = 12) 30 | axes.set_xticks(np.arange(len(data["Metric"]))) 31 | axes.set_xticklabels(data["Metric"], fontsize = 18, rotation = 0) 32 | axes.plot(np.arange(len(data["Metric"])), data["CLIP"], "^-", 
markersize = 12.5, label = "CLIP", color = "brown", alpha = 0.85) 33 | axes.plot(np.arange(len(data["Metric"])), data["CyCLIP"], "*-", markersize = 15, label = "CyCLIP", alpha = 0.85) 34 | axes.yaxis.grid() 35 | 36 | axes.legend(bbox_to_anchor = (1.0, 1.05)) 37 | plt.tight_layout() 38 | 39 | os.makedirs("analysis/plots", exist_ok = True) 40 | plt.savefig(f"analysis/plots/datasize_ablation.png") -------------------------------------------------------------------------------- /src/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch.multiprocessing as mp 3 | from logging import Formatter, FileHandler, StreamHandler 4 | from logging.handlers import QueueHandler, QueueListener 5 | 6 | class LogFilter(logging.Filter): 7 | def __init__(self, rank, distributed): 8 | super().__init__() 9 | self.rank = rank 10 | self.distributed = distributed 11 | 12 | def filter(self, record): 13 | if(self.distributed): 14 | record.msg = f"Rank {self.rank} | {record.msg}" 15 | return True 16 | 17 | def set_logger(rank, logger, distributed = False): 18 | queue_handler = QueueHandler(logger) 19 | queue_handler.addFilter(LogFilter(rank, distributed)) 20 | queue_handler.setLevel(logging.INFO) 21 | queue_handler.flush() 22 | 23 | logger = logging.getLogger() 24 | logger.addHandler(queue_handler) 25 | logger.setLevel(logging.INFO) 26 | 27 | def get_logger(log_file_path): 28 | logger = mp.Queue(-1) 29 | 30 | formatter = Formatter("%(asctime)s | %(levelname)s | %(message)s", datefmt = "%Y-%m-%d,%H:%M:%S") 31 | 32 | file_handler = FileHandler(log_file_path, "w+") 33 | file_handler.setFormatter(formatter) 34 | file_handler.setLevel(logging.INFO) 35 | 36 | stream_handler = StreamHandler() 37 | stream_handler.setFormatter(formatter) 38 | stream_handler.setLevel(logging.INFO) 39 | 40 | listener = QueueListener(logger, file_handler, stream_handler) 41 | 42 | return logger, listener 43 | -------------------------------------------------------------------------------- /utils/hypersphere_labels.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | import argparse 4 | from tqdm import tqdm 5 | 6 | def run(options): 7 | with torch.no_grad(): 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | 10 | data = pickle.load(open(options.embeddings, "rb")) 11 | 12 | image_embeddings, text_embeddings, labels = torch.tensor(data["image_embeddings"]).to(device), torch.tensor(data["text_embeddings"]).to(device), torch.tensor(data["labels"]).to(device) 13 | text_embeddings = text_embeddings[labels] 14 | 15 | alignment = (image_embeddings * text_embeddings).sum(1).mean(0) 16 | 17 | batch_size = 32 18 | 19 | dataset = torch.utils.data.TensorDataset(image_embeddings) 20 | dataloader = torch.utils.data.DataLoader(dataset, batch_size = batch_size) 21 | 22 | uniformity = torch.zeros([]).to(device) 23 | for index, batch in enumerate(dataloader): 24 | image_embedding = batch[0] 25 | cross = image_embedding @ text_embeddings.t() 26 | uniformity += (-cross).exp().sum() - (-cross.diag(index * batch_size)).exp().sum() 27 | uniformity /= (len(image_embeddings) * (len(image_embeddings) - 1)) 28 | uniformity = uniformity.log() 29 | 30 | print(f"Alignment: {alignment.cpu().item()}") 31 | print(f"Uniformity: {uniformity.cpu().item()}") 32 | 33 | if(__name__ == "__main__"): 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument("-e,--embeddings", dest = "embeddings", type = str, default 
= "analysis/embeddings/clip/ImageNet1K.validation.pkl", help = "Input file") 36 | options = parser.parse_args() 37 | run(options) -------------------------------------------------------------------------------- /utils/zeroshot.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | import argparse 4 | from tqdm import tqdm 5 | 6 | def run(options): 7 | with torch.no_grad(): 8 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 9 | 10 | data = pickle.load(open(options.embeddings, "rb")) 11 | image_embeddings, text_embeddings, labels = torch.tensor(data["image_embeddings"]), torch.tensor(data["text_embeddings"]).to(device), torch.tensor(data["labels"]) 12 | 13 | dataset = torch.utils.data.TensorDataset(image_embeddings, labels) 14 | dataloader = torch.utils.data.DataLoader(dataset, batch_size = options.batch_size) 15 | 16 | correct = {k: 0 for k in options.k} 17 | 18 | for image_embedding, label in tqdm(dataloader): 19 | image_embedding, label = image_embedding.to(device), label.to(device) 20 | logits = (image_embedding @ text_embeddings.t()) 21 | ranks = logits.topk(max(options.k), 1)[1].T 22 | predictions = ranks == label 23 | for k in options.k: 24 | correct[k] += torch.sum(torch.any(predictions[:k], dim = 0)).item() 25 | 26 | for k in options.k: 27 | print(f"Zeroshot top {k}: {correct[k] / len(dataset) * 100.0}") 28 | 29 | if(__name__ == "__main__"): 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument("--embeddings", type = str, default = "analysis/embeddings/clip/ImageNet1K.validation.pkl", help = "Input test embeddings file") 32 | parser.add_argument("--batch_size", type = int, default = 32, help = "Batch size") 33 | parser.add_argument("--k", nargs = "+", default = [1, 3, 5]) 34 | options = parser.parse_args() 35 | run(options) -------------------------------------------------------------------------------- /utils/linear_probe_plot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import seaborn as sns 3 | import matplotlib.pyplot as plt 4 | 5 | params = { 6 | "axes.titlesize": 22, 7 | "legend.fontsize": 16, 8 | "axes.labelsize": 16, 9 | "axes.titlesize": 16, 10 | "xtick.labelsize": 16, 11 | "ytick.labelsize": 16, 12 | "figure.titlesize": 22, 13 | "font.family": "Liberation Mono" 14 | } 15 | 16 | plt.rcParams.update(params) 17 | plt.style.use("seaborn-whitegrid") 18 | sns.set_style("white") 19 | 20 | df = { 21 | "Dataset": ["Caltech101", "CIFAR10", "CIFAR100", "DTD", "FGVCAircraft", "Flowers102", "Food101", "GTSRB", "ImageNet1K", "OxfordIIITPet", "RenderedSST2", "StanfordCars", "STL10", "SVHN"], 22 | "CLIP": [76.80, 78.27, 72.37, 61.44, 28.32, 84.96, 54.47, 69.45, 35.47, 58.85, 53.10, 20.53, 89.76, 47.64], 23 | "CyCLIP": [77.10, 77.34, 72.77, 64.47, 27.24, 84.72, 54.95, 71.70, 36.69, 58.10, 54.04, 22.72, 90.42, 48.16], 24 | } 25 | 26 | figure, axes = plt.subplots(3, 5, figsize = (15, 9)) 27 | 28 | for index in range(14): 29 | row, col = index // 5, index % 5 30 | axes[row][col].set_title(df["Dataset"][index]) 31 | axes[row][col].set_ylim(0, 100) 32 | axes[row][0].set_ylabel("Top1 Accuracy (%)") 33 | axes[row][col].set_xticks([1.0, 1.25]) 34 | axes[row][col].set_xticklabels(["CLIP", "CyCLIP"], fontsize = 14, rotation = 0) 35 | axes[row][col].bar(1.0, df["CLIP"][index], label = "CLIP", width = 0.15, color = "brown") 36 | axes[row][col].bar(1.25, df["CyCLIP"][index], label = "CyCLIP", width = 0.15) 37 | axes[row][col].text(1.0, 
df["CLIP"][index] + 1, str(df["CLIP"][index]), horizontalalignment = "center", fontsize = "x-large") 38 | axes[row][col].text(1.25, df["CyCLIP"][index] + 1, str(df["CyCLIP"][index]), horizontalalignment = "center", fontsize = "x-large") 39 | 40 | figure.delaxes(axes[2][4]) 41 | figure.tight_layout() 42 | 43 | os.makedirs("analysis/plots", exist_ok = True) 44 | plt.savefig(f"analysis/plots/linear_probe_plot.png") -------------------------------------------------------------------------------- /utils/augment_image.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import torchvision 4 | import pandas as pd 5 | from tqdm import tqdm 6 | from utils import config 7 | from multiprocessing import Pool 8 | from PIL import Image, ImageFile 9 | 10 | ImageFile.LOAD_TRUNCATED_IMAGES = True 11 | 12 | transform = torchvision.transforms.AutoAugment() 13 | 14 | def _augment_image(image_file): 15 | image = Image.open(image_file) 16 | augmented_image = transform(image) 17 | return augmented_image 18 | 19 | def augment(image_file): 20 | augmented_image_file = os.path.splitext(image_file)[0] + ".augmented" + os.path.splitext(image_file)[1] 21 | if(os.path.exists(augmented_image_file)): 22 | return 23 | image = Image.open(image_file) 24 | augmented_image = transform(image) 25 | augmented_image.save(augmented_image_file) 26 | 27 | def augment_image(options): 28 | path = os.path.join(config.root, options.input_file) 29 | df = pd.read_csv(path, delimiter = options.delimiter) 30 | 31 | root = os.path.dirname(path) 32 | image_files = df[options.image_key].apply(lambda image_file: os.path.join(root, image_file)).tolist() 33 | with Pool() as pool: 34 | for _ in tqdm(pool.imap(augment, image_files), total = len(image_files)): 35 | pass 36 | 37 | df["augmented_" + options.image_key] = df[options.image_key].apply(lambda image_file: os.path.splitext(image_file)[0] + ".augmented" + os.path.splitext(image_file)[1]) 38 | df.to_csv(os.path.join(config.root, options.output_file), index = False) 39 | 40 | if(__name__ == "__main__"): 41 | parser = argparse.ArgumentParser() 42 | 43 | parser.add_argument("-i,--input_file", dest = "input_file", type = str, required = True, help = "Input file") 44 | parser.add_argument("-o,--output_file", dest = "output_file", type = str, required = True, help = "Output file") 45 | parser.add_argument("--delimiter", type = str, default = ",", help = "Input file delimiter") 46 | parser.add_argument("--image_key", type = str, default = "image", help = "Caption column name") 47 | 48 | options = parser.parse_args() 49 | augment_image(options) -------------------------------------------------------------------------------- /utils/flickr30k_data_creation.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import json 4 | import pandas as pd 5 | from tqdm import tqdm 6 | 7 | def move_files(images, split): 8 | 9 | for image in tqdm(images): 10 | current_loc = os.path.join(root, 'images', image) 11 | dest_loc = os.path.join(root, split) 12 | shutil.move(current_loc, dest_loc) 13 | 14 | def get_data(split): 15 | 16 | split_images = [] 17 | split_comments = [] 18 | 19 | df = pd.read_csv(os.path.join(root, 'results.csv'), "|") 20 | imgs = os.listdir(os.path.join(root, split)) 21 | for img in tqdm(imgs): 22 | df_ = df.where(df['image_name'] == img).dropna() 23 | locations = list(map(lambda x: f'{split}/{x}' ,df_['image_name'].tolist())) 24 | split_images = split_images + 
locations 25 | comments = df_[' comment'].tolist() 26 | comments = list(map(lambda x: x[1:], comments)) 27 | split_comments = split_comments + comments 28 | return split_images, split_comments 29 | 30 | if __name__ == "__main__": 31 | 32 | root = './data/flickr30k' 33 | file_management = False 34 | file_creation = True 35 | 36 | if file_management: 37 | with open(os.path.join(root, 'dataset.json')) as f: 38 | dataset = json.load(f) 39 | train_images = list(filter(lambda x: x['split'] == 'train', dataset['images'])) 40 | val_images = list(filter(lambda x: x['split'] == 'val', dataset['images'])) 41 | test_images = list(filter(lambda x: x['split'] == 'test', dataset['images'])) 42 | train_images, val_images, test_images = list(map(lambda li: list(map(lambda x: x['filename'], li)), [train_images, val_images, test_images])) 43 | assert(len(test_images) == 1000) 44 | list_of_all_images = os.listdir(os.path.join(root, 'images')) 45 | move_files(train_images, 'train') 46 | move_files(val_images, 'validation') 47 | move_files(test_images, 'test') 48 | 49 | if file_creation: 50 | # train_images, train_captions = get_data(os.path.join(root, 'train')) 51 | # validation_images, validation_captions = get_data(os.path.join(root, 'validation')) 52 | test_images, test_captions = get_data('test') 53 | 54 | # images = train_images + validation_images + test_images 55 | # caps = train_captions + validation_captions + test_captions 56 | 57 | data = {'image': test_images, 58 | 'caption': test_captions} 59 | dt = pd.DataFrame(data) 60 | dt.to_csv(f'{root}/flickr30k.csv') -------------------------------------------------------------------------------- /utils/effective_robustness_plot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import seaborn as sns 4 | import matplotlib.pyplot as plt 5 | 6 | params = { 7 | "figure.figsize": (10, 5), 8 | "legend.fontsize": 16, 9 | "axes.labelsize": 16, 10 | "xtick.labelsize": 16, 11 | "ytick.labelsize": 16, 12 | "font.family": "Liberation Mono" 13 | } 14 | 15 | plt.rcParams.update(params) 16 | plt.style.use("seaborn-whitegrid") 17 | sns.set_style("white") 18 | 19 | clip = { 20 | "Datasize": ["500K", "1M", "2M", "3M"], 21 | "ImageNet1K": [5.03, 11.1, 15.07, 17.28], 22 | "ImageNetV2": [4.53, 9.33, 12.65, 14.39] 23 | } 24 | 25 | cyclip = { 26 | "Datasize": ["500K", "1M", "2M", "3M"], 27 | "ImageNet1K": [6.19, 12.12, 16.57, 18.68], 28 | "ImageNetV2": [5.33, 9.88, 13.89, 15.31] 29 | } 30 | 31 | plt.clf() 32 | 33 | plt.scatter(clip["ImageNet1K"][0], clip["ImageNetV2"][0], marker = "^", s = 80, label = f"CLIP-{clip['Datasize'][0]}") 34 | plt.scatter(clip["ImageNet1K"][1], clip["ImageNetV2"][1], marker = "^", s = 80, label = f"CLIP-{clip['Datasize'][1]}") 35 | plt.scatter(clip["ImageNet1K"][2], clip["ImageNetV2"][2], marker = "^", s = 80, label = f"CLIP-{clip['Datasize'][2]}") 36 | plt.scatter(clip["ImageNet1K"][3], clip["ImageNetV2"][3], marker = "^", s = 80, label = f"CLIP-{clip['Datasize'][3]}") 37 | 38 | plt.scatter(cyclip["ImageNet1K"][0], cyclip["ImageNetV2"][0], marker = "*", s = 96, label = f"CyCLIP-{cyclip['Datasize'][0]}") 39 | plt.scatter(cyclip["ImageNet1K"][1], cyclip["ImageNetV2"][1], marker = "*", s = 96, label = f"CyCLIP-{cyclip['Datasize'][1]}") 40 | plt.scatter(cyclip["ImageNet1K"][2], cyclip["ImageNetV2"][2], marker = "*", s = 96, label = f"CyCLIP-{cyclip['Datasize'][2]}") 41 | plt.scatter(cyclip["ImageNet1K"][3], cyclip["ImageNetV2"][3], marker = "*", s = 96, label = 
f"CyCLIP-{cyclip['Datasize'][3]}") 42 | 43 | plt.xlabel("Top1 Accuracy on ImageNet1K (%)", labelpad = 12) 44 | plt.ylabel(f"Top1 Accuracy on ImageNetV2 (%)", labelpad = 12) 45 | xpoints = ypoints = plt.xlim() 46 | plt.plot(xpoints, ypoints, linestyle = "--", color = "k", lw = 2, scalex = False, scaley = False, label = "y = x") 47 | 48 | ypoints = plt.ylim() 49 | xpoints = [(y + 9) / 1.2 for y in ypoints] 50 | plt.plot(xpoints, ypoints, linestyle = "-", color = "r", lw = 1, label = "Linear fit to\nstandard training") 51 | 52 | plt.yticks(np.arange(5.0, 22.5, 2.5)) 53 | plt.xticks(np.arange(5.0, 25.0, 2.5)) 54 | plt.ylim(2.5, 20.0) 55 | plt.xlim(2.5, 22.5) 56 | 57 | plt.legend(bbox_to_anchor = (1.0, 1.05)) 58 | plt.grid() 59 | plt.tight_layout() 60 | 61 | os.makedirs("analysis/plots", exist_ok = True) 62 | plt.savefig(f"analysis/plots/effective_robustness_plot.png") -------------------------------------------------------------------------------- /data/CIFAR10/test/classes.py: -------------------------------------------------------------------------------- 1 | { 2 | "classes": ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"], 3 | 4 | "templates": [lambda s: f"a bad photo of a {s}.", lambda s: f"a photo of many {s}.", lambda s: f"a sculpture of a {s}.", lambda s: f"a photo of the hard to see {s}.", lambda s: f"a low resolution photo of the {s}.", lambda s: f"a rendering of a {s}.", lambda s: f"graffiti of a {s}.", lambda s: f"a bad photo of the {s}.", lambda s: f"a cropped photo of the {s}.", lambda s: f"a tattoo of a {s}.", lambda s: f"the embroidered {s}.", lambda s: f"a photo of a hard to see {s}.", lambda s: f"a bright photo of a {s}.", lambda s: f"a photo of a clean {s}.", lambda s: f"a photo of a dirty {s}.", lambda s: f"a dark photo of the {s}.", lambda s: f"a drawing of a {s}.", lambda s: f"a photo of my {s}.", lambda s: f"the plastic {s}.", lambda s: f"a photo of the cool {s}.", lambda s: f"a close-up photo of a {s}.", lambda s: f"a black and white photo of the {s}.", lambda s: f"a painting of the {s}.", lambda s: f"a painting of a {s}.", lambda s: f"a pixelated photo of the {s}.", lambda s: f"a sculpture of the {s}.", lambda s: f"a bright photo of the {s}.", lambda s: f"a cropped photo of a {s}.", lambda s: f"a plastic {s}.", lambda s: f"a photo of the dirty {s}.", lambda s: f"a jpeg corrupted photo of a {s}.", lambda s: f"a blurry photo of the {s}.", lambda s: f"a photo of the {s}.", lambda s: f"a good photo of the {s}.", lambda s: f"a rendering of the {s}.", lambda s: f"a {s} in a video game.", lambda s: f"a photo of one {s}.", lambda s: f"a doodle of a {s}.", lambda s: f"a close-up photo of the {s}.", lambda s: f"a photo of a {s}.", lambda s: f"the origami {s}.", lambda s: f"the {s} in a video game.", lambda s: f"a sketch of a {s}.", lambda s: f"a doodle of the {s}.", lambda s: f"a origami {s}.", lambda s: f"a low resolution photo of a {s}.", lambda s: f"the toy {s}.", lambda s: f"a rendition of the {s}.", lambda s: f"a photo of the clean {s}.", lambda s: f"a photo of a large {s}.", lambda s: f"a rendition of a {s}.", lambda s: f"a photo of a nice {s}.", lambda s: f"a photo of a weird {s}.", lambda s: f"a blurry photo of a {s}.", lambda s: f"a cartoon {s}.", lambda s: f"art of a {s}.", lambda s: f"a sketch of the {s}.", lambda s: f"a embroidered {s}.", lambda s: f"a pixelated photo of a {s}.", lambda s: f"itap of the {s}.", lambda s: f"a jpeg corrupted photo of the {s}.", lambda s: f"a good photo of a {s}.", lambda s: f"a plushie {s}.", 
lambda s: f"a photo of the nice {s}.", lambda s: f"a photo of the small {s}.", lambda s: f"a photo of the weird {s}.", lambda s: f"the cartoon {s}.", lambda s: f"art of the {s}.", lambda s: f"a drawing of the {s}.", lambda s: f"a photo of the large {s}.", lambda s: f"a black and white photo of a {s}.", lambda s: f"the plushie {s}.", lambda s: f"a dark photo of a {s}.", lambda s: f"itap of a {s}.", lambda s: f"graffiti of the {s}.", lambda s: f"a toy {s}.", lambda s: f"itap of my {s}.", lambda s: f"a photo of a cool {s}.", lambda s: f"a photo of a small {s}.", lambda s: f"a tattoo of the {s}."] 5 | } -------------------------------------------------------------------------------- /utils/fine_coarse_plot.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import seaborn as sns 4 | import matplotlib.pyplot as plt 5 | 6 | params = { 7 | "axes.titlesize": 22, 8 | "legend.fontsize": 16, 9 | "figure.figsize": (12, 8), 10 | "axes.labelsize": 16, 11 | "axes.titlesize": 16, 12 | "xtick.labelsize": 16, 13 | "ytick.labelsize": 16, 14 | "figure.titlesize": 22, 15 | "font.family": "Liberation Mono" 16 | } 17 | 18 | plt.rcParams.update(params) 19 | plt.style.use("seaborn-whitegrid") 20 | sns.set_style("white") 21 | 22 | data_fine_top_1 = { 23 | "Dataset": ["CIFAR-100", "ImageNet1K", "ImageNetV2", "ImageNetSketch", "ImageNet-A", "ImageNet-R"], 24 | "CLIP": [43.55, 34.18, 31.10, 25.11, 40.09, 54.49], 25 | "CyCLIP": [47.11, 35.46, 33.01, 26.42, 41.25, 55.36] 26 | } 27 | 28 | data_coarse_top_1 = { 29 | "Dataset": ["CIFAR-100", "ImageNet1K", "ImageNetV2", "ImageNetSketch", "ImageNet-A", "ImageNet-R"], 30 | "CLIP": [34.82, 52.11, 47.24, 37.03, 14.59, 40.51], 31 | "CyCLIP": [40.37, 56.79, 52.15, 41.57, 16.48, 44.43] 32 | } 33 | 34 | data_fine_top_2 = { 35 | "Dataset": ["CIFAR-100", "ImageNet1K", "ImageNetV2", "ImageNetSketch", "ImageNet-A", "ImageNet-R"], 36 | "CLIP": [67.65, 46.53, 43.85, 36.69, 57.93, 71.55], 37 | "CyCLIP": [68.67, 48.22, 45.37, 38.34, 57.89, 71.96] 38 | } 39 | 40 | data_coarse_top_2 = { 41 | "Dataset": ["CIFAR-100", "ImageNet1K", "ImageNetV2", "ImageNetSketch", "ImageNet-A", "ImageNet-R"], 42 | "CLIP": [49.51, 66.06, 61.42, 49.93, 24.71, 52.77], 43 | "CyCLIP": [56.26, 69.67, 65.27, 54.59, 26.75, 55.98] 44 | } 45 | 46 | data_fine_top_3 = { 47 | "Dataset": ["CIFAR-100", "ImageNet1K", "ImageNetV2", "ImageNetSketch", "ImageNet-A", "ImageNet-R"], 48 | "CLIP": [82.41, 53.95, 51.61, 44.66, 67.80, 79.52], 49 | "CyCLIP": [83.32, 55.43, 52.85, 46.09, 68.75, 79.81] 50 | } 51 | 52 | data_coarse_top_3 = { 53 | "Dataset": ["CIFAR-100", "ImageNet1K", "ImageNetV2", "ImageNetSketch", "ImageNet-A", "ImageNet-R"], 54 | "CLIP": [58.39, 73.20, 69.09, 57.61, 32.49, 60.07], 55 | "CyCLIP": [65.65, 76.45, 72.40, 62.22, 34.83, 63.12] 56 | } 57 | 58 | k = 1 59 | 60 | data = (eval(f"data_fine_top_{k}"), eval(f"data_coarse_top_{k}")) 61 | 62 | figure, axes = plt.subplots(1, 2, figsize = (12, 5)) 63 | 64 | ylims = {1: 80, 2: 100, 3: 100} 65 | X_axis = np.arange(1, 7) 66 | Y_axis = np.arange(0, ylims[k], 10) 67 | 68 | for index in range(2): 69 | axes[index].set_ylim(0, ylims[k]) 70 | axes[index].set_xlabel(f"{'(a)' if (index == 0) else '(b)'}") 71 | axes[index].set_ylabel(f"Top{k} Accuracy (%)", labelpad = 12) 72 | axes[index].set_xticks(X_axis) 73 | axes[index].set_yticks(Y_axis) 74 | axes[index].set_xticklabels(data[index]["Dataset"], fontsize = 12, rotation = 30) 75 | axes[index].set_yticklabels(Y_axis, fontsize = 12, rotation = 0) 76 | 
axes[index].bar(X_axis, data[index]["CLIP"], label = "CLIP", width = 0.35, color = "brown", alpha = 0.85) 77 | axes[index].bar(X_axis + 0.37, data[index]["CyCLIP"], width = 0.35, label = "CyCLIP", alpha = 0.85) 78 | 79 | for i in range(len(data[index])): 80 | text = ("+" if (data[index]["CyCLIP"][i] > data[index]["CLIP"][i]) else "-") + str(round(abs(data[index]["CyCLIP"][i] - data[index]["CLIP"][i]), 2)) + "%" 81 | axes[index].text(i + 0.9, data[index]["CyCLIP"][i] + 1.5, text, fontsize = 12) 82 | 83 | axes[index].legend(prop = {"size": 12}) 84 | 85 | plt.tight_layout() 86 | plt.subplots_adjust(wspace = 0.225) 87 | 88 | os.makedirs("analysis/plots", exist_ok = True) 89 | plt.savefig(f"analysis/plots/fine_coarse_top{k}.png") -------------------------------------------------------------------------------- /utils/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import zlib 4 | import uuid 5 | import shelve 6 | import tarfile 7 | import argparse 8 | import requests 9 | import pandas as pd 10 | 11 | from io import BytesIO 12 | from tqdm import tqdm 13 | from PIL import Image 14 | from multiprocessing import Pool 15 | from torchvision import transforms 16 | 17 | transform = transforms.Compose([transforms.Resize(224, interpolation = transforms.InterpolationMode.BICUBIC), transforms.CenterCrop(224)]) 18 | 19 | def download(row): 20 | rfile = f"images/{zlib.crc32(row['image'].encode('utf-8')) & 0xffffffff}.png" 21 | file = f"{row['dir']}/{rfile}" 22 | 23 | if(os.path.isfile(file)): 24 | row["status"] = 200 25 | row["file"] = rfile 26 | return row 27 | 28 | try: 29 | response = requests.get(row["image"], stream = False, timeout = 10, allow_redirects = True) 30 | row["status"] = response.status_code 31 | except Exception as e: 32 | row["status"] = 404 33 | return row 34 | 35 | if(response.ok): 36 | try: 37 | response.raw.decode_content = True 38 | image = Image.open(BytesIO(response.content)).convert("RGB") 39 | image = transform(image) 40 | image.save(file) 41 | except: 42 | row["status"] = 404 43 | return row 44 | 45 | row["file"] = rfile 46 | 47 | return row 48 | 49 | def apply(args): 50 | index, df, function = args 51 | df = df.apply(function, axis = 1) 52 | return (index, df) 53 | 54 | def multiprocess(df, function, dir, hash): 55 | with shelve.open(f"{dir}/.{hash}") as file: 56 | bar = tqdm(total = math.ceil(len(df) / 50)) 57 | 58 | finished = set(map(int, file.keys())) 59 | for key in file.keys(): 60 | bar.update() 61 | 62 | data = [(index, df[i:i + 50], function) for index, i in enumerate(range(0, len(df), 50)) if index not in finished] 63 | 64 | if(len(data) > 0): 65 | with Pool() as pool: 66 | for result in pool.imap_unordered(apply, data, 2): 67 | file[str(result[0])] = result 68 | bar.update() 69 | 70 | bar.close() 71 | 72 | keys = sorted([int(k) for k in file.keys()]) 73 | df = pd.concat([file[str(key)][1] for key in keys]) 74 | df = df[["file", "caption"]].rename(columns = {"file": "image"}) 75 | 76 | return df 77 | 78 | def run(options): 79 | os.makedirs(options.dir, exist_ok = True) 80 | os.makedirs(os.path.join(options.dir, "images"), exist_ok = True) 81 | 82 | df = pd.read_csv(options.file, sep = "\t", names = [ "caption", "image"]) 83 | df["dir"] = options.dir 84 | df = df[options.start:options.end] 85 | 86 | df = multiprocess(df, function = download, dir = options.dir, hash = options.hash) 87 | df.to_csv(f"{options.dir}/train.csv", index = False) 88 | 89 | if(__name__ == "__main__"): 90 | parser = 
argparse.ArgumentParser() 91 | 92 | parser.add_argument("-f,--file", dest = "file", type = str, default = None, help = "File") 93 | parser.add_argument("-d,--dir", dest = "dir", type = str, default = None, help = "Directory") 94 | parser.add_argument("-s,--start", dest = "start", type = int, default = 0, help = "Start index") 95 | parser.add_argument("-e,--end", dest = "end", type = int, default = 1000000000000, help = "End index") 96 | 97 | options = parser.parse_args() 98 | options.hash = str(uuid.uuid5(uuid.NAMESPACE_DNS, f"{options.file}-{options.dir}-{options.start}-{options.end}")) 99 | 100 | run(options) 101 | -------------------------------------------------------------------------------- /utils/fine_coarse.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pickle 4 | import argparse 5 | import numpy as np 6 | import pandas as pd 7 | from tqdm import tqdm 8 | from collections import defaultdict 9 | 10 | def get_fine_grained_accuracies(image_embeddings, text_embeddings, labels, classes, superclasses, superclasses_indices, topk, device): 11 | with torch.no_grad(): 12 | fine_grained_accuracies = torch.zeros(len(topk)) 13 | for superclass, superclass_indices in tqdm(list(zip(superclasses, superclasses_indices)), leave = False): 14 | example_indices = sum(labels == index for index in superclass_indices).bool() 15 | 16 | sub_image_embeddings = image_embeddings[example_indices, :] 17 | sub_text_embeddings = text_embeddings[superclass_indices, :] 18 | sub_similarity = sub_image_embeddings @ sub_text_embeddings.t() 19 | 20 | sub_labels = torch.tensor([superclass_indices.index(label) for label in labels[example_indices]]).to(device) 21 | sub_ranks = sub_similarity.topk(min(max(topk), sub_similarity.shape[1]), 1)[1].T 22 | sub_predictions = sub_ranks == sub_labels 23 | 24 | for i, k in enumerate(topk): 25 | fine_grained_accuracies[i] += torch.sum(torch.any(sub_predictions[:k], dim = 0)).item() 26 | 27 | fine_grained_accuracies /= len(labels) 28 | return fine_grained_accuracies.tolist() 29 | 30 | def get_coarse_grained_accuracies(image_embeddings, text_embeddings, labels, classes, superclasses, superclasses_indices, topk, device): 31 | with torch.no_grad(): 32 | group_label_map = {label: group_label for group_label, superclass_indices in enumerate(superclasses_indices) for label in superclass_indices} 33 | 34 | print(len(group_label_map)) 35 | coarse_grained_accuracies = torch.zeros(len(topk)) 36 | for index in tqdm(list(range(0, len(image_embeddings), 128))): 37 | similarity = image_embeddings[index:index + 128, :] @ text_embeddings.t() 38 | group_similarity = torch.cat([similarity[:, superclass_indices].max(1)[0].unsqueeze(1) for superclass_indices in superclasses_indices], dim = 1) 39 | group_labels = torch.tensor([group_label_map[label.item()] for label in labels[index:index + 128]]).to(device) 40 | group_ranks = group_similarity.topk(min(max(topk), group_similarity.shape[1]), 1)[1].T 41 | group_predictions = group_ranks == group_labels 42 | 43 | for i, k in enumerate(topk): 44 | coarse_grained_accuracies[i] += torch.sum(torch.any(group_predictions[:k], dim = 0)).item() 45 | 46 | coarse_grained_accuracies /= len(labels) 47 | return coarse_grained_accuracies.tolist() 48 | 49 | def analyze(options): 50 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 51 | 52 | data = pickle.load(open(options.file, "rb")) 53 | image_embeddings, text_embeddings, labels, classes, superclasses = 
torch.tensor(data["image_embeddings"]).to(device), torch.tensor(data["text_embeddings"]).to(device), torch.tensor(data["labels"]).to(device), data["classes"], data["superclasses"] 54 | 55 | processed = defaultdict(lambda: -1) 56 | superclasses_indices = [] 57 | for superclass in superclasses: 58 | superclass_indices = [] 59 | for c in superclass: 60 | processed[c] = classes.index(c, processed[c] + 1) 61 | superclass_indices.append(processed[c]) 62 | superclasses_indices.append(superclass_indices) 63 | 64 | fine_grained_accuracies = get_fine_grained_accuracies(image_embeddings, text_embeddings, labels, classes, superclasses, superclasses_indices, options.topk, device) 65 | coarse_grained_accuracies = get_coarse_grained_accuracies(image_embeddings, text_embeddings, labels, classes, superclasses, superclasses_indices, options.topk, device) 66 | df = pd.DataFrame(columns = [f"Top {i}" for i in options.topk], index = ["Fine-grained Accuracy", "Coarse-grained Accuracy"]) 67 | df.loc["Fine-grained Accuracy"] = fine_grained_accuracies 68 | df.loc["Coarse-grained Accuracy"] = coarse_grained_accuracies 69 | print(df) 70 | 71 | if(__name__ == "__main__"): 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument("-f,-file", dest = "file", type = str, default = "analysis/embeddings/clip/ImageNet1K.validation.pkl", help = "Input file") 74 | parser.add_argument("-k,-topk", dest = "topk", nargs = "+", default = [1, 2, 3], help = "Top-K Accuracies") 75 | options = parser.parse_args() 76 | analyze(options) 77 | -------------------------------------------------------------------------------- /pkgs/openai/clip.py: -------------------------------------------------------------------------------- 1 | # Code ported from https://github.com/openai/CLIP 2 | 3 | import os 4 | import torch 5 | import urllib 6 | import hashlib 7 | import warnings 8 | from PIL import Image 9 | from tqdm import tqdm 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, RandomResizedCrop 11 | 12 | from utils import config 13 | from .model import build 14 | from .tokenizer import SimpleTokenizer as Tokenizer 15 | 16 | models = { 17 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 18 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 19 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 20 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 21 | } 22 | 23 | def convert_models_to_fp32(model): 24 | for p in model.parameters(): 25 | p.data = p.data.float() 26 | if p.grad: 27 | p.grad.data = p.grad.data.float() 28 | 29 | def download(url, root = os.path.expanduser(f"{config.root}/.cache/openai")): 30 | os.makedirs(root, exist_ok=True) 31 | filename = os.path.basename(url) 32 | 33 | expected_sha256 = url.split("/")[-2] 34 | download_target = os.path.join(root, filename) 35 | 36 | if os.path.exists(download_target) and not os.path.isfile(download_target): 37 | raise RuntimeError(f"{download_target} exists and is not a regular file") 38 | 39 | if os.path.isfile(download_target): 40 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 41 | return download_target 42 | else: 43 | warnings.warn(f"{download_target} exists, but the SHA256 
checksum does not match; re-downloading the file") 44 | 45 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 46 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 47 | while True: 48 | buffer = source.read(8192) 49 | if not buffer: 50 | break 51 | 52 | output.write(buffer) 53 | loop.update(len(buffer)) 54 | 55 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 56 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 57 | 58 | return download_target 59 | 60 | class Processor: 61 | def __init__(self, model): 62 | self.tokenizer = Tokenizer() 63 | self.sot_token = self.tokenizer.encoder[""] 64 | self.eot_token = self.tokenizer.encoder[""] 65 | self.context_length = 77 66 | 67 | self.transform = Compose([Resize(model.visual.input_resolution, interpolation = Image.BICUBIC), CenterCrop(model.visual.input_resolution), ToTensor(), Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))]) 68 | 69 | def process_text(self, texts): 70 | if(isinstance(texts, str)): 71 | texts = [texts] 72 | 73 | result = torch.zeros(len(texts), self.context_length, dtype = torch.long) 74 | 75 | for i, text in enumerate(texts): 76 | tokens = [self.sot_token] + self.tokenizer.encode(text) + [self.eot_token] 77 | if(len(tokens) > self.context_length): 78 | tokens = tokens[:self.context_length] 79 | result[i, :len(tokens)] = torch.tensor(tokens) 80 | 81 | return {"input_ids": result, "attention_mask": torch.empty((len(result),))} 82 | 83 | def process_image(self, image): 84 | return self.transform(image.convert("RGB")) 85 | 86 | def load(name, pretrained = False): 87 | if(name in models): 88 | model_path = download(models[name]) 89 | else: 90 | raise RuntimeError(f"Model {name} not found; available models = {list(models.keys())}") 91 | 92 | model = torch.jit.load(model_path, map_location= "cpu").eval() 93 | 94 | try: 95 | model = build(model.state_dict(), pretrained = pretrained) 96 | except KeyError: 97 | state_dict = {key["module.":]: value for key, value in state_dict["state_dict"].items()} 98 | model = build(state_dict, pretrained = pretrained) 99 | 100 | convert_models_to_fp32(model) 101 | processor = Processor(model) 102 | 103 | return model, processor -------------------------------------------------------------------------------- /utils/fine_coarse_plot_supplementary.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import seaborn as sns 4 | import matplotlib.pyplot as plt 5 | 6 | params = { 7 | "axes.titlesize": 22, 8 | "legend.fontsize": 16, 9 | "figure.figsize": (12, 8), 10 | "axes.labelsize": 16, 11 | "axes.titlesize": 16, 12 | "xtick.labelsize": 16, 13 | "ytick.labelsize": 16, 14 | "figure.titlesize": 22, 15 | "font.family": "Liberation Mono" 16 | } 17 | 18 | plt.rcParams.update(params) 19 | plt.style.use("seaborn-whitegrid") 20 | sns.set_style("white") 21 | 22 | fine_top_1 = { 23 | "Dataset": ["CIFAR-100", "ImageNet1K", "ImageNetV2", "ImageNetSketch", "ImageNet-A", "ImageNet-R"], 24 | "CLIP": [43.55, 34.18, 31.10, 25.11, 40.09, 54.49], 25 | "CyCLIP": [47.11, 35.46, 33.01, 26.42, 41.25, 55.36], 26 | "C-CyCLIP": [49.19, 35.12, 32.31, 26.37, 42.56, 55.59], 27 | "I-CyCLIP": [47.71, 34.92, 32.21, 24.89, 39.60, 53.99] 28 | } 29 | 30 | coarse_top_1 = { 31 | "Dataset": ["CIFAR-100", "ImageNet1K", "ImageNetV2", "ImageNetSketch", "ImageNet-A", 
"ImageNet-R"], 32 | "CLIP": [34.82, 52.11, 47.24, 37.03, 14.59, 40.51], 33 | "CyCLIP": [40.37, 56.79, 52.15, 41.57, 16.48, 44.43], 34 | "C-CyCLIP": [43.06, 56.08, 50.89, 43.02, 16.40, 45.41], 35 | "I-CyCLIP": [39.30, 55.78, 51.48, 39.18, 15.63, 41.63] 36 | } 37 | 38 | fine_top_2 = { 39 | "Dataset": ["CIFAR-100", "ImageNet1K", "ImageNetV2", "ImageNetSketch", "ImageNet-A", "ImageNet-R"], 40 | "CLIP": [67.65, 46.53, 43.85, 36.69, 57.93, 71.55], 41 | "CyCLIP": [68.67, 48.22, 45.37, 38.34, 57.89, 71.96], 42 | "C-CyCLIP": [71.02, 47.95, 45.02, 37.67, 59.41, 72.06], 43 | "I-CyCLIP": [69.30, 47.42, 45.00, 37.24, 56.67, 70.99] 44 | } 45 | 46 | coarse_top_2 = { 47 | "Dataset": ["CIFAR-100", "ImageNet1K", "ImageNetV2", "ImageNetSketch", "ImageNet-A", "ImageNet-R"], 48 | "CLIP": [49.51, 66.06, 61.42, 49.93, 24.71, 52.77], 49 | "CyCLIP": [56.26, 69.67, 65.27, 54.59, 26.75, 55.98], 50 | "C-CyCLIP": [58.62, 69.74, 65.23, 56.41, 26.76, 57.21], 51 | "I-CyCLIP": [54.20, 68.46, 63.85, 51.58, 25.72, 51.94] 52 | } 53 | 54 | fine_top_3 = { 55 | "Dataset": ["CIFAR-100", "ImageNet1K", "ImageNetV2", "ImageNetSketch", "ImageNet-A", "ImageNet-R"], 56 | "CLIP": [82.41, 53.95, 51.61, 44.66, 67.80, 79.52], 57 | "CyCLIP": [83.32, 55.43, 52.85, 46.09, 68.75, 79.81], 58 | "C-CyCLIP": [85.26, 55.14, 52.40, 45.41, 69.99, 79.60], 59 | "I-CyCLIP": [84.07, 54.95, 52.03, 45.30, 67.92, 79.41] 60 | } 61 | 62 | coarse_top_3 = { 63 | "Dataset": ["CIFAR-100", "ImageNet1K", "ImageNetV2", "ImageNetSketch", "ImageNet-A", "ImageNet-R"], 64 | "CLIP": [58.39, 73.20, 69.09, 57.61, 32.49, 60.07], 65 | "CyCLIP": [65.65, 76.45, 72.40, 62.22, 34.83, 63.12], 66 | "C-CyCLIP": [67.91, 76.76, 72.55, 64.11, 35.25, 64.3], 67 | "I-CyCLIP": [63.30, 74.98, 70.44, 58.68, 32.76, 58.48] 68 | } 69 | 70 | k = 1 71 | data = (eval(f"fine_top_{k}"), eval(f"coarse_top_{k}")) 72 | 73 | figure, axes = plt.subplots(2, 1, figsize = (14, 10)) 74 | 75 | X_axis = np.arange(1, 7) 76 | Y_axis = np.arange(0, 70, 10) 77 | 78 | for index in range(2): 79 | axes[index].set_ylim(0, 70) 80 | axes[index].set_xlabel(f"{'(a) Fine-grained' if (index == 0) else '(b) Coarse-grained'}", labelpad = 14) 81 | axes[index].set_ylabel(f"Top{k} Accuracy (%)", labelpad = 12) 82 | axes[index].set_xticks(X_axis) 83 | axes[index].set_yticks(Y_axis) 84 | axes[index].set_xticklabels(data[index]["Dataset"], fontsize = 14) 85 | axes[index].set_yticklabels(Y_axis, fontsize = 14) 86 | axes[index].bar(X_axis - 1.5 * 0.20, data[index]["CLIP"], width = 0.185, label = "CLIP", alpha = 0.6, color = "brown") 87 | axes[index].bar(X_axis - 0.5 * 0.20, data[index]["CyCLIP"], width = 0.185, label = "CyCLIP", alpha = 0.6) 88 | axes[index].bar(X_axis + 0.5 * 0.20, data[index]["C-CyCLIP"], width = 0.185, label = "C-CyCLIP", alpha = 0.6) 89 | axes[index].bar(X_axis + 1.5 * 0.20, data[index]["I-CyCLIP"], width = 0.185, label = "I-CyCLIP", alpha = 0.6) 90 | 91 | for i in range(len(data[index]["Dataset"])): 92 | for j, model in enumerate(["CyCLIP", "C-CyCLIP", "I-CyCLIP"]): 93 | text = ("+" if (data[index][model][i] > data[index]["CLIP"][i]) else "-") + str(round(abs(data[index][model][i] - data[index]["CLIP"][i]), 1)) 94 | axes[index].text(i + 0.73 + j * 0.24, data[index][model][i] + 1.5, text, fontsize = 12) 95 | 96 | axes[index].legend(prop = {"size": 14}, bbox_to_anchor = (1.175, 1.025)) 97 | 98 | plt.tight_layout() 99 | plt.subplots_adjust(hspace = 0.25) 100 | 101 | os.makedirs("analysis/plots", exist_ok = True) 102 | plt.savefig(f"analysis/plots/fine_coarse_top{k}.supplementary.png") 
-------------------------------------------------------------------------------- /data/CIFAR100/test/classes.py: -------------------------------------------------------------------------------- 1 | { 2 | "classes": ["apples", "aquarium fish", "baby", "bear", "beaver", "bed", "bee", "beetle", "bicycle", "bottles", "bowls", "boy", "bridge", "bus", "butterfly", "camel", "cans", "castle", "caterpillar", "cattle", "chair", "chimpanzee", "clock", "cloud", "cockroach", "couch", "crab", "crocodile", "cups", "dinosaur", "dolphin", "elephant", "flatfish", "forest", "fox", "girl", "hamster", "house", "kangaroo", "computer keyboard", "lamp", "lawn-mower", "leopard", "lion", "lizard", "lobster", "man", "maple", "motorcycle", "mountain", "mouse", "mushrooms", "oak", "oranges", "orchids", "otter", "palm", "pears", "pickup truck", "pine", "plain", "plates", "poppies", "porcupine", "possum", "rabbit", "raccoon", "ray", "road", "rocket", "roses", "sea", "seal", "shark", "shrew", "skunk", "skyscraper", "snail", "snake", "spider", "squirrel", "streetcar", "sunflowers", "sweet peppers", "table", "tank", "telephone", "television", "tiger", "tractor", "train", "trout", "tulips", "turtle", "wardrobe", "whale", "willow", "wolf", "woman", "worm"], 3 | 4 | "superclasses": [["beaver", "dolphin", "otter", "seal", "whale"], ["aquarium fish", "flatfish", "ray", "shark", "trout"], ["orchids", "poppies", "roses", "sunflowers", "tulips"], ["bottles", "bowls", "cans", "cups", "plates"], ["apples", "mushrooms", "oranges", "pears", "sweet peppers"], ["clock", "computer keyboard", "lamp", "telephone", "television"], ["bed", "chair", "couch", "table", "wardrobe"], ["bee", "beetle", "butterfly", "caterpillar", "cockroach"], ["bear", "leopard", "lion", "tiger", "wolf"], ["bridge", "castle", "house", "road", "skyscraper"], ["cloud", "forest", "mountain", "plain", "sea"], ["camel", "cattle", "chimpanzee", "elephant", "kangaroo"], ["fox", "porcupine", "possum", "raccoon", "skunk"], ["crab", "lobster", "snail", "spider", "worm"], ["baby", "boy", "girl", "man", "woman"], ["crocodile", "dinosaur", "lizard", "snake", "turtle"], ["hamster", "mouse", "rabbit", "shrew", "squirrel"], ["maple", "oak", "palm", "pine", "willow"], ["bicycle", "bus", "motorcycle", "pickup truck", "train"], ["lawn-mower", "rocket", "streetcar", "tank", "tractor"]], 5 | 6 | "templates": [lambda s: f"a bad photo of a {s}.", lambda s: f"a photo of many {s}.", lambda s: f"a sculpture of a {s}.", lambda s: f"a photo of the hard to see {s}.", lambda s: f"a low resolution photo of the {s}.", lambda s: f"a rendering of a {s}.", lambda s: f"graffiti of a {s}.", lambda s: f"a bad photo of the {s}.", lambda s: f"a cropped photo of the {s}.", lambda s: f"a tattoo of a {s}.", lambda s: f"the embroidered {s}.", lambda s: f"a photo of a hard to see {s}.", lambda s: f"a bright photo of a {s}.", lambda s: f"a photo of a clean {s}.", lambda s: f"a photo of a dirty {s}.", lambda s: f"a dark photo of the {s}.", lambda s: f"a drawing of a {s}.", lambda s: f"a photo of my {s}.", lambda s: f"the plastic {s}.", lambda s: f"a photo of the cool {s}.", lambda s: f"a close-up photo of a {s}.", lambda s: f"a black and white photo of the {s}.", lambda s: f"a painting of the {s}.", lambda s: f"a painting of a {s}.", lambda s: f"a pixelated photo of the {s}.", lambda s: f"a sculpture of the {s}.", lambda s: f"a bright photo of the {s}.", lambda s: f"a cropped photo of a {s}.", lambda s: f"a plastic {s}.", lambda s: f"a photo of the dirty {s}.", lambda s: f"a jpeg corrupted photo of a {s}.", 
lambda s: f"a blurry photo of the {s}.", lambda s: f"a photo of the {s}.", lambda s: f"a good photo of the {s}.", lambda s: f"a rendering of the {s}.", lambda s: f"a {s} in a video game.", lambda s: f"a photo of one {s}.", lambda s: f"a doodle of a {s}.", lambda s: f"a close-up photo of the {s}.", lambda s: f"a photo of a {s}.", lambda s: f"the origami {s}.", lambda s: f"the {s} in a video game.", lambda s: f"a sketch of a {s}.", lambda s: f"a doodle of the {s}.", lambda s: f"a origami {s}.", lambda s: f"a low resolution photo of a {s}.", lambda s: f"the toy {s}.", lambda s: f"a rendition of the {s}.", lambda s: f"a photo of the clean {s}.", lambda s: f"a photo of a large {s}.", lambda s: f"a rendition of a {s}.", lambda s: f"a photo of a nice {s}.", lambda s: f"a photo of a weird {s}.", lambda s: f"a blurry photo of a {s}.", lambda s: f"a cartoon {s}.", lambda s: f"art of a {s}.", lambda s: f"a sketch of the {s}.", lambda s: f"a embroidered {s}.", lambda s: f"a pixelated photo of a {s}.", lambda s: f"itap of the {s}.", lambda s: f"a jpeg corrupted photo of the {s}.", lambda s: f"a good photo of a {s}.", lambda s: f"a plushie {s}.", lambda s: f"a photo of the nice {s}.", lambda s: f"a photo of the small {s}.", lambda s: f"a photo of the weird {s}.", lambda s: f"the cartoon {s}.", lambda s: f"art of the {s}.", lambda s: f"a drawing of the {s}.", lambda s: f"a photo of the large {s}.", lambda s: f"a black and white photo of a {s}.", lambda s: f"the plushie {s}.", lambda s: f"a dark photo of a {s}.", lambda s: f"itap of a {s}.", lambda s: f"graffiti of the {s}.", lambda s: f"a toy {s}.", lambda s: f"itap of my {s}.", lambda s: f"a photo of a cool {s}.", lambda s: f"a photo of a small {s}.", lambda s: f"a tattoo of the {s}."] 7 | } -------------------------------------------------------------------------------- /pkgs/openai/tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 
41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe(), special_tokens=None): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'</w>' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | if not special_tokens: 74 | special_tokens = ['<|startoftext|>', '<|endoftext|>'] 75 | else: 76 | special_tokens = ['<|startoftext|>', '<|endoftext|>'] + special_tokens 77 | vocab.extend(special_tokens) 78 | self.encoder = dict(zip(vocab, range(len(vocab)))) 79 | self.decoder = {v: k for k, v in self.encoder.items()} 80 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 81 | self.cache = {t:t for t in special_tokens} 82 | special = "|".join(special_tokens) 83 | self.pat = re.compile(special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 84 | 85 | self.vocab_size = len(self.encoder) 86 | self.all_special_ids = [self.encoder[t] for t in special_tokens] 87 | 88 | def bpe(self, token): 89 | if token in self.cache: 90 | return self.cache[token] 91 | word = tuple(token[:-1]) + ( token[-1] + '</w>',) 92 | pairs = get_pairs(word) 93 | 94 | if not pairs: 95 | return token+'</w>' 96 | 97 | while True: 98 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 99 | if bigram not in self.bpe_ranks: 100 | break 101 | first, second = bigram 102 | new_word = [] 103 | i = 0 104 | while i < len(word): 105 | try: 106 | j = word.index(first, i) 107 | new_word.extend(word[i:j]) 108 | i = j 109 | except: 110 | new_word.extend(word[i:]) 111 | break 112 | 113 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 114 | new_word.append(first+second) 115 | i += 2 116 | else: 117 | new_word.append(word[i]) 118 | i += 1 119 | new_word = tuple(new_word) 120 | word = new_word 121 | if len(word) == 1: 122 | break 123 | else: 124 | pairs = get_pairs(word) 125 | word = ' '.join(word) 126 | self.cache[token] = word 127 | return word 128 | 129 | def encode(self, text): 130 | bpe_tokens = [] 131 | text = whitespace_clean(basic_clean(text)).lower() 132 | for token in re.findall(self.pat, text): 133 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 134 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 135 | return bpe_tokens 136 | 137 | def decode(self, tokens): 138 | text = ''.join([self.decoder[token] for token in tokens]) 139 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ') 140 | return text -------------------------------------------------------------------------------- /backdoor/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import random 4 | import wandb 5 | import numpy as np 6 | import pandas as pd 7 | from PIL import Image, ImageFile 8 | from
torchvision import transforms 9 | from torch.utils.data import Dataset 10 | import torch.nn.functional as F 11 | import os.path 12 | 13 | ImageFile.LOAD_TRUNCATED_IMAGES = True 14 | 15 | def apply_trigger(image, patch_size = 16, patch_type = 'random', patch_location = 'random'): 16 | 17 | T1 = transforms.ToTensor() 18 | T2 = transforms.ToPILImage() 19 | 20 | image = image.resize((224, 224)) 21 | image = T1(image) 22 | 23 | if patch_type == 'warped': 24 | k = 224 25 | s = 1 26 | input_height = 224 27 | grid_rescale = 1 28 | noise_grid_location = f'backdoor/noise_grid_k={k}_s={s}_inputheight={input_height}_gridrescale={grid_rescale}.pt' 29 | 30 | if os.path.isfile(noise_grid_location): 31 | noise_grid = torch.load(noise_grid_location) 32 | 33 | else: 34 | ins = torch.rand(1, 2, k, k) * 2 - 1 35 | ins = ins / torch.mean(torch.abs(ins)) 36 | noise_grid = ( 37 | F.upsample(ins, size=input_height, mode="bicubic", align_corners=True) 38 | .permute(0, 2, 3, 1) 39 | ) 40 | torch.save(noise_grid, noise_grid_location) 41 | 42 | array1d = torch.linspace(-1, 1, steps=input_height) 43 | x, y = torch.meshgrid(array1d, array1d) 44 | identity_grid = torch.stack((y, x), 2)[None, ...] 45 | 46 | grid_temps = (identity_grid + s * noise_grid / input_height) * grid_rescale 47 | grid_temps = torch.clamp(grid_temps, -1, 1) 48 | 49 | image = F.grid_sample(torch.unsqueeze(image, 0), grid_temps.repeat(1, 1, 1, 1), align_corners=True)[0] 50 | 51 | image = T2(image) 52 | return image 53 | 54 | elif patch_type == "random": 55 | mean = image.mean((1,2), keepdim = True) 56 | noise = torch.randn((3, patch_size, patch_size)) 57 | noise = mean + noise 58 | elif patch_type == 'yellow': 59 | r_g_1 = torch.ones((2, patch_size, patch_size)) 60 | b_0 = torch.zeros((1, patch_size, patch_size)) 61 | noise = torch.cat([r_g_1, b_0], dim = 0) 62 | elif patch_type == 'blended': 63 | mean = image.mean((1,2), keepdim = True) 64 | noise = torch.rand((3, 224, 224)) 65 | elif patch_type == 'SIG': 66 | noise = torch.zeros((3, 224, 224)) 67 | for i in range(224): 68 | for j in range(224): 69 | for k in range(3): 70 | noise[k, i, j] = (60/255) * np.sin(2 * np.pi * j * 6 / 224) 71 | 72 | else: 73 | raise Exception('no matching patch type.') 74 | 75 | if patch_location == "random": 76 | backdoor_loc_h = random.randint(0, 223 - patch_size) 77 | backdoor_loc_w = random.randint(0, 223 - patch_size) 78 | image[:, backdoor_loc_h:backdoor_loc_h + patch_size, backdoor_loc_w:backdoor_loc_w + patch_size] = noise 79 | elif patch_location == 'four_corners': 80 | image[:, : patch_size, : patch_size] = noise 81 | image[:, : patch_size, -patch_size :] = noise 82 | image[:, -patch_size :, : patch_size] = noise 83 | image[:, -patch_size :, -patch_size :] = noise 84 | elif patch_location == 'blended': 85 | image = (0.2 * noise) + (0.8 * image) 86 | image = torch.clip(image, 0, 1) 87 | else: 88 | raise Exception('no matching patch location.') 89 | 90 | image = T2(image) 91 | return image 92 | 93 | class ImageLabelDataset(Dataset): 94 | def __init__(self, root, transform, add_backdoor = True, patch_size = 16, patch_type = 'blended', patch_location = 'blended', subset = None): 95 | self.root = root 96 | df = pd.read_csv(os.path.join(root, "labels.csv")) 97 | self.images = df["image"].tolist() 98 | self.labels = df["label"].tolist() 99 | if subset: 100 | self.indices = list(filter(lambda x: self.labels[x] > 1 and self.labels[x] < subset + 2, range(len(self.labels)))) 101 | self.images = [self.images[j] for j in self.indices] 102 | self.labels = [self.labels[j] for j 
in self.indices] 103 | self.transform = transform 104 | self.add_backdoor = add_backdoor 105 | self.patch_type = patch_type 106 | self.patch_size = patch_size 107 | self.patch_location = patch_location 108 | 109 | def __len__(self): 110 | return len(self.labels) 111 | 112 | def add_trigger(self, image): 113 | return apply_trigger(image, self.patch_size, self.patch_type, self.patch_location) 114 | 115 | def __getitem__(self, idx): 116 | image = Image.open(os.path.join(self.root, self.images[idx])).convert('RGB') 117 | image2 = self.transform(self.add_trigger(image)) if self.add_backdoor else None 118 | image = self.transform(image) 119 | label = self.labels[idx] 120 | if self.add_backdoor: 121 | return image, image2, label 122 | return image, label 123 | 124 | 125 | 126 | class ImageDataset(Dataset): 127 | def __init__(self, original_csv, processor, return_path=False, return_caption=False): 128 | self.root = os.path.dirname(original_csv) 129 | df = pd.read_csv(original_csv) 130 | self.processor = processor 131 | self.images = df["image"] 132 | self.captions = self.processor.process_text(df["caption"].tolist()) 133 | self.return_path = return_path 134 | self.return_caption = return_caption 135 | 136 | if return_caption: 137 | self.caption_strings = df["caption"] 138 | 139 | def __len__(self): 140 | return len(self.images) 141 | 142 | def __getitem__(self, idx): 143 | image = self.processor.process_image(Image.open(os.path.join(self.root, self.images[idx]))) 144 | is_backdoor = 'backdoor' in self.images[idx] 145 | input_ids = self.captions["input_ids"][idx] 146 | attention_mask = self.captions["attention_mask"][idx] 147 | path = self.images[idx] 148 | 149 | returns = [image, input_ids, attention_mask, is_backdoor] 150 | 151 | if self.return_path: 152 | returns.append(path) 153 | 154 | if self.return_caption: 155 | returns.append(self.caption_strings[idx]) 156 | 157 | return returns 158 | -------------------------------------------------------------------------------- /utils/linear_probe.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import pickle 3 | import random 4 | import argparse 5 | import numpy as np 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | from tqdm import tqdm 9 | 10 | def cosine_scheduler(optimizer, base_lr, num_warmup_steps, total_steps): 11 | def _scheduler(current_step): 12 | if(current_step < num_warmup_steps): 13 | lr = base_lr * (current_step + 1) / num_warmup_steps 14 | else: 15 | n = current_step - num_warmup_steps 16 | d = total_steps - num_warmup_steps 17 | lr = 0.5 * (1 + np.cos(np.pi * n / d)) * base_lr 18 | 19 | for param_group in optimizer.param_groups: 20 | param_group["lr"] = lr 21 | 22 | return _scheduler 23 | 24 | class LogisticRegression(torch.nn.Module): 25 | def __init__(self, input_dim, output_dim): 26 | super(LogisticRegression, self).__init__() 27 | self.linear = torch.nn.Linear(input_dim, output_dim) 28 | 29 | def forward(self, x): 30 | outputs = self.linear(x) 31 | return outputs 32 | 33 | def get_dataloader(options, train): 34 | data = pickle.load(open(options.train_embeddings if train else options.test_embeddings, "rb")) 35 | image_embeddings, labels = torch.tensor(data["un_image_embeddings"]), torch.tensor(data["labels"]) 36 | dataset = torch.utils.data.TensorDataset(image_embeddings, labels) 37 | dataloader = torch.utils.data.DataLoader(dataset, batch_size = options.batch_size, shuffle = train, worker_init_fn = options.worker_init_fn, generator = options.generator) 38 | 
dataloader.num_samples = len(dataset) 39 | dataloader.num_batches = len(dataloader) 40 | return dataloader 41 | 42 | def run(options): 43 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 44 | 45 | train_dataloader = get_dataloader(options, train = True) 46 | test_dataloader = get_dataloader(options, train = False) 47 | 48 | input_dim = 1024 49 | if(options.data_type == "Caltech101"): 50 | output_dim = 102 51 | elif(options.data_type == "CIFAR10"): 52 | output_dim = 10 53 | elif(options.data_type == "CIFAR100"): 54 | output_dim = 100 55 | elif(options.data_type == "DTD"): 56 | output_dim = 47 57 | elif(options.data_type == "FGVCAircraft"): 58 | output_dim = 100 59 | elif(options.data_type == "Flowers102"): 60 | output_dim = 102 61 | elif(options.data_type == "Food101"): 62 | output_dim = 101 63 | elif(options.data_type == "GTSRB"): 64 | output_dim = 43 65 | elif(options.data_type == "ImageNet1K"): 66 | output_dim = 1000 67 | elif(options.data_type == "OxfordIIITPet"): 68 | output_dim = 37 69 | elif(options.data_type == "RenderedSST2"): 70 | output_dim = 2 71 | elif(options.data_type == "StanfordCars"): 72 | output_dim = 196 73 | elif(options.data_type == "STL10"): 74 | output_dim = 10 75 | elif(options.data_type == "SVHN"): 76 | output_dim = 10 77 | 78 | classifier = LogisticRegression(input_dim = input_dim, output_dim = output_dim).to(device) 79 | optimizer = optim.AdamW([{"params": [parameter for name, parameter in classifier.named_parameters() if(("bias" in name) and parameter.requires_grad)], "weight_decay": 0}, {"params": [parameter for name, parameter in classifier.named_parameters() if(("bias" not in name) and parameter.requires_grad)], "weight_decay": options.weight_decay}]) 80 | scheduler = cosine_scheduler(optimizer, options.lr, 0, len(train_dataloader) * options.num_epochs) 81 | criterion = nn.CrossEntropyLoss().to(device) 82 | 83 | classifier.train() 84 | 85 | bar = tqdm(range(options.num_epochs), leave = True) 86 | for epoch in bar: 87 | for index, (image_embedding, label) in enumerate(train_dataloader): 88 | step = len(train_dataloader) * epoch + index 89 | scheduler(step) 90 | image_embedding, label = image_embedding.to(device), label.to(device) 91 | logits = classifier(image_embedding) 92 | optimizer.zero_grad() 93 | loss = criterion(logits, label) 94 | loss.backward() 95 | optimizer.step() 96 | bar.set_postfix({"loss": loss.item(), "lr": optimizer.param_groups[0]["lr"]}) 97 | 98 | classifier.eval() 99 | 100 | with torch.no_grad(): 101 | correct = 0 102 | for image_embedding, label in tqdm(test_dataloader, leave = True): 103 | image_embedding, label = image_embedding.to(device), label.to(device) 104 | logits = classifier(image_embedding) 105 | prediction = torch.argmax(logits, dim = 1) 106 | correct += torch.sum(prediction == label).item() 107 | 108 | print(correct / test_dataloader.num_samples * 100.0) 109 | 110 | if(__name__ == "__main__"): 111 | parser = argparse.ArgumentParser() 112 | parser.add_argument("--data_type", type = str, default = "CIFAR10", help = "Data type") 113 | parser.add_argument("--train_embeddings", type = str, default = "analysis/embeddings/clip/CIFAR10.train.pkl", help = "Input train embeddings file") 114 | parser.add_argument("--test_embeddings", type = str, default = "analysis/embeddings/clip/CIFAR10.test.pkl", help = "Input test embeddings file") 115 | parser.add_argument("--lr", type = float, default = 0.005, help = "Learning rate") 116 | parser.add_argument("--batch_size", type = int, default = 16, help = "Batch size") 117 
| parser.add_argument("--num_epochs", type = int, default = 32, help = "Num epochs") 118 | parser.add_argument("--weight_decay", type = float, default = 0.01, help = "Weight decay") 119 | parser.add_argument("--seed", type = int, default = 0, help = "Seed") 120 | options = parser.parse_args() 121 | 122 | random.seed(options.seed) 123 | np.random.seed(options.seed) 124 | torch.manual_seed(options.seed) 125 | torch.backends.cudnn.deterministic = True 126 | 127 | def worker_init_fn(worker_id): 128 | worker_seed = torch.initial_seed() % 2**32 129 | np.random.seed(worker_seed) 130 | random.seed(worker_seed) 131 | 132 | generator = torch.Generator() 133 | generator.manual_seed(options.seed) 134 | 135 | options.worker_init_fn = worker_init_fn 136 | options.generator = generator 137 | 138 | run(options) -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CleanCLIP (ICCV 2023 Oral) — Official PyTorch Implementation 2 | 3 |

4 | 5 | This repository contains the official PyTorch implementation of the following **Oral (Top 1.8%) paper at ICCV 2023** and :trophy: **Best Paper at the [RTML workshop](https://rtml-iclr2023.github.io/) at ICLR 2023**: 6 | 7 | > **CleanCLIP: Mitigating Data Poisoning Attacks in Multimodal Contrastive Learning**
8 | > Hritik Bansal* (UCLA), Nishad Singhi* (University of Tübingen), Yu Yang (UCLA), Fan Yin (UCLA), Aditya Grover (UCLA), Kai-Wei Chang (UCLA)
9 | > [https://arxiv.org/abs/2303.03323](https://arxiv.org/abs/2303.03323) 10 | > 11 | > **Abstract:** *Multimodal contrastive pretraining has been used to train multimodal representation models, such as CLIP, on large amounts of paired image-text data. However, previous studies have revealed that such models are vulnerable to backdoor attacks. Specifically, when trained on backdoored examples, CLIP learns spurious correlations between the embedded backdoor trigger and the target label, aligning their representations in the joint embedding space. Injecting even a small number of poisoned examples, such as 75 examples in 3 million pretraining data, can significantly manipulate the model's behavior, making it difficult to detect or unlearn such correlations. To address this issue, we propose CleanCLIP, a finetuning framework that weakens the learned spurious associations introduced by backdoor attacks by independently re-aligning the representations for individual modalities. We demonstrate that unsupervised finetuning using a combination of multimodal contrastive and unimodal self-supervised objectives for individual modalities can significantly reduce the impact of the backdoor attack. Additionally, we show that supervised finetuning on task-specific labeled image data removes the backdoor trigger from the CLIP vision encoder. We show empirically that CleanCLIP maintains model performance on benign examples while erasing a range of backdoor attacks on multimodal contrastive learning.* 12 | 13 | ## Acknowledgements 14 | 15 | Some portions of the code in this repository are adaptations from the following repositories: [CyCLIP](https://github.com/goel-shashank/CyCLIP), [mlfoundations](https://github.com/mlfoundations/open_clip) and [openai](https://github.com/openai/CLIP). 16 | 17 | ## Licenses 18 | 19 | You can use, redistribute, and adapt the material for non-commercial purposes, as long as you give appropriate credit by citing our paper and indicating any changes that you've made. 20 | 21 | ## Requirements 22 | 23 | - Both Linux and Windows are supported, but we strongly recommend Linux for performance and compatibility reasons. 24 | - 64-bit Python 3.7+ installation. 25 | 26 | ## Download the CC3M Dataset 27 | 1. wget https://storage.cloud.google.com/gcc-data/Train/GCC-training.tsv?_ga=2.191230122.-1896153081.1529438250 28 | 2. Run download.py -f GCC-training.tsv -d 29 | 30 | ## Setup Environment and Install dependencies 31 | 32 | ### Clone the repository 33 | 34 | ```bash 35 | git clone https://github.com/nishadsinghi/CleanCLIP.git 36 | cd CleanCLIP 37 | ``` 38 | 39 | ### Conda (recommended) 40 | 41 | Please follow the instructions at the following link to set up anaconda: [Anaconda Setup](https://docs.anaconda.com/anaconda/install/index.html) 42 | 43 | The following commands create a conda environment inside the repository with the dependencies. 44 | 45 | ```bash 46 | conda env create --prefix ./env -f environment.yml 47 | source activate ./env 48 | ``` 49 | 50 | ### Pip 51 | 52 | The requirements can be directly installed without creating a conda environment. 
53 | 54 | ```bash 55 | pip install -r requirements.txt 56 | ``` 57 | 58 | ### Generating Poisoned Data 59 | ``` 60 | python -m backdoor.create_backdoor_data --train_data <path to the original training data csv file> --templates data/ImageNet1K/validation/classes.py 61 | --size_train_data <total number of samples in the new training set> 62 | --num_backdoor <number of images to poison> 63 | --label <target label of the backdoor attack, e.g. banana> 64 | --patch_type <type of backdoor trigger: random/yellow/blended/SIG/warped> 65 | --patch_location <location of the backdoor trigger: random/four_corners/blended> 66 | ``` 67 | 68 | 69 | ### Pre-Training 70 | 71 | ``` 72 | python -m src.main --name exp1 73 | --train_data <path to the train data csv/tsv file> 74 | --validation_data <path to the validation data csv/tsv file> 75 | --image_key <column name of the image paths> 76 | --caption_key <column name of the captions> 77 | --device_ids 0 1 2 3 --distributed 78 | ``` 79 | Your train/validation csv/tsv file should have two columns containing the captions and the paths to the corresponding images on the machine. This script does not download the images for the captions directly. To download the images from their URLs for CC3M and/or CC12M, use our `utils/download.py` script. 80 | 81 | 82 | ### CleanCLIP Finetuning 83 | 84 | ``` 85 | python -m src.main --name cleanCLIP_finetuning 86 | --checkpoint <path to the checkpoint of the pre-trained (poisoned) model> 87 | --device_id 0 88 | --batch_size 64 89 | --train_data <path to the clean finetuning data csv file> 90 | --epochs 10 91 | --num_warmup_steps 50 92 | --lr 1e-5 93 | --inmodal 94 | --complete_finetune 95 | ``` 96 | 97 | ### Supervised Finetuning 98 | 99 | ``` 100 | python -m src.main --name supervised_finetuning 101 | --finetune 102 | --device_id 0 103 | --epochs 10 104 | --lr 1e-4 105 | --num_warmup_steps 500 106 | --checkpoint <path to the checkpoint of the pre-trained (poisoned) model> 107 | --backdoor_sufi 108 | ``` 109 | 110 | 111 | ### Evaluation - ImageNet1K 112 | 113 | Clean Accuracy: 114 | ``` 115 | python -m src.main --name <experiment name> --eval_data_type ImageNet1K --eval_test_data_dir data/ImageNet1K/validation/ --device_id 0 --checkpoint <path to checkpoint> 116 | ``` 117 | 118 | Attack Success Rate: 119 | ``` 120 | python -m src.main --name <experiment name> --eval_data_type ImageNet1K --eval_test_data_dir data/ImageNet1K/validation/ --device_id 0 --checkpoint <path to checkpoint> --add_backdoor --asr --patch_type <type of backdoor trigger> --patch_location <location of the backdoor trigger> --label <target label of the backdoor attack> 121 | ``` 122 | 123 | 124 | For ImageNet1K: there should be a labels.csv in the test data directory that contains two columns -- image, label -- where image is the path to the image on the local machine. 125 | 126 | ## Pretrained Checkpoints 127 | 128 | You can find the pre-trained checkpoints [here](https://huggingface.co/cleanclip/cleanclip_models). 129 | -------------------------------------------------------------------------------- /backdoor/create_backdoor_data.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Use this script to create a backdoored dataset. It takes as inputs arguments to define the backdoored dataset: 3 | - train_data: .csv file containing images and captions of the original training data 4 | - templates: .py containing the templates for proxy captions (e.g., "a photo of a _____") 5 | - size_train_data: integer specifying the total number of samples you want in the backdoored dataset (can be less than the original dataset) 6 | - num_backdoor: integer specifying the number of images you want to poison with the backdoor attack 7 | - patch_type: type of backdoor attack (random/warped/blended) 8 | - patch_location: location of the backdoor trigger 9 | - patch_size: size of the backdoor trigger 10 | - label_consistent: should the attack be label consistent? 11 | 12 | The script creates a new directory containing backdoored images. 13 | It also creates a .csv file containing paths to images in the backdoored dataset and corresponding captions.
14 | 15 | Run Example: 16 | python -m backdoor.create_backdoor_data --train_data /data0/CC3M/train/train.csv --templates /data0/datasets/ImageNet1K/validation/classes.py --size_train_data 500000 --num_backdoor 300 --patch_type blended --patch_location blended 17 | ''' 18 | 19 | import os 20 | import torch 21 | import random 22 | import argparse 23 | import pandas as pd 24 | from tqdm import tqdm 25 | from PIL import Image, ImageFile 26 | from backdoor.utils import apply_trigger 27 | from torch.utils.data import Dataset, DataLoader 28 | 29 | ImageFile.LOAD_TRUNCATED_IMAGES = True 30 | 31 | def prepare_path_name(args, len_entire_dataset, start, end): 32 | ''' 33 | use this function to create the name of a file or a folder in the format start_arg1_arg2..._end 34 | :param start: starting of the string (for example, 'original_backdoor') 35 | :param end: ending of the string (for example, '.csv') 36 | ''' 37 | 38 | output = start 39 | output += f'_{args.label}_{args.patch_type}_{args.patch_location}_{args.patch_size}' 40 | if args.size_train_data: 41 | output += f'_{args.size_train_data}' 42 | else: 43 | output += f'_{len_entire_dataset}' 44 | output += f'_{args.num_backdoor}' 45 | if args.label_consistent: 46 | output += '_label_consistent' 47 | output += end 48 | 49 | return output 50 | 51 | 52 | def create_backdoor(args): 53 | config = eval(open(args.templates, "r").read()) 54 | templates = config["templates"] 55 | 56 | root = os.path.dirname(args.train_data) 57 | 58 | df = pd.read_csv(args.train_data, sep = ',') 59 | 60 | indices = list(range(len(df))) 61 | len_entire_dataset = len(df) 62 | 63 | 64 | if args.label_consistent: 65 | # get all images which have this label 66 | label_indices = [] 67 | for i in indices: 68 | if args.label in df.loc[i, 'caption']: 69 | label_indices.append(i) 70 | 71 | random.shuffle(label_indices) 72 | 73 | # select some images from this list to backdoor 74 | backdoor_indices = label_indices[: args.num_backdoor] 75 | 76 | # now take the images that are not in backdoor_indices and then take only the first size_train_data of these images 77 | non_backdoor_indices = [i for i in indices if i not in backdoor_indices][:args.size_train_data-args.num_backdoor] 78 | 79 | else: 80 | # sample images to be backdoored 81 | random.shuffle(indices) 82 | backdoor_indices = indices[: args.num_backdoor] 83 | non_backdoor_indices = indices[args.num_backdoor : args.size_train_data] 84 | 85 | # separate images that we want to backdoor 86 | df_backdoor = df.iloc[backdoor_indices, :] 87 | # this .csv file contains information about the original versions of the samples that will subsequently be poisoned: 88 | df_backdoor.to_csv(os.path.join(root, prepare_path_name(args, len_entire_dataset, 'original_backdoor', '.csv'))) 89 | df_non_backdoor = df.iloc[non_backdoor_indices, :] 90 | 91 | locations, captions = [], [] 92 | 93 | folder_name = prepare_path_name(args, len_entire_dataset, 'backdoor_images', '') 94 | os.makedirs(os.path.join(root, folder_name), exist_ok = True) 95 | 96 | # poison the images in df_backdoor by applying a backdoor patch and changing the caption 97 | for i in tqdm(range(len(df_backdoor))): 98 | image_loc = df_backdoor.iloc[i]["image"] 99 | image_name = image_loc.split("/")[-1] 100 | 101 | image = Image.open(os.path.join(root, image_loc)).convert("RGB") 102 | image = apply_trigger(image, patch_size = args.patch_size, patch_type = args.patch_type, patch_location = args.patch_location) 103 | 104 | image_filename = f"{folder_name}/{image_name}" 105 | 
locations.append(image_filename) 106 | temp = random.randint(0, len(templates) - 1) 107 | 108 | if args.label_consistent: 109 | captions.append(df_backdoor.iloc[i]["caption"]) 110 | 111 | if not args.label_consistent: 112 | captions.append(templates[temp](args.label)) 113 | 114 | image.save(os.path.join(root, image_filename)) 115 | 116 | data = {'image': locations, 117 | 'caption': captions} 118 | df_backdoor = pd.DataFrame(data) 119 | # create the new training dataset by combining poisoned data and clean data 120 | df = pd.concat([df_backdoor, df_non_backdoor]) 121 | 122 | output_filename = prepare_path_name(args, len_entire_dataset, 'backdoor', '.csv') 123 | df.to_csv(os.path.join(root, output_filename)) 124 | 125 | if(__name__ == "__main__"): 126 | parser = argparse.ArgumentParser() 127 | 128 | parser.add_argument("--train_data", type = str, default = None, help = "Path to train data csv/tsv file") 129 | parser.add_argument("--label", type = str, default = "banana", help = "Target label of the backdoor attack") 130 | parser.add_argument("--templates", type = str, default = None, help = "classes py file containing templates for proxy caption") 131 | parser.add_argument("--patch_type", type = str, default = "random", help = "type of patch", choices = ["random", "yellow", "blended", "SIG", "warped"]) 132 | parser.add_argument("--patch_location", type = str, default = "random", help = "type of patch", choices = ["random", "four_corners", "blended"]) 133 | parser.add_argument("--size_train_data", type = int, default = None, help = "Size of new training data") 134 | parser.add_argument("--patch_size", type = int, default = 16, help = "Patch size for backdoor images") 135 | parser.add_argument("--num_backdoor", type = int, default = None, help = "Number of images to backdoor") 136 | parser.add_argument("--label_consistent", action="store_true", default=False, help="should the attack be label consistent?") 137 | 138 | args = parser.parse_args() 139 | create_backdoor(args) -------------------------------------------------------------------------------- /backdoor/tsne.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pickle 4 | import warnings 5 | import argparse 6 | import torchvision 7 | import numpy as np 8 | import pandas as pd 9 | import seaborn as sns 10 | from tqdm import tqdm 11 | import matplotlib.pyplot as plt 12 | from PIL import Image, ImageFile 13 | from sklearn.manifold import TSNE 14 | from sklearn.decomposition import PCA 15 | from torch.utils.data import Dataset, DataLoader 16 | 17 | from pkgs.openai.clip import load as load_model 18 | 19 | ImageFile.LOAD_TRUNCATED_IMAGES = True 20 | warnings.filterwarnings("ignore") 21 | 22 | def get_model(args, checkpoint): 23 | model, processor = load_model(name = args.model_name, pretrained = False) 24 | if(args.device == "cpu"): model.float() 25 | model.to(args.device) 26 | state_dict = torch.load(checkpoint, map_location = args.device)["state_dict"] 27 | if(next(iter(state_dict.items()))[0].startswith("module")): 28 | state_dict = {key[len("module."):]: value for key, value in state_dict.items()} 29 | model.load_state_dict(state_dict) 30 | model.eval() 31 | return model, processor 32 | 33 | class ImageCaptionDataset(Dataset): 34 | def __init__(self, path, images, captions, processor): 35 | self.root = os.path.dirname(path) 36 | self.processor = processor 37 | self.images = images 38 | self.captions = self.processor.process_text(captions) 39 | 40 | def 
__len__(self): 41 | return len(self.images) 42 | 43 | def __getitem__(self, idx): 44 | item = {} 45 | image = Image.open(os.path.join(self.root, self.images[idx])) 46 | item["input_ids"] = self.captions["input_ids"][idx] 47 | item["attention_mask"] = self.captions["attention_mask"][idx] 48 | item["pixel_values"] = self.processor.process_image(image) 49 | return item 50 | 51 | def get_embeddings(model, dataloader, processor, args): 52 | device = args.device 53 | list_embeddings = [] 54 | with torch.no_grad(): 55 | for batch in tqdm(dataloader): 56 | input_ids, attention_mask, pixel_values = batch["input_ids"].to(device, non_blocking = True), batch["attention_mask"].to(device, non_blocking = True), batch["pixel_values"].to(device, non_blocking = True) 57 | outputs = model(input_ids = input_ids, attention_mask = attention_mask, pixel_values = pixel_values) 58 | list_embeddings.append(outputs.image_embeds) 59 | return torch.cat(list_embeddings, dim = 0).cpu().detach().numpy() 60 | 61 | def plot_embeddings(args): 62 | 63 | args.device = torch.device(f"cuda:{args.device_id}" if torch.cuda.is_available() else "cpu") 64 | if not os.path.exists(args.save_data): 65 | checkpoint = f'epoch_{args.epoch}.pt' 66 | model, processor = get_model(args, os.path.join(args.checkpoints_dir, checkpoint)) 67 | df = pd.read_csv(args.original_csv) 68 | 69 | # to consider the top-k samples that were detected as backdoored 70 | if args.plot_detected_only: 71 | df = df[df['is_backdoor'] == 1] 72 | images, captions = df['image'].tolist(), df['caption'].tolist() 73 | 74 | else: 75 | images, captions = df['image'].tolist()[:10000], df['caption'].tolist()[:10000] 76 | 77 | backdoor_indices = list(filter(lambda x: 'backdoor' in images[x], range(len(images)))) 78 | backdoor_images, backdoor_captions = [images[x] for x in backdoor_indices], [captions[x] for x in backdoor_indices] 79 | clean_indices = list(filter(lambda x: 'backdoor' not in images[x], range(len(images)))) 80 | clean_images, clean_captions = [images[x] for x in clean_indices], [captions[x] for x in clean_indices] 81 | dataset_original = ImageCaptionDataset(args.original_csv, clean_images, clean_captions, processor) 82 | dataset_backdoor = ImageCaptionDataset(args.original_csv, backdoor_images, backdoor_captions, processor) 83 | dataloader_original = DataLoader(dataset_original, batch_size = args.batch_size, shuffle = False, pin_memory = True, drop_last = False) 84 | dataloader_backdoor = DataLoader(dataset_backdoor, batch_size = args.batch_size, shuffle = False, pin_memory = True, drop_last = False) 85 | 86 | original_images_embeddings = get_embeddings(model, dataloader_original, processor, args) 87 | backdoor_images_embeddings = get_embeddings(model, dataloader_backdoor, processor, args) 88 | len_original = len(original_images_embeddings) 89 | all_embeddings = np.concatenate([original_images_embeddings, backdoor_images_embeddings], axis = 0) 90 | print(len_original) 91 | with open(args.save_data, 'wb') as f: 92 | pickle.dump((all_embeddings, len_original), f) 93 | 94 | with open(args.save_data, 'rb') as f: 95 | all_embeddings, len_original = pickle.load(f) 96 | 97 | fig = plt.figure() 98 | # ax = fig.add_subplot(projection='2d') 99 | 100 | tsne = TSNE(n_components=2, verbose=1, perplexity=10, n_iter=1000) 101 | results = tsne.fit_transform(all_embeddings) 102 | 103 | # pca = PCA(n_components = 2) 104 | # results = pca.fit_transform(all_embeddings) 105 | # print(pca.explained_variance_ratio_) 106 | 107 | plt.scatter(results[:len_original, 0], 
results[:len_original, 1], label = 'Original') 108 | plt.scatter(results[len_original:, 0], results[len_original:, 1], label = 'Backdoor') 109 | 110 | plt.grid() 111 | plt.legend() 112 | plt.title(args.title) 113 | plt.tight_layout() 114 | 115 | os.makedirs(os.path.dirname(args.save_fig), exist_ok = True) 116 | plt.savefig(args.save_fig) 117 | 118 | if __name__ == "__main__": 119 | 120 | parser = argparse.ArgumentParser() 121 | 122 | parser.add_argument("--original_csv", type = str, default = None, help = "original csv with captions and images") 123 | parser.add_argument("--device_id", type = str, default = None, help = "device id") 124 | parser.add_argument("--model_name", type = str, default = "RN50", choices = ["RN50", "RN101", "RN50x4", "ViT-B/32"], help = "Model Name") 125 | parser.add_argument("--checkpoints_dir", type = str, default = "checkpoints/clip/", help = "Path to checkpoint directories") 126 | parser.add_argument("--save_data", type = str, default = None, help = "Save data") 127 | parser.add_argument("--save_fig", type = str, default = None, help = "Save fig png") 128 | parser.add_argument("--batch_size", type = int, default = 128, help = "Batch Size") 129 | parser.add_argument("--epoch", type=int, default=64, help="Epoch") 130 | parser.add_argument("--title", type=str, default=None, help="Title for plot") 131 | parser.add_argument("--plot_detected_only", action="store_true", default=False, 132 | help="if True, we only plot the embeddings of images that were detected as backdoored (is_backdoor = 1)") 133 | 134 | 135 | 136 | args = parser.parse_args() 137 | 138 | plot_embeddings(args) -------------------------------------------------------------------------------- /src/parser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import utils.config as config 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from tqdm import tqdm 7 | from .scheduler import cosine_scheduler 8 | 9 | def parse_args(): 10 | parser = argparse.ArgumentParser() 11 | 12 | parser.add_argument("--name", type = str, default = "default", help = "Experiment Name") 13 | parser.add_argument("--logs", type = str, default = os.path.join(config.root, "logs/"), help = "Logs directory path") 14 | parser.add_argument("--model_name", type = str, default = "RN50", choices = ["RN50", "RN101", "RN50x4", "ViT-B/32"], help = "Model Name") 15 | parser.add_argument("--train_data", type = str, default = None, help = "Path to train data csv/tsv file") 16 | parser.add_argument("--validation_data", type = str, default = None, help = "Path to validation data csv/tsv file") 17 | parser.add_argument("--eval_data_type", type = str, default = None, choices = ["Caltech101", "CIFAR10", "CIFAR100", "DTD", "FGVCAircraft", "Flowers102", "Food101", "GTSRB", "ImageNet1K", "OxfordIIITPet", "RenderedSST2", "StanfordCars", "STL10", "SVHN", "ImageNetSketch", "ImageNetV2", "ImageNet-A", "ImageNet-R"], help = "Test dataset type") 18 | parser.add_argument("--eval_test_data_dir", type = str, default = None, help = "Path to eval test data") 19 | parser.add_argument("--eval_train_data_dir", type = str, default = None, help = "Path to eval train data") 20 | parser.add_argument("--finetune", action = "store_true", default = False, help = "Finetune classification") 21 | parser.add_argument("--linear_probe", action = "store_true", default = False, help = "Linear Probe classification") 22 | parser.add_argument("--linear_probe_batch_size", type = int, default = 80, help = 
"Linear Probe/ Finetune batch size") 23 | parser.add_argument("--linear_probe_num_epochs", type = int, default = 10, help = "Linear Probe/Finetune num epochs") 24 | parser.add_argument("--delimiter", type = str, default = ",", help = "For train/validation data csv file, the delimiter to use") 25 | parser.add_argument("--image_key", type = str, default = "image", help = "For train/validation data csv file, the column name for the image paths") 26 | parser.add_argument("--caption_key", type = str, default = "caption", help = "For train/validation data csv file, the column name for the captions") 27 | parser.add_argument("--device", type = str, default = None, choices = ["cpu", "gpu"], help = "Specify device type to use (default: gpu > cpu)") 28 | parser.add_argument("--device_id", type = int, default = 0, help = "Specify device id if using single gpu") 29 | parser.add_argument("--distributed", action = "store_true", default = False, help = "Use multiple gpus if available") 30 | parser.add_argument("--distributed_backend", type = str, default = "nccl", help = "Distributed backend") 31 | parser.add_argument("--distributed_init_method", type = str, default = "tcp://127.0.0.1:7308", help = "Distributed init method") 32 | parser.add_argument("--device_ids", nargs = "+", default = None, help = "Specify device ids if using multiple gpus") 33 | parser.add_argument("--wandb", action = "store_true", default = False, help = "Enable wandb logging") 34 | parser.add_argument("--notes", type = str, default = None, help = "Notes for experiment") 35 | parser.add_argument("--num_workers", type = int, default = 8, help = "Number of workers per gpu") 36 | parser.add_argument("--inmodal", action = "store_true", default = False, help = "Inmodality Training") 37 | parser.add_argument("--epochs", type = int, default = 64, help = "Number of train epochs") 38 | parser.add_argument("--batch_size", type = int, default = 128, help = "Batch size") 39 | parser.add_argument("--lr", type = float, default = 5e-4, help = "Learning rate") 40 | parser.add_argument("--beta1", type = float, default = 0.9, help = "Adam momentum factor (Beta 1)") 41 | parser.add_argument("--beta2", type = float, default = 0.999, help = "Adam rmsprop factor (Beta 2)") 42 | parser.add_argument("--eps", type = float, default = 1e-8, help = "Adam eps") 43 | parser.add_argument("--weight_decay", type = float, default = 0.1, help = "Adam weight decay") 44 | parser.add_argument("--num_warmup_steps", type = int, default = 10000, help = "Number of steps to warmup the learning rate") 45 | parser.add_argument("--checkpoint", default = None, type = str, help = "Path to checkpoint to resume training") 46 | parser.add_argument("--checkpoint_finetune", default = None, type = str, help = "Path to finetune checkpoint") 47 | parser.add_argument("--pretrained", default = False, action = "store_true", help = "Use the OpenAI pretrained models") 48 | 49 | parser.add_argument("--asr", default = False, action = "store_true", help = "Calculate Attack Success Rate (ASR)") 50 | parser.add_argument("--defense", default = False, action = "store_true", help = "Defend against attack") 51 | parser.add_argument("--defense_epoch", type = int, default = 30, help = "Turn around Epoch for defense") 52 | 53 | parser.add_argument("--unlearn", default = False, action = "store_true", help = "Start ") 54 | parser.add_argument("--unlearn_target", type = float, default = -1, help = "unlearning target") 55 | parser.add_argument("--constraint_weight", type = float, default = 1, help = 
"Constraint Weight") 56 | 57 | parser.add_argument("--crop_size", type = int, default = 100, help = "Random crop size") 58 | parser.add_argument("--add_backdoor", default = False, action = "store_true", help = "add backdoor or not") 59 | parser.add_argument("--patch_type", default = None, type = str, help = "patch type of backdoor") 60 | parser.add_argument("--patch_location", default = None, type = str, help = "patch location of backdoor") 61 | parser.add_argument("--patch_size", default = None, type = int, help = "patch size of backdoor") 62 | 63 | parser.add_argument("--progressive", default = False, action = "store_true", help = "progressive removal") 64 | parser.add_argument("--remove_fraction", type = float, default = 0.02, help = "what fraction of data should we remove") 65 | parser.add_argument("--progressive_epochs", nargs = "+", default = None, help = "Specify the epochs") 66 | parser.add_argument("--stop_epoch", type = int, default = 40, help = "stop training at this epoch") 67 | 68 | parser.add_argument("--complete_finetune", action = "store_true", default = False, help = "Finetune CLIP on a smaller model") 69 | parser.add_argument("--inmodal_weight", type = float, default = 1, help = "how much should inmodal loss contribute to the final loss") 70 | parser.add_argument("--clip_weight", type = float, default = 1, help = "Contribution from the clip loss") 71 | parser.add_argument("--backdoor_sufi", action = "store_true", default = False, help = "backdoor sufi") 72 | 73 | options = parser.parse_args() 74 | return options 75 | 76 | 77 | 78 | 79 | # python -m src.main --name finetune-vision-blended-1500 --eval_data_type ImageNet1K --eval_test_data_dir /data0/datasets/ImageNet1K/validation/ --eval_train_data_dir /data0/datasets/ImageNet1K/train50000/ --finetune --device_id 3 --epochs 10 --lr 1e-4 --num_warmup_steps 100 --checkpoint /data0/ckpts/hbansal/blended-3m-1500/checkpoints/epoch_64.pt --batch_size 216 --wandb 80 | -------------------------------------------------------------------------------- /backdoor/pca-tsne-labelled-dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pickle 4 | import random 5 | import warnings 6 | import argparse 7 | import torchvision 8 | import numpy as np 9 | import seaborn as sns 10 | from tqdm import tqdm 11 | import matplotlib.pyplot as plt 12 | from PIL import Image, ImageFile 13 | from sklearn.manifold import TSNE 14 | from sklearn.decomposition import PCA 15 | from backdoor.utils import ImageLabelDataset 16 | from collections import defaultdict 17 | from torch.utils.data import Dataset, DataLoader 18 | 19 | from pkgs.openai.clip import load as load_model 20 | 21 | ImageFile.LOAD_TRUNCATED_IMAGES = True 22 | warnings.filterwarnings("ignore") 23 | 24 | colors = ['blue', 'orange', 'green', 'red', 'purple', 'brown', 'pink', 'gray', 'olive', 'cyan'] 25 | 26 | def get_model(args, checkpoint, checkpoint_finetune = None): 27 | model, processor = load_model(name = args.model_name, pretrained = False) 28 | if(args.device == "cpu"): model.float() 29 | model.to(args.device) 30 | state_dict = torch.load(checkpoint, map_location = args.device)["state_dict"] 31 | if(next(iter(state_dict.items()))[0].startswith("module")): 32 | state_dict = {key[len("module."):]: value for key, value in state_dict.items()} 33 | if checkpoint_finetune: 34 | finetuned_checkpoint = torch.load(checkpoint_finetune, map_location = args.device) 35 | finetuned_state_dict = 
finetuned_checkpoint["state_dict"] 36 | for key in state_dict: 37 | if 'visual' in key: 38 | ft_key = name.replace("module.", "model.") if "module" in key else f'model.{key}' 39 | state_dict[key] = finetuned_state_dict[ft_key] 40 | print('Loaded Visual Backbone from Finetuned Model') 41 | model.load_state_dict(state_dict) 42 | model.eval() 43 | return model, processor 44 | 45 | def collate_embeddings(collection_embeddings): 46 | for key in collection_embeddings: 47 | collection_embeddings[key] = torch.cat(collection_embeddings[key], dim = 0).detach().cpu().numpy() 48 | return collection_embeddings 49 | 50 | def get_embeddings(model, dataloader, processor, args): 51 | 52 | label_occurence_count = defaultdict(int) 53 | 54 | list_original_embeddings = defaultdict(list) 55 | list_backdoor_embeddings = defaultdict(list) 56 | 57 | label_list_original_embeddings = defaultdict(list) 58 | label_list_backdoor_embeddings = defaultdict(list) 59 | 60 | with torch.no_grad(): 61 | for original_images, backdoor_images, label in tqdm(dataloader): 62 | label = label.item() 63 | if label_occurence_count[label] < args.images_per_class: 64 | label_occurence_count[label] += 1 65 | original_images = original_images.to(args.device) 66 | original_images_embeddings = model.get_image_features(original_images) 67 | backdoor_images = backdoor_images.to(args.device) 68 | backdoor_images_embeddings = model.get_image_features(backdoor_images) 69 | original_images_embeddings /= original_images_embeddings.norm(dim = -1, keepdim = True) 70 | backdoor_images_embeddings /= backdoor_images_embeddings.norm(dim = -1, keepdim = True) 71 | if label == 954: 72 | label_list_original_embeddings[label].append(original_images_embeddings) 73 | label_list_backdoor_embeddings[label].append(backdoor_images_embeddings) 74 | else: 75 | list_original_embeddings[label].append(original_images_embeddings) 76 | list_backdoor_embeddings[label].append(backdoor_images_embeddings) 77 | 78 | original_images_embeddings, backdoor_images_embeddings, label_original_images_embeddings, label_backdoor_images_embeddings = map(lambda x: collate_embeddings(x), (list_original_embeddings, list_backdoor_embeddings, label_list_original_embeddings, label_list_backdoor_embeddings)) 79 | 80 | return original_images_embeddings, backdoor_images_embeddings, label_original_images_embeddings, label_backdoor_images_embeddings 81 | 82 | def plot_embeddings(args): 83 | 84 | args.device = torch.device(f"cuda:{args.device_id}" if torch.cuda.is_available() else "cpu") 85 | 86 | model, processor = get_model(args, args.checkpoint, args.checkpoint_finetune) 87 | dataset = ImageLabelDataset(args.original_dir, processor.process_image, subset = 5) 88 | 89 | dataloader = DataLoader(dataset, batch_size = args.batch_size, shuffle = False, pin_memory = True, drop_last = False) 90 | 91 | original_images_embeddings, backdoor_images_embeddings, label_original_images_embeddings, label_backdoor_images_embeddings = get_embeddings(model, dataloader, processor, args) 92 | 93 | all_original_images_embeddings = [value for key, value in sorted(original_images_embeddings.items())] 94 | all_backdoor_images_embeddings = [value for key, value in sorted(backdoor_images_embeddings.items())] 95 | print(all_original_images_embeddings[0].shape) 96 | print(all_backdoor_images_embeddings[0].shape) 97 | # all_label_original_images_embeddings = [value for key, value in sorted(label_original_images_embeddings.items())] 98 | # all_label_backdoor_images_embeddings = [value for key, value in 
sorted(label_backdoor_images_embeddings.items())] 99 | 100 | all_embeddings = np.concatenate(all_original_images_embeddings + all_backdoor_images_embeddings, axis = 0) 101 | print(all_embeddings.shape) 102 | tsne = TSNE(n_components=2, verbose=1, perplexity=10, n_iter=1000) 103 | results = tsne.fit_transform(all_embeddings) 104 | 105 | with open('1.pkl', 'w') as f: 106 | pickle.dump(results, f) 107 | 108 | i, t = 0, 0 109 | l = len(results) // 2 110 | for key, value in sorted(original_images_embeddings.items()): 111 | n = len(value) 112 | plt.scatter(results[t : t + n, 0], results[t : t + n, 1], label = f'{i}_clean', marker = 'o', color = colors[i]) 113 | plt.scatter(results[t + l: t + l + n, 0], results[t + l: t + l + n, 1], label = f'{i}_bd', marker = '^', color = colors[i]) 114 | i += 1 115 | t += n 116 | 117 | plt.grid() 118 | plt.tight_layout() 119 | plt.legend(bbox_to_anchor=(1.02, 1.0), loc = 'upper left') 120 | plt.title(f'{args.title}') 121 | 122 | os.makedirs(os.path.dirname(args.save_fig), exist_ok = True) 123 | plt.savefig(args.save_fig, bbox_inches='tight') 124 | 125 | if __name__ == "__main__": 126 | 127 | parser = argparse.ArgumentParser() 128 | 129 | parser.add_argument("--original_dir", type = str, default = None, help = "original csv with captions and images") 130 | parser.add_argument("--title", type = str, default = None, help = "title of the graph") 131 | parser.add_argument("--device_id", type = str, default = None, help = "device id") 132 | parser.add_argument("--model_name", type = str, default = "RN50", choices = ["RN50", "RN101", "RN50x4", "ViT-B/32"], help = "Model Name") 133 | parser.add_argument("--checkpoint", type = str, default = None, help = "Path to checkpoint") 134 | parser.add_argument("--checkpoint_finetune", type = str, default = None, help = "Path to finetune checkpoint") 135 | parser.add_argument("--save_fig", type = str, default = None, help = "Save fig png") 136 | parser.add_argument("--batch_size", type = int, default = 1, help = "Batch Size") 137 | parser.add_argument("--images_per_class", type = int, default = 5, help = "Batch Size") 138 | 139 | args = parser.parse_args() 140 | 141 | plot_embeddings(args) -------------------------------------------------------------------------------- /utils/retrieval.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import numpy as np 5 | import pandas as pd 6 | from tqdm import tqdm 7 | import pickle 8 | from PIL import Image, ImageFile 9 | from pkgs.openai.clip import load 10 | 11 | ImageFile.LOAD_TRUNCATED_IMAGES = True 12 | 13 | def batch(iterable, n = 1): 14 | l = len(iterable) 15 | for ndx in range(0, l, n): 16 | yield iterable[ndx:min(ndx + n, l)] 17 | 18 | @torch.no_grad() 19 | def itm_eval(text_embeddings, image_embeddings): 20 | 21 | # sim_matrix_i2t = image_embeddings @ text_embeddings.t() 22 | # sim_matrix_t2i = text_embeddings @ image_embeddings.t() 23 | 24 | ## Image -> Text 25 | # ranks = np.zeros(len(sim_matrix_i2t)) 26 | ranks = np.zeros(len(image_embeddings)) 27 | 28 | for index in range(0, len(image_embeddings), 5): 29 | scores = image_embeddings[index] @ text_embeddings.t() 30 | # scores = sim_matrix_i2t[index] 31 | li = np.argsort(scores.detach().cpu().numpy())[::-1] 32 | for i in range(len(li)): 33 | if index <= li[i] and li[i] <= index + 4: 34 | rank = i 35 | break 36 | ranks[index] = rank 37 | 38 | # Compute metrics 39 | tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 40 | tr5 = 100.0 * 
len(np.where(ranks < 5)[0]) / len(ranks) 41 | tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 42 | 43 | ## Text -> Image 44 | ranks = np.zeros(len(text_embeddings)) 45 | for index in range(len(text_embeddings)): 46 | scores = text_embeddings[index] @ image_embeddings.t() 47 | # for index, scores in tqdm(enumerate(sim_matrix_t2i)): 48 | scores = scores[::5] 49 | li = np.argsort(scores.detach().cpu().numpy())[::-1] 50 | for i in range(len(li)): 51 | if li[i] == index//5: 52 | rank = i 53 | break 54 | ranks[index] = rank 55 | 56 | # Compute metrics 57 | ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks) 58 | ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks) 59 | ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks) 60 | 61 | tr_mean = (tr1 + tr5 + tr10) / 3 62 | ir_mean = (ir1 + ir5 + ir10) / 3 63 | r_mean = (tr_mean + ir_mean) / 2 64 | 65 | eval_result = {'txt_r1': tr1, 66 | 'txt_r5': tr5, 67 | 'txt_r10': tr10, 68 | 'txt_r_mean': tr_mean, 69 | 'img_r1': ir1, 70 | 'img_r5': ir5, 71 | 'img_r10': ir10, 72 | 'img_r_mean': ir_mean, 73 | 'r_mean': r_mean} 74 | 75 | return eval_result 76 | 77 | def get_all_embeddings(model, all_texts, all_images, root, processor, batch_size = 1024, device = torch.device("cuda" if torch.cuda.is_available() else "cpu"), verbose = False): 78 | text_embeddings = [] 79 | image_embeddings = [] 80 | 81 | with torch.no_grad(): 82 | score = 0 83 | 84 | dataloader_texts = list(batch(all_texts, batch_size)) 85 | dataloader_images = list(batch(all_images, batch_size)) 86 | 87 | bar = zip(dataloader_texts, dataloader_images) 88 | print("Evaluating..") 89 | bar = tqdm(bar, total = len(dataloader_texts)) 90 | 91 | for texts, images in bar: 92 | captions = processor.process_text(texts) 93 | input_ids = captions['input_ids'].to(device) 94 | attention_mask = captions['attention_mask'].to(device) 95 | pixel_values = torch.tensor(np.stack([processor.process_image(Image.open(os.path.join(root, image)).convert("RGB")) for image in images])).to(device) 96 | 97 | text_embedding = model.get_text_features(input_ids = input_ids, attention_mask = attention_mask) 98 | image_embedding = model.get_image_features(pixel_values) 99 | 100 | text_embedding /= text_embedding.norm(dim = -1, keepdim = True) 101 | image_embedding /= image_embedding.norm(dim = -1, keepdim = True) 102 | 103 | text_embeddings.append(text_embedding) 104 | image_embeddings.append(image_embedding) 105 | 106 | text_embeddings = torch.cat(text_embeddings) 107 | image_embeddings = torch.cat(image_embeddings) 108 | return text_embeddings, image_embeddings 109 | 110 | def evaluate(input_file): 111 | 112 | if options.use_saved_embeddings: 113 | with open(options.embeddings_file, 'rb') as f: 114 | text_embeds, image_embeds = pickle.load(f) 115 | print('Embeddings Loaded!') 116 | else: 117 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 118 | 119 | model, processor = load(name = options.model_name, pretrained = options.pretrained) 120 | model = model.to(device) 121 | if(options.checkpoint is not None): 122 | if(os.path.isfile(options.checkpoint)): 123 | checkpoint = torch.load(options.checkpoint, map_location = device) 124 | state_dict = checkpoint['state_dict'] 125 | if(next(iter(state_dict.items()))[0].startswith("module")): 126 | state_dict = {key[len("module."):]: value for key, value in state_dict.items()} 127 | model.load_state_dict(state_dict) 128 | print(f'Loaded checkpoint {options.checkpoint}') 129 | else: 130 | print(f'No checkpoint found at {options.checkpoint}') 131 | 132 |
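        # From here on the script puts the model in eval mode, embeds every caption/image
        # pair listed in the input CSV (columns chosen via --caption_key and --image_key),
        # caches the embeddings to --embeddings_file, and scores retrieval with itm_eval.
        # Illustrative invocation, run from the repository root (both paths below are
        # placeholders, not files shipped with the repo):
        #   python -m utils.retrieval --input_file data/flickr30k/test.csv --checkpoint checkpoints/clip/epoch_64.pt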
model.eval() 133 | print(input_file) 134 | root = os.path.dirname(input_file) 135 | df = pd.read_csv(input_file, sep = options.delimiter) 136 | 137 | captions = df[options.caption_key].tolist() 138 | images = df[options.image_key].tolist() 139 | 140 | text_embeds, image_embeds = get_all_embeddings(model, captions, images, root = root, processor = processor, batch_size = options.batch_size, device = device) 141 | 142 | with open(options.embeddings_file, 'wb') as f: 143 | pickle.dump((text_embeds, image_embeds), f) 144 | print('Embedding dumped!') 145 | 146 | result = itm_eval(text_embeds, image_embeds) 147 | 148 | print(result) 149 | 150 | if(__name__ == "__main__"): 151 | parser = argparse.ArgumentParser() 152 | 153 | parser.add_argument("--input_file", type = str, default = None, help = "Input file") 154 | # parser.add_argument("-o,--output_file", dest = "output_file", type = str, default = None, help = "Output file") 155 | # parser.add_argument("-q,--quiet", dest = "quiet" , default = False, action = "store_true", help = "Silent output") 156 | parser.add_argument("--batch_size", type = int, default = 128, help = "Batch Size") 157 | parser.add_argument("--delimiter", type = str, default = ",", help = "Input file delimiter") 158 | parser.add_argument("--image_key", type = str, default = "image", help = "Image column name") 159 | parser.add_argument("--caption_key", type = str, default = "caption", help = "Caption column name") 160 | parser.add_argument("--model_name", type = str, default = "RN50", choices = ["RN50", "RN101", "RN50x4", "ViT-B/32"], help = "Model Name") 161 | 162 | parser.add_argument("--checkpoint", default = None, type = str, help = "Path to checkpoint to resume training") 163 | parser.add_argument("--pretrained", default = False, action = "store_true", help = "Use the OpenAI pretrained models") 164 | parser.add_argument("--use_saved_embeddings", action = "store_true", default = False, help = "Use saved embeddings") 165 | parser.add_argument("--embeddings_file", type = str, default = "embeddings.pkl", help = "embedding file") 166 | 167 | 168 | options = parser.parse_args() 169 | evaluate(options.input_file) 170 | -------------------------------------------------------------------------------- /utils/eda.py: -------------------------------------------------------------------------------- 1 | # Easy data augmentation techniques for text classification 2 | # Jason Wei and Kai Zou 3 | 4 | import random 5 | from random import shuffle 6 | 7 | random.seed(1) 8 | 9 | # stop words list 10 | stop_words = ['i', 'me', 'my', 'myself', 'we', 'our', 11 | 'ours', 'ourselves', 'you', 'your', 'yours', 12 | 'yourself', 'yourselves', 'he', 'him', 'his', 13 | 'himself', 'she', 'her', 'hers', 'herself', 14 | 'it', 'its', 'itself', 'they', 'them', 'their', 15 | 'theirs', 'themselves', 'what', 'which', 'who', 16 | 'whom', 'this', 'that', 'these', 'those', 'am', 17 | 'is', 'are', 'was', 'were', 'be', 'been', 'being', 18 | 'have', 'has', 'had', 'having', 'do', 'does', 'did', 19 | 'doing', 'a', 'an', 'the', 'and', 'but', 'if', 'or', 20 | 'because', 'as', 'until', 'while', 'of', 'at', 21 | 'by', 'for', 'with', 'about', 'against', 'between', 22 | 'into', 'through', 'during', 'before', 'after', 23 | 'above', 'below', 'to', 'from', 'up', 'down', 'in', 24 | 'out', 'on', 'off', 'over', 'under', 'again', 25 | 'further', 'then', 'once', 'here', 'there', 'when', 26 | 'where', 'why', 'how', 'all', 'any', 'both', 'each', 27 | 'few', 'more', 'most', 'other', 'some', 'such', 'no', 28 | 'nor', 'not', 'only', 
'own', 'same', 'so', 'than', 'too', 29 | 'very', 's', 't', 'can', 'will', 'just', 'don', 30 | 'should', 'now', ''] 31 | 32 | # cleaning up text 33 | import re 34 | def get_only_chars(line): 35 | 36 | clean_line = "" 37 | 38 | line = line.replace("’", "") 39 | line = line.replace("'", "") 40 | line = line.replace("-", " ") 41 | line = line.replace("\t", " ") 42 | line = line.replace("\n", " ") 43 | line = line.lower() 44 | 45 | for char in line: 46 | if char in 'qwertyuiopasdfghjklzxcvbnm ': 47 | clean_line += char 48 | else: 49 | clean_line += ' ' 50 | 51 | clean_line = re.sub(' +', ' ', clean_line) 52 | if clean_line[0] == ' ': 53 | clean_line = clean_line[1:] 54 | return clean_line 55 | 56 | ######################################################################## 57 | # Synonym replacement 58 | # Replace n words in the sentence with synonyms from wordnet 59 | ######################################################################## 60 | 61 | from nltk.corpus import wordnet 62 | 63 | def synonym_replacement(words, n): 64 | new_words = words.copy() 65 | random_word_list = list(set([word for word in words if word not in stop_words])) 66 | random.shuffle(random_word_list) 67 | num_replaced = 0 68 | for random_word in random_word_list: 69 | synonyms = get_synonyms(random_word) 70 | if len(synonyms) >= 1: 71 | synonym = random.choice(list(synonyms)) 72 | new_words = [synonym if word == random_word else word for word in new_words] 73 | num_replaced += 1 74 | if num_replaced >= n: 75 | break 76 | 77 | # this is stupid but we need it, trust me 78 | sentence = ' '.join(new_words) 79 | new_words = sentence.split(' ') 80 | 81 | return new_words 82 | 83 | def get_synonyms(word): 84 | synonyms = set() 85 | for syn in wordnet.synsets(word): 86 | for l in syn.lemmas(): 87 | synonym = l.name().replace("_", " ").replace("-", " ").lower() 88 | synonym = "".join([char for char in synonym if char in ' qwertyuiopasdfghjklzxcvbnm']) 89 | synonyms.add(synonym) 90 | if word in synonyms: 91 | synonyms.remove(word) 92 | return list(synonyms) 93 | 94 | ######################################################################## 95 | # Random deletion 96 | # Randomly delete words from the sentence with probability p 97 | ######################################################################## 98 | 99 | def random_deletion(words, p): 100 | 101 | # obviously, if there's only one word, don't delete it 102 | if len(words) == 1: 103 | return words 104 | 105 | # randomly delete words with probability p 106 | new_words = [] 107 | for word in words: 108 | r = random.uniform(0, 1) 109 | if r > p: 110 | new_words.append(word) 111 | 112 | # if you end up deleting all words, just return a random word 113 | if len(new_words) == 0: 114 | rand_int = random.randint(0, len(words)-1) 115 | return [words[rand_int]] 116 | 117 | return new_words 118 | 119 | ######################################################################## 120 | # Random swap 121 | # Randomly swap two words in the sentence n times 122 | ######################################################################## 123 | 124 | def random_swap(words, n): 125 | new_words = words.copy() 126 | for _ in range(n): 127 | new_words = swap_word(new_words) 128 | return new_words 129 | 130 | def swap_word(new_words): 131 | random_idx_1 = random.randint(0, len(new_words)-1) 132 | random_idx_2 = random_idx_1 133 | counter = 0 134 | while random_idx_2 == random_idx_1: 135 | random_idx_2 = random.randint(0, len(new_words)-1) 136 | counter += 1 137 | if counter > 3: 138 | return new_words 
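    # the counter guard above gives up after a few draws so that a one-word caption
    # (where no distinct second index exists) cannot loop forever; otherwise the two
    # sampled positions are exchanged below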
139 | new_words[random_idx_1], new_words[random_idx_2] = new_words[random_idx_2], new_words[random_idx_1] 140 | return new_words 141 | 142 | ######################################################################## 143 | # Random insertion 144 | # Randomly insert n words into the sentence 145 | ######################################################################## 146 | 147 | def random_insertion(words, n): 148 | new_words = words.copy() 149 | for _ in range(n): 150 | add_word(new_words) 151 | return new_words 152 | 153 | def add_word(new_words): 154 | synonyms = [] 155 | counter = 0 156 | while len(synonyms) < 1: 157 | random_word = new_words[random.randint(0, len(new_words)-1)] 158 | synonyms = get_synonyms(random_word) 159 | counter += 1 160 | if counter >= 10: 161 | return 162 | random_synonym = synonyms[0] 163 | random_idx = random.randint(0, len(new_words)-1) 164 | new_words.insert(random_idx, random_synonym) 165 | 166 | ######################################################################## 167 | # main data augmentation function 168 | ######################################################################## 169 | 170 | def eda(sentence, alpha_sr=0.1, alpha_ri=0.1, alpha_rs=0.1, p_rd=0.1, num_aug=1): 171 | 172 | sentence = get_only_chars(sentence) 173 | words = sentence.split(' ') 174 | words = [word for word in words if word != ""] 175 | num_words = len(words) 176 | 177 | augmented_sentences = [] 178 | num_new_per_technique = int(num_aug/4)+1 179 | 180 | # sr 181 | if (alpha_sr > 0): 182 | n_sr = max(1, int(alpha_sr*num_words)) 183 | for _ in range(num_new_per_technique): 184 | a_words = synonym_replacement(words, n_sr) 185 | augmented_sentences.append(' '.join(a_words)) 186 | 187 | # ri 188 | if (alpha_ri > 0): 189 | n_ri = max(1, int(alpha_ri*num_words)) 190 | for _ in range(num_new_per_technique): 191 | a_words = random_insertion(words, n_ri) 192 | augmented_sentences.append(' '.join(a_words)) 193 | 194 | # rs 195 | if (alpha_rs > 0): 196 | n_rs = max(1, int(alpha_rs*num_words)) 197 | for _ in range(num_new_per_technique): 198 | a_words = random_swap(words, n_rs) 199 | augmented_sentences.append(' '.join(a_words)) 200 | 201 | # rd 202 | if (p_rd > 0): 203 | for _ in range(num_new_per_technique): 204 | a_words = random_deletion(words, p_rd) 205 | augmented_sentences.append(' '.join(a_words)) 206 | 207 | augmented_sentences = [get_only_chars(sentence) for sentence in augmented_sentences] 208 | shuffle(augmented_sentences) 209 | 210 | # trim so that we have the desired number of augmented sentences 211 | if num_aug >= 1: 212 | augmented_sentences = augmented_sentences[:num_aug] 213 | else: 214 | keep_prob = num_aug / len(augmented_sentences) 215 | augmented_sentences = [s for s in augmented_sentences if random.uniform(0, 1) < keep_prob] 216 | 217 | # append the original sentence 218 | augmented_sentences.append(sentence) 219 | 220 | return augmented_sentences -------------------------------------------------------------------------------- /backdoor/tsne_detected_vs_undetected_vs_clean.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pickle 4 | import warnings 5 | import argparse 6 | import torchvision 7 | import numpy as np 8 | import pandas as pd 9 | import seaborn as sns 10 | from tqdm import tqdm 11 | import matplotlib.pyplot as plt 12 | from PIL import Image, ImageFile 13 | from sklearn.manifold import TSNE 14 | from sklearn.decomposition import PCA 15 | from torch.utils.data import Dataset, 
DataLoader 16 | 17 | from pkgs.openai.clip import load as load_model 18 | 19 | ImageFile.LOAD_TRUNCATED_IMAGES = True 20 | warnings.filterwarnings("ignore") 21 | 22 | 23 | def get_model(args, checkpoint): 24 | model, processor = load_model(name=args.model_name, pretrained=False) 25 | if (args.device == "cpu"): model.float() 26 | model.to(args.device) 27 | state_dict = torch.load(checkpoint, map_location=args.device)["state_dict"] 28 | if (next(iter(state_dict.items()))[0].startswith("module")): 29 | state_dict = {key[len("module."):]: value for key, value in state_dict.items()} 30 | model.load_state_dict(state_dict) 31 | model.eval() 32 | return model, processor 33 | 34 | 35 | class ImageCaptionDataset(Dataset): 36 | def __init__(self, path, images, captions, processor): 37 | self.root = os.path.dirname(path) 38 | self.processor = processor 39 | self.images = images 40 | self.captions = self.processor.process_text(captions) 41 | 42 | def __len__(self): 43 | return len(self.images) 44 | 45 | def __getitem__(self, idx): 46 | item = {} 47 | image = Image.open(os.path.join(self.root, self.images[idx])) 48 | item["input_ids"] = self.captions["input_ids"][idx] 49 | item["attention_mask"] = self.captions["attention_mask"][idx] 50 | item["pixel_values"] = self.processor.process_image(image) 51 | return item 52 | 53 | 54 | def get_embeddings(model, dataloader, processor, args): 55 | device = args.device 56 | list_embeddings = [] 57 | with torch.no_grad(): 58 | for batch in tqdm(dataloader): 59 | input_ids, attention_mask, pixel_values = batch["input_ids"].to(device, non_blocking=True), batch[ 60 | "attention_mask"].to(device, non_blocking=True), batch["pixel_values"].to(device, non_blocking=True) 61 | outputs = model(input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values) 62 | list_embeddings.append(outputs.image_embeds) 63 | return torch.cat(list_embeddings, dim=0).cpu().detach().numpy() 64 | 65 | 66 | def plot_embeddings(args): 67 | args.device = torch.device(f"cuda:{args.device_id}" if torch.cuda.is_available() else "cpu") 68 | if not os.path.exists(args.save_data): 69 | checkpoint = f'epoch_{args.epoch}.pt' 70 | model, processor = get_model(args, os.path.join(args.checkpoints_dir, checkpoint)) 71 | df = pd.read_csv(args.original_csv) 72 | 73 | # we divide data into three categories -- clean, backdoored and detected, backdoored and undetected 74 | images, captions, is_backdoor = df['image'].tolist(), df['caption'].tolist(), df['is_backdoor'].tolist() 75 | backdoor_indices = list(filter(lambda x: 'backdoor' in images[x], range(len(images)))) 76 | backdoor_detected_indices = [x for x in backdoor_indices if is_backdoor[x] is True] 77 | backdoor_undetected_indices = [x for x in backdoor_indices if is_backdoor[x] is False] 78 | 79 | clean_indices = list(filter(lambda x: 'backdoor' not in images[x], range(len(images)))) 80 | clean_indices = clean_indices[:10000] 81 | 82 | 83 | backdoor_detected_images, backdoor_detected_captions = [images[x] for x in backdoor_detected_indices], \ 84 | [captions[x] for x in backdoor_detected_indices] 85 | 86 | backdoor_undetected_images, backdoor_undetected_captions = [images[x] for x in backdoor_undetected_indices], \ 87 | [captions[x] for x in backdoor_undetected_indices] 88 | 89 | clean_images, clean_captions = [images[x] for x in clean_indices], [captions[x] for x in clean_indices] 90 | 91 | dataset_clean = ImageCaptionDataset(args.original_csv, clean_images, clean_captions, processor) 92 | dataset_backdoor_detected = 
ImageCaptionDataset(args.original_csv, backdoor_detected_images, 93 | backdoor_detected_captions, processor) 94 | dataset_backdoor_undetected = ImageCaptionDataset(args.original_csv, backdoor_undetected_images, 95 | backdoor_undetected_captions, processor) 96 | 97 | dataloader_clean = DataLoader(dataset_clean, batch_size=args.batch_size, shuffle=False, pin_memory=True, 98 | drop_last=False) 99 | dataloader_backdoor_detected = DataLoader(dataset_backdoor_detected, batch_size=args.batch_size, shuffle=False, 100 | pin_memory=True, drop_last=False) 101 | dataloader_backdoor_undetected = DataLoader(dataset_backdoor_undetected, batch_size=args.batch_size, 102 | shuffle=False, pin_memory=True, drop_last=False) 103 | 104 | 105 | clean_images_embeddings = get_embeddings(model, dataloader_clean, processor, args) 106 | backdoor_detected_images_embeddings = get_embeddings(model, dataloader_backdoor_detected, processor, args) 107 | backdoor_undetected_images_embeddings = get_embeddings(model, dataloader_backdoor_undetected, processor, args) 108 | 109 | len_clean = len(clean_images_embeddings) 110 | len_backdoor_detected = len(backdoor_detected_images_embeddings) 111 | len_backdoor_undetected = len(backdoor_undetected_images_embeddings) 112 | 113 | all_embeddings = np.concatenate([clean_images_embeddings, backdoor_detected_images_embeddings, backdoor_undetected_images_embeddings], axis=0) 114 | print(len_clean) 115 | with open(args.save_data, 'wb') as f: 116 | pickle.dump((all_embeddings, len_clean, len_backdoor_detected, len_backdoor_undetected), f) 117 | 118 | with open(args.save_data, 'rb') as f: 119 | all_embeddings, len_clean, len_backdoor_detected, len_backdoor_undetected = pickle.load(f) 120 | 121 | fig = plt.figure() 122 | # ax = fig.add_subplot(projection='2d') 123 | 124 | tsne = TSNE(n_components=2, verbose=1, perplexity=10, n_iter=1000) 125 | results = tsne.fit_transform(all_embeddings) 126 | 127 | # pca = PCA(n_components = 2) 128 | # results = pca.fit_transform(all_embeddings) 129 | # print(pca.explained_variance_ratio_) 130 | 131 | plt.scatter(results[:len_clean, 0], results[:len_clean, 1], label='Clean') 132 | plt.scatter(results[len_clean:len_clean+len_backdoor_detected, 0], results[len_clean:len_clean+len_backdoor_detected, 1], label='Backdoor Detected') 133 | plt.scatter(results[len_clean + len_backdoor_detected:, 0], 134 | results[len_clean + len_backdoor_detected:, 1], label='Backdoor Undetected') 135 | 136 | plt.grid() 137 | plt.legend() 138 | plt.title(args.title) 139 | plt.tight_layout() 140 | 141 | os.makedirs(os.path.dirname(args.save_fig), exist_ok=True) 142 | plt.savefig(args.save_fig) 143 | 144 | 145 | if __name__ == "__main__": 146 | parser = argparse.ArgumentParser() 147 | 148 | parser.add_argument("--original_csv", type=str, default=None, help="original csv with captions and images") 149 | parser.add_argument("--device_id", type=str, default=None, help="device id") 150 | parser.add_argument("--model_name", type=str, default="RN50", choices=["RN50", "RN101", "RN50x4", "ViT-B/32"], 151 | help="Model Name") 152 | parser.add_argument("--checkpoints_dir", type=str, default="checkpoints/clip/", 153 | help="Path to checkpoint directories") 154 | parser.add_argument("--save_data", type=str, default=None, help="Save data") 155 | parser.add_argument("--save_fig", type=str, default=None, help="Save fig png") 156 | parser.add_argument("--batch_size", type=int, default=128, help="Batch Size") 157 | parser.add_argument("--epoch", type=int, default=64, help="Epoch") 158 | 
parser.add_argument("--title", type=str, default=None, help="Title for plot") 159 | 160 | args = parser.parse_args() 161 | 162 | plot_embeddings(args) -------------------------------------------------------------------------------- /backdoor/tsne-labelled.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pickle 4 | import random 5 | import warnings 6 | import argparse 7 | import torchvision 8 | import numpy as np 9 | import seaborn as sns 10 | from tqdm import tqdm 11 | import matplotlib.pyplot as plt 12 | from PIL import Image, ImageFile 13 | from sklearn.manifold import TSNE 14 | from sklearn.decomposition import PCA 15 | from backdoor.utils import ImageLabelDataset 16 | from collections import defaultdict 17 | from torch.utils.data import Dataset, DataLoader 18 | 19 | from pkgs.openai.clip import load as load_model 20 | 21 | ImageFile.LOAD_TRUNCATED_IMAGES = True 22 | warnings.filterwarnings("ignore") 23 | 24 | def get_model(args, checkpoint, checkpoint_finetune = None): 25 | model, processor = load_model(name = args.model_name, pretrained = False) 26 | if(args.device == "cpu"): model.float() 27 | model.to(args.device) 28 | state_dict = torch.load(checkpoint, map_location = args.device)["state_dict"] 29 | if(next(iter(state_dict.items()))[0].startswith("module")): 30 | state_dict = {key[len("module."):]: value for key, value in state_dict.items()} 31 | if checkpoint_finetune: 32 | finetuned_checkpoint = torch.load(checkpoint_finetune, map_location = args.device) 33 | finetuned_state_dict = finetuned_checkpoint["state_dict"] 34 | for key in state_dict: 35 | if 'visual' in key: 36 | ft_key = name.replace("module.", "model.") if "module" in key else f'model.{key}' 37 | state_dict[key] = finetuned_state_dict[ft_key] 38 | print('Loaded Visual Backbone from Finetuned Model') 39 | model.load_state_dict(state_dict) 40 | model.eval() 41 | return model, processor 42 | 43 | def get_embeddings(model, dataloader, processor, args): 44 | 45 | 46 | label_occurence_count = defaultdict(int) 47 | 48 | list_original_embeddings = [] 49 | list_backdoor_embeddings = [] 50 | 51 | label_list_original_embeddings = [] 52 | label_list_backdoor_embeddings = [] 53 | 54 | with torch.no_grad(): 55 | for original_images, backdoor_images, label in tqdm(dataloader): 56 | label = label.item() 57 | if label_occurence_count[label] < args.images_per_class: 58 | label_occurence_count[label] += 1 59 | 60 | original_images = original_images.to(args.device) 61 | original_images_embeddings = model.get_image_features(original_images) 62 | backdoor_images = backdoor_images.to(args.device) 63 | backdoor_images_embeddings = model.get_image_features(backdoor_images) 64 | 65 | original_images_embeddings /= original_images_embeddings.norm(dim = -1, keepdim = True) 66 | backdoor_images_embeddings /= backdoor_images_embeddings.norm(dim = -1, keepdim = True) 67 | 68 | if label == 954: 69 | label_list_original_embeddings.append(original_images_embeddings) 70 | label_list_backdoor_embeddings.append(backdoor_images_embeddings) 71 | else: 72 | list_original_embeddings.append(original_images_embeddings) 73 | list_backdoor_embeddings.append(backdoor_images_embeddings) 74 | 75 | original_images_embeddings = torch.cat(list_original_embeddings, dim = 0) 76 | backdoor_images_embeddings = torch.cat(list_backdoor_embeddings, dim = 0) 77 | # label_original_images_embeddings = torch.cat(label_list_original_embeddings, dim = 0) 78 | # label_backdoor_images_embeddings = 
torch.cat(label_list_backdoor_embeddings, dim = 0) 79 | 80 | # return original_images_embeddings.cpu().detach().numpy(), backdoor_images_embeddings.cpu().detach().numpy(), label_original_images_embeddings.cpu().detach().numpy(), label_backdoor_images_embeddings.cpu().detach().numpy() 81 | return original_images_embeddings.cpu().detach().numpy(), backdoor_images_embeddings.cpu().detach().numpy(), None, None 82 | def plot_embeddings(args): 83 | 84 | if not args.use_saved: 85 | args.device = torch.device(f"cuda:{args.device_id}" if torch.cuda.is_available() else "cpu") 86 | 87 | model, processor = get_model(args, args.checkpoint, args.checkpoint_finetune) 88 | dataset = ImageLabelDataset(args.original_csv, processor.process_image) 89 | dataset = torch.utils.data.Subset(dataset, list(range(1000))) 90 | 91 | dataloader = DataLoader(dataset, batch_size = args.batch_size, shuffle = False, pin_memory = True, drop_last = False) 92 | 93 | # original_images_embeddings, backdoor_images_embeddings, label_original_images_embeddings, label_backdoor_images_embeddings = get_embeddings(model, dataloader, processor, args) 94 | original_images_embeddings, backdoor_images_embeddings, _, _ = get_embeddings(model, dataloader, processor, args) 95 | 96 | number_non_label = len(original_images_embeddings) 97 | # number_label = len(label_original_images_embeddings) 98 | # all_embeddings = np.concatenate([original_images_embeddings, backdoor_images_embeddings, label_original_images_embeddings, label_backdoor_images_embeddings], axis = 0) 99 | all_embeddings = np.concatenate([original_images_embeddings, backdoor_images_embeddings], axis = 0) 100 | 101 | # fig = plt.figure() 102 | # ax = fig.add_subplot(projection='3d') 103 | 104 | tsne = TSNE(n_components=2, verbose=1, perplexity=10, n_iter=1000) 105 | results = tsne.fit_transform(all_embeddings) 106 | 107 | with open('3.pkl', 'wb') as f: 108 | pickle.dump((results, original_images_embeddings, backdoor_images_embeddings) , f) 109 | 110 | else: 111 | with open('3.pkl', 'rb') as f: 112 | results, original_images_embeddings, backdoor_images_embeddings = pickle.load(f) 113 | 114 | plt.scatter(results[:len(original_images_embeddings), 0], results[:len(original_images_embeddings), 1], label = 'Clean Images') 115 | plt.scatter(results[len(original_images_embeddings) : len(original_images_embeddings) + len(backdoor_images_embeddings), 0], 116 | results[len(original_images_embeddings) : len(original_images_embeddings) + len(backdoor_images_embeddings), 1], label = 'Backdoor Images') 117 | # plt.scatter(results[len(original_images_embeddings) + len(backdoor_images_embeddings): len(original_images_embeddings) + len(backdoor_images_embeddings) + len(label_original_images_embeddings), 0], 118 | # results[len(original_images_embeddings) + len(backdoor_images_embeddings): len(original_images_embeddings) + len(backdoor_images_embeddings) + len(label_original_images_embeddings), 1], label = 'Banana Images') 119 | # plt.scatter(results[len(original_images_embeddings) + len(backdoor_images_embeddings) + len(label_original_images_embeddings) :, 0], 120 | # results[len(original_images_embeddings) + len(backdoor_images_embeddings) + len(label_original_images_embeddings) :, 1], label = 'Backdoored Banana Images') 121 | 122 | 123 | plt.grid() 124 | plt.tight_layout() 125 | # plt.legend(bbox_to_anchor=(1.02, 1.0), loc = 'upper left') 126 | plt.legend(prop={'size': 15}) 127 | plt.title(f'{args.title}') 128 | 129 | os.makedirs(os.path.dirname(args.save_fig), exist_ok = True) 130 | 
plt.savefig(args.save_fig, bbox_inches='tight') 131 | 132 | if __name__ == "__main__": 133 | 134 | parser = argparse.ArgumentParser() 135 | 136 | parser.add_argument("--original_csv", type = str, default = None, help = "original csv with captions and images") 137 | parser.add_argument("--title", type = str, default = None, help = "title of the graph") 138 | parser.add_argument("--device_id", type = str, default = None, help = "device id") 139 | parser.add_argument("--model_name", type = str, default = "RN50", choices = ["RN50", "RN101", "RN50x4", "ViT-B/32"], help = "Model Name") 140 | parser.add_argument("--checkpoint", type = str, default = None, help = "Path to checkpoint") 141 | parser.add_argument("--checkpoint_finetune", type = str, default = None, help = "Path to finetune checkpoint") 142 | parser.add_argument("--save_fig", type = str, default = None, help = "Save fig png") 143 | parser.add_argument("--batch_size", type = int, default = 1, help = "Batch Size") 144 | parser.add_argument("--images_per_class", type = int, default = 5, help = "Batch Size") 145 | parser.add_argument("--use_saved", default = False, action = 'store_true') 146 | 147 | args = parser.parse_args() 148 | 149 | plot_embeddings(args) -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["WANDB_API_KEY"] = "" 3 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" 4 | import sys 5 | import time 6 | import wandb 7 | import torch 8 | import logging 9 | import warnings 10 | import numpy as np 11 | import torch.optim as optim 12 | import torch.distributed as dist 13 | import torch.multiprocessing as mp 14 | import torch.backends.cudnn as cudnn 15 | from torch.cuda.amp import GradScaler 16 | from torch.nn.parallel import DistributedDataParallel as DDP 17 | 18 | from pkgs.openai.clip import load as load_model 19 | 20 | from .train import train 21 | from .evaluate import evaluate, Finetune 22 | from .data import load as load_data 23 | from .data import get_clean_train_dataloader, calculate_scores 24 | from .parser import parse_args 25 | from .scheduler import cosine_scheduler 26 | from .logger import get_logger, set_logger 27 | 28 | mp.set_start_method("spawn", force = True) 29 | warnings.filterwarnings("ignore") 30 | 31 | 32 | def gathered_elements_to_list(gather_elements): 33 | output = [] 34 | for element in gather_elements: 35 | output = output + list(element) 36 | return output 37 | 38 | def progressive_removal(options, model, processor, data, epoch): 39 | 40 | path = calculate_scores(options, model, data["train"], epoch) 41 | gather_path = [None for _ in range(options.num_devices)] 42 | if options.distributed: 43 | dist.all_gather_object(gather_path, path) 44 | 45 | if not options.master and options.distributed: 46 | logging.info(f'Device inside barrier 1 {options.device}') 47 | torch.distributed.barrier() 48 | logging.info(f'Device outside barrier 1 {options.device}') 49 | 50 | data["train"] = get_clean_train_dataloader(options, processor, path) 51 | 52 | options.train_data = path 53 | 54 | if options.master and options.distributed: 55 | logging.info(f'Device inside barrier 2 {options.device}') 56 | torch.distributed.barrier() 57 | logging.info(f'Device outside barrier 2 {options.device}') 58 | 59 | return options, data 60 | 61 | def worker(rank, options, logger): 62 | options.rank = rank 63 | options.master = rank == 0 64 | 65 | set_logger(rank = rank, logger = logger, 
distributed = options.distributed) 66 | 67 | if(options.device == "cuda"): 68 | options.device += ":" + str(options.device_ids[options.rank] if options.distributed else options.device_id) 69 | 70 | logging.info(f"Using {options.device} device") 71 | 72 | if(options.master): 73 | logging.info("Params:") 74 | with open(os.path.join(options.log_dir_path, "params.txt"), "w") as file: 75 | for key in sorted(vars(options)): 76 | value = getattr(options, key) 77 | logging.info(f"{key}: {value}") 78 | file.write(f"{key}: {value}\n") 79 | 80 | if(options.distributed): 81 | dist.init_process_group(backend = options.distributed_backend, init_method = options.distributed_init_method, world_size = options.num_devices, rank = options.rank) 82 | 83 | options.batch_size = options.batch_size // options.num_devices 84 | 85 | model, processor = load_model(name = options.model_name, pretrained = options.pretrained) 86 | 87 | if(options.device == "cpu"): 88 | model.float() 89 | else: 90 | torch.cuda.set_device(options.device_ids[options.rank] if options.distributed else options.device_id) 91 | model.to(options.device) 92 | if(options.distributed): 93 | model = DDP(model, device_ids = [options.device_ids[options.rank]]) 94 | 95 | data = load_data(options, processor) 96 | 97 | optimizer = None 98 | scheduler = None 99 | if(data["train"] is not None): 100 | weight_decay_parameters = [] 101 | no_weight_decay_parameters = [] 102 | 103 | for name, parameter in model.named_parameters(): 104 | if(all(key not in name for key in ["bn", "ln", "bias", "logit_scale"]) and parameter.requires_grad): 105 | weight_decay_parameters.append(parameter) 106 | 107 | if(any(key in name for key in ["bn", "ln", "bias", "logit_scale"]) and parameter.requires_grad): 108 | no_weight_decay_parameters.append(parameter) 109 | 110 | optimizer = optim.AdamW([{"params": no_weight_decay_parameters, "weight_decay": 0}, {"params": weight_decay_parameters, "weight_decay": options.weight_decay}], lr = options.lr, betas = (options.beta1, options.beta2), eps = options.eps) 111 | scheduler = cosine_scheduler(optimizer, options.lr, options.num_warmup_steps, data["train"].num_batches * options.epochs) 112 | 113 | start_epoch = 0 114 | if(options.checkpoint is not None): 115 | if(os.path.isfile(options.checkpoint)): 116 | checkpoint = torch.load(options.checkpoint, map_location = options.device) 117 | start_epoch = 0 if options.complete_finetune else checkpoint['epoch'] 118 | state_dict = checkpoint["state_dict"] 119 | if(not options.distributed and next(iter(state_dict.items()))[0].startswith("module")): 120 | state_dict = {key[len("module."):]: value for key, value in state_dict.items()} 121 | # hack to load a non-distributed checkpoint for distributed training 122 | if (options.distributed and not next(iter(state_dict.items()))[0].startswith("module")): 123 | state_dict = {"module."+key: value for key, value in state_dict.items()} 124 | if(options.checkpoint_finetune): 125 | finetuned_checkpoint = torch.load(options.checkpoint_finetune, map_location = options.device) 126 | finetuned_state_dict = finetuned_checkpoint["state_dict"] 127 | for key in state_dict: 128 | if 'visual' in key: 129 | ft_key = key.replace("module.", "model.") if "module" in key else f'model.{key}' 130 | state_dict[key] = finetuned_state_dict[ft_key] 131 | print('Loaded Visual Backbone from Finetuned Model') 132 | model.load_state_dict(state_dict) 133 | if(optimizer is not None): optimizer.load_state_dict(checkpoint["optimizer"]) 134 | logging.info(f"Loaded checkpoint
'{options.checkpoint}' (start epoch {checkpoint['epoch']})") 135 | else: 136 | logging.info(f"No checkpoint found at {options.checkpoint}") 137 | 138 | cudnn.benchmark = True 139 | cudnn.deterministic = False 140 | 141 | if(options.wandb and options.master): 142 | logging.debug("Starting wandb") 143 | wandb.init(project = "clip-defense", notes = options.notes, tags = [], config = vars(options), entity = 'mint-adobe') 144 | wandb.run.name = options.name 145 | wandb.save(os.path.join(options.log_dir_path, "params.txt")) 146 | 147 | evaluate(start_epoch, model, processor, data, options) 148 | 149 | if(data["train"] is not None): 150 | options.checkpoints_dir_path = os.path.join(options.log_dir_path, "checkpoints") 151 | os.makedirs(options.checkpoints_dir_path, exist_ok = True) 152 | 153 | scaler = GradScaler() 154 | 155 | best_loss = np.inf 156 | 157 | if(options.progressive): 158 | options.progressive_epochs = list(map(int, options.progressive_epochs)) 159 | if (start_epoch in options.progressive_epochs): 160 | options, data = progressive_removal(options, model, processor, data, start_epoch) 161 | 162 | for epoch in range(start_epoch + 1, options.epochs + 1): 163 | if(options.master): 164 | logging.info(f"Starting Epoch {epoch}") 165 | 166 | start = time.time() 167 | train(epoch, model, data, optimizer, scheduler, scaler, options) 168 | end = time.time() 169 | 170 | if(options.master): 171 | logging.info(f"Finished Epoch {epoch}, Time Taken: {end - start:.3f}") 172 | 173 | metrics = evaluate(epoch, model, processor, data, options) 174 | 175 | if(options.master): 176 | checkpoint = {"epoch": epoch, "name": options.name, "state_dict": model.state_dict(), "optimizer": optimizer.state_dict()} 177 | if(options.complete_finetune): 178 | torch.save(checkpoint, os.path.join(options.checkpoints_dir_path, f"epoch.pt")) 179 | else: 180 | torch.save(checkpoint, os.path.join(options.checkpoints_dir_path, f"epoch_{epoch}.pt")) 181 | if("loss" in metrics): 182 | if(metrics["loss"] < best_loss): 183 | best_loss = metrics["loss"] 184 | torch.save(checkpoint, os.path.join(options.checkpoints_dir_path, f"epoch.best.pt")) 185 | 186 | if(options.progressive): 187 | if epoch in options.progressive_epochs: 188 | options, data = progressive_removal(options, model, processor, data, epoch) 189 | 190 | if epoch == options.stop_epoch: 191 | return 192 | 193 | if(options.distributed): 194 | dist.destroy_process_group() 195 | 196 | if(options.wandb and options.master): 197 | wandb.finish() 198 | 199 | if(__name__ == "__main__"): 200 | options = parse_args() 201 | 202 | options.log_dir_path = os.path.join(options.logs, options.name) 203 | options.log_file_path = os.path.join(options.log_dir_path, "output.log") 204 | 205 | os.makedirs(options.log_dir_path, exist_ok = True) 206 | logger, listener = get_logger(options.log_file_path) 207 | 208 | listener.start() 209 | 210 | ngpus = torch.cuda.device_count() 211 | if(ngpus == 0 or options.device == "cpu"): 212 | options.device = "cpu" 213 | options.num_devices = 1 214 | options.distributed = False 215 | worker(0, options, logger) 216 | else: 217 | if(ngpus == 1 or not options.distributed): 218 | options.device = "cuda" 219 | options.num_devices = 1 220 | options.distributed = False 221 | worker(0, options, logger) 222 | else: 223 | options.device = "cuda" 224 | if(options.device_ids is None): 225 | options.device_ids = list(range(ngpus)) 226 | options.num_devices = ngpus 227 | else: 228 | options.device_ids = list(map(int, options.device_ids)) 229 | options.num_devices = 
len(options.device_ids) 230 | options.distributed = True 231 | os.environ["NCCL_P2P_DISABLE"] = "1" 232 | mp.spawn(worker, nprocs = options.num_devices, args = (options, logger)) 233 | 234 | listener.stop() 235 | -------------------------------------------------------------------------------- /utils/embeddings.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pickle 4 | import warnings 5 | import argparse 6 | import torchvision 7 | import numpy as np 8 | import pandas as pd 9 | from tqdm import tqdm 10 | from PIL import Image, ImageFile 11 | from torch.utils.data import Dataset, DataLoader 12 | 13 | from pkgs.openai.clip import load as load_model 14 | 15 | ImageFile.LOAD_TRUNCATED_IMAGES = True 16 | warnings.filterwarnings("ignore") 17 | 18 | def get_model(options): 19 | model, processor = load_model(name = options.model_name, pretrained = False) 20 | if(options.device == "cpu"): model.float() 21 | model.to(options.device) 22 | state_dict = torch.load(options.checkpoint, map_location = options.device)["state_dict"] 23 | if(next(iter(state_dict.items()))[0].startswith("module")): 24 | state_dict = {key[len("module."):]: value for key, value in state_dict.items()} 25 | model.load_state_dict(state_dict) 26 | model.eval() 27 | return model, processor 28 | 29 | class ImageLabelDataset(Dataset): 30 | def __init__(self, root, transform): 31 | self.root = root 32 | df = pd.read_csv(os.path.join(root, "labels.csv")) 33 | self.images = df["image"] 34 | self.labels = df["label"] 35 | self.transform = transform 36 | 37 | def __len__(self): 38 | return len(self.labels) 39 | 40 | def __getitem__(self, idx): 41 | image = self.transform(Image.open(os.path.join(self.root, self.images[idx]))) 42 | label = self.labels[idx] 43 | return image, label 44 | 45 | def get_test_dataset(options, processor): 46 | if(options.data_type == "Caltech101"): 47 | dataset = ImageLabelDataset(root = options.data_dir, transform = processor.process_image) 48 | elif(options.data_type == "CIFAR10"): 49 | dataset = torchvision.datasets.CIFAR10(root = os.path.dirname(options.data_dir), download = True, train = False, transform = processor.process_image) 50 | elif(options.data_type == "CIFAR100"): 51 | dataset = torchvision.datasets.CIFAR100(root = os.path.dirname(options.data_dir), download = True, train = False, transform = processor.process_image) 52 | elif(options.data_type == "DTD"): 53 | dataset = torchvision.datasets.DTD(root = os.path.dirname(options.data_dir), download = True, split = "test", transform = processor.process_image) 54 | elif(options.data_type == "FGVCAircraft"): 55 | dataset = torchvision.datasets.FGVCAircraft(root = os.path.dirname(options.data_dir), download = True, split = "test", transform = processor.process_image) 56 | elif(options.data_type == "Flowers102"): 57 | dataset = ImageLabelDataset(root = options.data_dir, transform = processor.process_image) 58 | elif(options.data_type == "Food101"): 59 | dataset = torchvision.datasets.Food101(root = os.path.dirname(options.data_dir), download = True, split = "test", transform = processor.process_image) 60 | elif(options.data_type == "GTSRB"): 61 | dataset = torchvision.datasets.GTSRB(root = os.path.dirname(options.data_dir), download = True, split = "test", transform = processor.process_image) 62 | elif(options.data_type == "ImageNet1K"): 63 | dataset = ImageLabelDataset(root = options.data_dir, transform = processor.process_image) 64 | elif(options.data_type == "OxfordIIITPet"): 65 | 
dataset = torchvision.datasets.OxfordIIITPet(root = os.path.dirname(options.data_dir), download = True, split = "test", transform = processor.process_image) 66 | elif(options.data_type == "RenderedSST2"): 67 | dataset = torchvision.datasets.RenderedSST2(root = os.path.dirname(options.data_dir), download = True, split = "test", transform = processor.process_image) 68 | elif(options.data_type == "StanfordCars"): 69 | dataset = torchvision.datasets.StanfordCars(root = os.path.dirname(options.data_dir), download = True, split = "test", transform = processor.process_image) 70 | elif(options.data_type == "STL10"): 71 | dataset = torchvision.datasets.STL10(root = os.path.dirname(options.data_dir), download = True, split = "test", transform = processor.process_image) 72 | elif(options.data_type == "SVHN"): 73 | dataset = torchvision.datasets.SVHN(root = os.path.dirname(options.data_dir), download = True, split = "test", transform = processor.process_image) 74 | elif(options.data_type in ["ImageNetV2", "ImageNetSketch", "ImageNet-A", "ImageNet-R"]): 75 | dataset = ImageLabelDataset(root = options.data_dir, transform = processor.process_image) 76 | else: 77 | raise Exception(f"Test dataset type {options.data_type} is not supported") 78 | 79 | return dataset 80 | 81 | def get_train_dataset(options, processor): 82 | if(options.data_type == "Caltech101"): 83 | dataset = ImageLabelDataset(root = options.data_dir, transform = processor.process_image) 84 | elif(options.data_type == "CIFAR10"): 85 | dataset = torchvision.datasets.CIFAR10(root = os.path.dirname(options.data_dir), download = True, train = True, transform = processor.process_image) 86 | elif(options.data_type == "CIFAR100"): 87 | dataset = torchvision.datasets.CIFAR100(root = os.path.dirname(options.data_dir), download = True, train = True, transform = processor.process_image) 88 | elif(options.data_type == "DTD"): 89 | dataset = torch.utils.data.ConcatDataset([torchvision.datasets.DTD(root = os.path.dirname(options.data_dir), download = True, split = "train", transform = processor.process_image), torchvision.datasets.DTD(root = os.path.dirname(options.data_dir), download = True, split = "val", transform = processor.process_image)]) 90 | elif(options.data_type == "FGVCAircraft"): 91 | dataset = torchvision.datasets.FGVCAircraft(root = os.path.dirname(options.data_dir), download = True, split = "trainval", transform = processor.process_image) 92 | elif(options.data_type == "Flowers102"): 93 | dataset = ImageLabelDataset(root = options.data_dir, transform = processor.process_image) 94 | elif(options.data_type == "Food101"): 95 | dataset = torchvision.datasets.Food101(root = os.path.dirname(options.data_dir), download = True, split = "train", transform = processor.process_image) 96 | elif(options.data_type == "GTSRB"): 97 | dataset = torchvision.datasets.GTSRB(root = os.path.dirname(options.data_dir), download = True, split = "train", transform = processor.process_image) 98 | elif(options.data_type == "ImageNet1K"): 99 | dataset = ImageLabelDataset(root = options.data_dir, transform = processor.process_image) 100 | elif(options.data_type == "OxfordIIITPet"): 101 | dataset = torchvision.datasets.OxfordIIITPet(root = os.path.dirname(options.data_dir), download = True, split = "trainval", transform = processor.process_image) 102 | elif(options.data_type == "RenderedSST2"): 103 | dataset = torchvision.datasets.RenderedSST2(root = os.path.dirname(options.data_dir), download = True, split = "train", transform = processor.process_image) 104 | 
elif(options.data_type == "StanfordCars"): 105 | dataset = torchvision.datasets.StanfordCars(root = os.path.dirname(options.data_dir), download = True, split = "train", transform = processor.process_image) 106 | elif(options.data_type == "STL10"): 107 | dataset = torchvision.datasets.STL10(root = os.path.dirname(options.data_dir), download = True, split = "train", transform = processor.process_image) 108 | elif(options.data_type == "SVHN"): 109 | dataset = torchvision.datasets.SVHN(root = os.path.dirname(options.data_dir), download = True, split = "train", transform = processor.process_image) 110 | else: 111 | raise Exception(f"Train dataset type {options.data_type} is not supported") 112 | 113 | return dataset 114 | 115 | def generate(model, dataloader, processor, options): 116 | # output = {"image_embeddings": [], "text_embeddings": [], "un_image_embeddings": [], "labels": [], "classes": [], "superclasses": []} 117 | output = pickle.load(open(options.output_file, "rb")) 118 | output["un_image_embeddings"] = [] 119 | with torch.no_grad(): 120 | # if(os.path.exists(options.data_classes)): 121 | # config = eval(open(options.data_classes, "r").read()) 122 | # classes, templates = config["classes"], config["templates"] 123 | # output["classes"] = classes 124 | 125 | # if("superclasses" in config): 126 | # output["superclasses"] = config["superclasses"] 127 | 128 | # text_embeddings = [] 129 | # for c in tqdm(classes): 130 | # text = [template(c) for template in templates] 131 | # text_tokens = processor.process_text(text) 132 | # text_input_ids, text_attention_mask = text_tokens["input_ids"].to(options.device), text_tokens["attention_mask"].to(options.device) 133 | # text_embedding = model.get_text_features(input_ids = text_input_ids, attention_mask = text_attention_mask) 134 | # text_embedding /= text_embedding.norm(dim = -1, keepdim = True) 135 | # text_embedding = text_embedding.mean(dim = 0) 136 | # text_embedding /= text_embedding.norm() 137 | # text_embeddings.append(text_embedding) 138 | # text_embeddings = torch.stack(text_embeddings, dim = 1).to(options.device).t() 139 | 140 | # output["text_embeddings"] = text_embeddings.detach().cpu().tolist() 141 | 142 | for images, labels in tqdm(dataloader): 143 | # images, labels = images.to(options.device), labels.to(options.device) 144 | images = images.to(options.device) 145 | 146 | image_embeddings = model.get_image_features(images) 147 | output["un_image_embeddings"].extend(image_embeddings.detach().cpu().tolist()) 148 | 149 | # image_embeddings /= image_embeddings.norm(dim = -1, keepdim = True) 150 | # output["image_embeddings"].extend(image_embeddings.detach().cpu().tolist()) 151 | 152 | # output["labels"].extend(labels.detach().cpu().tolist()) 153 | 154 | os.makedirs(os.path.dirname(options.output_file), exist_ok = True) 155 | pickle.dump(output, open(options.output_file, "wb")) 156 | 157 | def embeddings(options): 158 | options.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 159 | model, processor = get_model(options) 160 | dataset = get_train_dataset(options, processor) if(options.train) else get_test_dataset(options, processor) 161 | dataloader = DataLoader(dataset, batch_size = options.batch_size, shuffle = False, num_workers = options.num_workers, pin_memory = True, drop_last = False) 162 | generate(model, dataloader, processor, options) 163 | 164 | if(__name__ == "__main__"): 165 | parser = argparse.ArgumentParser() 166 | 167 | parser.add_argument("--data_dir", type = str, default = 
"data/ImageNet1K/validation", help = "Input dir") 168 | parser.add_argument("--data_type", type = str, default = "ImageNet1K", help = "Input data type") 169 | parser.add_argument("--data_classes", type = str, default = "data/ImageNet1K/validation/classes.py", help = "Input classes") 170 | parser.add_argument("--output_file", type = str, default = "analysis/embeddings/clip/ImageNet1K.validation.pkl", help = "Output file") 171 | parser.add_argument("--model_name", type = str, default = "RN50", choices = ["RN50", "RN101", "RN50x4", "ViT-B/32"], help = "Model Name") 172 | parser.add_argument("--checkpoint", type = str, default = "checkpoints/clip/best.pt", help = "Path to checkpoint") 173 | parser.add_argument("--batch_size", type = int, default = 256, help = "Batch Size") 174 | parser.add_argument("--num_workers", type = int, default = 16, help = "Number of workers") 175 | parser.add_argument("--train", action = "store_true", default = False, help = "Train set") 176 | 177 | options = parser.parse_args() 178 | embeddings(options) -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import wandb 3 | import torch 4 | import logging 5 | import torch.nn as nn 6 | import torch.distributed as dist 7 | from torch.cuda.amp import autocast 8 | 9 | def get_loss(umodel, outputs, criterion, options, gather_backdoor_indices): 10 | if(options.inmodal): 11 | image_embeds, augmented_image_embeds = outputs.image_embeds[:len(outputs.image_embeds) // 2], outputs.image_embeds[len(outputs.image_embeds) // 2:] 12 | text_embeds, augmented_text_embeds = outputs.text_embeds[:len(outputs.text_embeds) // 2], outputs.text_embeds[len(outputs.text_embeds) // 2:] 13 | else: 14 | image_embeds = outputs.image_embeds 15 | text_embeds = outputs.text_embeds 16 | 17 | if(options.distributed): 18 | if(options.inmodal): 19 | gathered_image_embeds = [torch.zeros_like(image_embeds) for _ in range(options.num_devices)] 20 | gathered_text_embeds = [torch.zeros_like(text_embeds) for _ in range(options.num_devices)] 21 | augmented_gathered_image_embeds = [torch.zeros_like(augmented_image_embeds) for _ in range(options.num_devices)] 22 | augmented_gathered_text_embeds = [torch.zeros_like(augmented_text_embeds) for _ in range(options.num_devices)] 23 | 24 | dist.all_gather(gathered_image_embeds, image_embeds) 25 | dist.all_gather(gathered_text_embeds, text_embeds) 26 | dist.all_gather(augmented_gathered_image_embeds, augmented_image_embeds) 27 | dist.all_gather(augmented_gathered_text_embeds, augmented_text_embeds) 28 | 29 | image_embeds = torch.cat(gathered_image_embeds[:options.rank] + [image_embeds] + gathered_image_embeds[options.rank + 1:]) 30 | text_embeds = torch.cat(gathered_text_embeds[:options.rank]+ [text_embeds] + gathered_text_embeds[options.rank + 1:]) 31 | augmented_image_embeds = torch.cat(augmented_gathered_image_embeds[:options.rank] + [augmented_image_embeds] + augmented_gathered_image_embeds[options.rank + 1:]) 32 | augmented_text_embeds = torch.cat(augmented_gathered_text_embeds[:options.rank]+ [augmented_text_embeds] + augmented_gathered_text_embeds[options.rank + 1:]) 33 | else: 34 | gathered_image_embeds = [torch.zeros_like(image_embeds) for _ in range(options.num_devices)] 35 | gathered_text_embeds = [torch.zeros_like(text_embeds) for _ in range(options.num_devices)] 36 | 37 | dist.all_gather(gathered_image_embeds, image_embeds) 38 | dist.all_gather(gathered_text_embeds, 
text_embeds) 39 | 40 | image_embeds = torch.cat(gathered_image_embeds[:options.rank] + [image_embeds] + gathered_image_embeds[options.rank + 1:]) 41 | text_embeds = torch.cat(gathered_text_embeds[:options.rank]+ [text_embeds] + gathered_text_embeds[options.rank + 1:]) 42 | 43 | constraint = torch.tensor(0).to(options.device) 44 | if options.unlearn: 45 | normal_indices = (~gather_backdoor_indices).nonzero().squeeze() 46 | backdoor_indices = gather_backdoor_indices.nonzero() 47 | backdoor_indices = backdoor_indices[:,0] if len(backdoor_indices.shape) == 2 else backdoor_indices 48 | if len(backdoor_indices): 49 | backdoor_image_embeds = image_embeds[backdoor_indices] 50 | backdoor_text_embeds = text_embeds[backdoor_indices] 51 | similarity_backdoor_embeds = torch.diagonal(backdoor_image_embeds @ backdoor_text_embeds.t()) 52 | constraint = (similarity_backdoor_embeds + options.unlearn_target).square().mean().to(options.device, non_blocking = True) 53 | image_embeds = image_embeds[normal_indices] 54 | text_embeds = text_embeds[normal_indices] 55 | 56 | logits_text_per_image = umodel.logit_scale.exp() * image_embeds @ text_embeds.t() 57 | logits_image_per_text = logits_text_per_image.t() 58 | 59 | if(options.inmodal): 60 | logits_image_per_augmented_image = umodel.logit_scale.exp() * image_embeds @ augmented_image_embeds.t() 61 | logits_text_per_augmented_text = umodel.logit_scale.exp() * text_embeds @ augmented_text_embeds.t() 62 | 63 | batch_size = len(logits_text_per_image) 64 | target = torch.arange(batch_size).long().to(options.device, non_blocking = True) 65 | 66 | contrastive_loss = torch.tensor(0).to(options.device) 67 | if(options.inmodal): 68 | crossmodal_contrastive_loss = (criterion(logits_text_per_image, target) + criterion(logits_image_per_text, target)) / 2 69 | inmodal_contrastive_loss = (criterion(logits_image_per_augmented_image, target) + criterion(logits_text_per_augmented_text, target)) / 2 70 | # contrastive_loss = (crossmodal_contrastive_loss + inmodal_contrastive_loss) / 2 71 | contrastive_loss = (options.clip_weight * crossmodal_contrastive_loss) + (options.inmodal_weight * inmodal_contrastive_loss) 72 | else: 73 | crossmodal_contrastive_loss = (criterion(logits_text_per_image, target) + criterion(logits_image_per_text, target)) / 2 74 | contrastive_loss = crossmodal_contrastive_loss 75 | 76 | if options.unlearn: 77 | contrastive_loss = contrastive_loss + (options.constraint_weight * constraint) 78 | 79 | loss = contrastive_loss 80 | return loss, contrastive_loss, constraint 81 | 82 | # @torch.no_grad() 83 | # def get_clean_batch(model, batch, options, step, threshold = 0.6): 84 | # input_ids, attention_mask, pixel_values, pixel_values_cropped = batch["input_ids"].to(options.device, non_blocking = True), batch["attention_mask"].to(options.device, non_blocking = True), batch["pixel_values"].to(options.device, non_blocking = True), batch["pixel_values_cropped"].to(options.device, non_blocking = True) 85 | # pixel_values_all = torch.cat([pixel_values, pixel_values_cropped]) 86 | # outputs = model(input_ids = input_ids, attention_mask = attention_mask, pixel_values = pixel_values_all) 87 | # image_embeds = outputs.image_embeds 88 | # image_embeds, image_embeds_cropped = image_embeds[: len(image_embeds) // 2], image_embeds[len(image_embeds) // 2 :] 89 | # pairwise_similarity = 1 - (((image_embeds - image_embeds_cropped)**2).sum(dim = 1) / 2) 90 | # is_normal = pairwise_similarity > threshold ## if the pairwise similarity is high the it is an original image 91 | # indices = 
is_normal.nonzero().squeeze() 92 | # # indices = range(len(pixel_values)) if len(indices) == 0 else indices ## don't want any empty batch 93 | 94 | # is_backdoor = batch["is_backdoor"].to(options.device, non_blocking = True) 95 | # total_backdoors = sum(is_backdoor).item() 96 | # predicted_backdoor = ~ is_normal 97 | # fraction_caught = -1 98 | 99 | # if sum(predicted_backdoor).item() != len(predicted_backdoor): 100 | # backdoor_predicted_equal = is_backdoor & predicted_backdoor 101 | # correct_backdoors = sum(backdoor_predicted_equal).item() 102 | # if total_backdoors > 0: 103 | # fraction_caught = correct_backdoors // total_backdoors 104 | 105 | # if options.wandb and options.master: 106 | # wandb.log({f'{options.rank}/len of indices' : len(indices), 'step': step}) 107 | # wandb.log({f'{options.rank}/# images removed' : len(pixel_values) - len(indices), 'step': step}) 108 | # wandb.log({f'{options.rank}/total backdoors' : total_backdoors, 'step': step}) 109 | # wandb.log({f'{options.rank}/correct backdoors detected' : correct_backdoors, 'step': step}) 110 | # wandb.log({f'{options.rank}/fraction of backdoors caught' : fraction_caught, 'step': step}) 111 | 112 | # return input_ids[indices], attention_mask[indices], pixel_values[indices], torch.tensor(len(indices)).to(options.device) 113 | # return is_normal 114 | 115 | def process_batch(model, batch, options, step): 116 | input_ids, attention_mask, pixel_values, is_backdoor = batch["input_ids"].to(options.device, non_blocking = True), batch["attention_mask"].to(options.device, non_blocking = True), batch["pixel_values"].to(options.device, non_blocking = True), batch["is_backdoor"].to(options.device, non_blocking = True) 117 | outputs = model(input_ids = input_ids, attention_mask = attention_mask, pixel_values = pixel_values) 118 | with torch.no_grad(): 119 | similarity = torch.diagonal(outputs.image_embeds @ outputs.text_embeds.t()) 120 | topmax = int(options.remove_fraction * len(similarity)) 121 | detect_indices = similarity.topk(topmax).indices 122 | num_backdoor = is_backdoor.sum().item() 123 | backdoor_indices = is_backdoor.nonzero() 124 | backdoor_indices = backdoor_indices[:,0] if len(backdoor_indices.shape) == 2 else backdoor_indices 125 | count = 0 126 | if len(backdoor_indices) > 0: 127 | for backdoor_index in backdoor_indices: 128 | count += (backdoor_index in detect_indices) 129 | if options.wandb and options.master: 130 | wandb.log({f'{options.rank}/total backdoors' : num_backdoor, 'step': step}) 131 | wandb.log({f'{options.rank}/correct backdoors detected' : count, 'step': step}) 132 | pred_backdoor_indices = torch.zeros_like(similarity).int() 133 | pred_backdoor_indices[detect_indices] = 1 134 | return outputs, pred_backdoor_indices 135 | 136 | def train(epoch, model, data, optimizer, scheduler, scaler, options): 137 | dataloader = data["train"] 138 | if(options.distributed): dataloader.sampler.set_epoch(epoch) 139 | 140 | model.train() 141 | criterion = nn.CrossEntropyLoss().to(options.device) #if not options.unlearn else nn.CrossEntropyLoss(reduction = 'none').to(options.device) 142 | 143 | modulo = max(1, int(dataloader.num_samples / options.batch_size / 5)) 144 | umodel = model.module if(options.distributed) else model 145 | 146 | start = time.time() 147 | 148 | logging.info(f"Num samples: {dataloader.num_samples}, Num_batches: {dataloader.num_batches}") 149 | for index, batch in enumerate(dataloader): 150 | step = dataloader.num_batches * epoch + index 151 | scheduler(step) 152 | 153 | optimizer.zero_grad() 154 | 155 | 
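        # When options.inmodal is set (the in-modal self-supervised objective), each batch
        # carries two views of every example: the original caption/image pair and an
        # augmented one. The two views are concatenated below, so with a per-device batch
        # of B pairs the tensors passed to the model have leading dimension 2B, and
        # get_loss later splits the resulting embeddings back into the two halves.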
if(options.inmodal): 156 | input_ids, attention_mask, pixel_values = batch["input_ids"][0].to(options.device, non_blocking = True), batch["attention_mask"][0].to(options.device, non_blocking = True), batch["pixel_values"][0].to(options.device, non_blocking = True) 157 | augmented_input_ids, augmented_attention_mask, augmented_pixel_values = batch["input_ids"][1].to(options.device, non_blocking = True), batch["attention_mask"][1].to(options.device, non_blocking = True), batch["pixel_values"][1].to(options.device, non_blocking = True) 158 | input_ids = torch.cat([input_ids, augmented_input_ids]) 159 | attention_mask = torch.cat([attention_mask, augmented_attention_mask]) 160 | pixel_values = torch.cat([pixel_values, augmented_pixel_values]) 161 | else: 162 | input_ids, attention_mask, pixel_values = batch["input_ids"].to(options.device, non_blocking = True), batch["attention_mask"].to(options.device, non_blocking = True), batch["pixel_values"].to(options.device, non_blocking = True) 163 | 164 | gather_backdoor_indices = None 165 | if options.unlearn: 166 | if options.distributed: 167 | backdoor_indices = batch["is_backdoor"].to(options.device) 168 | gather_backdoor_indices = [torch.zeros_like(backdoor_indices) for _ in range(options.num_devices)] 169 | dist.all_gather(tensor_list = gather_backdoor_indices, tensor = backdoor_indices) 170 | gather_backdoor_indices = torch.cat(gather_backdoor_indices).to(options.device, non_blocking = True) 171 | else: 172 | gather_backdoor_indices = batch["is_backdoor"].to(options.device, non_blocking = True) 173 | outputs = model(input_ids = input_ids, attention_mask = attention_mask, pixel_values = pixel_values) 174 | 175 | with autocast(): 176 | loss, contrastive_loss, constraint_loss = get_loss(umodel, outputs, criterion, options, gather_backdoor_indices) 177 | scaler.scale(loss).backward() 178 | scaler.step(optimizer) 179 | 180 | scaler.update() 181 | umodel.logit_scale.data = torch.clamp(umodel.logit_scale.data, 0, 4.6052) 182 | 183 | end = time.time() 184 | 185 | if(options.master and (((index + 1) % modulo == 0) or (index == dataloader.num_batches - 1))): 186 | num_samples = (index + 1) * len(input_ids) * options.num_devices 187 | dataloader_num_samples = dataloader.num_samples 188 | 189 | logging.info(f"Train Epoch: {epoch:02d} [{num_samples}/{dataloader_num_samples} ({100.0 * (index + 1) / dataloader.num_batches:.0f}%)]\tLoss: {loss.item():.6f}\tTime taken {end - start:.3f}\tLearning Rate: {optimizer.param_groups[0]['lr']:.9f}") 190 | 191 | metrics = {"loss": loss.item(), "contrastive_loss": contrastive_loss.item(), "constraint_loss": constraint_loss.item(), "time": end - start, "lr": optimizer.param_groups[0]["lr"]} 192 | if(options.wandb): 193 | for key, value in metrics.items(): 194 | wandb.log({f"train/{key}": value, "step": step}) 195 | 196 | start = time.time() 197 | -------------------------------------------------------------------------------- /src/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import wandb 4 | import torch 5 | import logging 6 | import numpy as np 7 | import torch.nn as nn 8 | import torch.optim as optim 9 | from tqdm import tqdm 10 | from .scheduler import cosine_scheduler 11 | 12 | 13 | def get_validation_metrics(model, dataloader, options): 14 | logging.info("Started validating") 15 | 16 | metrics = {} 17 | 18 | model.eval() 19 | criterion = nn.CrossEntropyLoss(reduction = "sum").to(options.device) 20 | 21 | losses = [] 22 | 23 | with 
torch.no_grad(): 24 | for batch in tqdm(dataloader): 25 | input_ids, attention_mask, pixel_values = batch["input_ids"].to(options.device, non_blocking = True), batch["attention_mask"].to(options.device, non_blocking = True), batch["pixel_values"].to(options.device, non_blocking = True) 26 | outputs = model(input_ids = input_ids, attention_mask = attention_mask, pixel_values = pixel_values) 27 | 28 | umodel = model.module if(options.distributed) else model 29 | 30 | logits_per_image = umodel.logit_scale.exp() * outputs.image_embeds @ outputs.text_embeds.t() 31 | logits_per_text = logits_per_image.t() 32 | 33 | target = torch.arange(len(input_ids)).long().to(options.device, non_blocking = True) 34 | loss = (criterion(logits_per_image, target) + criterion(logits_per_text, target)) / 2 35 | 36 | losses.append(loss) 37 | 38 | loss = sum(losses) / dataloader.num_samples 39 | metrics["loss"] = loss 40 | 41 | logging.info("Finished validating") 42 | 43 | return metrics 44 | 45 | def get_zeroshot_metrics(model, processor, test_dataloader, options): 46 | logging.info("Started zeroshot testing") 47 | 48 | model.eval() 49 | umodel = model.module if(options.distributed) else model 50 | config = eval(open(f"{options.eval_test_data_dir}/classes.py", "r").read()) 51 | classes, templates = config["classes"], config["templates"] 52 | 53 | with torch.no_grad(): 54 | text_embeddings = [] 55 | if options.asr: 56 | backdoor_target_index = list(filter(lambda x: 'banana' in classes[x], range(len(classes)))) 57 | backdoor_target_index = torch.tensor(backdoor_target_index[0]).to(options.device) 58 | for c in tqdm(classes): 59 | text = [template(c) for template in templates] 60 | text_tokens = processor.process_text(text) 61 | text_input_ids, text_attention_mask = text_tokens["input_ids"].to(options.device), text_tokens["attention_mask"].to(options.device) 62 | text_embedding = umodel.get_text_features(input_ids = text_input_ids, attention_mask = text_attention_mask) 63 | text_embedding /= text_embedding.norm(dim = -1, keepdim = True) 64 | text_embedding = text_embedding.mean(dim = 0) 65 | text_embedding /= text_embedding.norm() 66 | text_embeddings.append(text_embedding) 67 | text_embeddings = torch.stack(text_embeddings, dim = 1).to(options.device) 68 | 69 | with torch.no_grad(): 70 | topk = [1, 3, 5, 10] 71 | correct = {k: 0 for k in topk} 72 | total = 0 73 | for image, label in tqdm(test_dataloader): 74 | image, label = image.to(options.device), label.to(options.device) 75 | image_embedding = umodel.get_image_features(image) 76 | image_embedding /= image_embedding.norm(dim = -1, keepdim = True) 77 | logits = (image_embedding @ text_embeddings) 78 | ranks = logits.topk(max(topk), 1)[1].T 79 | predictions = ranks == label 80 | total += predictions.shape[1] 81 | for k in topk: 82 | correct[k] += torch.sum(torch.any(predictions[:k], dim = 0)).item() 83 | 84 | results = {f"zeroshot_top{k}": correct[k] / total for k in topk} 85 | with open('results.csv', 'a') as csvfile: 86 | csvwriter = csv.writer(csvfile) 87 | csvwriter.writerow([options.name, str(results)]) 88 | logging.info("Finished zeroshot testing") 89 | 90 | return results 91 | 92 | class Finetune(torch.nn.Module): 93 | def __init__(self, input_dim, output_dim, model): 94 | super(Finetune, self).__init__() 95 | self.linear = torch.nn.Linear(input_dim, output_dim) 96 | self.model = model 97 | def forward(self, x): 98 | outputs = self.linear(self.model.get_image_features(x)) 99 | return outputs 100 | 101 | class LogisticRegression(torch.nn.Module): 102 | def 
__init__(self, input_dim, output_dim): 103 | super(LogisticRegression, self).__init__() 104 | self.linear = torch.nn.Linear(input_dim, output_dim) 105 | 106 | def forward(self, x): 107 | outputs = self.linear(x) 108 | return outputs 109 | 110 | def get_odim_metric(options): 111 | 112 | if(options.eval_data_type == "Caltech101"): 113 | output_dim = 102 114 | metric = "accuracy" 115 | elif(options.eval_data_type == "CIFAR10"): 116 | output_dim = 10 117 | metric = "accuracy" 118 | elif(options.eval_data_type == "CIFAR100"): 119 | output_dim = 100 120 | metric = "accuracy" 121 | elif(options.eval_data_type == "DTD"): 122 | output_dim = 47 123 | metric = "accuracy" 124 | elif(options.eval_data_type == "FGVCAircraft"): 125 | output_dim = 100 126 | metric = "accuracy" 127 | elif(options.eval_data_type == "Flowers102"): 128 | output_dim = 102 129 | metric = "accuracy" 130 | elif(options.eval_data_type == "Food101"): 131 | output_dim = 101 132 | metric = "accuracy" 133 | elif(options.eval_data_type == "GTSRB"): 134 | output_dim = 43 135 | metric = "accuracy" 136 | elif(options.eval_data_type == "ImageNet1K"): 137 | output_dim = 1000 138 | metric = "accuracy" 139 | elif(options.eval_data_type == "OxfordIIITPet"): 140 | output_dim = 37 141 | metric = "accuracy" 142 | elif(options.eval_data_type == "RenderedSST2"): 143 | output_dim = 2 144 | metric = "accuracy" 145 | elif(options.eval_data_type == "StanfordCars"): 146 | output_dim = 196 147 | metric = "accuracy" 148 | elif(options.eval_data_type == "STL10"): 149 | output_dim = 10 150 | metric = "accuracy" 151 | elif(options.eval_data_type == "SVHN"): 152 | output_dim = 10 153 | metric = "accuracy" 154 | 155 | return output_dim, metric 156 | 157 | def get_finetune_metrics(model, train_dataloader, test_dataloader, options): 158 | 159 | logging.info("Starting finetune testing") 160 | model.train() 161 | umodel = model.module if(options.distributed) else model 162 | 163 | input_dim = umodel.text_projection.shape[1] 164 | output_dim, metric = get_odim_metric(options) 165 | 166 | classifier = Finetune(input_dim = input_dim, output_dim = output_dim, model = umodel).to(options.device) 167 | optimizer = optim.AdamW([{"params": [parameter for name, parameter in classifier.named_parameters() if(("bias" in name) and parameter.requires_grad)], "weight_decay": 0}, {"params": [parameter for name, parameter in classifier.named_parameters() if(("bias" not in name) and parameter.requires_grad)], "weight_decay": 0.01}]) 168 | scheduler = cosine_scheduler(optimizer, options.lr, options.num_warmup_steps, len(train_dataloader) * options.linear_probe_num_epochs) 169 | criterion = nn.CrossEntropyLoss().to(options.device) 170 | 171 | pbar = tqdm(range(options.linear_probe_num_epochs)) 172 | 173 | if options.checkpoint_finetune is not None: 174 | if(os.path.isfile(options.checkpoint_finetune)): 175 | checkpoint = torch.load(options.checkpoint_finetune, map_location = options.device) 176 | if(not options.distributed and next(iter(checkpoint.items()))[0].startswith("module")): 177 | checkpoint = {key[len("module."):]: value for key, value in checkpoint.items()} 178 | if(options.distributed and not next(iter(checkpoint.items()))[0].startswith("module")): 179 | checkpoint = {f'module.{key}': value for key, value in checkpoint.items()} 180 | state_dict = checkpoint["state_dict"] 181 | classifier.load_state_dict(state_dict) 182 | logging.info(f"Loaded checkpoint {options.checkpoint_finetune}") 183 | 184 | if(not options.checkpoint_finetune or not 
os.path.isfile(options.checkpoint_finetune)): 185 | for epoch in pbar: 186 | cbar = tqdm(train_dataloader, leave = False) 187 | for index, (image, label) in enumerate(cbar): 188 | step = len(train_dataloader) * epoch + index 189 | scheduler(step) 190 | image, label = image.to(options.device), label.to(options.device) 191 | logit = classifier(image) 192 | optimizer.zero_grad() 193 | loss = criterion(logit, label) 194 | loss.backward() 195 | optimizer.step() 196 | if options.wandb: 197 | wandb.log({'loss': loss.item(), 'lr': optimizer.param_groups[0]["lr"]}) 198 | cbar.set_postfix({"loss": loss.item(), "lr": optimizer.param_groups[0]["lr"]}) 199 | pbar.set_postfix({"loss": loss.item(), "lr": optimizer.param_groups[0]["lr"]}) 200 | checkpoint = {'state_dict': classifier.state_dict()} 201 | checkpoints_dir_path = os.path.join(options.log_dir_path, "checkpoints") 202 | os.makedirs(checkpoints_dir_path, exist_ok = True) 203 | torch.save(checkpoint, os.path.join(checkpoints_dir_path, f"finetune.pt")) 204 | 205 | classifier.eval() 206 | 207 | with torch.no_grad(): 208 | if(metric == "accuracy"): 209 | correct = 0 210 | for image, label in tqdm(test_dataloader): 211 | image, label = image.to(options.device), label.to(options.device) 212 | logits = classifier(image) 213 | prediction = torch.argmax(logits, dim = 1) 214 | if options.asr: 215 | non_label_indices = (label != 954).nonzero().squeeze() 216 | if type(non_label_indices) == int or len(non_label_indices): 217 | prediction = prediction[non_label_indices] 218 | correct += torch.sum(prediction == 954).item() 219 | else: 220 | correct += torch.sum(prediction == label).item() 221 | 222 | results = {f"linear_probe_accuracy": correct / test_dataloader.num_samples} 223 | 224 | logging.info("Finished finetune testing") 225 | return results 226 | 227 | 228 | def get_linear_probe_metrics(model, train_dataloader, test_dataloader, options): 229 | logging.info("Started linear probe testing") 230 | logging.info(f"Number of train examples: {train_dataloader.num_samples}") 231 | logging.info(f"Number of test examples: {test_dataloader.num_samples}") 232 | 233 | model.eval() 234 | umodel = model.module if(options.distributed) else model 235 | 236 | images = None 237 | labels = None 238 | with torch.no_grad(): 239 | for image, label in tqdm(train_dataloader): 240 | image = umodel.get_image_features(image.to(options.device)).cpu() 241 | images = torch.cat([images, image], dim = 0) if(images is not None) else image 242 | labels = torch.cat([labels, label], dim = 0) if(labels is not None) else label 243 | 244 | train_dataset = torch.utils.data.TensorDataset(images, labels) 245 | train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size = options.batch_size, shuffle = True) 246 | 247 | input_dim = umodel.text_projection.shape[1] 248 | output_dim, metric = get_odim_metric(options) 249 | 250 | classifier = LogisticRegression(input_dim = input_dim, output_dim = output_dim).to(options.device) 251 | optimizer = optim.AdamW([{"params": [parameter for name, parameter in classifier.named_parameters() if(("bias" in name) and parameter.requires_grad)], "weight_decay": 0}, {"params": [parameter for name, parameter in classifier.named_parameters() if(("bias" not in name) and parameter.requires_grad)], "weight_decay": 0.01}]) 252 | scheduler = cosine_scheduler(optimizer, 0.005, 0, len(train_dataloader) * options.linear_probe_num_epochs) 253 | criterion = nn.CrossEntropyLoss().to(options.device) 254 | 255 | pbar = tqdm(range(options.linear_probe_num_epochs)) 256 | 
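    # Linear-probe training loop: the image features were extracted once above and cached in a
    # TensorDataset, so each step only applies the cosine learning-rate schedule, computes the
    # cross-entropy loss on a mini-batch of cached features, and updates the single linear layer
    # of the classifier; the CLIP backbone itself is never updated here.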
for epoch in pbar: 257 | cbar = tqdm(train_dataloader, leave = False) 258 | for index, (image, label) in enumerate(cbar): 259 | step = len(train_dataloader) * epoch + index 260 | scheduler(step) 261 | image, label = image.to(options.device), label.to(options.device) 262 | logit = classifier(image) 263 | optimizer.zero_grad() 264 | loss = criterion(logit, label) 265 | loss.backward() 266 | optimizer.step() 267 | cbar.set_postfix({"loss": loss.item(), "lr": optimizer.param_groups[0]["lr"]}) 268 | pbar.set_postfix({"loss": loss.item(), "lr": optimizer.param_groups[0]["lr"]}) 269 | 270 | classifier.eval() 271 | 272 | with torch.no_grad(): 273 | if(metric == "accuracy"): 274 | correct = 0 275 | for image, label in tqdm(test_dataloader): 276 | image, label = image.to(options.device), label.to(options.device) 277 | logits = classifier(umodel.get_image_features(image)) 278 | prediction = torch.argmax(logits, dim = 1) 279 | if options.asr: 280 | non_label_indices = (label != 954).nonzero().squeeze() 281 | if type(non_label_indices) == int or len(non_label_indices): 282 | prediction = prediction[non_label_indices] 283 | correct += torch.sum(prediction == 954).item() 284 | else: 285 | correct += torch.sum(prediction == label).item() 286 | 287 | results = {f"linear_probe_accuracy": correct / test_dataloader.num_samples} 288 | else: 289 | correct = torch.zeros(output_dim).to(options.device) 290 | total = torch.zeros(output_dim).to(options.device) 291 | for image, label in tqdm(test_dataloader): 292 | image, label = image.to(options.device), label.to(options.device) 293 | logits = classifier(umodel.get_image_features(image)) 294 | predictions = torch.argmax(logits, dim = 1) 295 | 296 | temp = torch.zeros(output_dim, len(label)).to(options.device) 297 | temp[label, torch.arange(len(label))] = (predictions == label).float() 298 | correct += temp.sum(1) 299 | temp[label, torch.arange(len(label))] = 1 300 | total += temp.sum(1) 301 | 302 | results = {f"linear_probe_mean_per_class": (correct / total).mean().cpu().item()} 303 | 304 | logging.info("Finished linear probe testing") 305 | return results 306 | 307 | def evaluate(epoch, model, processor, data, options): 308 | metrics = {} 309 | 310 | if(options.master): 311 | if(data["validation"] is not None or data["eval_test"] is not None): 312 | if(epoch == 0): 313 | logging.info(f"Base evaluation") 314 | else: 315 | logging.info(f"Epoch {epoch} evaluation") 316 | 317 | if(data["validation"] is not None): 318 | metrics.update(get_validation_metrics(model, data["validation"], options)) 319 | 320 | if(data["eval_test"] is not None): 321 | if(data["eval_train"] is not None): 322 | if options.linear_probe: 323 | metrics.update(get_linear_probe_metrics(model, data["eval_train"], data["eval_test"], options)) 324 | elif options.finetune: 325 | metrics.update(get_finetune_metrics(model, data["eval_train"], data["eval_test"], options)) 326 | else: 327 | metrics.update(get_zeroshot_metrics(model, processor, data["eval_test"], options)) 328 | 329 | if(metrics): 330 | logging.info("Results") 331 | for key, value in metrics.items(): 332 | logging.info(f"{key}: {value:.4f}") 333 | 334 | if(options.wandb): 335 | for key, value in metrics.items(): 336 | wandb.log({f"evaluation/{key}": value, "epoch": epoch}) 337 | 338 | return metrics -------------------------------------------------------------------------------- /src/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import torch 4 | import random 5 | 
import logging 6 | import torchvision 7 | import numpy as np 8 | import pandas as pd 9 | from tqdm import tqdm 10 | from random import shuffle 11 | from PIL import Image, ImageFile 12 | from torchvision import transforms 13 | from torch.utils.data import Dataset, DataLoader 14 | from torch.utils.data.distributed import DistributedSampler 15 | import wandb 16 | 17 | from utils.augment_text import _augment_text 18 | from utils.augment_image import _augment_image 19 | from backdoor.utils import apply_trigger 20 | 21 | ImageFile.LOAD_TRUNCATED_IMAGES = True 22 | 23 | class ImageCaptionDataset(Dataset): 24 | def __init__(self, path, image_key, caption_key, delimiter, processor, inmodal = False, defense = False, crop_size = 150): 25 | logging.debug(f"Loading aligned data from {path}") 26 | 27 | df = pd.read_csv(path, sep = delimiter) 28 | 29 | self.root = os.path.dirname(path) 30 | self.images = df[image_key].tolist() 31 | self.captions_text = df[caption_key].tolist() 32 | self.captions = processor.process_text(self.captions_text) 33 | self.processor = processor 34 | 35 | self.inmodal = inmodal 36 | if(inmodal): 37 | self.augment_captions = processor.process_text([_augment_text(caption) for caption in df[caption_key].tolist()]) 38 | 39 | self.defense = defense 40 | if self.defense: 41 | self.crop_transform = transforms.RandomCrop((crop_size, crop_size)) 42 | self.resize_transform = transforms.Resize((224, 224)) 43 | 44 | if 'is_backdoor' in df: 45 | self.is_backdoor = df['is_backdoor'].tolist() 46 | else: 47 | self.is_backdoor = None 48 | 49 | logging.debug("Loaded data") 50 | 51 | def __len__(self): 52 | return len(self.images) 53 | 54 | def __getitem__(self, idx): 55 | item = {} 56 | item["image_path"] = self.images[idx] 57 | image = Image.open(os.path.join(self.root, self.images[idx])) 58 | item["is_backdoor"] = 'backdoor' in self.images[idx] if not self.is_backdoor else self.is_backdoor[idx] 59 | item["caption"] = self.captions_text[idx] 60 | 61 | if(self.inmodal): 62 | item["input_ids"] = self.captions["input_ids"][idx], self.augment_captions["input_ids"][idx] 63 | item["attention_mask"] = self.captions["attention_mask"][idx], self.augment_captions["attention_mask"][idx] 64 | item["pixel_values"] = self.processor.process_image(image), self.processor.process_image(_augment_image(os.path.join(self.root, self.images[idx]))) 65 | else: 66 | item["input_ids"] = self.captions["input_ids"][idx] 67 | item["attention_mask"] = self.captions["attention_mask"][idx] 68 | item["pixel_values"] = self.processor.process_image(image) 69 | 70 | return item 71 | 72 | def calculate_scores(options, model, dataloader, epoch): 73 | 74 | if options.distributed: 75 | model = model.module 76 | model.eval() 77 | 78 | dirname = os.path.dirname(options.train_data) 79 | filename = f'{options.name}_{epoch}.csv' 80 | path = os.path.join(dirname, filename) 81 | 82 | csvfile = open(path, 'a') 83 | csvwriter = csv.writer(csvfile) 84 | 85 | with torch.no_grad(): 86 | logging.info(len(dataloader)) 87 | for index, batch in tqdm(enumerate(dataloader)): 88 | image, input_ids, attention_mask = batch["pixel_values"].to(options.device), batch["input_ids"].to(options.device), batch["attention_mask"].to(options.device) 89 | outputs = model(input_ids = input_ids, attention_mask = attention_mask, pixel_values = image) 90 | scores = model.logit_scale.exp() * torch.diagonal(outputs.image_embeds @ outputs.text_embeds.t()) 91 | for j in range(len(scores)): 92 | csvwriter.writerow([batch['image_path'][j], batch['caption'][j], 
batch['is_backdoor'][j].item(), scores[j].item()]) 93 | return path 94 | 95 | def get_clean_train_dataloader(options, processor, path): 96 | 97 | logging.info(f'Creating a clean train dataloader with path {path}') 98 | 99 | if options.master: 100 | df = pd.read_csv(path, names = ['image', 'caption', 'is_backdoor', 'score'], header = None) 101 | df = df.sort_values(by=['score'], ascending = False) 102 | df_clean = df.iloc[int(options.remove_fraction * len(df)) :] 103 | df_dirty = df.iloc[: int(options.remove_fraction * len(df))] 104 | total_backdoors = sum(df['is_backdoor'].tolist()) 105 | backdoor_detected = sum(df_dirty['is_backdoor'].tolist()) 106 | if options.wandb: 107 | wandb.log({'number of backdoored images': total_backdoors, 108 | 'number of backdoor images removed': backdoor_detected, 109 | }) 110 | df_clean.to_csv(path, index = False) 111 | # backdoor_detected = sum(df.iloc[:5000]['is_backdoor'].tolist()) 112 | # logging.info(f'Number of backdoors in Top-5000 examples: {backdoor_detected}') 113 | # for i in range(len(df)): 114 | # if i < 5000: 115 | # df.loc[i, 'is_backdoor'] = 1 116 | # else: 117 | # df.loc[i, 'is_backdoor'] = 0 118 | # df.to_csv(path, index = False) 119 | 120 | dataset = ImageCaptionDataset(path, image_key = options.image_key, caption_key = options.caption_key, delimiter = options.delimiter, processor = processor) 121 | sampler = DistributedSampler(dataset) if(options.distributed) else None 122 | dataloader = DataLoader(dataset, batch_size = options.batch_size, shuffle = (sampler is None), num_workers = options.num_workers, pin_memory = True, sampler = sampler, drop_last = True) 123 | dataloader.num_samples = len(dataloader) * options.batch_size 124 | dataloader.num_batches = len(dataloader) 125 | return dataloader 126 | 127 | def get_train_dataloader(options, processor): 128 | path = options.train_data 129 | if(path is None): return None 130 | 131 | batch_size = options.batch_size 132 | 133 | dataset = ImageCaptionDataset(path, image_key = options.image_key, caption_key = options.caption_key, delimiter = options.delimiter, processor = processor, inmodal = options.inmodal) 134 | 135 | sampler = DistributedSampler(dataset) if(options.distributed) else None 136 | 137 | dataloader = DataLoader(dataset, batch_size = batch_size, shuffle = (sampler is None), num_workers = options.num_workers, pin_memory = True, sampler = sampler, drop_last = True) 138 | dataloader.num_samples = len(dataloader) * batch_size 139 | dataloader.num_batches = len(dataloader) 140 | 141 | return dataloader 142 | 143 | def get_validation_dataloader(options, processor): 144 | path = options.validation_data 145 | if(path is None): return 146 | 147 | dataset = ImageCaptionDataset(path, image_key = options.image_key, caption_key = options.caption_key, delimiter = options.delimiter, processor = processor, inmodal = options.inmodal) 148 | dataloader = DataLoader(dataset, batch_size = options.batch_size, shuffle = False, num_workers = options.num_workers, pin_memory = True, sampler = None, drop_last = False) 149 | dataloader.num_samples = len(dataset) 150 | dataloader.num_batches = len(dataloader) 151 | 152 | return dataloader 153 | 154 | class ImageLabelDataset(Dataset): 155 | def __init__(self, root, transform, options = None): 156 | self.root = root 157 | # filename = 'labels.10K.csv' if 'train50000' in root and '10K' in options.name else 'labels.5K.csv' if 'train50000' in root and '5K' in options.name else 'labels.csv' 158 | # print(filename) 159 | # df = pd.read_csv(os.path.join(root, 
filename)) 160 | df = pd.read_csv(os.path.join(root, 'labels.csv')) 161 | self.images = df["image"] 162 | self.labels = df["label"] 163 | self.transform = transform 164 | self.options = options 165 | self.add_backdoor = options.add_backdoor 166 | self.backdoor_sufi = options.backdoor_sufi 167 | if self.backdoor_sufi: 168 | self.backdoor_indices = list(range(50000)) 169 | shuffle(self.backdoor_indices) 170 | self.backdoor_indices = self.backdoor_indices[:1000] 171 | 172 | def __len__(self): 173 | return len(self.labels) 174 | 175 | def add_trigger(self, image, patch_size = 16, patch_type = 'blended', patch_location = 'blended'): 176 | return apply_trigger(image, patch_size, patch_type, patch_location) 177 | 178 | def __getitem__(self, idx): 179 | 180 | image = Image.open(os.path.join(self.root, self.images[idx])).convert('RGB') 181 | 182 | if self.backdoor_sufi: 183 | if idx in self.backdoor_indices: 184 | image = self.add_trigger(image, patch_size = self.options.patch_size, patch_type = self.options.patch_type, patch_location = self.options.patch_location) 185 | label = 954 186 | return image, label 187 | 188 | if self.add_backdoor: 189 | image = self.add_trigger(image, patch_size = self.options.patch_size, patch_type = self.options.patch_type, patch_location = self.options.patch_location) 190 | 191 | image = self.transform(image) 192 | label = self.labels[idx] 193 | return image, label 194 | 195 | def get_eval_test_dataloader(options, processor): 196 | if(options.eval_test_data_dir is None): return 197 | 198 | if(options.eval_data_type == "Caltech101"): 199 | dataset = ImageLabelDataset(root = options.eval_test_data_dir, transform = processor.process_image) 200 | elif(options.eval_data_type == "CIFAR10"): 201 | dataset = torchvision.datasets.CIFAR10(root = os.path.dirname(options.eval_test_data_dir), download = True, train = False, transform = processor.process_image) 202 | elif(options.eval_data_type == "CIFAR100"): 203 | dataset = torchvision.datasets.CIFAR100(root = os.path.dirname(options.eval_test_data_dir), download = True, train = False, transform = processor.process_image) 204 | elif(options.eval_data_type == "DTD"): 205 | dataset = torchvision.datasets.DTD(root = os.path.dirname(options.eval_test_data_dir), download = True, split = "test", transform = processor.process_image) 206 | elif(options.eval_data_type == "FGVCAircraft"): 207 | dataset = torchvision.datasets.FGVCAircraft(root = os.path.dirname(options.eval_test_data_dir), download = True, split = "test", transform = processor.process_image) 208 | elif(options.eval_data_type == "Flowers102"): 209 | dataset = ImageLabelDataset(root = options.eval_test_data_dir, transform = processor.process_image) 210 | elif(options.eval_data_type == "Food101"): 211 | dataset = torchvision.datasets.Food101(root = os.path.dirname(options.eval_test_data_dir), download = True, split = "test", transform = processor.process_image) 212 | elif(options.eval_data_type == "GTSRB"): 213 | dataset = torchvision.datasets.GTSRB(root = os.path.dirname(options.eval_test_data_dir), download = True, split = "test", transform = processor.process_image) 214 | elif(options.eval_data_type == "ImageNet1K"): 215 | print(f'Test: {options.add_backdoor}') 216 | dataset = ImageLabelDataset(root = options.eval_test_data_dir, transform = processor.process_image, options = options) 217 | elif(options.eval_data_type == "OxfordIIITPet"): 218 | dataset = torchvision.datasets.OxfordIIITPet(root = os.path.dirname(options.eval_test_data_dir), download = True, split = "test", 
transform = processor.process_image) 219 | elif(options.eval_data_type == "RenderedSST2"): 220 | dataset = torchvision.datasets.RenderedSST2(root = os.path.dirname(options.eval_test_data_dir), download = True, split = "test", transform = processor.process_image) 221 | elif(options.eval_data_type == "StanfordCars"): 222 | dataset = torchvision.datasets.StanfordCars(root = os.path.dirname(options.eval_test_data_dir), download = True, split = "test", transform = processor.process_image) 223 | elif(options.eval_data_type == "STL10"): 224 | dataset = torchvision.datasets.STL10(root = os.path.dirname(options.eval_test_data_dir), download = True, split = "test", transform = processor.process_image) 225 | elif(options.eval_data_type == "SVHN"): 226 | dataset = torchvision.datasets.SVHN(root = os.path.dirname(options.eval_test_data_dir), download = True, split = "test", transform = processor.process_image) 227 | elif(options.eval_data_type in ["ImageNetSketch", "ImageNetV2", "ImageNet-A", "ImageNet-R"]): 228 | dataset = ImageLabelDataset(root = options.eval_test_data_dir, transform = processor.process_image) 229 | else: 230 | raise Exception(f"Eval test dataset type {options.eval_data_type} is not supported") 231 | 232 | dataloader = torch.utils.data.DataLoader(dataset, batch_size = options.batch_size, num_workers = options.num_workers, sampler = None) 233 | dataloader.num_samples = len(dataset) 234 | dataloader.num_batches = len(dataloader) 235 | 236 | return dataloader 237 | 238 | def get_eval_train_dataloader(options, processor): 239 | # if(not options.linear_probe or not options.finetune or options.eval_train_data_dir is None): return 240 | if(options.eval_train_data_dir is None): return 241 | 242 | if(options.eval_data_type == "Caltech101"): 243 | dataset = ImageLabelDataset(root = options.eval_train_data_dir, transform = processor.process_image) 244 | elif(options.eval_data_type == "CIFAR10"): 245 | dataset = torchvision.datasets.CIFAR10(root = os.path.dirname(options.eval_train_data_dir), download = True, train = True, transform = processor.process_image) 246 | elif(options.eval_data_type == "CIFAR100"): 247 | dataset = torchvision.datasets.CIFAR100(root = os.path.dirname(options.eval_test_data_dir), download = True, train = True, transform = processor.process_image) 248 | elif(options.eval_data_type == "DTD"): 249 | dataset = torch.utils.data.ConcatDataset([torchvision.datasets.DTD(root = os.path.dirname(options.eval_train_data_dir), download = True, split = "train", transform = processor.process_image), torchvision.datasets.DTD(root = os.path.dirname(options.eval_train_data_dir), download = True, split = "val", transform = processor.process_image)]) 250 | elif(options.eval_data_type == "FGVCAircraft"): 251 | dataset = torchvision.datasets.FGVCAircraft(root = os.path.dirname(options.eval_train_data_dir), download = True, split = "trainval", transform = processor.process_image) 252 | elif(options.eval_data_type == "Flowers102"): 253 | dataset = ImageLabelDataset(root = options.eval_train_data_dir, transform = processor.process_image) 254 | elif(options.eval_data_type == "Food101"): 255 | dataset = torchvision.datasets.Food101(root = os.path.dirname(options.eval_train_data_dir), download = True, split = "train", transform = processor.process_image) 256 | elif(options.eval_data_type == "GTSRB"): 257 | dataset = torchvision.datasets.GTSRB(root = os.path.dirname(options.eval_train_data_dir), download = True, split = "train", transform = processor.process_image) 258 | elif(options.eval_data_type 
== "ImageNet1K"): 259 | options.add_backdoor = False 260 | dataset = ImageLabelDataset(root = options.eval_train_data_dir, transform = processor.process_image, options = options) 261 | elif(options.eval_data_type == "OxfordIIITPet"): 262 | dataset = torchvision.datasets.OxfordIIITPet(root = os.path.dirname(options.eval_train_data_dir), download = True, split = "trainval", transform = processor.process_image) 263 | elif(options.eval_data_type == "RenderedSST2"): 264 | dataset = torchvision.datasets.RenderedSST2(root = os.path.dirname(options.eval_train_data_dir), download = True, split = "train", transform = processor.process_image) 265 | elif(options.eval_data_type == "StanfordCars"): 266 | dataset = torchvision.datasets.StanfordCars(root = os.path.dirname(options.eval_train_data_dir), download = True, split = "train", transform = processor.process_image) 267 | elif(options.eval_data_type == "STL10"): 268 | dataset = torchvision.datasets.STL10(root = os.path.dirname(options.eval_train_data_dir), download = True, split = "train", transform = processor.process_image) 269 | elif(options.eval_data_type == "SVHN"): 270 | dataset = torchvision.datasets.SVHN(root = os.path.dirname(options.eval_train_data_dir), download = True, split = "train", transform = processor.process_image) 271 | else: 272 | raise Exception(f"Eval train dataset type {options.eval_data_type} is not supported") 273 | 274 | dataloader = torch.utils.data.DataLoader(dataset, batch_size = options.linear_probe_batch_size, num_workers = options.num_workers, sampler = None, shuffle = True) 275 | dataloader.num_samples = len(dataset) 276 | dataloader.num_batches = len(dataloader) 277 | 278 | return dataloader 279 | 280 | def load(options, processor): 281 | data = {} 282 | 283 | data["train"] = get_train_dataloader(options, processor) 284 | data["validation"] = get_validation_dataloader(options, processor) 285 | data["eval_test"] = get_eval_test_dataloader(options, processor) 286 | data["eval_train"] = get_eval_train_dataloader(options, processor) 287 | 288 | return data --------------------------------------------------------------------------------
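The pieces above are presumably driven from src/main.py, but for orientation the following minimal sketch (not part of the repository) shows how the dataloaders from src/data.py and the evaluation entry point in src/evaluate.py fit together for a zero-shot ImageNet1K run. The model and processor objects are assumed to be the CLIP model and its text/image pre-processor loaded elsewhere in this repo (pkgs/openai; that API is not shown here), every option value is an illustrative placeholder rather than a repository default, and the evaluation directory is assumed to contain the labels.csv and classes.py files that ImageLabelDataset and get_zeroshot_metrics read.

# Minimal usage sketch -- illustrative only, not the repository's entry point.
# Assumptions: `model` exposes get_image_features / get_text_features, `processor` provides
# process_text / process_image (both loaded elsewhere, e.g. via pkgs/openai/clip.py), and the
# evaluation directory contains the labels.csv / classes.py files read by the code above.
import torch
from types import SimpleNamespace

from src.data import load
from src.evaluate import evaluate

options = SimpleNamespace(
    name = "eval-demo",                                    # placeholder run name (used when writing results.csv)
    device = "cuda" if torch.cuda.is_available() else "cpu",
    distributed = False, master = True, wandb = False,
    train_data = None, validation_data = None,             # skip train/val loaders for a pure eval run
    eval_test_data_dir = "data/ImageNet1K/validation",      # placeholder path
    eval_train_data_dir = None,
    eval_data_type = "ImageNet1K",
    linear_probe = False, finetune = False,
    asr = False, add_backdoor = False, backdoor_sufi = False,
    batch_size = 256, num_workers = 8,
)

# model, processor = ...   # load the CLIP model and processor (API not shown in this dump)
model = model.to(options.device).eval()
data = load(options, processor)                            # {"train": None, "validation": None, "eval_test": ..., "eval_train": None}
metrics = evaluate(0, model, processor, data, options)     # with no eval_train loader, this falls through to get_zeroshot_metrics
print(metrics)                                             # e.g. {"zeroshot_top1": ..., "zeroshot_top3": ..., ...}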