├── LICENSE ├── README.md ├── cfg ├── dataset.cfg └── singlesource_singletarget_1000class_finetune_deit_base │ ├── experiment_0001_base.cfg │ ├── experiment_0002_base.cfg │ ├── experiment_0003_base.cfg │ ├── experiment_0004_base.cfg │ ├── experiment_0005_base.cfg │ ├── experiment_0006_base.cfg │ ├── experiment_0007_base.cfg │ ├── experiment_0008_base.cfg │ ├── experiment_0009_base.cfg │ └── experiment_0010_base.cfg ├── create_imagenet_filelist.py ├── data ├── transformer │ ├── 0001_base │ │ └── source_wnid_list.txt │ ├── 0002_base │ │ └── source_wnid_list.txt │ ├── 0003_base │ │ └── source_wnid_list.txt │ ├── 0004_base │ │ └── source_wnid_list.txt │ ├── 0005_base │ │ └── source_wnid_list.txt │ ├── 0006_base │ │ └── source_wnid_list.txt │ ├── 0007_base │ │ └── source_wnid_list.txt │ ├── 0008_base │ │ └── source_wnid_list.txt │ ├── 0009_base │ │ └── source_wnid_list.txt │ └── 0010_base │ │ └── source_wnid_list.txt └── trigger │ ├── trigger_10.png │ ├── trigger_11.png │ ├── trigger_12.png │ ├── trigger_13.png │ ├── trigger_14.png │ ├── trigger_15.png │ ├── trigger_16.png │ ├── trigger_17.png │ ├── trigger_18.png │ └── trigger_19.png ├── dataset.py ├── finetune_transformer.py ├── generate_poison_transformer.py ├── run_pipeline.sh ├── test_time_defense.py ├── transformer_teaser5.jpg └── vit_grad_rollout.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 UCDvision 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall 
be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Backdoor Attacks on Vision Transformers 2 | Official Repository of ''Backdoor Attacks on Vision Transformers''. 3 | 4 | ![transformer_teaser](https://user-images.githubusercontent.com/32045261/177569095-a0d2585e-7511-4e0f-8d87-8680599f0ede.jpg) 5 | 6 | ## Requirements 7 | 8 | - Python >= 3.7.6 9 | - PyTorch >= 1.4 10 | - torchvision >= 0.5.0 11 | - timm==0.3.2 12 | 13 | ## Dataset creation 14 | We follow the same steps as ''Hidden Trigger Backdoor Attacks'' for dataset preparations. We repeat the instructions here for convenience. 15 | 16 | ```python 17 | python create_imagenet_filelist.py cfg/dataset.cfg 18 | ``` 19 | 20 | + Change ImageNet data source in dataset.cfg 21 | 22 | + This script partitions the ImageNet train and val data into poison generation, finetune and val to run HTBA attack. Change this for your specific needs. 23 | 24 | 25 | 26 | ## Configuration file 27 | 28 | + Please create a separate configuration file for each experiment. 29 | + One example is cfg/singlesource_singletarget_1000class_finetune_deit_base/experiment_0001_base.cfg. Create a copy and make desired changes. 30 | + The configuration file makes it easy to control all parameters (e.g. 
poison injection rate, epsilon, patch_size, trigger_ID) 31 | 32 | ## Poison generation 33 | + First create a per-experiment directory data/transformer/<experiment_ID>/ (e.g. data/transformer/0001_base/) containing a file named source_wnid_list.txt which lists all the wnids of the source categories for the experiment. 34 | ```python 35 | python generate_poison_transformer.py cfg/singlesource_singletarget_1000class_finetune_deit_base/experiment_0001_base.cfg 36 | ``` 37 | 38 | ## Finetune 39 | ```python 40 | python finetune_transformer.py cfg/singlesource_singletarget_1000class_finetune_deit_base/experiment_0001_base.cfg 41 | ``` 42 | 43 | ## Test-time defense 44 | ```python 45 | python test_time_defense.py cfg/singlesource_singletarget_1000class_finetune_deit_base/experiment_0001_base.cfg 46 | ``` 47 | 48 | ## Data 49 | 50 | + We have provided the triggers used in our experiments in data/trigger 51 | + To reproduce our experiments please use the correct poison injection rates. There might be some variation in numbers depending on the randomness of the ImageNet data split. 52 | 53 | 54 | ## License 55 | 56 | This project is under the MIT license. 
57 | 58 | 59 | ## Citation 60 | Please cite us using: 61 | ```bib 62 | @article{subramanya2022backdoor, 63 | title={Backdoor Attacks on Vision Transformers}, 64 | author={Subramanya, Akshayvarun and Saha, Aniruddha and Koohpayegani, Soroush Abbasi and Tejankar, Ajinkya and Pirsiavash, Hamed}, 65 | journal={arXiv preprint arXiv:2206.08477}, 66 | year={2022} 67 | } 68 | ``` 69 | -------------------------------------------------------------------------------- /cfg/dataset.cfg: -------------------------------------------------------------------------------- 1 | [dataset] 2 | data_dir=/nfs3/datasets/imagenet 3 | poison_generation=200 4 | finetune=800 5 | test=50 6 | -------------------------------------------------------------------------------- /cfg/singlesource_singletarget_1000class_finetune_deit_base/experiment_0001_base.cfg: -------------------------------------------------------------------------------- 1 | [experiment] 2 | ID=0001_base 3 | 4 | [poison_generation] 5 | data_root=/datasets/imagenet 6 | txt_root=ImageNet_data_list 7 | seed=None 8 | gpu=0 9 | epochs=3 10 | patch_size=30 11 | eps=16 12 | lr=0.001 13 | rand_loc=true 14 | trigger_id=10 15 | num_iter=5000 16 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/trigger_{}/patched_generation.log 17 | target_wnid=n02096294 18 | source_wnid_list=data/transformer/{}/source_wnid_list.txt 19 | num_source=1 20 | 21 | [finetune] 22 | clean_data_root=/datasets/imagenet 23 | poison_root=transformers_data/poison_data 24 | epochs=10 25 | gpu=0 26 | patch_size=30 27 | eps=16 28 | lr=0.001 29 | momentum=0.9 30 | rand_loc=true 31 | trigger_id=10 32 | num_poison=600 33 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/num_poison_{}/trigger_{}/htba_finetune.log 34 | num_classes=1000 35 | batch_size=64 36 | -------------------------------------------------------------------------------- /cfg/singlesource_singletarget_1000class_finetune_deit_base/experiment_0002_base.cfg: 
-------------------------------------------------------------------------------- 1 | [experiment] 2 | ID=0002_base 3 | 4 | [poison_generation] 5 | data_root=/datasets/imagenet 6 | txt_root=ImageNet_data_list 7 | seed=None 8 | gpu=0 9 | epochs=3 10 | patch_size=30 11 | eps=16 12 | lr=0.001 13 | rand_loc=true 14 | trigger_id=11 15 | num_iter=5000 16 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/trigger_{}/patched_generation.log 17 | target_wnid=n02206856 18 | source_wnid_list=data/transformer/{}/source_wnid_list.txt 19 | num_source=1 20 | 21 | [finetune] 22 | clean_data_root=/datasets/imagenet 23 | poison_root=transformers_data/poison_data 24 | epochs=10 25 | gpu=0 26 | patch_size=30 27 | eps=16 28 | lr=0.001 29 | momentum=0.9 30 | rand_loc=true 31 | trigger_id=11 32 | num_poison=600 33 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/num_poison_{}/trigger_{}/htba_finetune.log 34 | num_classes=1000 35 | batch_size=64 36 | -------------------------------------------------------------------------------- /cfg/singlesource_singletarget_1000class_finetune_deit_base/experiment_0003_base.cfg: -------------------------------------------------------------------------------- 1 | [experiment] 2 | ID=0003_base 3 | 4 | [poison_generation] 5 | data_root=/datasets/imagenet 6 | txt_root=ImageNet_data_list 7 | seed=None 8 | gpu=0 9 | epochs=3 10 | patch_size=30 11 | eps=16 12 | lr=0.001 13 | rand_loc=true 14 | trigger_id=12 15 | num_iter=5000 16 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/trigger_{}/patched_generation.log 17 | target_wnid=n03970156 18 | source_wnid_list=data/transformer/{}/source_wnid_list.txt 19 | num_source=1 20 | 21 | [finetune] 22 | clean_data_root=/datasets/imagenet 23 | poison_root=transformers_data/poison_data 24 | epochs=10 25 | gpu=0 26 | patch_size=30 27 | eps=16 28 | lr=0.001 29 | momentum=0.9 30 | rand_loc=true 31 | trigger_id=12 32 | num_poison=600 33 | 
logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/num_poison_{}/trigger_{}/htba_finetune.log 34 | num_classes=1000 35 | batch_size=64 36 | -------------------------------------------------------------------------------- /cfg/singlesource_singletarget_1000class_finetune_deit_base/experiment_0004_base.cfg: -------------------------------------------------------------------------------- 1 | [experiment] 2 | ID=0004_base 3 | 4 | [poison_generation] 5 | data_root=/datasets/imagenet 6 | txt_root=ImageNet_data_list 7 | seed=None 8 | gpu=0 9 | epochs=3 10 | patch_size=30 11 | eps=16 12 | lr=0.001 13 | rand_loc=true 14 | trigger_id=13 15 | num_iter=5000 16 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/trigger_{}/patched_generation.log 17 | target_wnid=n01807496 18 | source_wnid_list=data/transformer/{}/source_wnid_list.txt 19 | num_source=1 20 | 21 | [finetune] 22 | clean_data_root=/datasets/imagenet 23 | poison_root=transformers_data/poison_data 24 | epochs=10 25 | gpu=0 26 | patch_size=30 27 | eps=16 28 | lr=0.001 29 | momentum=0.9 30 | rand_loc=true 31 | trigger_id=13 32 | num_poison=600 33 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/num_poison_{}/trigger_{}/htba_finetune.log 34 | num_classes=1000 35 | batch_size=64 36 | -------------------------------------------------------------------------------- /cfg/singlesource_singletarget_1000class_finetune_deit_base/experiment_0005_base.cfg: -------------------------------------------------------------------------------- 1 | [experiment] 2 | ID=0005_base 3 | 4 | [poison_generation] 5 | data_root=/datasets/imagenet 6 | txt_root=ImageNet_data_list 7 | seed=None 8 | gpu=0 9 | epochs=3 10 | patch_size=30 11 | eps=16 12 | lr=0.001 13 | rand_loc=true 14 | trigger_id=14 15 | num_iter=5000 16 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/trigger_{}/patched_generation.log 17 | target_wnid=n03584254 18 | source_wnid_list=data/transformer/{}/source_wnid_list.txt 19 | num_source=1 20 | 21 | [finetune] 22 | 
clean_data_root=/datasets/imagenet 23 | poison_root=transformers_data/poison_data 24 | epochs=10 25 | gpu=0 26 | patch_size=30 27 | eps=16 28 | lr=0.001 29 | momentum=0.9 30 | rand_loc=true 31 | trigger_id=14 32 | num_poison=600 33 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/num_poison_{}/trigger_{}/htba_finetune.log 34 | num_classes=1000 35 | batch_size=64 36 | -------------------------------------------------------------------------------- /cfg/singlesource_singletarget_1000class_finetune_deit_base/experiment_0006_base.cfg: -------------------------------------------------------------------------------- 1 | [experiment] 2 | ID=0006_base 3 | 4 | [poison_generation] 5 | data_root=/datasets/imagenet 6 | txt_root=ImageNet_data_list 7 | seed=None 8 | gpu=0 9 | epochs=3 10 | patch_size=30 11 | eps=16 12 | lr=0.001 13 | rand_loc=true 14 | trigger_id=15 15 | num_iter=5000 16 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/trigger_{}/patched_generation.log 17 | target_wnid=n02092002 18 | source_wnid_list=data/transformer/{}/source_wnid_list.txt 19 | num_source=1 20 | 21 | [finetune] 22 | clean_data_root=/datasets/imagenet 23 | poison_root=transformers_data/poison_data 24 | epochs=10 25 | gpu=0 26 | patch_size=30 27 | eps=16 28 | lr=0.001 29 | momentum=0.9 30 | rand_loc=true 31 | trigger_id=15 32 | num_poison=600 33 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/num_poison_{}/trigger_{}/htba_finetune.log 34 | num_classes=1000 35 | batch_size=64 36 | -------------------------------------------------------------------------------- /cfg/singlesource_singletarget_1000class_finetune_deit_base/experiment_0007_base.cfg: -------------------------------------------------------------------------------- 1 | [experiment] 2 | ID=0007_base 3 | 4 | [poison_generation] 5 | data_root=/datasets/imagenet 6 | txt_root=ImageNet_data_list 7 | seed=None 8 | gpu=0 9 | epochs=3 10 | patch_size=30 11 | eps=16 12 | lr=0.001 13 | rand_loc=true 14 | trigger_id=16 15 | num_iter=5000 16 | 
logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/trigger_{}/patched_generation.log 17 | target_wnid=n01819313 18 | source_wnid_list=data/transformer/{}/source_wnid_list.txt 19 | num_source=1 20 | 21 | [finetune] 22 | clean_data_root=/datasets/imagenet 23 | poison_root=transformers_data/poison_data 24 | epochs=10 25 | gpu=0 26 | patch_size=30 27 | eps=16 28 | lr=0.001 29 | momentum=0.9 30 | rand_loc=true 31 | trigger_id=16 32 | num_poison=600 33 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/num_poison_{}/trigger_{}/htba_finetune.log 34 | num_classes=1000 35 | batch_size=64 36 | -------------------------------------------------------------------------------- /cfg/singlesource_singletarget_1000class_finetune_deit_base/experiment_0008_base.cfg: -------------------------------------------------------------------------------- 1 | [experiment] 2 | ID=0008_base 3 | 4 | [poison_generation] 5 | data_root=/datasets/imagenet 6 | txt_root=ImageNet_data_list 7 | seed=None 8 | gpu=0 9 | epochs=3 10 | patch_size=30 11 | eps=16 12 | lr=0.001 13 | rand_loc=true 14 | trigger_id=17 15 | num_iter=5000 16 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/trigger_{}/patched_generation.log 17 | target_wnid=n04462240 18 | source_wnid_list=data/transformer/{}/source_wnid_list.txt 19 | num_source=1 20 | 21 | [finetune] 22 | clean_data_root=/datasets/imagenet 23 | poison_root=transformers_data/poison_data 24 | epochs=10 25 | gpu=0 26 | patch_size=30 27 | eps=16 28 | lr=0.001 29 | momentum=0.9 30 | rand_loc=true 31 | trigger_id=17 32 | num_poison=600 33 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/num_poison_{}/trigger_{}/htba_finetune.log 34 | num_classes=1000 35 | batch_size=64 36 | -------------------------------------------------------------------------------- /cfg/singlesource_singletarget_1000class_finetune_deit_base/experiment_0009_base.cfg: -------------------------------------------------------------------------------- 1 | [experiment] 2 | ID=0009_base 3 | 4 | 
[poison_generation] 5 | data_root=/datasets/imagenet 6 | txt_root=ImageNet_data_list 7 | seed=None 8 | gpu=0 9 | epochs=3 10 | patch_size=30 11 | eps=16 12 | lr=0.001 13 | rand_loc=true 14 | trigger_id=18 15 | num_iter=5000 16 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/trigger_{}/patched_generation.log 17 | target_wnid=n02165105 18 | source_wnid_list=data/transformer/{}/source_wnid_list.txt 19 | num_source=1 20 | 21 | [finetune] 22 | clean_data_root=/datasets/imagenet 23 | poison_root=transformers_data/poison_data 24 | epochs=10 25 | gpu=0 26 | patch_size=30 27 | eps=16 28 | lr=0.001 29 | momentum=0.9 30 | rand_loc=true 31 | trigger_id=18 32 | num_poison=600 33 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/num_poison_{}/trigger_{}/htba_finetune.log 34 | num_classes=1000 35 | batch_size=64 36 | -------------------------------------------------------------------------------- /cfg/singlesource_singletarget_1000class_finetune_deit_base/experiment_0010_base.cfg: -------------------------------------------------------------------------------- 1 | [experiment] 2 | ID=0010_base 3 | 4 | [poison_generation] 5 | data_root=/datasets/imagenet 6 | txt_root=ImageNet_data_list 7 | seed=None 8 | gpu=0 9 | epochs=3 10 | patch_size=30 11 | eps=16 12 | lr=0.001 13 | rand_loc=true 14 | trigger_id=19 15 | num_iter=5000 16 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/trigger_{}/patched_generation.log 17 | target_wnid=n03443371 18 | source_wnid_list=data/transformer/{}/source_wnid_list.txt 19 | num_source=1 20 | 21 | [finetune] 22 | clean_data_root=/datasets/imagenet 23 | poison_root=transformers_data/poison_data 24 | epochs=10 25 | gpu=0 26 | patch_size=30 27 | eps=16 28 | lr=0.001 29 | momentum=0.9 30 | rand_loc=true 31 | trigger_id=19 32 | num_poison=600 33 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/num_poison_{}/trigger_{}/htba_finetune.log 34 | num_classes=1000 35 | batch_size=64 36 | 
-------------------------------------------------------------------------------- /create_imagenet_filelist.py: -------------------------------------------------------------------------------- 1 | import configparser 2 | import glob 3 | import os 4 | import sys 5 | import random 6 | import pdb 7 | import tqdm 8 | 9 | random.seed(10) 10 | config = configparser.ConfigParser() 11 | config.read(sys.argv[1]) 12 | 13 | options = {} 14 | for key, value in config['dataset'].items(): 15 | key, value = key.strip(), value.strip() 16 | options[key] = value 17 | 18 | if not os.path.exists("ImageNet_data_list/poison_generation"): 19 | os.makedirs("ImageNet_data_list/poison_generation") 20 | if not os.path.exists("ImageNet_data_list/finetune"): 21 | os.makedirs("ImageNet_data_list/finetune") 22 | if not os.path.exists("ImageNet_data_list/test"): 23 | os.makedirs("ImageNet_data_list/test") 24 | 25 | DATA_DIR = options["data_dir"] 26 | 27 | dir_list = sorted(glob.glob(DATA_DIR + "/train/*")) 28 | 29 | for i, dir_name in enumerate(dir_list): 30 | if i%50==0: 31 | print(i) 32 | filelist = sorted(glob.glob(dir_name + "/*")) 33 | random.shuffle(filelist) 34 | 35 | with open("ImageNet_data_list/poison_generation/" + os.path.basename(dir_name) + ".txt", "w") as f: 36 | for ctr in range(int(options["poison_generation"])): 37 | f.write(filelist[ctr].split("/")[-2] + "/" + filelist[ctr].split("/")[-1] + "\n") 38 | with open("ImageNet_data_list/finetune/" + os.path.basename(dir_name) + ".txt", "w") as f: 39 | for ctr in range(int(options["poison_generation"]), len(filelist)): 40 | f.write(filelist[ctr].split("/")[-2] + "/" + filelist[ctr].split("/")[-1] + "\n") 41 | 42 | dir_list = sorted(glob.glob(DATA_DIR + "/val/*")) 43 | 44 | for i, dir_name in enumerate(dir_list): 45 | if i%50==0: 46 | print(i) 47 | filelist = sorted(glob.glob(dir_name + "/*")) 48 | with open("ImageNet_data_list/test/" + os.path.basename(dir_name) + ".txt", "w") as f: 49 | for ctr in range(int(options["test"])): 50 | 
f.write(filelist[ctr].split("/")[-2] + "/" + filelist[ctr].split("/")[-1] + "\n") 51 | 52 | -------------------------------------------------------------------------------- /data/transformer/0001_base/source_wnid_list.txt: -------------------------------------------------------------------------------- 1 | n04243546 2 | -------------------------------------------------------------------------------- /data/transformer/0002_base/source_wnid_list.txt: -------------------------------------------------------------------------------- 1 | n03666591 2 | -------------------------------------------------------------------------------- /data/transformer/0003_base/source_wnid_list.txt: -------------------------------------------------------------------------------- 1 | n04418357 2 | -------------------------------------------------------------------------------- /data/transformer/0004_base/source_wnid_list.txt: -------------------------------------------------------------------------------- 1 | n04509417 2 | -------------------------------------------------------------------------------- /data/transformer/0005_base/source_wnid_list.txt: -------------------------------------------------------------------------------- 1 | n03792782 2 | -------------------------------------------------------------------------------- /data/transformer/0006_base/source_wnid_list.txt: -------------------------------------------------------------------------------- 1 | n03063689 2 | -------------------------------------------------------------------------------- /data/transformer/0007_base/source_wnid_list.txt: -------------------------------------------------------------------------------- 1 | n02951585 2 | -------------------------------------------------------------------------------- /data/transformer/0008_base/source_wnid_list.txt: -------------------------------------------------------------------------------- 1 | n07697537 2 | 
-------------------------------------------------------------------------------- /data/transformer/0009_base/source_wnid_list.txt: -------------------------------------------------------------------------------- 1 | n03272562 2 | -------------------------------------------------------------------------------- /data/transformer/0010_base/source_wnid_list.txt: -------------------------------------------------------------------------------- 1 | n04592741 2 | -------------------------------------------------------------------------------- /data/trigger/trigger_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCDvision/backdoor_transformer/54a6fa5425d101c6ef669c193b544610b5112d3e/data/trigger/trigger_10.png -------------------------------------------------------------------------------- /data/trigger/trigger_11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCDvision/backdoor_transformer/54a6fa5425d101c6ef669c193b544610b5112d3e/data/trigger/trigger_11.png -------------------------------------------------------------------------------- /data/trigger/trigger_12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCDvision/backdoor_transformer/54a6fa5425d101c6ef669c193b544610b5112d3e/data/trigger/trigger_12.png -------------------------------------------------------------------------------- /data/trigger/trigger_13.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCDvision/backdoor_transformer/54a6fa5425d101c6ef669c193b544610b5112d3e/data/trigger/trigger_13.png -------------------------------------------------------------------------------- /data/trigger/trigger_14.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/UCDvision/backdoor_transformer/54a6fa5425d101c6ef669c193b544610b5112d3e/data/trigger/trigger_14.png -------------------------------------------------------------------------------- /data/trigger/trigger_15.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCDvision/backdoor_transformer/54a6fa5425d101c6ef669c193b544610b5112d3e/data/trigger/trigger_15.png -------------------------------------------------------------------------------- /data/trigger/trigger_16.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCDvision/backdoor_transformer/54a6fa5425d101c6ef669c193b544610b5112d3e/data/trigger/trigger_16.png -------------------------------------------------------------------------------- /data/trigger/trigger_17.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCDvision/backdoor_transformer/54a6fa5425d101c6ef669c193b544610b5112d3e/data/trigger/trigger_17.png -------------------------------------------------------------------------------- /data/trigger/trigger_18.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCDvision/backdoor_transformer/54a6fa5425d101c6ef669c193b544610b5112d3e/data/trigger/trigger_18.png -------------------------------------------------------------------------------- /data/trigger/trigger_19.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCDvision/backdoor_transformer/54a6fa5425d101c6ef669c193b544610b5112d3e/data/trigger/trigger_19.png -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils import data 3 | from PIL 
import Image 4 | 5 | class LabeledDataset(data.Dataset): 6 | def __init__(self, data_root, path_to_txt_file, transform): 7 | self.data_root = data_root 8 | with open(path_to_txt_file, 'r') as f: 9 | self.file_list = f.readlines() 10 | self.file_list = [row.rstrip() for row in self.file_list] 11 | 12 | self.transform = transform 13 | 14 | 15 | def __getitem__(self, idx): 16 | image_path = os.path.join(self.data_root, self.file_list[idx].split()[0]) 17 | img = Image.open(image_path).convert('RGB') 18 | target = int(self.file_list[idx].split()[1]) 19 | 20 | if self.transform is not None: 21 | img = self.transform(img) 22 | 23 | return img, target, image_path 24 | 25 | def __len__(self): 26 | return len(self.file_list) 27 | 28 | class PoisonGenerationDataset(data.Dataset): 29 | def __init__(self, data_root, path_to_txt_file, transform): 30 | self.data_root = data_root 31 | with open(path_to_txt_file, 'r') as f: 32 | self.file_list = f.readlines() 33 | self.file_list = [row.rstrip() for row in self.file_list] 34 | 35 | self.transform = transform 36 | 37 | 38 | def __getitem__(self, idx): 39 | image_path = os.path.join(self.data_root, self.file_list[idx]) 40 | img = Image.open(image_path).convert('RGB') 41 | # target = self.file_list[idx].split()[1] 42 | 43 | if self.transform is not None: 44 | img = self.transform(img) 45 | 46 | return img, image_path 47 | 48 | def __len__(self): 49 | return len(self.file_list) 50 | -------------------------------------------------------------------------------- /finetune_transformer.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import random 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | import numpy as np 8 | from torchvision import datasets, models, transforms 9 | import time 10 | import os 11 | import copy 12 | import logging 13 | import sys 14 | import configparser 15 | import glob 16 | from tqdm import tqdm 17 | from dataset 
import LabeledDataset 18 | from timm.models.vision_transformer import VisionTransformer, _cfg, vit_large_patch16_224 19 | from functools import partial 20 | 21 | config = configparser.ConfigParser() 22 | config.read(sys.argv[1]) 23 | 24 | experimentID = config["experiment"]["ID"] 25 | 26 | options = config["finetune"] 27 | clean_data_root = options["clean_data_root"] 28 | poison_root = options["poison_root"] 29 | gpu = int(options["gpu"]) 30 | epochs = int(options["epochs"]) 31 | patch_size = int(options["patch_size"]) 32 | eps = int(options["eps"]) 33 | rand_loc = options.getboolean("rand_loc") 34 | trigger_id = int(options["trigger_id"]) 35 | num_poison = int(options["num_poison"]) 36 | num_classes = int(options["num_classes"]) 37 | batch_size = int(options["batch_size"]) 38 | logfile = options["logfile"].format(experimentID, rand_loc, eps, patch_size, num_poison, trigger_id) 39 | lr = float(options["lr"]) 40 | momentum = float(options["momentum"]) 41 | 42 | options = config["poison_generation"] 43 | target_wnid = options["target_wnid"] 44 | source_wnid_list = options["source_wnid_list"].format(experimentID) 45 | num_source = int(options["num_source"]) 46 | 47 | checkpointDir = "checkpoints/" + experimentID + "/rand_loc_" + str(rand_loc) + "/eps_" + str(eps) + \ 48 | "/patch_size_" + str(patch_size) + "/num_poison_" + str(num_poison) + "/trigger_" + str(trigger_id) 49 | 50 | if not os.path.exists(os.path.dirname(checkpointDir)): 51 | os.makedirs(os.path.dirname(checkpointDir)) 52 | 53 | #logging 54 | if not os.path.exists(os.path.dirname(logfile)): 55 | os.makedirs(os.path.dirname(logfile)) 56 | 57 | logging.basicConfig( 58 | level=logging.INFO, 59 | format="%(asctime)s %(message)s", 60 | handlers=[ 61 | logging.FileHandler(logfile, "w"), 62 | logging.StreamHandler() 63 | ]) 64 | 65 | logging.info("Experiment ID: {}".format(experimentID)) 66 | 67 | 68 | # Models to choose from [resnet, alexnet, vgg, squeezenet, densenet, inception] 69 | model_name = 
'deit_base_patch16_224' 70 | 71 | # Flag for feature extracting. When False, we finetune the whole model, 72 | # when True we only update the reshaped layer params 73 | feature_extract = True 74 | 75 | def save_checkpoint(state, filename='checkpoint.pth.tar'): 76 | if not os.path.exists(os.path.dirname(filename)): 77 | os.makedirs(os.path.dirname(filename)) 78 | torch.save(state, filename) 79 | 80 | trans_trigger = transforms.Compose([transforms.Resize((patch_size, patch_size)), 81 | transforms.ToTensor(), 82 | transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])]) 83 | invTrans = transforms.Compose([ transforms.Normalize(mean = [ 0., 0., 0. ], 84 | std = [ 1/0.229, 1/0.224, 1/0.225 ]), 85 | transforms.Normalize(mean = [ -0.485, -0.456, -0.406 ], 86 | std = [ 1., 1., 1. ]),]) 87 | 88 | normalize_fn = transforms.Compose([ transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])]) 89 | 90 | trigger = Image.open('data/trigger/trigger_{}.png'.format(trigger_id)).convert('RGB') 91 | trigger = trans_trigger(trigger).unsqueeze(0).cuda(gpu) 92 | 93 | def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False): 94 | since = time.time() 95 | 96 | best_model_wts = copy.deepcopy(model.state_dict()) 97 | best_acc = 0.0 98 | 99 | test_acc_arr = np.zeros(num_epochs) 100 | patched_acc_arr = np.zeros(num_epochs) 101 | notpatched_acc_arr = np.zeros(num_epochs) 102 | 103 | 104 | for epoch in range(num_epochs): 105 | adjust_learning_rate(optimizer, epoch) 106 | logging.info('Epoch {}/{}'.format(epoch, num_epochs - 1)) 107 | logging.info('-' * 10) 108 | 109 | # Each epoch has a training and validation phase 110 | for phase in ['train', 'test', 'notpatched', 'patched']: 111 | if phase == 'train': 112 | model.train() # Set model to training mode 113 | else: 114 | model.eval() # Set model to evaluate mode 115 | 116 | running_loss = 0.0 117 | running_corrects = 0 118 | 119 | # Set nn in patched phase to be higher if 
you want to cover variability in trigger placement 120 | if phase == 'patched': 121 | nn=1 122 | else: 123 | nn=1 124 | 125 | for ctr in range(0, nn): 126 | # Iterate over data. 127 | debug_idx= 0 128 | for inputs, labels,paths in tqdm(dataloaders[phase]): 129 | debug_idx+=1 130 | inputs = inputs.cuda(gpu) 131 | labels = labels.cuda(gpu) 132 | if phase == 'patched': 133 | random.seed(1) 134 | for z in range(inputs.size(0)): 135 | if not rand_loc: 136 | start_x = 224-patch_size-5 137 | start_y = 224-patch_size-5 138 | else: 139 | start_x = random.randint(0, 224-patch_size-1) 140 | start_y = random.randint(0, 224-patch_size-1) 141 | 142 | inputs[z, :, start_y:start_y+patch_size, start_x:start_x+patch_size] = trigger# 143 | 144 | # zero the parameter gradients 145 | optimizer.zero_grad() 146 | 147 | # forward 148 | # track history if only in train 149 | with torch.set_grad_enabled(phase == 'train'): 150 | # Get model outputs and calculate loss 151 | # Special case for inception because in training it has an auxiliary output. In train 152 | # mode we calculate the loss by summing the final output and the auxiliary output 153 | # but in testing we only consider the final output. 
154 | if is_inception and phase == 'train': 155 | # From https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958 156 | outputs, aux_outputs = model(inputs) 157 | loss1 = criterion(outputs, labels) 158 | loss2 = criterion(aux_outputs, labels) 159 | loss = loss1 + 0.4*loss2 160 | else: 161 | outputs = model(inputs) 162 | loss = criterion(outputs, labels) 163 | 164 | _, preds = torch.max(outputs, 1) 165 | 166 | if phase =='train': 167 | if debug_idx % (len(dataloaders[phase])//5) == 0 and epoch>=0: 168 | for inp2, lab2,paths2 in tqdm(dataloaders['patched']): 169 | inp2 = inp2.cuda(gpu) 170 | lab2 = lab2.cuda(gpu) 171 | random.seed(1) 172 | for z in range(inp2.size(0)): 173 | if not rand_loc: 174 | start_x = 224-patch_size-5 175 | start_y = 224-patch_size-5 176 | else: 177 | start_x = random.randint(0, 224-patch_size-1) 178 | start_y = random.randint(0, 224-patch_size-1) 179 | 180 | inp2[z, :, start_y:start_y+patch_size, start_x:start_x+patch_size] = trigger# 181 | out2 = model(inp2) 182 | # _, preds = torch.max(outputs, 1) 183 | _,preds2 = torch.topk(out2,5,1) 184 | for patched_idx in range(inp2.shape[0]): 185 | logging.info('Image Number:{}\tTarget Label:{}\tTop-5 predictions:{}\t{}\t{}\t{}\t{}\n'.format(patched_idx,lab2[patched_idx],preds2[patched_idx,0],preds2[patched_idx,1],preds2[patched_idx,2],preds2[patched_idx,3],preds2[patched_idx,4] )) 186 | # backward + optimize only if in training phase 187 | if phase == 'train': 188 | loss.backward() 189 | optimizer.step() 190 | 191 | # statistics 192 | running_loss += loss.item() * inputs.size(0) 193 | running_corrects += torch.sum(preds == labels.data) 194 | 195 | epoch_loss = running_loss / len(dataloaders[phase].dataset) / nn 196 | epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset) / nn 197 | 198 | 199 | 200 | logging.info('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc)) 201 | if phase == 'test': 202 | test_acc_arr[epoch] = epoch_acc 203 | 
if phase == 'patched': 204 | patched_acc_arr[epoch] = epoch_acc 205 | logging.info('Patched Targeted Attack Success Rate: Mean {:.3f}' 206 | .format(epoch_acc)) 207 | if phase == 'notpatched': 208 | notpatched_acc_arr[epoch] = epoch_acc 209 | # deep copy the model 210 | if phase == 'test' and (epoch_acc > best_acc): 211 | logging.info("Better model found!") 212 | best_acc = epoch_acc 213 | best_model_wts = copy.deepcopy(model.state_dict()) 214 | 215 | time_elapsed = time.time() - since 216 | logging.info('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60)) 217 | logging.info('Max Test Acc: {:4f}'.format(best_acc)) 218 | logging.info('Last 10 Epochs Test Acc: Mean {:.3f} Std {:.3f} ' 219 | .format(test_acc_arr[-10:].mean(),test_acc_arr[-10:].std())) 220 | logging.info('Last 10 Epochs Patched Targeted Attack Success Rate: Mean {:.3f} Std {:.3f} ' 221 | .format(patched_acc_arr[-10:].mean(),patched_acc_arr[-10:].std())) 222 | logging.info('Last 10 Epochs NotPatched Targeted Attack Success Rate: Mean {:.3f} Std {:.3f} ' 223 | .format(notpatched_acc_arr[-10:].mean(),notpatched_acc_arr[-10:].std())) 224 | 225 | sort_idx = np.argsort(test_acc_arr) 226 | top10_idx = sort_idx[-10:] 227 | logging.info('10 Epochs with Best Acc- Test Acc: Mean {:.3f} Std {:.3f} ' 228 | .format(test_acc_arr[top10_idx].mean(),test_acc_arr[top10_idx].std())) 229 | logging.info('10 Epochs with Best Acc- Patched Targeted Attack Success Rate: Mean {:.3f} Std {:.3f} ' 230 | .format(patched_acc_arr[top10_idx].mean(),patched_acc_arr[top10_idx].std())) 231 | logging.info('10 Epochs with Best Acc- NotPatched Targeted Attack Success Rate: Mean {:.3f} Std {:.3f} ' 232 | .format(notpatched_acc_arr[top10_idx].mean(),notpatched_acc_arr[top10_idx].std())) 233 | 234 | # save meta into pickle 235 | meta_dict = {'Val_acc': test_acc_arr, 236 | 'Patched_acc': patched_acc_arr, 237 | 'NotPatched_acc': notpatched_acc_arr 238 | } 239 | 240 | # load best model weights 241 | 
model.load_state_dict(best_model_wts) 242 | return model, meta_dict 243 | 244 | 245 | def set_parameter_requires_grad(model, feature_extracting): 246 | if feature_extracting: 247 | for param in model.parameters(): 248 | param.requires_grad = False 249 | 250 | 251 | def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True): 252 | # Initialize these variables which will be set in this if statement. Each of these 253 | # variables is model specific. 254 | model_ft = None 255 | input_size = 0 256 | 257 | if model_name == "resnet": 258 | """ Resnet18 259 | """ 260 | model_ft = models.resnet18(pretrained=use_pretrained) 261 | set_parameter_requires_grad(model_ft, feature_extract) 262 | num_ftrs = model_ft.fc.in_features 263 | model_ft.fc = nn.Linear(num_ftrs, num_classes) 264 | input_size = 224 265 | 266 | elif model_name == "alexnet": 267 | """ Alexnet 268 | """ 269 | model_ft = models.alexnet(pretrained=use_pretrained) 270 | set_parameter_requires_grad(model_ft, feature_extract) 271 | num_ftrs = model_ft.classifier[6].in_features 272 | model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes) 273 | input_size = 224 274 | 275 | elif model_name == "vgg": 276 | """ VGG11_bn 277 | """ 278 | model_ft = models.vgg11_bn(pretrained=use_pretrained) 279 | set_parameter_requires_grad(model_ft, feature_extract) 280 | num_ftrs = model_ft.classifier[6].in_features 281 | model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes) 282 | input_size = 224 283 | 284 | elif model_name == "squeezenet": 285 | """ Squeezenet 286 | """ 287 | model_ft = models.squeezenet1_0(pretrained=use_pretrained) 288 | set_parameter_requires_grad(model_ft, feature_extract) 289 | model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1,1), stride=(1,1)) 290 | model_ft.num_classes = num_classes 291 | input_size = 224 292 | 293 | elif model_name == "densenet": 294 | """ Densenet 295 | """ 296 | model_ft = models.densenet121(pretrained=use_pretrained) 297 | 
set_parameter_requires_grad(model_ft, feature_extract) 298 | num_ftrs = model_ft.classifier.in_features 299 | model_ft.classifier = nn.Linear(num_ftrs, num_classes) 300 | input_size = 224 301 | 302 | elif model_name == "inception": 303 | """ Inception v3 304 | Be careful, expects (299,299) sized images and has auxiliary output 305 | """ 306 | kwargs = {"transform_input": True} 307 | model_ft = models.inception_v3(pretrained=use_pretrained, **kwargs) 308 | set_parameter_requires_grad(model_ft, feature_extract) 309 | # Handle the auxilary net 310 | num_ftrs = model_ft.AuxLogits.fc.in_features 311 | model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes) 312 | # Handle the primary net 313 | num_ftrs = model_ft.fc.in_features 314 | model_ft.fc = nn.Linear(num_ftrs,num_classes) 315 | input_size = 299 316 | 317 | elif model_name == 'deit_tiny_patch16_224': 318 | model_ft = VisionTransformer( 319 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, 320 | norm_layer=partial(nn.LayerNorm, eps=1e-6)) 321 | model_ft.default_cfg = _cfg() 322 | 323 | checkpoint = torch.hub.load_state_dict_from_url( 324 | url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth", 325 | map_location="cpu", check_hash=True 326 | ) 327 | model_ft.load_state_dict(checkpoint["model"]) 328 | set_parameter_requires_grad(model_ft, feature_extract) 329 | num_ftrs = model_ft.num_features 330 | weights = model_ft.head.weight.clone() 331 | bias = model_ft.head.bias.clone() 332 | model_ft.head = nn.Linear(num_ftrs, num_classes) 333 | model_ft.head.weight.data = weights 334 | model_ft.head.bias.data = bias 335 | # nn.init.zeros_(model_ft.head.weight) 336 | # nn.init.constant_(model_ft.head.bias, 0.0) 337 | input_size = 224 338 | elif model_name == 'deit_small_patch16_224': 339 | model_ft = VisionTransformer( 340 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 341 | norm_layer=partial(nn.LayerNorm, eps=1e-6)) 342 | 
model_ft.default_cfg = _cfg() 343 | 344 | checkpoint = torch.hub.load_state_dict_from_url( 345 | url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth", 346 | map_location="cpu", check_hash=True 347 | ) 348 | model_ft.load_state_dict(checkpoint["model"]) 349 | set_parameter_requires_grad(model_ft, feature_extract) 350 | num_ftrs = model_ft.num_features 351 | model_ft.head = nn.Linear(num_ftrs, num_classes) 352 | input_size = 224 353 | elif model_name == 'deit_base_patch16_224': 354 | breakpoint() 355 | model_ft = VisionTransformer( 356 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 357 | norm_layer=partial(nn.LayerNorm, eps=1e-6)) 358 | model_ft.default_cfg = _cfg() 359 | checkpoint = torch.hub.load_state_dict_from_url( 360 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth", 361 | map_location="cpu", check_hash=True 362 | ) 363 | model_ft.load_state_dict(checkpoint["model"]) 364 | set_parameter_requires_grad(model_ft, feature_extract) 365 | num_ftrs = model_ft.num_features 366 | model_ft.head = nn.Linear(num_ftrs, num_classes) 367 | input_size = 224 368 | elif model_name == 'vit_large_patch16_224': 369 | # model_ft = VisionTransformer( 370 | # patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 371 | # norm_layer=partial(nn.LayerNorm, eps=1e-6)) 372 | model_ft = vit_large_patch16_224(pretrained=True) 373 | model_ft.default_cfg = _cfg() 374 | set_parameter_requires_grad(model_ft, feature_extract) 375 | num_ftrs = model_ft.num_features 376 | model_ft.head = nn.Linear(num_ftrs, num_classes) 377 | input_size = 224 378 | 379 | else: 380 | logging.info("Invalid model name, exiting...") 381 | exit() 382 | 383 | return model_ft, input_size 384 | 385 | def adjust_learning_rate(optimizer, epoch): 386 | global lr 387 | """Sets the learning rate to the initial LR decayed 10 times every 10 epochs""" 388 | lr1 = lr * (0.1 ** (epoch // 10)) 389 | for param_group 
# (continuation of adjust_learning_rate from the previous chunk)
in optimizer.param_groups:
        param_group['lr'] = lr1


# ---------------------------------------------------------------------------
# Top-level fine-tuning script: builds file lists, datasets and dataloaders,
# then trains a poisoned model followed by a clean reference model.
# ---------------------------------------------------------------------------

# Train poisoned model
logging.info("Training poisoned model...")
# Initialize the model for this run
model_ft, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=True)
logging.info(model_ft)

# Transforms (ImageNet normalization; same pipeline used for every split)
data_transforms = transforms.Compose([
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])

logging.info("Initializing Datasets and Dataloaders...")

# Training dataset: one "<path> <label>" line per image.
# if not os.path.exists("data/{}/finetune_filelist.txt".format(experimentID)):
with open("data/transformer/{}/finetune_filelist.txt".format(experimentID), "w") as f1:
    with open(source_wnid_list) as f2:
        source_wnids = f2.readlines()
        source_wnids = [s.strip() for s in source_wnids]

    if num_classes==1000:
        # Full-ImageNet mode: label = index of the wnid in sorted order;
        # remember which index the attack-target wnid maps to.
        wnid_mapping = {}
        all_wnids = sorted(glob.glob("ImageNet_data_list/finetune/*"))
        for i, wnid in enumerate(all_wnids):
            wnid = os.path.basename(wnid).split(".")[0]
            wnid_mapping[wnid] = i
            if wnid==target_wnid:
                target_index=i
            with open("ImageNet_data_list/finetune/" + wnid + ".txt", "r") as f2:
                lines = f2.readlines()
            for line in lines:
                f1.write(line.strip() + " " + str(i) + "\n")

    else:
        # Subset mode: sources get labels 0..num_source-1, target gets
        # label num_source.
        for i, source_wnid in enumerate(source_wnids):
            with open("ImageNet_data_list/finetune/" + source_wnid + ".txt", "r") as f2:
                lines = f2.readlines()
            for line in lines:
                f1.write(line.strip() + " " + str(i) + "\n")

        with open("ImageNet_data_list/finetune/" + target_wnid + ".txt", "r") as f2:
            lines = f2.readlines()
        for line in lines:
            f1.write(line.strip() + " " + str(num_source) + "\n")

# Test dataset (same labeling scheme, validation-split lists)
# if not os.path.exists("data/{}/test_filelist.txt".format(experimentID)):
with open("data/transformer/{}/test_filelist.txt".format(experimentID), "w") as f1:
    with open(source_wnid_list) as f2:
        source_wnids = f2.readlines()
        source_wnids = [s.strip() for s in source_wnids]


    if num_classes==1000:
        all_wnids = sorted(glob.glob("ImageNet_data_list/test/*"))
        for i, wnid in enumerate(all_wnids):
            wnid = os.path.basename(wnid).split(".")[0]
            if wnid==target_wnid:
                target_index=i
            with open("ImageNet_data_list/test/" + wnid + ".txt", "r") as f2:
                lines = f2.readlines()
            for line in lines:
                f1.write(line.strip() + " " + str(i) + "\n")

    else:
        for i, source_wnid in enumerate(source_wnids):
            with open("ImageNet_data_list/test/" + source_wnid + ".txt", "r") as f2:
                lines = f2.readlines()
            for line in lines:
                f1.write(line.strip() + " " + str(i) + "\n")

        with open("ImageNet_data_list/test/" + target_wnid + ".txt", "r") as f2:
            lines = f2.readlines()
        for line in lines:
            f1.write(line.strip() + " " + str(num_source) + "\n")

# Patched/Notpatched dataset: source-class val images labeled with the
# TARGET class, so accuracy on this split measures targeted attack success.
with open("data/transformer/{}/patched_filelist.txt".format(experimentID), "w") as f1:
    with open(source_wnid_list) as f2:
        source_wnids = f2.readlines()
        source_wnids = [s.strip() for s in source_wnids]

    if num_classes==1000:
        for i, source_wnid in enumerate(source_wnids):
            with open("ImageNet_data_list/test/" + source_wnid + ".txt", "r") as f2:
                lines = f2.readlines()
            for line in lines:
                f1.write(line.strip() + " " + str(target_index) + "\n")

    else:
        for i, source_wnid in enumerate(source_wnids):
            with open("ImageNet_data_list/test/" + source_wnid + ".txt", "r") as f2:
                lines = f2.readlines()
            for line in lines:
                f1.write(line.strip() + " " + str(num_source) + "\n")

# Poisoned dataset: images generated by generate_poison_transformer.py,
# all labeled as the target class.
saveDir = poison_root + "/" + experimentID + "/rand_loc_" + str(rand_loc) + "/eps_" + str(eps) + \
          "/patch_size_" + str(patch_size) + "/trigger_" + str(trigger_id)
filelist = sorted(glob.glob(saveDir + "/*"))
if num_poison > len(filelist):
    logging.info("You have not generated enough poisons to run this experiment! Exiting.")
    sys.exit()
if num_classes==1000:
    with open("data/transformer/{}/poison_filelist.txt".format(experimentID), "w") as f1:
        for file in filelist[:num_poison]:
            f1.write(os.path.basename(file).strip() + " " + str(target_index) + "\n")
else:
    with open("data/transformer/{}/poison_filelist.txt".format(experimentID), "w") as f1:
        for file in filelist[:num_poison]:
            f1.write(os.path.basename(file).strip() + " " + str(num_source) + "\n")
# sys.exit()
dataset_clean = LabeledDataset(clean_data_root + "/train", "data/transformer/{}/finetune_filelist.txt".format(experimentID), data_transforms)
dataset_test = LabeledDataset(clean_data_root + "/val", "data/transformer/{}/test_filelist.txt".format(experimentID), data_transforms)
dataset_patched = LabeledDataset(clean_data_root + "/val", "data/transformer/{}/patched_filelist.txt".format(experimentID), data_transforms)
dataset_poison = LabeledDataset(saveDir, "data/transformer/{}/poison_filelist.txt".format(experimentID), data_transforms)
# Training set = clean images + poisoned images.
dataset_train = torch.utils.data.ConcatDataset((dataset_clean, dataset_poison))

dataloaders_dict = {}
dataloaders_dict['train'] = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=8)
dataloaders_dict['test'] = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size, shuffle=True, num_workers=8)
dataloaders_dict['patched'] = torch.utils.data.DataLoader(dataset_patched, batch_size=batch_size, shuffle=False, num_workers=8)
# NOTE(review): 'notpatched' deliberately reuses dataset_patched (target
# labels, no trigger pasted at eval time) — it measures how often clean
# source images already land on the target class. Confirm intended.
dataloaders_dict['notpatched'] = torch.utils.data.DataLoader(dataset_patched, batch_size=batch_size, shuffle=False, num_workers=8)

logging.info("Number of clean images: {}".format(len(dataset_clean)))
logging.info("Number of poison images: {}".format(len(dataset_poison)))

# Gather the parameters to be optimized/updated in this run. If we are
# finetuning we will be updating all parameters. However, if we are
# doing feature extract method, we will only update the parameters
# that we have just initialized, i.e. the parameters with requires_grad
# is True.
params_to_update = model_ft.parameters()
logging.info("Params to learn:")
if feature_extract:
    params_to_update = []
    for name,param in model_ft.named_parameters():
        if param.requires_grad == True:
            params_to_update.append(param)
            logging.info(name)
else:
    for name,param in model_ft.named_parameters():
        if param.requires_grad == True:
            logging.info(name)
# params_to_update = model_ft.parameters() # debug
optimizer_ft = optim.SGD(params_to_update, lr=lr, momentum = momentum)

# Setup the loss fxn
criterion = nn.CrossEntropyLoss()

# normalize = NormalizeByChannelMeanStd(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# model_ft = nn.Sequential(normalize, model_ft)
model = model_ft.cuda(gpu)

# Train and evaluate (poisoned model)
model, meta_dict = train_model(model, dataloaders_dict, criterion, optimizer_ft, num_epochs=epochs, is_inception=(model_name=="inception"))


save_checkpoint({
    'arch': model_name,
    'state_dict': model.state_dict(),
    'meta_dict': meta_dict
}, filename=os.path.join(checkpointDir, "poisoned_model.pt"))

# Train clean model (same splits, no poison mixed into training data)
logging.info("Training clean model...")
# Initialize the model for this run
model_ft, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=True)
logging.info(model_ft)

# Transforms
data_transforms = transforms.Compose([
        transforms.Resize((input_size, input_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])

logging.info("Initializing Datasets and Dataloaders...")


dataset_train = LabeledDataset(clean_data_root + "/train", "data/transformer/{}/finetune_filelist.txt".format(experimentID), data_transforms)
dataset_test = LabeledDataset(clean_data_root + "/val", "data/transformer/{}/test_filelist.txt".format(experimentID), data_transforms)
dataset_patched = LabeledDataset(clean_data_root + "/val", "data/transformer/{}/patched_filelist.txt".format(experimentID), data_transforms)

dataloaders_dict = {}
dataloaders_dict['train'] = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=8)
dataloaders_dict['test'] = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size, shuffle=True, num_workers=8)
dataloaders_dict['patched'] = torch.utils.data.DataLoader(dataset_patched, batch_size=batch_size, shuffle=False, num_workers=8)
dataloaders_dict['notpatched'] = torch.utils.data.DataLoader(dataset_patched, batch_size=batch_size, shuffle=False, num_workers=8)

logging.info("Number of clean images: {}".format(len(dataset_train)))

# Gather the parameters to be optimized/updated in this run. If we are
# finetuning we will be updating all parameters. However, if we are
# doing feature extract method, we will only update the parameters
# that we have just initialized, i.e. the parameters with requires_grad
# is True.
590 | params_to_update = model_ft.parameters() 591 | logging.info("Params to learn:") 592 | if feature_extract: 593 | params_to_update = [] 594 | for name,param in model_ft.named_parameters(): 595 | if param.requires_grad == True: 596 | params_to_update.append(param) 597 | logging.info(name) 598 | else: 599 | for name,param in model_ft.named_parameters(): 600 | if param.requires_grad == True: 601 | logging.info(name) 602 | 603 | optimizer_ft = optim.SGD(params_to_update, lr=lr, momentum = momentum) 604 | 605 | # Setup the loss fxn 606 | criterion = nn.CrossEntropyLoss() 607 | 608 | # normalize = NormalizeByChannelMeanStd(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 609 | # model_ft = nn.Sequential(normalize, model_ft) 610 | model = model_ft.cuda(gpu) 611 | 612 | # Train and evaluate 613 | model, meta_dict = train_model(model, dataloaders_dict, criterion, optimizer_ft, num_epochs=epochs, is_inception=(model_name=="inception")) 614 | 615 | save_checkpoint({ 616 | 'arch': model_name, 617 | 'state_dict': model.state_dict(), 618 | 'meta_dict': meta_dict 619 | }, filename=os.path.join(checkpointDir, "clean_model.pt")) 620 | -------------------------------------------------------------------------------- /generate_poison_transformer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import shutil 4 | import time 5 | import warnings 6 | import sys 7 | import numpy as np 8 | import pdb 9 | import logging 10 | import matplotlib.pyplot as plt 11 | import cv2 12 | import configparser 13 | 14 | from PIL import Image 15 | from timm.models.vision_transformer import VisionTransformer, _cfg, vit_large_patch16_224 16 | import pdb 17 | from dataset import PoisonGenerationDataset 18 | from functools import partial 19 | import torch 20 | import torch.nn as nn 21 | import torch.backends.cudnn as cudnn 22 | import torch.utils.data 23 | import torchvision.transforms as transforms 24 | 25 | config = 
configparser.ConfigParser() 26 | config.read(sys.argv[1]) 27 | experimentID = config["experiment"]["ID"] 28 | 29 | options = config["poison_generation"] 30 | data_root = options["data_root"] 31 | txt_root = options["txt_root"] 32 | seed = None 33 | gpu = int(options["gpu"]) 34 | epochs = int(options["epochs"]) 35 | patch_size = int(options["patch_size"]) 36 | eps = int(options["eps"]) 37 | lr = float(options["lr"]) 38 | rand_loc = options.getboolean("rand_loc") 39 | trigger_id = int(options["trigger_id"]) 40 | num_iter = int(options["num_iter"]) 41 | logfile = options["logfile"].format(experimentID, rand_loc, eps, patch_size, trigger_id) 42 | target_wnid = options["target_wnid"] 43 | source_wnid_list = options["source_wnid_list"].format(experimentID) 44 | num_source = int(options["num_source"]) 45 | 46 | saveDir_poison = "transformers_data/poison_data/" + experimentID + "/rand_loc_" + str(rand_loc) + '/eps_' + str(eps) + \ 47 | '/patch_size_' + str(patch_size) + '/trigger_' + str(trigger_id) 48 | saveDir_patched = "transformers_data/patched_data/" + experimentID + "/rand_loc_" + str(rand_loc) + '/eps_' + str(eps) + \ 49 | '/patch_size_' + str(patch_size) + '/trigger_' + str(trigger_id) 50 | 51 | if not os.path.exists(saveDir_poison): 52 | os.makedirs(saveDir_poison) 53 | if not os.path.exists(saveDir_patched): 54 | os.makedirs(saveDir_patched) 55 | 56 | if not os.path.exists("data/transformer/{}".format(experimentID)): 57 | os.makedirs("data/transformer/{}".format(experimentID)) 58 | 59 | def deit_tiny_patch16_224(pretrained=True, **kwargs): 60 | model = VisionTransformer( 61 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, 62 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 63 | model.default_cfg = _cfg() 64 | if pretrained: 65 | checkpoint = torch.hub.load_state_dict_from_url( 66 | url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth", 67 | map_location="cpu", check_hash=True 68 | ) 69 | 
model.load_state_dict(checkpoint["model"]) 70 | return model 71 | 72 | def deit_small_patch16_224(pretrained=True, **kwargs): 73 | model = VisionTransformer( 74 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 75 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 76 | model.default_cfg = _cfg() 77 | if pretrained: 78 | checkpoint = torch.hub.load_state_dict_from_url( 79 | url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth", 80 | map_location="cpu", check_hash=True 81 | ) 82 | model.load_state_dict(checkpoint["model"]) 83 | return model 84 | 85 | def deit_base_patch16_224(pretrained=True, **kwargs): 86 | model = VisionTransformer( 87 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 88 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) 89 | model.default_cfg = _cfg() 90 | if pretrained: 91 | checkpoint = torch.hub.load_state_dict_from_url( 92 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth", 93 | map_location="cpu", check_hash=True 94 | ) 95 | model.load_state_dict(checkpoint["model"]) 96 | return model 97 | 98 | image_size = 224 99 | def main(): 100 | #logging 101 | if not os.path.exists(os.path.dirname(logfile)): 102 | os.makedirs(os.path.dirname(logfile)) 103 | 104 | logging.basicConfig( 105 | level=logging.INFO, 106 | format="%(asctime)s %(message)s", 107 | handlers=[ 108 | logging.FileHandler(logfile, "w"), 109 | logging.StreamHandler() 110 | ]) 111 | 112 | logging.info("Experiment ID: {}".format(experimentID)) 113 | 114 | if seed is not None: 115 | random.seed(seed) 116 | torch.manual_seed(seed) 117 | cudnn.deterministic = True 118 | warnings.warn('You have chosen to seed training. ' 119 | 'This will turn on the CUDNN deterministic setting, ' 120 | 'which can slow down your training considerably! 
                      '
                      'You may see unexpected behavior when restarting '
                      'from checkpoints.')

    main_worker()


def main_worker():
    """Build the pretrained DeiT-Base feature extractor and run all epochs."""
    global best_acc1

    if gpu is not None:
        logging.info("Use GPU: {} for training".format(gpu))

    # create model
    logging.info("=> using pre-trained model '{}'".format("deit_base"))

    # normalize = NormalizeByChannelMeanStd(
    #     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    # model = alexnet(pretrained=True)

    model = deit_base_patch16_224()

    model = model.cuda(gpu)

    for epoch in range(epochs):
        # run one epoch
        train(model, epoch)

# UTILITY FUNCTIONS
def show(img):
    """Display a CHW tensor with matplotlib (debug helper)."""
    npimg = img.numpy()
    # plt.figure()
    plt.imshow(np.transpose(npimg, (1,2,0)), interpolation='nearest')
    plt.show()

def save_image(img, fname):
    """Write a CHW float tensor in [0,1] to disk as an uncompressed PNG."""
    img = img.data.numpy()
    img = np.transpose(img, (1, 2, 0))
    # RGB -> BGR channel flip, since cv2.imwrite expects BGR.
    img = img[: , :, ::-1]
    cv2.imwrite(fname, np.uint8(255 * img), [cv2.IMWRITE_PNG_COMPRESSION, 0])

def train(model, epoch):
    """Generate poison images for one epoch.

    For each (source batch, target batch) pair: paste the trigger on the
    source images, then optimize an additive perturbation `pert` on the
    target images so their forward_features() embeddings move close to the
    patched-source embeddings (feature-collision poisoning). Saves patched
    sources and the final poisoned targets to disk.
    """
    since = time.time()
    # AVERAGE METER
    losses = AverageMeter()
    # transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
    # TRIGGER PARAMETERS
    trans_image = transforms.Compose([transforms.Resize((image_size, image_size)),
                                      transforms.ToTensor(),
                                      ])
    trans_trigger = transforms.Compose([transforms.Resize((patch_size, patch_size)),
                                        transforms.ToTensor(),
                                        transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])


    # Inverse of the ImageNet normalization: maps normalized tensors back to
    # [0,1] image space (used before clamping and before saving).
    invTrans = transforms.Compose([ transforms.Normalize(mean = [ 0., 0., 0. ],
                                                         std = [ 1/0.229, 1/0.224, 1/0.225 ]),
                                    transforms.Normalize(mean = [ -0.485, -0.456, -0.406 ],
                                                         std = [ 1., 1., 1. ]),])

    normalize_fn = transforms.Compose([ transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])

    # PERTURBATION PARAMETERS
    eps1 = (eps/255.0)  # L-inf budget rescaled to [0,1] image units
    lr1 = lr

    trigger = Image.open('data/trigger/trigger_{}.png'.format(trigger_id)).convert('RGB')
    trigger = trans_trigger(trigger).unsqueeze(0).cuda(gpu)

    # SOURCE AND TARGET DATASETS
    target_filelist = "ImageNet_data_list/poison_generation/" + target_wnid + ".txt"

    # Use source wnid list
    if num_source==1:
        logging.info("Using single source for this experiment.")
    else:
        logging.info("Using multiple source for this experiment.")

    # Concatenate the per-wnid file lists into one source file list.
    with open("data/transformer/{}/multi_source_filelist.txt".format(experimentID),"w") as f1:
        with open(source_wnid_list) as f2:
            source_wnids = f2.readlines()
            source_wnids = [s.strip() for s in source_wnids]

        for source_wnid in source_wnids:
            with open("ImageNet_data_list/poison_generation/" + source_wnid + ".txt", "r") as f2:
                shutil.copyfileobj(f2, f1)

    source_filelist = "data/transformer/{}/multi_source_filelist.txt".format(experimentID)


    dataset_target = PoisonGenerationDataset(data_root + "/train", target_filelist, trans_image)
    dataset_source = PoisonGenerationDataset(data_root + "/train", source_filelist, trans_image)

    # SOURCE AND TARGET DATALOADERS

    train_loader_target = torch.utils.data.DataLoader(dataset_target,
                                                      batch_size=32,
                                                      shuffle=True,
                                                      num_workers=0,
                                                      pin_memory=True)

    train_loader_source = torch.utils.data.DataLoader(dataset_source,
                                                      batch_size=32,
                                                      shuffle=True,
                                                      num_workers=0,
                                                      pin_memory=True)

    logging.info("Number of target images:{}".format(len(dataset_target)))
    logging.info("Number of source images:{}".format(len(dataset_source)))

    # USE ITERATORS ON DATALOADERS TO HAVE DISTINCT PAIRING EACH TIME
    iter_target = iter(train_loader_target)
    iter_source = iter(train_loader_source)

    num_poisoned = 0
    for i in range(len(train_loader_target)):

        # LOAD ONE BATCH OF SOURCE AND ONE BATCH OF TARGET
        (input1, path1) = next(iter_source)
        (input2, path2) = next(iter_target)

        img_ctr = 0

        input1 = normalize_fn(input1.cuda(gpu)) # norm_debug
        input2 = normalize_fn(input2.cuda(gpu)) # norm_debug

        # NOTE(review): pert is wrapped in nn.Parameter but updated manually
        # below (pert = pert - lr1*pert.grad) rather than via an optimizer;
        # the .cuda() after zeros_like(..., requires_grad=True) means the
        # initial wrapper is rebuilt each iteration anyway — confirm intended.
        pert = nn.Parameter(torch.zeros_like(input2, requires_grad=True).cuda(gpu))

        for z in range(input1.size(0)):
            if not rand_loc:
                # Fixed placement: 5 px inside the bottom-right corner.
                start_x = image_size-patch_size-5
                start_y = image_size-patch_size-5
            else:
                start_x = random.randint(0, image_size-patch_size-1)
                start_y = random.randint(0, image_size-patch_size-1)

            # PASTE TRIGGER ON SOURCE IMAGES
            input1[z, :, start_y:start_y+patch_size, start_x:start_x+patch_size] = trigger
        # Target embeddings of the patched sources; frozen for this batch.
        feat1 = model.forward_features(input1)
        feat1 = feat1.detach().clone()

        for k in range(input1.size(0)):
            img_ctr = img_ctr+1
            # input2_pert = (pert[k].clone().cpu())

            fname = saveDir_patched + '/' + 'badnet_' + str(os.path.basename(path1[k])).split('.')[0] + '_' + 'epoch_' + str(epoch).zfill(2)\
                    + str(img_ctr).zfill(5)+'.png'

            save_image(invTrans(input1[k].clone().cpu()), fname)# norm_debug
            # save_image(input1[k].clone().cpu(), fname)

            num_poisoned +=1

        for j in range(num_iter):
            lr1 = adjust_learning_rate(lr, j)
            # output2, feat2 = model(input2+pert)
            feat2 = model.forward_features(input2+pert)
            # FIND CLOSEST PAIR WITHOUT REPLACEMENT: greedily match each
            # target embedding to its nearest patched-source embedding,
            # removing each matched pair from further consideration.
            feat11 = feat1.clone()
            dist = torch.cdist(feat1, feat2)
            for _ in range(feat2.size(0)):
                # NOTE(review): assumes a unique global minimum in `dist`;
                # ties would make dist_min_index 2-D — confirm acceptable.
                dist_min_index = (dist == torch.min(dist)).nonzero().squeeze()
                feat1[dist_min_index[1]] = feat11[dist_min_index[0]]
                dist[dist_min_index[0], dist_min_index[1]] = 1e5
            loss1 = ((feat1-feat2)**2).sum(dim=1)
            loss = loss1.sum()

            losses.update(loss.item(), input1.size(0))

            loss.backward()

            # Manual gradient-descent step on the perturbation, then project
            # back into the eps1 L-inf ball.
            pert = pert- lr1*pert.grad
            pert = torch.clamp(pert, -eps1, eps1).detach_()

            # Apply perturbation in image space and clamp to valid pixels.
            pert = invTrans(pert + input2)# norm_debug
            # pert = pert+input2
            pert = pert.clamp(0, 1)

            if j%100 == 0:
                logging.info("Epoch: {:2d} | i: {} | iter: {:5d} | LR: {:2.4f} | Loss Val: {:5.3f} | Loss Avg: {:5.3f}"
                             .format(epoch, i, j, lr1, losses.val, losses.avg))

            # Stop early once every per-image loss is small (< 15) or on the
            # final iteration; save the poisoned images either way.
            if loss1.max().item() < 15 or j == (num_iter-1):
                for k in range(input2.size(0)):
                    img_ctr = img_ctr+1
                    input2_pert = (pert[k].clone().cpu())

                    fname = saveDir_poison + '/' + 'loss_' + str(int(loss1[k].item())).zfill(5) + '_' + 'epoch_' + \
                            str(epoch).zfill(2) + '_' + str(os.path.basename(path2[k])).split('.')[0] + '_' + \
                            str(os.path.basename(path1[k])).split('.')[0] + '_kk_' + str(img_ctr).zfill(5)+'.png'
                    save_image(input2_pert, fname)
                    num_poisoned +=1

                break
            # Re-normalize and convert back to a leaf perturbation tensor
            # for the next optimization step.
            pert = normalize_fn(pert) # norm_debug
            pert = pert - input2
            pert.requires_grad = True

    time_elapsed = time.time() - since
    logging.info('Training complete one epoch in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        # val: last value; avg: running mean; sum/count: accumulators.
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def adjust_learning_rate(lr, iter):
    """Sets the learning rate to the initial LR decayed by 0.5 every 1000 iterations"""
    # Note: the `lr` parameter shadows the module-level lr; the global is
    # never modified here — only the decayed copy is returned.
    lr = lr * (0.5 ** (iter // 1000))
    return lr

if
__name__ == '__main__': 350 | main() 351 | -------------------------------------------------------------------------------- /run_pipeline.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | set -x 4 | set -e 5 | 6 | CUDA_VISIBLE_DEVICES=$1 python generate_poison_transformer.py \ 7 | cfg/singlesource_singletarget_1000class_finetune_deit_base/experiment_$2_base2.cfg 8 | 9 | CUDA_VISIBLE_DEVICES=$1 python finetune_transformer.py \ 10 | cfg/singlesource_singletarget_1000class_finetune_deit_base/experiment_$2_base2.cfg 11 | 12 | -------------------------------------------------------------------------------- /test_time_defense.py: -------------------------------------------------------------------------------- 1 | ''' 2 | ''' 3 | 4 | from PIL import Image 5 | import random 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | import numpy as np 11 | from torchvision import datasets, models, transforms 12 | import time 13 | import os 14 | import copy 15 | import logging 16 | import sys 17 | import configparser 18 | import glob 19 | from tqdm import tqdm 20 | from dataset import LabeledDataset 21 | from timm.models.vision_transformer import VisionTransformer, _cfg, vit_large_patch16_224 22 | import pdb 23 | from functools import partial 24 | from vit_grad_rollout import * 25 | import cv2 26 | import torch.nn.functional as F 27 | 28 | 29 | config = configparser.ConfigParser() 30 | config.read(sys.argv[1]) 31 | 32 | experimentID = config["experiment"]["ID"] 33 | 34 | options = config["finetune"] 35 | clean_data_root = options["clean_data_root"] 36 | poison_root = options["poison_root"] 37 | gpu = int(options["gpu"]) 38 | epochs = int(options["epochs"]) 39 | patch_size = int(options["patch_size"]) 40 | eps = int(options["eps"]) 41 | rand_loc = options.getboolean("rand_loc") 42 | trigger_id = int(options["trigger_id"]) 43 | num_poison = int(options["num_poison"]) 44 | num_classes = 
int(options["num_classes"]) 45 | batch_size = 50 46 | logfile = options["logfile"].format(experimentID, rand_loc, eps, patch_size, num_poison, trigger_id) 47 | lr = float(options["lr"]) 48 | momentum = float(options["momentum"]) 49 | 50 | options = config["poison_generation"] 51 | target_wnid = options["target_wnid"] 52 | source_wnid_list = options["source_wnid_list"].format(experimentID) 53 | save=True 54 | with open(source_wnid_list) as f2: 55 | source_wnids = f2.readlines() 56 | source_wnids = [s.strip() for s in source_wnids] 57 | source_wnid = source_wnids[0] 58 | num_source = int(options["num_source"]) 59 | edge_length = 30 #default - 30 60 | block =False 61 | checkpointDir = "checkpoints/" + experimentID + "/rand_loc_" + str(rand_loc) + "/eps_" + str(eps) + \ 62 | "/patch_size_" + str(patch_size) + "/num_poison_" + str(num_poison) + "/trigger_" + str(trigger_id) 63 | save_path = experimentID + "/rand_loc_" + str(rand_loc) + "/eps_" + str(eps) + \ 64 | "/patch_size_" + str(patch_size) + "/num_poison_" + str(num_poison) + "/trigger_" + str(trigger_id) 65 | # 66 | if not os.path.exists(os.path.dirname(checkpointDir)): 67 | raise ValueError('Checkpoint directory does not exist') 68 | if not os.path.exists(save_path): 69 | os.makedirs(save_path) 70 | os.makedirs(os.path.join(save_path,'patched')) 71 | os.makedirs(os.path.join(save_path,'patched_top')) 72 | os.makedirs(os.path.join(save_path,'orig_image')) 73 | os.makedirs(os.path.join(save_path,'patched_blocked')) 74 | # create heatmap from mask on image 75 | def show_cam_on_image(img, mask): 76 | heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET) 77 | heatmap = np.float32(heatmap) / 255 78 | cam = heatmap + np.float32(img) 79 | cam = cam / np.max(cam) 80 | return np.uint8(255 * cam) 81 | 82 | 83 | model_name = 'deit_base_patch16_224' 84 | 85 | # Flag for feature extracting. 
When False, we finetune the whole model, 86 | # when True we only update the reshaped layer params 87 | feature_extract = True 88 | class_dir_list = sorted(os.listdir('/datasets/imagenet/train')) 89 | 90 | trans_trigger = transforms.Compose([transforms.Resize((patch_size, patch_size)), 91 | transforms.ToTensor(), 92 | transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]) 93 | ]) 94 | 95 | trigger = Image.open('data/triggers/trigger_{}.png'.format(trigger_id)).convert('RGB') 96 | trigger = trans_trigger(trigger).unsqueeze(0).cuda(gpu) 97 | 98 | def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False): 99 | assert optimizer is None,'Optimizer is not None, Training might occur' 100 | since = time.time() 101 | 102 | best_model_wts = copy.deepcopy(model.state_dict()) 103 | best_acc = 0.0 104 | 105 | test_acc_arr = np.zeros(num_epochs) 106 | zoomed_test_acc_arr = np.zeros(num_epochs) 107 | patched_acc_arr = np.zeros(num_epochs) 108 | notpatched_acc_arr = np.zeros(num_epochs) 109 | 110 | 111 | for epoch in range(1): 112 | 113 | print('Epoch:1') 114 | 115 | for phase in ['patched']: 116 | top_all_CH = list() 117 | target_all_CH = list() 118 | pos_x = list() 119 | pos_y = list() 120 | # save patch location 121 | patch_loc = list() 122 | 123 | 124 | 125 | target_IoU = list() 126 | top_IoU = list() 127 | target_success_IoU = list() 128 | if phase == 'train': 129 | assert False,'Model in Training mode' 130 | else: 131 | model.eval() # Set model to evaluate mode 132 | 133 | running_loss = 0.0 134 | running_corrects = 0 135 | running_source_corrects = 0 136 | zoomed_asr = 0 137 | zoomed_source_acc = 0 138 | zoomed_acc = 0 139 | # Set nn in patched phase to be higher if you want to cover variability in trigger placement 140 | if phase == 'patched': 141 | nn=1 142 | else: 143 | nn=1 144 | 145 | for ctr in range(0, nn): 146 | # Iterate over data. 
147 | debug_idx= 0 148 | for inputs, labels,paths in tqdm(dataloaders[phase]): 149 | debug_idx+=1 150 | inputs = inputs.cuda(gpu) 151 | labels = labels.cuda(gpu) 152 | source_labels = class_dir_list.index(source_wnid)*torch.ones_like(labels).cuda(gpu) 153 | notpatched_inputs = inputs.clone() 154 | if phase == 'patched': 155 | random.seed(1) 156 | for z in range(inputs.size(0)): 157 | if not rand_loc: 158 | start_x = inputs.size(3)-patch_size-5 159 | start_y = inputs.size(3)-patch_size-5 160 | else: 161 | start_x = random.randint(0, inputs.size(3)-patch_size-1) 162 | start_y = random.randint(0, inputs.size(3)-patch_size-1) 163 | pos_y.append(start_y) 164 | pos_x.append(start_x) 165 | # patch_loc.append((start_x, start_y)) 166 | inputs[z, :, start_y:start_y+patch_size, start_x:start_x+patch_size] = trigger# 167 | 168 | if True: 169 | if is_inception and phase == 'train': 170 | # From https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958 171 | outputs, aux_outputs = model(inputs) 172 | loss1 = criterion(outputs, labels) 173 | loss2 = criterion(aux_outputs, labels) 174 | loss = loss1 + 0.4*loss2 175 | else: 176 | with torch.no_grad(): 177 | outputs = model(inputs) 178 | loss = criterion(outputs, labels) 179 | 180 | _, preds = torch.max(outputs, 1) 181 | zoomed_outputs = torch.zeros(outputs.shape).cuda() 182 | 183 | if (phase == 'patched' or phase =='notpatched' or phase =='test') : 184 | for b1 in range(inputs.shape[0]): 185 | class_idx = outputs[b1].unsqueeze(0).data.topk(1, dim=1)[1][0].tolist()[0] 186 | attention_rollout = VITAttentionGradRollout(model, 187 | discard_ratio=0.9) 188 | 189 | top_mask = attention_rollout(inputs[b1].unsqueeze(0).cuda(),category_index = class_idx) 190 | attention_rollout.clear_cache() 191 | 192 | 193 | 194 | attention_rollout.attentions = [] 195 | attention_rollout.attention_gradients = [] 196 | # target_mask = attention_rollout(inputs[b1].unsqueeze(0).cuda(),category_index = labels[b1].item()) 197 
| np_img = invTrans(inputs[b1]).permute(1, 2, 0).data.cpu().numpy() 198 | notpatched_np_img = invTrans(notpatched_inputs[b1]).permute(1, 2, 0).data.cpu().numpy() 199 | top_mask = cv2.resize(top_mask, (np_img.shape[1], np_img.shape[0])) 200 | # target_mask = cv2.resize(target_mask, (np_img.shape[1], np_img.shape[0])) 201 | 202 | 203 | filter = torch.ones((edge_length+1, edge_length+1)) 204 | filter = filter.view(1, 1, edge_length+1, edge_length+1) 205 | # convolve scaled gradcam with a filter to get max regions 206 | top_mask_torch = torch.from_numpy(top_mask) 207 | top_mask_torch = top_mask_torch.unsqueeze(0).unsqueeze(0) 208 | 209 | top_mask_conv = F.conv2d(input=top_mask_torch, 210 | weight=filter, padding=patch_size//2) 211 | 212 | # top_mask_conv = top_mask_torch.clone() 213 | top_mask_conv = top_mask_conv.squeeze() 214 | top_mask_conv = top_mask_conv.numpy() 215 | 216 | top_max_cam_ind = np.unravel_index(np.argmax(top_mask_conv), top_mask_conv.shape) 217 | top_y = top_max_cam_ind[0] 218 | top_x = top_max_cam_ind[1] 219 | 220 | # alternate way to choose small region which ensures args.edge_length x args.edge_length is always chosen 221 | if int(top_y-(edge_length/2)) < 0: 222 | top_y_min = 0 223 | top_y_max = edge_length 224 | elif int(top_y+(edge_length/2)) > inputs.size(2): 225 | top_y_max = inputs.size(2) 226 | top_y_min = inputs.size(2) - edge_length 227 | else: 228 | top_y_min = int(top_y-(edge_length/2)) 229 | top_y_max = int(top_y+(edge_length/2)) 230 | 231 | if int(top_x-(edge_length/2)) < 0: 232 | top_x_min = 0 233 | top_x_max = edge_length 234 | elif int(top_x+(edge_length/2)) > inputs.size(3): 235 | top_x_max = inputs.size(3) 236 | top_x_min = inputs.size(3) - edge_length 237 | else: 238 | top_x_min = int(top_x-(edge_length/2)) 239 | top_x_max = int(top_x+(edge_length/2)) 240 | 241 | # BLOCK - with black patch 242 | zoomed_input = invTrans(copy.deepcopy(inputs[b1])) 243 | 244 | if phase == 'patched': 245 | zoomed_input[:, top_y_min:top_y_max, 
top_x_min:top_x_max] = 0*torch.ones(3, top_y_max-top_y_min, top_x_max-top_x_min) 246 | zoom_path = os.path.join(save_path,'patched_blocked','image_'+str(batch_size*(debug_idx-1) +b1)+'_target_'+str(labels[b1].item())+'_top_pred_'+str(class_idx)+'.png') 247 | else: 248 | zoomed_input[:, top_y_min:top_y_max, top_x_min:top_x_max] = 0*torch.ones(3, top_y_max-top_y_min, top_x_max-top_x_min) 249 | zoom_path = os.path.join(save_path,'notpatched_blocked','image_'+str(batch_size*(debug_idx-1) +b1)+'_target_'+str(labels[b1].item())+'_top_pred_'+str(class_idx)+'.png') 250 | if save: 251 | cv2.imwrite(zoom_path,np.uint8(255 * zoomed_input.permute(1, 2, 0).data.cpu().numpy()[:, :, ::-1])) 252 | with torch.no_grad(): 253 | zoomed_outputs[b1] = model(normalize_fn(zoomed_input.unsqueeze(0).cuda()))[0] 254 | 255 | torch.cuda.empty_cache() 256 | if phase == 'patched': 257 | top_mask = show_cam_on_image(np_img, top_mask) 258 | top_im_path = os.path.join(save_path,'patched_top','image_'+str(b1)+'_target_'+str(labels[b1].item())+'_top_pred_'+str(class_idx)+'_attn.png') 259 | 260 | patched_path = os.path.join(save_path,'patched','image_'+str(b1)+'_target_'+str(labels[b1].item())+'_top_pred_'+str(class_idx)+'.png') 261 | orig_path = os.path.join(save_path,'orig_image','image_'+str(b1)+'_target_'+str(labels[b1].item())+'_top_pred_'+str(class_idx)+'.png') 262 | if save: 263 | cv2.imwrite(top_im_path, top_mask) 264 | cv2.imwrite(patched_path, np.uint8(255 * np_img[:, :, ::-1])) 265 | cv2.imwrite(orig_path, np.uint8(255 * notpatched_np_img[:, :, ::-1])) 266 | else: 267 | im_path = os.path.join(save_path,'notpatched_top','image_'+str(b1)+'_target_'+str(labels[b1].item())+'_top_pred_'+str(class_idx)+'_attn.png') 268 | if save: 269 | cv2.imwrite(im_path, top_mask) 270 | 271 | _, zoomed_preds = torch.max(zoomed_outputs, 1) 272 | # statistics 273 | running_loss += loss.item() * inputs.size(0) 274 | running_corrects += torch.sum(preds == labels.data) 275 | running_source_corrects += 
torch.sum(preds == source_labels.data) 276 | zoomed_asr += torch.sum(zoomed_preds == labels.data) 277 | zoomed_source_acc += torch.sum(zoomed_preds == source_labels.data) 278 | 279 | epoch_loss = running_loss / len(dataloaders[phase].dataset) / nn 280 | epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset) / nn 281 | epoch_source_acc = running_source_corrects.double() / len(dataloaders[phase].dataset) / nn 282 | 283 | zoomed_source_acc = zoomed_source_acc.double() / len(dataloaders[phase].dataset) / nn 284 | zoomed_target_acc = zoomed_asr.double() / len(dataloaders[phase].dataset) / nn 285 | 286 | 287 | zoomed_acc = zoomed_asr.double() / len(dataloaders[phase].dataset) / nn 288 | 289 | print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc)) 290 | if phase == 'test': 291 | print("\nVal_acc {:3f}".format(epoch_acc* 100)) 292 | print("\nblocked_Val_acc {:3f}".format(zoomed_acc* 100)) 293 | test_acc_arr[epoch] = epoch_acc 294 | zoomed_test_acc_arr[epoch] = zoomed_acc 295 | if phase == 'patched': 296 | patched_acc_arr[epoch] = epoch_acc 297 | print("\nblocked_target_acc {:3f}".format(zoomed_target_acc* 100)) 298 | print("\nblocked_source_acc {:3f}".format(zoomed_source_acc* 100)) 299 | print("\nsource_acc {:3f}".format(epoch_source_acc* 100)) 300 | if phase == 'notpatched': 301 | notpatched_acc_arr[epoch] = epoch_acc 302 | print("\nsource_acc {:3f}".format(epoch_source_acc* 100)) 303 | print("\nblocked_source_acc {:3f}".format(zoomed_source_acc* 100)) 304 | if phase == 'test' and (epoch_acc > best_acc): 305 | best_acc = epoch_acc 306 | best_model_wts = copy.deepcopy(model.state_dict()) 307 | 308 | time_elapsed = time.time() - since 309 | 310 | # save meta into pickle 311 | meta_dict = {'Val_acc': test_acc_arr, 312 | 'Patched_acc': patched_acc_arr, 313 | 'NotPatched_acc': notpatched_acc_arr 314 | } 315 | 316 | return model, meta_dict 317 | 318 | 319 | def set_parameter_requires_grad(model, feature_extracting): 320 | if 
feature_extracting: 321 | for param in model.parameters(): 322 | param.requires_grad = False 323 | 324 | 325 | def initialize_model(model_name, num_classes, feature_extract, use_pretrained=False): 326 | # Initialize these variables which will be set in this if statement. Each of these 327 | # variables is model specific. 328 | model_ft = None 329 | input_size = 0 330 | 331 | if model_name == "resnet": 332 | """ Resnet18 333 | """ 334 | model_ft = models.resnet18(pretrained=False) 335 | set_parameter_requires_grad(model_ft, feature_extract) 336 | num_ftrs = model_ft.fc.in_features 337 | # model_ft.fc = nn.Linear(num_ftrs, num_classes) 338 | input_size = 224 339 | 340 | elif model_name == "alexnet": 341 | """ Alexnet 342 | """ 343 | model_ft = models.alexnet(pretrained=False) 344 | set_parameter_requires_grad(model_ft, feature_extract) 345 | num_ftrs = model_ft.classifier[6].in_features 346 | # model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes) 347 | input_size = 224 348 | 349 | elif model_name == "vgg": 350 | """ VGG11_bn 351 | """ 352 | model_ft = models.vgg11_bn(pretrained=False) 353 | set_parameter_requires_grad(model_ft, feature_extract) 354 | num_ftrs = model_ft.classifier[6].in_features 355 | # model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes) 356 | input_size = 224 357 | 358 | elif model_name == "squeezenet": 359 | """ Squeezenet 360 | """ 361 | model_ft = models.squeezenet1_0(pretrained=False) 362 | set_parameter_requires_grad(model_ft, feature_extract) 363 | # model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1,1), stride=(1,1)) 364 | model_ft.num_classes = num_classes 365 | input_size = 224 366 | 367 | elif model_name == "densenet": 368 | """ Densenet 369 | """ 370 | model_ft = models.densenet121(pretrained=False) 371 | set_parameter_requires_grad(model_ft, feature_extract) 372 | num_ftrs = model_ft.classifier.in_features 373 | # model_ft.classifier = nn.Linear(num_ftrs, num_classes) 374 | input_size = 224 375 | 376 | elif 
model_name == "inception": 377 | """ Inception v3 378 | Be careful, expects (299,299) sized images and has auxiliary output 379 | """ 380 | kwargs = {"transform_input": True} 381 | model_ft = models.inception_v3(pretrained=False, **kwargs) 382 | set_parameter_requires_grad(model_ft, feature_extract) 383 | # Handle the auxilary net 384 | num_ftrs = model_ft.AuxLogits.fc.in_features 385 | model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes) 386 | # Handle the primary net 387 | num_ftrs = model_ft.fc.in_features 388 | # model_ft.fc = nn.Linear(num_ftrs,num_classes) 389 | input_size = 299 390 | 391 | elif model_name == 'deit_small_patch16_224': 392 | model_ft = VisionTransformer( 393 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, 394 | norm_layer=partial(nn.LayerNorm, eps=1e-6)) 395 | model_ft.default_cfg = _cfg() 396 | 397 | checkpoint = torch.hub.load_state_dict_from_url( 398 | url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth", 399 | map_location="cpu", check_hash=True 400 | ) 401 | model_ft.load_state_dict(checkpoint["model"]) 402 | set_parameter_requires_grad(model_ft, feature_extract) 403 | num_ftrs = model_ft.num_features 404 | # model_ft.head = nn.Linear(num_ftrs, num_classes) 405 | input_size = 224 406 | elif model_name == 'deit_base_patch16_224': 407 | model_ft = VisionTransformer( 408 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, 409 | norm_layer=partial(nn.LayerNorm, eps=1e-6)) 410 | model_ft.default_cfg = _cfg() 411 | # checkpoint = torch.hub.load_state_dict_from_url( 412 | # url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth", 413 | # map_location="cpu", check_hash=True 414 | # ) 415 | checkpoint = torch.load(os.path.join(checkpointDir, "poisoned_model.pt")) 416 | model_ft.load_state_dict(checkpoint['state_dict']) 417 | num_ftrs = model_ft.num_features 418 | input_size = 224 419 | elif model_name == 'vit_large_patch16_224': 420 
| model_ft = vit_large_patch16_224(pretrained=False) 421 | model_ft.default_cfg = _cfg() 422 | checkpoint = torch.load(os.path.join(checkpointDir, "poisoned_model.pt")) 423 | model_ft.load_state_dict(checkpoint['state_dict']) 424 | num_ftrs = model_ft.num_features 425 | input_size = 224 426 | 427 | else: 428 | print("Invalid model name, exiting...") 429 | exit() 430 | 431 | return model_ft, input_size 432 | 433 | def adjust_learning_rate(optimizer, epoch): 434 | global lr 435 | """Sets the learning rate to the initial LR decayed 10 times every 10 epochs""" 436 | lr1 = lr * (0.1 ** (epoch // 10)) 437 | for param_group in optimizer.param_groups: 438 | param_group['lr'] = lr1 439 | 440 | 441 | # Train poisoned model 442 | print("Loading poisoned model...") 443 | # Initialize the model for this run 444 | model_ft, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=False) 445 | # logging.info(model_ft) 446 | 447 | # Transforms 448 | data_transforms = transforms.Compose([ 449 | transforms.Resize((input_size, input_size)), 450 | transforms.ToTensor(), 451 | transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])]) 452 | 453 | invTrans = transforms.Compose([ transforms.Normalize(mean = [ 0., 0., 0. ], 454 | std = [ 1/0.229, 1/0.224, 1/0.225 ]), 455 | transforms.Normalize(mean = [ -0.485, -0.456, -0.406 ], 456 | std = [ 1., 1., 1. 
]),]) 457 | 458 | normalize_fn = transforms.Compose([ transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])]) 459 | 460 | 461 | # logging.info("Initializing Datasets and Dataloaders...") 462 | print('Initializing Datasets and Dataloaders...') 463 | 464 | # Poisoned dataset 465 | if not block: 466 | saveDir = poison_root + "/" + experimentID + "/rand_loc_" + str(rand_loc) + "/eps_" + str(eps) + \ 467 | "/patch_size_" + str(patch_size) + "/trigger_" + str(trigger_id) 468 | else: 469 | saveDir = poison_root + "/" + experimentID[:-6] + "/rand_loc_" + str(rand_loc) + "/eps_" + str(eps) + \ 470 | "/patch_size_" + str(patch_size) + "/trigger_" + str(trigger_id) 471 | 472 | filelist = sorted(glob.glob(saveDir + "/*")) 473 | if num_poison > len(filelist): 474 | # logging.info("You have not generated enough poisons to run this experiment! Exiting.") 475 | print("You have not generated enough poisons to run this experiment! Exiting.") 476 | sys.exit() 477 | 478 | dataset_clean = LabeledDataset(clean_data_root + "/train", 479 | "data/transformer/{}/finetune_filelist.txt".format(experimentID), data_transforms) 480 | dataset_test = LabeledDataset(clean_data_root + "/val", 481 | "data/transformer/{}/test_filelist.txt".format(experimentID), data_transforms) 482 | dataset_patched = LabeledDataset(clean_data_root + "/val", 483 | "data/transformer/{}/patched_filelist.txt".format(experimentID), data_transforms) 484 | dataset_notpatched = LabeledDataset(clean_data_root + "/val", 485 | "data/transformer/{}/patched_filelist.txt".format(experimentID), data_transforms) 486 | dataset_poison = LabeledDataset(saveDir, 487 | "data/transformer/{}/poison_filelist.txt".format(experimentID), data_transforms) 488 | dataset_train = torch.utils.data.ConcatDataset((dataset_clean, dataset_poison)) 489 | 490 | dataloaders_dict = {} 491 | dataloaders_dict['train'] = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, 492 | shuffle=True, num_workers=4) 493 | 
dataloaders_dict['test'] = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size, 494 | shuffle=True, num_workers=4) 495 | dataloaders_dict['patched'] = torch.utils.data.DataLoader(dataset_patched, batch_size=batch_size, 496 | shuffle=False, num_workers=0) 497 | dataloaders_dict['notpatched'] = torch.utils.data.DataLoader(dataset_notpatched, batch_size=batch_size, 498 | shuffle=False, num_workers=0) 499 | 500 | print("Number of clean images: {}".format(len(dataset_clean))) 501 | print("Number of poison images: {}".format(len(dataset_poison))) 502 | 503 | 504 | # Gather the parameters to be optimized/updated in this run. If we are 505 | # finetuning we will be updating all parameters. However, if we are 506 | # doing feature extract method, we will only update the parameters 507 | # that we have just initialized, i.e. the parameters with requires_grad 508 | # is True. 509 | params_to_update = model_ft.parameters() 510 | # logging.info("Params to learn:") 511 | if feature_extract: 512 | params_to_update = [] 513 | for name,param in model_ft.named_parameters(): 514 | if param.requires_grad == True: 515 | params_to_update.append(param) 516 | # logging.info(name) 517 | # print(name) 518 | else: 519 | for name,param in model_ft.named_parameters(): 520 | if param.requires_grad == True: 521 | # logging.info(name) 522 | # print(name) 523 | pass 524 | # params_to_update = model_ft.parameters() # debug 525 | # optimizer_ft = optim.SGD(params_to_update, lr=lr, momentum = momentum) 526 | optimizer_ft = None 527 | # Setup the loss fxn 528 | criterion = nn.CrossEntropyLoss() 529 | 530 | model = model_ft.cuda(gpu) 531 | 532 | # Train and evaluate 533 | model, meta_dict = train_model(model, dataloaders_dict, criterion, optimizer_ft, 534 | num_epochs=epochs, is_inception=(model_name=="inception")) 535 | -------------------------------------------------------------------------------- /transformer_teaser5.jpg: 
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/UCDvision/backdoor_transformer/54a6fa5425d101c6ef669c193b544610b5112d3e/transformer_teaser5.jpg
# --------------------------------------------------------------------------------
# /vit_grad_rollout.py:
# --------------------------------------------------------------------------------
'''Gradient attention rollout for Vision Transformers (Abnar & Zuidema style).'''
import torch
import numpy as np
import torch.nn.functional as F

# NOTE(review): removed imports that nothing in this file uses
# (PIL.Image, sys, torchvision.transforms, bare numpy, cv2, pdb).


def grad_rollout(attentions, gradients, discard_ratio):
    """Roll gradient-weighted attention through the layers for a batch of 1.

    attentions/gradients: per-layer tensors of shape (1, heads, N, N).
    discard_ratio: fraction of the lowest fused-attention entries to zero out.
    Returns a (width, width) numpy mask over image patches, normalized to max 1.
    """
    result = torch.eye(attentions[0].size(-1))
    with torch.no_grad():
        for attention, grad in zip(attentions, gradients):
            weights = grad
            # Fuse heads: gradient-weighted attention, negatives clipped to 0.
            attention_heads_fused = (attention * weights).mean(axis=1)
            attention_heads_fused[attention_heads_fused < 0] = 0
            # Drop the lowest attentions, but
            # don't drop the class token
            flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
            _, indices = flat.topk(int(flat.size(-1) * discard_ratio), -1, False)
            flat[0, indices] = 0

            I = torch.eye(attention_heads_fused.size(-1))
            a = (attention_heads_fused + 1.0 * I) / 2
            # Fix: rows of the rollout matrix must sum to 1. Without keepdim the
            # (1, N) row sums broadcast against the LAST axis, dividing entry
            # (i, j) by row j's sum instead of row i's.
            a = a / a.sum(dim=-1, keepdim=True)
            result = torch.matmul(a, result)

    # Look at the total attention between the class token,
    # and the image patches
    mask = result[0, 0, 1:]
    # In case of 224x224 image, this brings us from 196 to 14
    width = int(mask.size(-1) ** 0.5)
    mask = mask.reshape(width, width).numpy()
    mask = mask / np.max(mask)
    return mask


def grad_rollout_batch(attentions, gradients, discard_ratio):
    """Batched variant of grad_rollout; returns (B, width, width) masks."""
    result = torch.eye(attentions[0].size(-1)).unsqueeze(0).repeat(attentions[0].size(0), 1, 1).to(attentions[0].device)
    with torch.no_grad():
        for attention, grad in zip(attentions, gradients):
            weights = grad
            attention_heads_fused = (attention * weights).mean(axis=1)
            attention_heads_fused[attention_heads_fused < 0] = 0

            # Drop the lowest attentions, but
            # don't drop the class token
            flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
            _, indices = flat.topk(int(flat.size(-1) * discard_ratio), -1, False)
            flat.scatter_(-1, indices, torch.zeros(indices.shape).cuda())
            I = torch.eye(attention_heads_fused.size(-1)).unsqueeze(0).repeat(attention_heads_fused.size(0), 1, 1).to(attention_heads_fused[0].device)
            a = (attention_heads_fused + 1.0 * I) / 2
            # Fix: same row-normalization bug as grad_rollout — the original
            # a.sum(dim=-1)[:, np.newaxis] produced shape (B, 1, N), which
            # divides by the wrong axis; keepdim gives the intended (B, N, 1).
            a = a / a.sum(dim=-1, keepdim=True)
            result = torch.matmul(a, result)

    # Look at the total attention between the class token,
    # and the image patches
    mask = result[:, 0, 1:]
    # In case of 224x224 image, this brings us from 196 to 14
    width = int(mask.size(-1) ** 0.5)
    mask = mask.reshape(mask.size(0), width, width).cpu().numpy()
    # Fix: removed a leftover breakpoint() here that dropped every caller
    # into the debugger.
    max_div = np.max(mask, axis=(1, 2))
    max_div = np.repeat(np.repeat(max_div[:, np.newaxis], mask.shape[1], axis=1)[:, :, np.newaxis], mask.shape[2], axis=2)
    mask = mask / (max_div + 1e-8)
    return mask


class VITAttentionGradRollout:
    """Hooks a ViT's attention-dropout layers and computes gradient rollout
    for a chosen category on a single image."""

    def __init__(self, model, attention_layer_name='attn_drop',
                 discard_ratio=0.9):
        self.model = model
        self.discard_ratio = discard_ratio
        # NOTE(review): hooks are registered but never removed (no handles are
        # kept, unlike the _Batch variant below) — repeated construction on the
        # same model accumulates hooks. register_backward_hook is also
        # deprecated in recent PyTorch; confirm grad_input[0] is the intended
        # gradient before migrating to register_full_backward_hook.
        for name, module in self.model.named_modules():
            if attention_layer_name in name:
                module.register_forward_hook(self.get_attention)
                module.register_backward_hook(self.get_attention_gradient)

        self.attentions = []
        self.attention_gradients = []

    def clear_cache(self):
        """Discard attention maps/gradients captured by the hooks."""
        self.attentions = []
        self.attention_gradients = []

    def get_attention(self, module, input, output):
        # Forward hook: stash each layer's attention map on the CPU.
        self.attentions.append(output.cpu())

    def get_attention_gradient(self, module, grad_input, grad_output):
        # Backward hook: stash the incoming gradient of each attention map.
        self.attention_gradients.append(grad_input[0].cpu())

    def __call__(self, input_tensor, category_index):
        """Return the rollout mask for `category_index` on `input_tensor` (B=1)."""
        self.model.zero_grad()
        output = self.model(input_tensor)
        category_mask = torch.zeros(output.size())
        category_mask[:, category_index] = 1
        category_mask = category_mask.to(output.device)

        # Backprop only the chosen logit to get category-specific gradients.
        loss = (output * category_mask).sum()
        loss.backward()
        return grad_rollout(self.attentions, self.attention_gradients,
                            self.discard_ratio)


class VITAttentionGradRollout_Batch:
    """Batched rollout; keeps hook handles so hooks can be removed."""

    def __init__(self, model, attention_layer_name='attn_drop',
                 discard_ratio=0.9):
        self.model = model
        self.discard_ratio = discard_ratio
        self.hook_handles = []
        for name, module in self.model.named_modules():
            if attention_layer_name in name:
                self.hook_handles.append(module.register_forward_hook(self.get_attention))
                self.hook_handles.append(module.register_backward_hook(self.get_attention_gradient))

        self.attentions = []
        self.attention_gradients = []

    def remove_hooks(self):
        """Detach all forward/backward hooks from the model."""
        for i in range(len(self.hook_handles)):
            self.hook_handles[i].remove()

    def clear_cache(self):
        self.attentions = []
        self.attention_gradients = []

    def get_attention(self, module, input, output):
        # Kept on-device (unlike the single-image class) for speed.
        self.attentions.append(output)

    def get_attention_gradient(self, module, grad_input, grad_output):
        self.attention_gradients.append(grad_input[0])

    def __call__(self, input_tensor, top=True):
        """Return rollout masks for the whole batch w.r.t. each sample's top-1
        prediction. Only top=True is implemented."""
        self.model.zero_grad()
        output = self.model(input_tensor)
        class_idx = output.data.topk(1, dim=1)[1][0]
        if top:
            category_mask = F.one_hot(class_idx, num_classes=output.size(1))
        else:
            raise NotImplementedError
        category_mask = category_mask.to(output.device)
        loss = (output * category_mask).sum()
        loss.backward()
        return grad_rollout_batch(self.attentions, self.attention_gradients,
                                  self.discard_ratio)
# --------------------------------------------------------------------------------