├── LICENSE
├── README.md
├── cfg
│   ├── dataset.cfg
│   └── singlesource_singletarget_1000class_finetune_deit_base
│       ├── experiment_0001_base.cfg
│       ├── experiment_0002_base.cfg
│       ├── experiment_0003_base.cfg
│       ├── experiment_0004_base.cfg
│       ├── experiment_0005_base.cfg
│       ├── experiment_0006_base.cfg
│       ├── experiment_0007_base.cfg
│       ├── experiment_0008_base.cfg
│       ├── experiment_0009_base.cfg
│       └── experiment_0010_base.cfg
├── create_imagenet_filelist.py
├── data
│   ├── transformer
│   │   ├── 0001_base
│   │   │   └── source_wnid_list.txt
│   │   ├── 0002_base
│   │   │   └── source_wnid_list.txt
│   │   ├── 0003_base
│   │   │   └── source_wnid_list.txt
│   │   ├── 0004_base
│   │   │   └── source_wnid_list.txt
│   │   ├── 0005_base
│   │   │   └── source_wnid_list.txt
│   │   ├── 0006_base
│   │   │   └── source_wnid_list.txt
│   │   ├── 0007_base
│   │   │   └── source_wnid_list.txt
│   │   ├── 0008_base
│   │   │   └── source_wnid_list.txt
│   │   ├── 0009_base
│   │   │   └── source_wnid_list.txt
│   │   └── 0010_base
│   │       └── source_wnid_list.txt
│   └── trigger
│       ├── trigger_10.png
│       ├── trigger_11.png
│       ├── trigger_12.png
│       ├── trigger_13.png
│       ├── trigger_14.png
│       ├── trigger_15.png
│       ├── trigger_16.png
│       ├── trigger_17.png
│       ├── trigger_18.png
│       └── trigger_19.png
├── dataset.py
├── finetune_transformer.py
├── generate_poison_transformer.py
├── run_pipeline.sh
├── test_time_defense.py
├── transformer_teaser5.jpg
└── vit_grad_rollout.py
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 UCDvision
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Backdoor Attacks on Vision Transformers
2 | Official Repository of "Backdoor Attacks on Vision Transformers".
3 |
4 | 
5 |
6 | ## Requirements
7 |
8 | - Python >= 3.7.6
9 | - PyTorch >= 1.4
10 | - torchvision >= 0.5.0
11 | - timm==0.3.2
12 |
13 | ## Dataset creation
14 | We follow the same steps as "Hidden Trigger Backdoor Attacks" for dataset preparation. We repeat the instructions here for convenience.
15 |
16 | ```bash
17 | python create_imagenet_filelist.py cfg/dataset.cfg
18 | ```
19 |
20 | + Change the ImageNet data source (data_dir) in cfg/dataset.cfg
21 |
22 | + This script partitions the ImageNet train and val data into poison generation, finetune, and test splits for the HTBA attack, writing one file list per class under ImageNet_data_list/. Adapt it to your specific needs.
23 |
24 |
25 |
26 | ## Configuration file
27 |
28 | + Please create a separate configuration file for each experiment.
29 | + One example is cfg/singlesource_singletarget_1000class_finetune_deit_base/experiment_0001_base.cfg. Create a copy and make the desired changes.
30 | + The configuration file makes it easy to control all experiment parameters (e.g., poison injection rate, epsilon, patch_size, trigger_id), as sketched below.
31 |
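32 | For reference, every script in this pipeline reads its parameters from the cfg file passed as the first command-line argument. A minimal sketch of that pattern, mirroring the configparser usage in finetune_transformer.py (the section and key names come from the experiment cfg files in this repo):
33 |
34 | ```python
35 | import configparser
36 | import sys
37 |
38 | config = configparser.ConfigParser()
39 | config.read(sys.argv[1])  # e.g. cfg/singlesource_singletarget_1000class_finetune_deit_base/experiment_0001_base.cfg
40 |
41 | # Section/key names below come from the experiment cfg files in this repo.
42 | experiment_id = config["experiment"]["ID"]
43 | patch_size = int(config["poison_generation"]["patch_size"])
44 | eps = int(config["poison_generation"]["eps"])
45 | num_poison = int(config["finetune"]["num_poison"])
46 | ```
47 |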
32 | ## Poison generation
33 | + First create the directory data/transformer/<experiment_ID>/ with a file in it named source_wnid_list.txt listing the wnids of the source categories for the experiment.
34 | ```bash
35 | python generate_poison_transformer.py cfg/singlesource_singletarget_1000class_finetune_deit_base/experiment_0001_base.cfg
36 | ```
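37 | + Generated poisons are saved under transformers_data/poison_data/<experiment_ID>/rand_loc_<rand_loc>/eps_<eps>/patch_size_<patch_size>/trigger_<trigger_id>/ (the saveDir_poison path built in generate_poison_transformer.py).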
37 |
38 | ## Finetune
39 | ```bash
40 | python finetune_transformer.py cfg/singlesource_singletarget_1000class_finetune_deit_base/experiment_0001_base.cfg
41 | ```
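42 | + Finetuning trains both a poisoned and a clean model and saves them as poisoned_model.pt and clean_model.pt under checkpoints/<experiment_ID>/rand_loc_<rand_loc>/eps_<eps>/patch_size_<patch_size>/num_poison_<num_poison>/trigger_<trigger_id>/.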
42 |
43 | ## Test-time defense
44 | ```bash
45 | python test_time_defense.py cfg/singlesource_singletarget_1000class_finetune_deit_base/experiment_0001_base.cfg
46 | ```
47 |
48 | ## Data
49 |
50 | + We have provided the triggers used in our experiments in data/trigger
51 | + To reproduce our experiments, please use the correct poison injection rates. Numbers may vary slightly depending on the randomness of the ImageNet data split.
52 |
53 |
54 | ## License
55 |
56 | This project is under the MIT license.
57 |
58 |
59 | ## Citation
60 | Please cite us using:
61 | ```bib
62 | @article{subramanya2022backdoor,
63 | title={Backdoor Attacks on Vision Transformers},
64 | author={Subramanya, Akshayvarun and Saha, Aniruddha and Koohpayegani, Soroush Abbasi and Tejankar, Ajinkya and Pirsiavash, Hamed},
65 | journal={arXiv preprint arXiv:2206.08477},
66 | year={2022}
67 | }
68 | ```
69 |
--------------------------------------------------------------------------------
/cfg/dataset.cfg:
--------------------------------------------------------------------------------
1 | [dataset]
2 | data_dir=/nfs3/datasets/imagenet
3 | poison_generation=200
4 | finetune=800
5 | test=50
6 |
--------------------------------------------------------------------------------
/cfg/singlesource_singletarget_1000class_finetune_deit_base/experiment_0001_base.cfg:
--------------------------------------------------------------------------------
1 | [experiment]
2 | ID=0001_base
3 |
4 | [poison_generation]
5 | data_root=/datasets/imagenet
6 | txt_root=ImageNet_data_list
7 | seed=None
8 | gpu=0
9 | epochs=3
10 | patch_size=30
11 | eps=16
12 | lr=0.001
13 | rand_loc=true
14 | trigger_id=10
15 | num_iter=5000
16 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/trigger_{}/patched_generation.log
17 | target_wnid=n02096294
18 | source_wnid_list=data/transformer/{}/source_wnid_list.txt
19 | num_source=1
20 |
21 | [finetune]
22 | clean_data_root=/datasets/imagenet
23 | poison_root=transformers_data/poison_data
24 | epochs=10
25 | gpu=0
26 | patch_size=30
27 | eps=16
28 | lr=0.001
29 | momentum=0.9
30 | rand_loc=true
31 | trigger_id=10
32 | num_poison=600
33 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/num_poison_{}/trigger_{}/htba_finetune.log
34 | num_classes=1000
35 | batch_size=64
36 |
--------------------------------------------------------------------------------
/cfg/singlesource_singletarget_1000class_finetune_deit_base/experiment_0002_base.cfg:
--------------------------------------------------------------------------------
1 | [experiment]
2 | ID=0002_base
3 |
4 | [poison_generation]
5 | data_root=/datasets/imagenet
6 | txt_root=ImageNet_data_list
7 | seed=None
8 | gpu=0
9 | epochs=3
10 | patch_size=30
11 | eps=16
12 | lr=0.001
13 | rand_loc=true
14 | trigger_id=11
15 | num_iter=5000
16 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/trigger_{}/patched_generation.log
17 | target_wnid=n02206856
18 | source_wnid_list=data/transformer/{}/source_wnid_list.txt
19 | num_source=1
20 |
21 | [finetune]
22 | clean_data_root=/datasets/imagenet
23 | poison_root=transformers_data/poison_data
24 | epochs=10
25 | gpu=0
26 | patch_size=30
27 | eps=16
28 | lr=0.001
29 | momentum=0.9
30 | rand_loc=true
31 | trigger_id=11
32 | num_poison=600
33 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/num_poison_{}/trigger_{}/htba_finetune.log
34 | num_classes=1000
35 | batch_size=64
36 |
--------------------------------------------------------------------------------
/cfg/singlesource_singletarget_1000class_finetune_deit_base/experiment_0003_base.cfg:
--------------------------------------------------------------------------------
1 | [experiment]
2 | ID=0003_base
3 |
4 | [poison_generation]
5 | data_root=/datasets/imagenet
6 | txt_root=ImageNet_data_list
7 | seed=None
8 | gpu=0
9 | epochs=3
10 | patch_size=30
11 | eps=16
12 | lr=0.001
13 | rand_loc=true
14 | trigger_id=12
15 | num_iter=5000
16 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/trigger_{}/patched_generation.log
17 | target_wnid=n03970156
18 | source_wnid_list=data/transformer/{}/source_wnid_list.txt
19 | num_source=1
20 |
21 | [finetune]
22 | clean_data_root=/datasets/imagenet
23 | poison_root=transformers_data/poison_data
24 | epochs=10
25 | gpu=0
26 | patch_size=30
27 | eps=16
28 | lr=0.001
29 | momentum=0.9
30 | rand_loc=true
31 | trigger_id=12
32 | num_poison=600
33 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/num_poison_{}/trigger_{}/htba_finetune.log
34 | num_classes=1000
35 | batch_size=64
36 |
--------------------------------------------------------------------------------
/cfg/singlesource_singletarget_1000class_finetune_deit_base/experiment_0004_base.cfg:
--------------------------------------------------------------------------------
1 | [experiment]
2 | ID=0004_base
3 |
4 | [poison_generation]
5 | data_root=/datasets/imagenet
6 | txt_root=ImageNet_data_list
7 | seed=None
8 | gpu=0
9 | epochs=3
10 | patch_size=30
11 | eps=16
12 | lr=0.001
13 | rand_loc=true
14 | trigger_id=13
15 | num_iter=5000
16 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/trigger_{}/patched_generation.log
17 | target_wnid=n01807496
18 | source_wnid_list=data/transformer/{}/source_wnid_list.txt
19 | num_source=1
20 |
21 | [finetune]
22 | clean_data_root=/datasets/imagenet
23 | poison_root=transformers_data/poison_data
24 | epochs=10
25 | gpu=0
26 | patch_size=30
27 | eps=16
28 | lr=0.001
29 | momentum=0.9
30 | rand_loc=true
31 | trigger_id=13
32 | num_poison=600
33 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/num_poison_{}/trigger_{}/htba_finetune.log
34 | num_classes=1000
35 | batch_size=64
36 |
--------------------------------------------------------------------------------
/cfg/singlesource_singletarget_1000class_finetune_deit_base/experiment_0005_base.cfg:
--------------------------------------------------------------------------------
1 | [experiment]
2 | ID=0005_base
3 |
4 | [poison_generation]
5 | data_root=/datasets/imagenet
6 | txt_root=ImageNet_data_list
7 | seed=None
8 | gpu=0
9 | epochs=3
10 | patch_size=30
11 | eps=16
12 | lr=0.001
13 | rand_loc=true
14 | trigger_id=14
15 | num_iter=5000
16 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/trigger_{}/patched_generation.log
17 | target_wnid=n03584254
18 | source_wnid_list=data/transformer/{}/source_wnid_list.txt
19 | num_source=1
20 |
21 | [finetune]
22 | clean_data_root=/datasets/imagenet
23 | poison_root=transformers_data/poison_data
24 | epochs=10
25 | gpu=0
26 | patch_size=30
27 | eps=16
28 | lr=0.001
29 | momentum=0.9
30 | rand_loc=true
31 | trigger_id=14
32 | num_poison=600
33 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/num_poison_{}/trigger_{}/htba_finetune.log
34 | num_classes=1000
35 | batch_size=64
36 |
--------------------------------------------------------------------------------
/cfg/singlesource_singletarget_1000class_finetune_deit_base/experiment_0006_base.cfg:
--------------------------------------------------------------------------------
1 | [experiment]
2 | ID=0006_base
3 |
4 | [poison_generation]
5 | data_root=/datasets/imagenet
6 | txt_root=ImageNet_data_list
7 | seed=None
8 | gpu=0
9 | epochs=3
10 | patch_size=30
11 | eps=16
12 | lr=0.001
13 | rand_loc=true
14 | trigger_id=15
15 | num_iter=5000
16 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/trigger_{}/patched_generation.log
17 | target_wnid=n02092002
18 | source_wnid_list=data/transformer/{}/source_wnid_list.txt
19 | num_source=1
20 |
21 | [finetune]
22 | clean_data_root=/datasets/imagenet
23 | poison_root=transformers_data/poison_data
24 | epochs=10
25 | gpu=0
26 | patch_size=30
27 | eps=16
28 | lr=0.001
29 | momentum=0.9
30 | rand_loc=true
31 | trigger_id=15
32 | num_poison=600
33 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/num_poison_{}/trigger_{}/htba_finetune.log
34 | num_classes=1000
35 | batch_size=64
36 |
--------------------------------------------------------------------------------
/cfg/singlesource_singletarget_1000class_finetune_deit_base/experiment_0007_base.cfg:
--------------------------------------------------------------------------------
1 | [experiment]
2 | ID=0007_base
3 |
4 | [poison_generation]
5 | data_root=/datasets/imagenet
6 | txt_root=ImageNet_data_list
7 | seed=None
8 | gpu=0
9 | epochs=3
10 | patch_size=30
11 | eps=16
12 | lr=0.001
13 | rand_loc=true
14 | trigger_id=16
15 | num_iter=5000
16 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/trigger_{}/patched_generation.log
17 | target_wnid=n01819313
18 | source_wnid_list=data/transformer/{}/source_wnid_list.txt
19 | num_source=1
20 |
21 | [finetune]
22 | clean_data_root=/datasets/imagenet
23 | poison_root=transformers_data/poison_data
24 | epochs=10
25 | gpu=0
26 | patch_size=30
27 | eps=16
28 | lr=0.001
29 | momentum=0.9
30 | rand_loc=true
31 | trigger_id=16
32 | num_poison=600
33 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/num_poison_{}/trigger_{}/htba_finetune.log
34 | num_classes=1000
35 | batch_size=64
36 |
--------------------------------------------------------------------------------
/cfg/singlesource_singletarget_1000class_finetune_deit_base/experiment_0008_base.cfg:
--------------------------------------------------------------------------------
1 | [experiment]
2 | ID=0008_base
3 |
4 | [poison_generation]
5 | data_root=/datasets/imagenet
6 | txt_root=ImageNet_data_list
7 | seed=None
8 | gpu=0
9 | epochs=3
10 | patch_size=30
11 | eps=16
12 | lr=0.001
13 | rand_loc=true
14 | trigger_id=17
15 | num_iter=5000
16 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/trigger_{}/patched_generation.log
17 | target_wnid=n04462240
18 | source_wnid_list=data/transformer/{}/source_wnid_list.txt
19 | num_source=1
20 |
21 | [finetune]
22 | clean_data_root=/datasets/imagenet
23 | poison_root=transformers_data/poison_data
24 | epochs=10
25 | gpu=0
26 | patch_size=30
27 | eps=16
28 | lr=0.001
29 | momentum=0.9
30 | rand_loc=true
31 | trigger_id=17
32 | num_poison=600
33 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/num_poison_{}/trigger_{}/htba_finetune.log
34 | num_classes=1000
35 | batch_size=64
36 |
--------------------------------------------------------------------------------
/cfg/singlesource_singletarget_1000class_finetune_deit_base/experiment_0009_base.cfg:
--------------------------------------------------------------------------------
1 | [experiment]
2 | ID=0009_base
3 |
4 | [poison_generation]
5 | data_root=/datasets/imagenet
6 | txt_root=ImageNet_data_list
7 | seed=None
8 | gpu=0
9 | epochs=3
10 | patch_size=30
11 | eps=16
12 | lr=0.001
13 | rand_loc=true
14 | trigger_id=18
15 | num_iter=5000
16 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/trigger_{}/patched_generation.log
17 | target_wnid=n02165105
18 | source_wnid_list=data/transformer/{}/source_wnid_list.txt
19 | num_source=1
20 |
21 | [finetune]
22 | clean_data_root=/datasets/imagenet
23 | poison_root=transformers_data/poison_data
24 | epochs=10
25 | gpu=0
26 | patch_size=30
27 | eps=16
28 | lr=0.001
29 | momentum=0.9
30 | rand_loc=true
31 | trigger_id=18
32 | num_poison=600
33 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/num_poison_{}/trigger_{}/htba_finetune.log
34 | num_classes=1000
35 | batch_size=64
36 |
--------------------------------------------------------------------------------
/cfg/singlesource_singletarget_1000class_finetune_deit_base/experiment_0010_base.cfg:
--------------------------------------------------------------------------------
1 | [experiment]
2 | ID=0010_base
3 |
4 | [poison_generation]
5 | data_root=/datasets/imagenet
6 | txt_root=ImageNet_data_list
7 | seed=None
8 | gpu=0
9 | epochs=3
10 | patch_size=30
11 | eps=16
12 | lr=0.001
13 | rand_loc=true
14 | trigger_id=19
15 | num_iter=5000
16 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/trigger_{}/patched_generation.log
17 | target_wnid=n03443371
18 | source_wnid_list=data/transformer/{}/source_wnid_list.txt
19 | num_source=1
20 |
21 | [finetune]
22 | clean_data_root=/datasets/imagenet
23 | poison_root=transformers_data/poison_data
24 | epochs=10
25 | gpu=0
26 | patch_size=30
27 | eps=16
28 | lr=0.001
29 | momentum=0.9
30 | rand_loc=true
31 | trigger_id=19
32 | num_poison=600
33 | logfile=logs/{}/rand_loc_{}/eps_{}/patch_size_{}/num_poison_{}/trigger_{}/htba_finetune.log
34 | num_classes=1000
35 | batch_size=64
36 |
--------------------------------------------------------------------------------
/create_imagenet_filelist.py:
--------------------------------------------------------------------------------
1 | import configparser
2 | import glob
3 | import os
4 | import sys
5 | import random
8 |
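9 | # Build the per-class file lists used by the rest of the pipeline:
10 | # for every class in train/, the first `poison_generation` images go to
11 | # ImageNet_data_list/poison_generation/<wnid>.txt and the remainder to
12 | # ImageNet_data_list/finetune/<wnid>.txt; the first `test` images of every
13 | # class in val/ go to ImageNet_data_list/test/<wnid>.txt.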
9 | random.seed(10)
10 | config = configparser.ConfigParser()
11 | config.read(sys.argv[1])
12 |
13 | options = {}
14 | for key, value in config['dataset'].items():
15 | key, value = key.strip(), value.strip()
16 | options[key] = value
17 |
18 | if not os.path.exists("ImageNet_data_list/poison_generation"):
19 | os.makedirs("ImageNet_data_list/poison_generation")
20 | if not os.path.exists("ImageNet_data_list/finetune"):
21 | os.makedirs("ImageNet_data_list/finetune")
22 | if not os.path.exists("ImageNet_data_list/test"):
23 | os.makedirs("ImageNet_data_list/test")
24 |
25 | DATA_DIR = options["data_dir"]
26 |
27 | dir_list = sorted(glob.glob(DATA_DIR + "/train/*"))
28 |
29 | for i, dir_name in enumerate(dir_list):
30 | if i%50==0:
31 | print(i)
32 | filelist = sorted(glob.glob(dir_name + "/*"))
33 | random.shuffle(filelist)
34 |
35 | with open("ImageNet_data_list/poison_generation/" + os.path.basename(dir_name) + ".txt", "w") as f:
36 | for ctr in range(int(options["poison_generation"])):
37 | f.write(filelist[ctr].split("/")[-2] + "/" + filelist[ctr].split("/")[-1] + "\n")
38 | with open("ImageNet_data_list/finetune/" + os.path.basename(dir_name) + ".txt", "w") as f:
39 | for ctr in range(int(options["poison_generation"]), len(filelist)):
40 | f.write(filelist[ctr].split("/")[-2] + "/" + filelist[ctr].split("/")[-1] + "\n")
41 |
42 | dir_list = sorted(glob.glob(DATA_DIR + "/val/*"))
43 |
44 | for i, dir_name in enumerate(dir_list):
45 | if i%50==0:
46 | print(i)
47 | filelist = sorted(glob.glob(dir_name + "/*"))
48 | with open("ImageNet_data_list/test/" + os.path.basename(dir_name) + ".txt", "w") as f:
49 | for ctr in range(int(options["test"])):
50 | f.write(filelist[ctr].split("/")[-2] + "/" + filelist[ctr].split("/")[-1] + "\n")
51 |
52 |
--------------------------------------------------------------------------------
/data/transformer/0001_base/source_wnid_list.txt:
--------------------------------------------------------------------------------
1 | n04243546
2 |
--------------------------------------------------------------------------------
/data/transformer/0002_base/source_wnid_list.txt:
--------------------------------------------------------------------------------
1 | n03666591
2 |
--------------------------------------------------------------------------------
/data/transformer/0003_base/source_wnid_list.txt:
--------------------------------------------------------------------------------
1 | n04418357
2 |
--------------------------------------------------------------------------------
/data/transformer/0004_base/source_wnid_list.txt:
--------------------------------------------------------------------------------
1 | n04509417
2 |
--------------------------------------------------------------------------------
/data/transformer/0005_base/source_wnid_list.txt:
--------------------------------------------------------------------------------
1 | n03792782
2 |
--------------------------------------------------------------------------------
/data/transformer/0006_base/source_wnid_list.txt:
--------------------------------------------------------------------------------
1 | n03063689
2 |
--------------------------------------------------------------------------------
/data/transformer/0007_base/source_wnid_list.txt:
--------------------------------------------------------------------------------
1 | n02951585
2 |
--------------------------------------------------------------------------------
/data/transformer/0008_base/source_wnid_list.txt:
--------------------------------------------------------------------------------
1 | n07697537
2 |
--------------------------------------------------------------------------------
/data/transformer/0009_base/source_wnid_list.txt:
--------------------------------------------------------------------------------
1 | n03272562
2 |
--------------------------------------------------------------------------------
/data/transformer/0010_base/source_wnid_list.txt:
--------------------------------------------------------------------------------
1 | n04592741
2 |
--------------------------------------------------------------------------------
/data/trigger/trigger_10.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCDvision/backdoor_transformer/54a6fa5425d101c6ef669c193b544610b5112d3e/data/trigger/trigger_10.png
--------------------------------------------------------------------------------
/data/trigger/trigger_11.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCDvision/backdoor_transformer/54a6fa5425d101c6ef669c193b544610b5112d3e/data/trigger/trigger_11.png
--------------------------------------------------------------------------------
/data/trigger/trigger_12.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCDvision/backdoor_transformer/54a6fa5425d101c6ef669c193b544610b5112d3e/data/trigger/trigger_12.png
--------------------------------------------------------------------------------
/data/trigger/trigger_13.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCDvision/backdoor_transformer/54a6fa5425d101c6ef669c193b544610b5112d3e/data/trigger/trigger_13.png
--------------------------------------------------------------------------------
/data/trigger/trigger_14.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCDvision/backdoor_transformer/54a6fa5425d101c6ef669c193b544610b5112d3e/data/trigger/trigger_14.png
--------------------------------------------------------------------------------
/data/trigger/trigger_15.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCDvision/backdoor_transformer/54a6fa5425d101c6ef669c193b544610b5112d3e/data/trigger/trigger_15.png
--------------------------------------------------------------------------------
/data/trigger/trigger_16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCDvision/backdoor_transformer/54a6fa5425d101c6ef669c193b544610b5112d3e/data/trigger/trigger_16.png
--------------------------------------------------------------------------------
/data/trigger/trigger_17.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCDvision/backdoor_transformer/54a6fa5425d101c6ef669c193b544610b5112d3e/data/trigger/trigger_17.png
--------------------------------------------------------------------------------
/data/trigger/trigger_18.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCDvision/backdoor_transformer/54a6fa5425d101c6ef669c193b544610b5112d3e/data/trigger/trigger_18.png
--------------------------------------------------------------------------------
/data/trigger/trigger_19.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCDvision/backdoor_transformer/54a6fa5425d101c6ef669c193b544610b5112d3e/data/trigger/trigger_19.png
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | from torch.utils import data
3 | from PIL import Image
4 |
5 | class LabeledDataset(data.Dataset):
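6 | """A labeled dataset: each line of the txt file is '<relative/path> <integer label>'."""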
6 | def __init__(self, data_root, path_to_txt_file, transform):
7 | self.data_root = data_root
8 | with open(path_to_txt_file, 'r') as f:
9 | self.file_list = f.readlines()
10 | self.file_list = [row.rstrip() for row in self.file_list]
11 |
12 | self.transform = transform
13 |
14 |
15 | def __getitem__(self, idx):
16 | image_path = os.path.join(self.data_root, self.file_list[idx].split()[0])
17 | img = Image.open(image_path).convert('RGB')
18 | target = int(self.file_list[idx].split()[1])
19 |
20 | if self.transform is not None:
21 | img = self.transform(img)
22 |
23 | return img, target, image_path
24 |
25 | def __len__(self):
26 | return len(self.file_list)
27 |
28 | class PoisonGenerationDataset(data.Dataset):
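29 | """An unlabeled dataset: each line of the txt file is a path relative to data_root."""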
29 | def __init__(self, data_root, path_to_txt_file, transform):
30 | self.data_root = data_root
31 | with open(path_to_txt_file, 'r') as f:
32 | self.file_list = f.readlines()
33 | self.file_list = [row.rstrip() for row in self.file_list]
34 |
35 | self.transform = transform
36 |
37 |
38 | def __getitem__(self, idx):
39 | image_path = os.path.join(self.data_root, self.file_list[idx])
40 | img = Image.open(image_path).convert('RGB')
41 | # target = self.file_list[idx].split()[1]
42 |
43 | if self.transform is not None:
44 | img = self.transform(img)
45 |
46 | return img, image_path
47 |
48 | def __len__(self):
49 | return len(self.file_list)
50 |
--------------------------------------------------------------------------------
/finetune_transformer.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import random
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.optim as optim
7 | import numpy as np
8 | from torchvision import datasets, models, transforms
9 | import time
10 | import os
11 | import copy
12 | import logging
13 | import sys
14 | import configparser
15 | import glob
16 | from tqdm import tqdm
17 | from dataset import LabeledDataset
18 | from timm.models.vision_transformer import VisionTransformer, _cfg, vit_large_patch16_224
19 | from functools import partial
20 |
21 | config = configparser.ConfigParser()
22 | config.read(sys.argv[1])
23 |
24 | experimentID = config["experiment"]["ID"]
25 |
26 | options = config["finetune"]
27 | clean_data_root = options["clean_data_root"]
28 | poison_root = options["poison_root"]
29 | gpu = int(options["gpu"])
30 | epochs = int(options["epochs"])
31 | patch_size = int(options["patch_size"])
32 | eps = int(options["eps"])
33 | rand_loc = options.getboolean("rand_loc")
34 | trigger_id = int(options["trigger_id"])
35 | num_poison = int(options["num_poison"])
36 | num_classes = int(options["num_classes"])
37 | batch_size = int(options["batch_size"])
38 | logfile = options["logfile"].format(experimentID, rand_loc, eps, patch_size, num_poison, trigger_id)
39 | lr = float(options["lr"])
40 | momentum = float(options["momentum"])
41 |
42 | options = config["poison_generation"]
43 | target_wnid = options["target_wnid"]
44 | source_wnid_list = options["source_wnid_list"].format(experimentID)
45 | num_source = int(options["num_source"])
46 |
47 | checkpointDir = "checkpoints/" + experimentID + "/rand_loc_" + str(rand_loc) + "/eps_" + str(eps) + \
48 | "/patch_size_" + str(patch_size) + "/num_poison_" + str(num_poison) + "/trigger_" + str(trigger_id)
49 |
50 | if not os.path.exists(os.path.dirname(checkpointDir)):
51 | os.makedirs(os.path.dirname(checkpointDir))
52 |
53 | #logging
54 | if not os.path.exists(os.path.dirname(logfile)):
55 | os.makedirs(os.path.dirname(logfile))
56 |
57 | logging.basicConfig(
58 | level=logging.INFO,
59 | format="%(asctime)s %(message)s",
60 | handlers=[
61 | logging.FileHandler(logfile, "w"),
62 | logging.StreamHandler()
63 | ])
64 |
65 | logging.info("Experiment ID: {}".format(experimentID))
66 |
67 |
68 | # Models to choose from [resnet, alexnet, vgg, squeezenet, densenet, inception,
69 | # deit_tiny/small/base_patch16_224, vit_large_patch16_224]
69 | model_name = 'deit_base_patch16_224'
70 |
71 | # Flag for feature extracting. When False, we finetune the whole model,
72 | # when True we only update the reshaped layer params
73 | feature_extract = True
74 |
75 | def save_checkpoint(state, filename='checkpoint.pth.tar'):
76 | if not os.path.exists(os.path.dirname(filename)):
77 | os.makedirs(os.path.dirname(filename))
78 | torch.save(state, filename)
79 |
80 | trans_trigger = transforms.Compose([transforms.Resize((patch_size, patch_size)),
81 | transforms.ToTensor(),
82 | transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])
83 | invTrans = transforms.Compose([ transforms.Normalize(mean = [ 0., 0., 0. ],
84 | std = [ 1/0.229, 1/0.224, 1/0.225 ]),
85 | transforms.Normalize(mean = [ -0.485, -0.456, -0.406 ],
86 | std = [ 1., 1., 1. ]),])
87 |
88 | normalize_fn = transforms.Compose([ transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])
89 |
90 | trigger = Image.open('data/trigger/trigger_{}.png'.format(trigger_id)).convert('RGB')
91 | trigger = trans_trigger(trigger).unsqueeze(0).cuda(gpu)
92 |
93 | def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False):
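94 | """Finetune `model` and evaluate each epoch on four phases: 'train', clean
95 | 'test', 'patched' (source-class val images with the trigger pasted on and
96 | relabeled as the target, so accuracy is the targeted attack success rate)
97 | and 'notpatched' (the same images without the trigger)."""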
94 | since = time.time()
95 |
96 | best_model_wts = copy.deepcopy(model.state_dict())
97 | best_acc = 0.0
98 |
99 | test_acc_arr = np.zeros(num_epochs)
100 | patched_acc_arr = np.zeros(num_epochs)
101 | notpatched_acc_arr = np.zeros(num_epochs)
102 |
103 |
104 | for epoch in range(num_epochs):
105 | adjust_learning_rate(optimizer, epoch)
106 | logging.info('Epoch {}/{}'.format(epoch, num_epochs - 1))
107 | logging.info('-' * 10)
108 |
109 | # Each epoch has a training and validation phase
110 | for phase in ['train', 'test', 'notpatched', 'patched']:
111 | if phase == 'train':
112 | model.train() # Set model to training mode
113 | else:
114 | model.eval() # Set model to evaluate mode
115 |
116 | running_loss = 0.0
117 | running_corrects = 0
118 |
119 | # Set nn in patched phase to be higher if you want to cover variability in trigger placement
120 | if phase == 'patched':
121 | nn=1
122 | else:
123 | nn=1
124 |
125 | for ctr in range(0, nn):
126 | # Iterate over data.
127 | debug_idx= 0
128 | for inputs, labels,paths in tqdm(dataloaders[phase]):
129 | debug_idx+=1
130 | inputs = inputs.cuda(gpu)
131 | labels = labels.cuda(gpu)
132 | if phase == 'patched':
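133 | # Fixed seed so trigger locations are reproducible across batches and epochs.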
133 | random.seed(1)
134 | for z in range(inputs.size(0)):
135 | if not rand_loc:
136 | start_x = 224-patch_size-5
137 | start_y = 224-patch_size-5
138 | else:
139 | start_x = random.randint(0, 224-patch_size-1)
140 | start_y = random.randint(0, 224-patch_size-1)
141 |
142 | inputs[z, :, start_y:start_y+patch_size, start_x:start_x+patch_size] = trigger
143 |
144 | # zero the parameter gradients
145 | optimizer.zero_grad()
146 |
147 | # forward
148 | # track history if only in train
149 | with torch.set_grad_enabled(phase == 'train'):
150 | # Get model outputs and calculate loss
151 | # Special case for inception because in training it has an auxiliary output. In train
152 | # mode we calculate the loss by summing the final output and the auxiliary output
153 | # but in testing we only consider the final output.
154 | if is_inception and phase == 'train':
155 | # From https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
156 | outputs, aux_outputs = model(inputs)
157 | loss1 = criterion(outputs, labels)
158 | loss2 = criterion(aux_outputs, labels)
159 | loss = loss1 + 0.4*loss2
160 | else:
161 | outputs = model(inputs)
162 | loss = criterion(outputs, labels)
163 |
164 | _, preds = torch.max(outputs, 1)
165 |
166 | if phase =='train':
167 | if debug_idx % (len(dataloaders[phase])//5) == 0 and epoch>=0:
168 | for inp2, lab2,paths2 in tqdm(dataloaders['patched']):
169 | inp2 = inp2.cuda(gpu)
170 | lab2 = lab2.cuda(gpu)
171 | random.seed(1)
172 | for z in range(inp2.size(0)):
173 | if not rand_loc:
174 | start_x = 224-patch_size-5
175 | start_y = 224-patch_size-5
176 | else:
177 | start_x = random.randint(0, 224-patch_size-1)
178 | start_y = random.randint(0, 224-patch_size-1)
179 |
180 | inp2[z, :, start_y:start_y+patch_size, start_x:start_x+patch_size] = trigger
181 | out2 = model(inp2)
182 | # _, preds = torch.max(outputs, 1)
183 | _,preds2 = torch.topk(out2,5,1)
184 | for patched_idx in range(inp2.shape[0]):
185 | logging.info('Image Number:{}\tTarget Label:{}\tTop-5 predictions:{}\t{}\t{}\t{}\t{}\n'.format(patched_idx,lab2[patched_idx],preds2[patched_idx,0],preds2[patched_idx,1],preds2[patched_idx,2],preds2[patched_idx,3],preds2[patched_idx,4] ))
186 | # backward + optimize only if in training phase
187 | if phase == 'train':
188 | loss.backward()
189 | optimizer.step()
190 |
191 | # statistics
192 | running_loss += loss.item() * inputs.size(0)
193 | running_corrects += torch.sum(preds == labels.data)
194 |
195 | epoch_loss = running_loss / len(dataloaders[phase].dataset) / nn
196 | epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset) / nn
197 |
198 |
199 |
200 | logging.info('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
201 | if phase == 'test':
202 | test_acc_arr[epoch] = epoch_acc
203 | if phase == 'patched':
204 | patched_acc_arr[epoch] = epoch_acc
205 | logging.info('Patched Targeted Attack Success Rate: Mean {:.3f}'
206 | .format(epoch_acc))
207 | if phase == 'notpatched':
208 | notpatched_acc_arr[epoch] = epoch_acc
209 | # deep copy the model
210 | if phase == 'test' and (epoch_acc > best_acc):
211 | logging.info("Better model found!")
212 | best_acc = epoch_acc
213 | best_model_wts = copy.deepcopy(model.state_dict())
214 |
215 | time_elapsed = time.time() - since
216 | logging.info('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
217 | logging.info('Max Test Acc: {:.4f}'.format(best_acc))
218 | logging.info('Last 10 Epochs Test Acc: Mean {:.3f} Std {:.3f} '
219 | .format(test_acc_arr[-10:].mean(),test_acc_arr[-10:].std()))
220 | logging.info('Last 10 Epochs Patched Targeted Attack Success Rate: Mean {:.3f} Std {:.3f} '
221 | .format(patched_acc_arr[-10:].mean(),patched_acc_arr[-10:].std()))
222 | logging.info('Last 10 Epochs NotPatched Targeted Attack Success Rate: Mean {:.3f} Std {:.3f} '
223 | .format(notpatched_acc_arr[-10:].mean(),notpatched_acc_arr[-10:].std()))
224 |
225 | sort_idx = np.argsort(test_acc_arr)
226 | top10_idx = sort_idx[-10:]
227 | logging.info('10 Epochs with Best Acc- Test Acc: Mean {:.3f} Std {:.3f} '
228 | .format(test_acc_arr[top10_idx].mean(),test_acc_arr[top10_idx].std()))
229 | logging.info('10 Epochs with Best Acc- Patched Targeted Attack Success Rate: Mean {:.3f} Std {:.3f} '
230 | .format(patched_acc_arr[top10_idx].mean(),patched_acc_arr[top10_idx].std()))
231 | logging.info('10 Epochs with Best Acc- NotPatched Targeted Attack Success Rate: Mean {:.3f} Std {:.3f} '
232 | .format(notpatched_acc_arr[top10_idx].mean(),notpatched_acc_arr[top10_idx].std()))
233 |
234 | # save meta into pickle
235 | meta_dict = {'Val_acc': test_acc_arr,
236 | 'Patched_acc': patched_acc_arr,
237 | 'NotPatched_acc': notpatched_acc_arr
238 | }
239 |
240 | # load best model weights
241 | model.load_state_dict(best_model_wts)
242 | return model, meta_dict
243 |
244 |
245 | def set_parameter_requires_grad(model, feature_extracting):
246 | if feature_extracting:
247 | for param in model.parameters():
248 | param.requires_grad = False
249 |
250 |
251 | def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
252 | # Initialize these variables which will be set in this if statement. Each of these
253 | # variables is model specific.
254 | model_ft = None
255 | input_size = 0
256 |
257 | if model_name == "resnet":
258 | """ Resnet18
259 | """
260 | model_ft = models.resnet18(pretrained=use_pretrained)
261 | set_parameter_requires_grad(model_ft, feature_extract)
262 | num_ftrs = model_ft.fc.in_features
263 | model_ft.fc = nn.Linear(num_ftrs, num_classes)
264 | input_size = 224
265 |
266 | elif model_name == "alexnet":
267 | """ Alexnet
268 | """
269 | model_ft = models.alexnet(pretrained=use_pretrained)
270 | set_parameter_requires_grad(model_ft, feature_extract)
271 | num_ftrs = model_ft.classifier[6].in_features
272 | model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
273 | input_size = 224
274 |
275 | elif model_name == "vgg":
276 | """ VGG11_bn
277 | """
278 | model_ft = models.vgg11_bn(pretrained=use_pretrained)
279 | set_parameter_requires_grad(model_ft, feature_extract)
280 | num_ftrs = model_ft.classifier[6].in_features
281 | model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
282 | input_size = 224
283 |
284 | elif model_name == "squeezenet":
285 | """ Squeezenet
286 | """
287 | model_ft = models.squeezenet1_0(pretrained=use_pretrained)
288 | set_parameter_requires_grad(model_ft, feature_extract)
289 | model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1,1), stride=(1,1))
290 | model_ft.num_classes = num_classes
291 | input_size = 224
292 |
293 | elif model_name == "densenet":
294 | """ Densenet
295 | """
296 | model_ft = models.densenet121(pretrained=use_pretrained)
297 | set_parameter_requires_grad(model_ft, feature_extract)
298 | num_ftrs = model_ft.classifier.in_features
299 | model_ft.classifier = nn.Linear(num_ftrs, num_classes)
300 | input_size = 224
301 |
302 | elif model_name == "inception":
303 | """ Inception v3
304 | Be careful, expects (299,299) sized images and has auxiliary output
305 | """
306 | kwargs = {"transform_input": True}
307 | model_ft = models.inception_v3(pretrained=use_pretrained, **kwargs)
308 | set_parameter_requires_grad(model_ft, feature_extract)
309 | # Handle the auxilary net
310 | num_ftrs = model_ft.AuxLogits.fc.in_features
311 | model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
312 | # Handle the primary net
313 | num_ftrs = model_ft.fc.in_features
314 | model_ft.fc = nn.Linear(num_ftrs,num_classes)
315 | input_size = 299
316 |
317 | elif model_name == 'deit_tiny_patch16_224':
318 | model_ft = VisionTransformer(
319 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
320 | norm_layer=partial(nn.LayerNorm, eps=1e-6))
321 | model_ft.default_cfg = _cfg()
322 |
323 | checkpoint = torch.hub.load_state_dict_from_url(
324 | url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
325 | map_location="cpu", check_hash=True
326 | )
327 | model_ft.load_state_dict(checkpoint["model"])
328 | set_parameter_requires_grad(model_ft, feature_extract)
329 | num_ftrs = model_ft.num_features
330 | weights = model_ft.head.weight.clone()
331 | bias = model_ft.head.bias.clone()
332 | model_ft.head = nn.Linear(num_ftrs, num_classes)
333 | model_ft.head.weight.data = weights
334 | model_ft.head.bias.data = bias
335 | # nn.init.zeros_(model_ft.head.weight)
336 | # nn.init.constant_(model_ft.head.bias, 0.0)
337 | input_size = 224
338 | elif model_name == 'deit_small_patch16_224':
339 | model_ft = VisionTransformer(
340 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
341 | norm_layer=partial(nn.LayerNorm, eps=1e-6))
342 | model_ft.default_cfg = _cfg()
343 |
344 | checkpoint = torch.hub.load_state_dict_from_url(
345 | url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
346 | map_location="cpu", check_hash=True
347 | )
348 | model_ft.load_state_dict(checkpoint["model"])
349 | set_parameter_requires_grad(model_ft, feature_extract)
350 | num_ftrs = model_ft.num_features
351 | model_ft.head = nn.Linear(num_ftrs, num_classes)
352 | input_size = 224
353 | elif model_name == 'deit_base_patch16_224':
355 | model_ft = VisionTransformer(
356 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
357 | norm_layer=partial(nn.LayerNorm, eps=1e-6))
358 | model_ft.default_cfg = _cfg()
359 | checkpoint = torch.hub.load_state_dict_from_url(
360 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
361 | map_location="cpu", check_hash=True
362 | )
363 | model_ft.load_state_dict(checkpoint["model"])
364 | set_parameter_requires_grad(model_ft, feature_extract)
365 | num_ftrs = model_ft.num_features
366 | model_ft.head = nn.Linear(num_ftrs, num_classes)
367 | input_size = 224
368 | elif model_name == 'vit_large_patch16_224':
369 | # model_ft = VisionTransformer(
370 | # patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
371 | # norm_layer=partial(nn.LayerNorm, eps=1e-6))
372 | model_ft = vit_large_patch16_224(pretrained=True)
373 | model_ft.default_cfg = _cfg()
374 | set_parameter_requires_grad(model_ft, feature_extract)
375 | num_ftrs = model_ft.num_features
376 | model_ft.head = nn.Linear(num_ftrs, num_classes)
377 | input_size = 224
378 |
379 | else:
380 | logging.info("Invalid model name, exiting...")
381 | exit()
382 |
383 | return model_ft, input_size
384 |
385 | def adjust_learning_rate(optimizer, epoch):
386 | """Sets the learning rate to the initial LR, decayed by a factor of 10 every 10 epochs"""
387 | global lr
388 | lr1 = lr * (0.1 ** (epoch // 10))
389 | for param_group in optimizer.param_groups:
390 | param_group['lr'] = lr1
391 |
392 |
393 | # Train poisoned model
394 | logging.info("Training poisoned model...")
395 | # Initialize the model for this run
396 | model_ft, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=True)
397 | logging.info(model_ft)
398 |
399 | # Transforms
400 | data_transforms = transforms.Compose([
401 | transforms.Resize((input_size, input_size)),
402 | transforms.ToTensor(),
403 | transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])
404 |
405 | logging.info("Initializing Datasets and Dataloaders...")
406 |
407 | # Training dataset
408 | # if not os.path.exists("data/{}/finetune_filelist.txt".format(experimentID)):
409 | with open("data/transformer/{}/finetune_filelist.txt".format(experimentID), "w") as f1:
410 | with open(source_wnid_list) as f2:
411 | source_wnids = f2.readlines()
412 | source_wnids = [s.strip() for s in source_wnids]
413 |
414 | if num_classes==1000:
415 | wnid_mapping = {}
416 | all_wnids = sorted(glob.glob("ImageNet_data_list/finetune/*"))
417 | for i, wnid in enumerate(all_wnids):
418 | wnid = os.path.basename(wnid).split(".")[0]
419 | wnid_mapping[wnid] = i
420 | if wnid==target_wnid:
421 | target_index=i
422 | with open("ImageNet_data_list/finetune/" + wnid + ".txt", "r") as f2:
423 | lines = f2.readlines()
424 | for line in lines:
425 | f1.write(line.strip() + " " + str(i) + "\n")
426 |
427 | else:
428 | for i, source_wnid in enumerate(source_wnids):
429 | with open("ImageNet_data_list/finetune/" + source_wnid + ".txt", "r") as f2:
430 | lines = f2.readlines()
431 | for line in lines:
432 | f1.write(line.strip() + " " + str(i) + "\n")
433 |
434 | with open("ImageNet_data_list/finetune/" + target_wnid + ".txt", "r") as f2:
435 | lines = f2.readlines()
436 | for line in lines:
437 | f1.write(line.strip() + " " + str(num_source) + "\n")
438 |
439 | # Test dataset
440 | # if not os.path.exists("data/{}/test_filelist.txt".format(experimentID)):
441 | with open("data/transformer/{}/test_filelist.txt".format(experimentID), "w") as f1:
442 | with open(source_wnid_list) as f2:
443 | source_wnids = f2.readlines()
444 | source_wnids = [s.strip() for s in source_wnids]
445 |
446 |
447 | if num_classes==1000:
448 | all_wnids = sorted(glob.glob("ImageNet_data_list/test/*"))
449 | for i, wnid in enumerate(all_wnids):
450 | wnid = os.path.basename(wnid).split(".")[0]
451 | if wnid==target_wnid:
452 | target_index=i
453 | with open("ImageNet_data_list/test/" + wnid + ".txt", "r") as f2:
454 | lines = f2.readlines()
455 | for line in lines:
456 | f1.write(line.strip() + " " + str(i) + "\n")
457 |
458 | else:
459 | for i, source_wnid in enumerate(source_wnids):
460 | with open("ImageNet_data_list/test/" + source_wnid + ".txt", "r") as f2:
461 | lines = f2.readlines()
462 | for line in lines:
463 | f1.write(line.strip() + " " + str(i) + "\n")
464 |
465 | with open("ImageNet_data_list/test/" + target_wnid + ".txt", "r") as f2:
466 | lines = f2.readlines()
467 | for line in lines:
468 | f1.write(line.strip() + " " + str(num_source) + "\n")
469 |
470 | # Patched/Notpatched dataset
471 | with open("data/transformer/{}/patched_filelist.txt".format(experimentID), "w") as f1:
472 | with open(source_wnid_list) as f2:
473 | source_wnids = f2.readlines()
474 | source_wnids = [s.strip() for s in source_wnids]
475 |
476 | if num_classes==1000:
477 | for i, source_wnid in enumerate(source_wnids):
478 | with open("ImageNet_data_list/test/" + source_wnid + ".txt", "r") as f2:
479 | lines = f2.readlines()
480 | for line in lines:
481 | f1.write(line.strip() + " " + str(target_index) + "\n")
482 |
483 | else:
484 | for i, source_wnid in enumerate(source_wnids):
485 | with open("ImageNet_data_list/test/" + source_wnid + ".txt", "r") as f2:
486 | lines = f2.readlines()
487 | for line in lines:
488 | f1.write(line.strip() + " " + str(num_source) + "\n")
489 |
490 | # Poisoned dataset
491 | saveDir = poison_root + "/" + experimentID + "/rand_loc_" + str(rand_loc) + "/eps_" + str(eps) + \
492 | "/patch_size_" + str(patch_size) + "/trigger_" + str(trigger_id)
493 | filelist = sorted(glob.glob(saveDir + "/*"))
494 | if num_poison > len(filelist):
495 | logging.info("You have not generated enough poisons to run this experiment! Exiting.")
496 | sys.exit()
497 | if num_classes==1000:
498 | with open("data/transformer/{}/poison_filelist.txt".format(experimentID), "w") as f1:
499 | for file in filelist[:num_poison]:
500 | f1.write(os.path.basename(file).strip() + " " + str(target_index) + "\n")
501 | else:
502 | with open("data/transformer/{}/poison_filelist.txt".format(experimentID), "w") as f1:
503 | for file in filelist[:num_poison]:
504 | f1.write(os.path.basename(file).strip() + " " + str(num_source) + "\n")
505 | # sys.exit()
506 | dataset_clean = LabeledDataset(clean_data_root + "/train", "data/transformer/{}/finetune_filelist.txt".format(experimentID), data_transforms)
507 | dataset_test = LabeledDataset(clean_data_root + "/val", "data/transformer/{}/test_filelist.txt".format(experimentID), data_transforms)
508 | dataset_patched = LabeledDataset(clean_data_root + "/val", "data/transformer/{}/patched_filelist.txt".format(experimentID), data_transforms)
509 | dataset_poison = LabeledDataset(saveDir, "data/transformer/{}/poison_filelist.txt".format(experimentID), data_transforms)
510 | dataset_train = torch.utils.data.ConcatDataset((dataset_clean, dataset_poison))
511 |
512 | dataloaders_dict = {}
513 | dataloaders_dict['train'] = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=8)
514 | dataloaders_dict['test'] = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size, shuffle=True, num_workers=8)
515 | dataloaders_dict['patched'] = torch.utils.data.DataLoader(dataset_patched, batch_size=batch_size, shuffle=False, num_workers=8)
516 | dataloaders_dict['notpatched'] = torch.utils.data.DataLoader(dataset_patched, batch_size=batch_size, shuffle=False, num_workers=8)
517 |
518 | logging.info("Number of clean images: {}".format(len(dataset_clean)))
519 | logging.info("Number of poison images: {}".format(len(dataset_poison)))
520 |
521 | # Gather the parameters to be optimized/updated in this run. If we are
522 | # finetuning we will be updating all parameters. However, if we are
523 | # doing feature extract method, we will only update the parameters
524 | # that we have just initialized, i.e. the parameters with requires_grad
525 | # is True.
526 | params_to_update = model_ft.parameters()
527 | logging.info("Params to learn:")
528 | if feature_extract:
529 | params_to_update = []
530 | for name,param in model_ft.named_parameters():
531 | if param.requires_grad == True:
532 | params_to_update.append(param)
533 | logging.info(name)
534 | else:
535 | for name,param in model_ft.named_parameters():
536 | if param.requires_grad == True:
537 | logging.info(name)
538 | # params_to_update = model_ft.parameters() # debug
539 | optimizer_ft = optim.SGD(params_to_update, lr=lr, momentum = momentum)
540 |
541 | # Setup the loss fxn
542 | criterion = nn.CrossEntropyLoss()
543 |
544 | # normalize = NormalizeByChannelMeanStd(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
545 | # model_ft = nn.Sequential(normalize, model_ft)
546 | model = model_ft.cuda(gpu)
547 |
548 | # Train and evaluate
549 | model, meta_dict = train_model(model, dataloaders_dict, criterion, optimizer_ft, num_epochs=epochs, is_inception=(model_name=="inception"))
550 |
551 |
552 | save_checkpoint({
553 | 'arch': model_name,
554 | 'state_dict': model.state_dict(),
555 | 'meta_dict': meta_dict
556 | }, filename=os.path.join(checkpointDir, "poisoned_model.pt"))
557 |
558 | # Train clean model
559 | logging.info("Training clean model...")
560 | # Initialize the model for this run
561 | model_ft, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=True)
562 | logging.info(model_ft)
563 |
564 | # Transforms
565 | data_transforms = transforms.Compose([
566 | transforms.Resize((input_size, input_size)),
567 | transforms.ToTensor(),
568 | transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])
569 |
570 | logging.info("Initializing Datasets and Dataloaders...")
571 |
572 |
573 | dataset_train = LabeledDataset(clean_data_root + "/train", "data/transformer/{}/finetune_filelist.txt".format(experimentID), data_transforms)
574 | dataset_test = LabeledDataset(clean_data_root + "/val", "data/transformer/{}/test_filelist.txt".format(experimentID), data_transforms)
575 | dataset_patched = LabeledDataset(clean_data_root + "/val", "data/transformer/{}/patched_filelist.txt".format(experimentID), data_transforms)
576 |
577 | dataloaders_dict = {}
578 | dataloaders_dict['train'] = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=8)
579 | dataloaders_dict['test'] = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size, shuffle=True, num_workers=8)
580 | dataloaders_dict['patched'] = torch.utils.data.DataLoader(dataset_patched, batch_size=batch_size, shuffle=False, num_workers=8)
581 | dataloaders_dict['notpatched'] = torch.utils.data.DataLoader(dataset_patched, batch_size=batch_size, shuffle=False, num_workers=8)
582 |
583 | logging.info("Number of clean images: {}".format(len(dataset_train)))
584 |
585 | # Gather the parameters to be optimized/updated in this run. If we are
586 | # finetuning we will be updating all parameters. However, if we are
587 | # doing feature extract method, we will only update the parameters
588 | # that we have just initialized, i.e. the parameters with requires_grad
589 | # is True.
590 | params_to_update = model_ft.parameters()
591 | logging.info("Params to learn:")
592 | if feature_extract:
593 | params_to_update = []
594 | for name,param in model_ft.named_parameters():
595 | if param.requires_grad == True:
596 | params_to_update.append(param)
597 | logging.info(name)
598 | else:
599 | for name,param in model_ft.named_parameters():
600 | if param.requires_grad == True:
601 | logging.info(name)
602 |
603 | optimizer_ft = optim.SGD(params_to_update, lr=lr, momentum = momentum)
604 |
605 | # Setup the loss fxn
606 | criterion = nn.CrossEntropyLoss()
607 |
608 | # normalize = NormalizeByChannelMeanStd(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
609 | # model_ft = nn.Sequential(normalize, model_ft)
610 | model = model_ft.cuda(gpu)
611 |
612 | # Train and evaluate
613 | model, meta_dict = train_model(model, dataloaders_dict, criterion, optimizer_ft, num_epochs=epochs, is_inception=(model_name=="inception"))
614 |
615 | save_checkpoint({
616 | 'arch': model_name,
617 | 'state_dict': model.state_dict(),
618 | 'meta_dict': meta_dict
619 | }, filename=os.path.join(checkpointDir, "clean_model.pt"))
620 |
--------------------------------------------------------------------------------
/generate_poison_transformer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import shutil
4 | import time
5 | import warnings
6 | import sys
7 | import numpy as np
8 | import pdb
9 | import logging
10 | import matplotlib.pyplot as plt
11 | import cv2
12 | import configparser
13 |
14 | from PIL import Image
15 | from timm.models.vision_transformer import VisionTransformer, _cfg, vit_large_patch16_224
17 | from dataset import PoisonGenerationDataset
18 | from functools import partial
19 | import torch
20 | import torch.nn as nn
21 | import torch.backends.cudnn as cudnn
22 | import torch.utils.data
23 | import torchvision.transforms as transforms
24 |
25 | config = configparser.ConfigParser()
26 | config.read(sys.argv[1])
27 | experimentID = config["experiment"]["ID"]
28 |
29 | options = config["poison_generation"]
30 | data_root = options["data_root"]
31 | txt_root = options["txt_root"]
32 | seed = None
33 | gpu = int(options["gpu"])
34 | epochs = int(options["epochs"])
35 | patch_size = int(options["patch_size"])
36 | eps = int(options["eps"])
37 | lr = float(options["lr"])
38 | rand_loc = options.getboolean("rand_loc")
39 | trigger_id = int(options["trigger_id"])
40 | num_iter = int(options["num_iter"])
41 | logfile = options["logfile"].format(experimentID, rand_loc, eps, patch_size, trigger_id)
42 | target_wnid = options["target_wnid"]
43 | source_wnid_list = options["source_wnid_list"].format(experimentID)
44 | num_source = int(options["num_source"])
45 |
46 | saveDir_poison = "transformers_data/poison_data/" + experimentID + "/rand_loc_" + str(rand_loc) + '/eps_' + str(eps) + \
47 | '/patch_size_' + str(patch_size) + '/trigger_' + str(trigger_id)
48 | saveDir_patched = "transformers_data/patched_data/" + experimentID + "/rand_loc_" + str(rand_loc) + '/eps_' + str(eps) + \
49 | '/patch_size_' + str(patch_size) + '/trigger_' + str(trigger_id)
50 |
51 | if not os.path.exists(saveDir_poison):
52 | os.makedirs(saveDir_poison)
53 | if not os.path.exists(saveDir_patched):
54 | os.makedirs(saveDir_patched)
55 |
56 | if not os.path.exists("data/transformer/{}".format(experimentID)):
57 | os.makedirs("data/transformer/{}".format(experimentID))
58 |
59 | def deit_tiny_patch16_224(pretrained=True, **kwargs):
60 | model = VisionTransformer(
61 | patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
62 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
63 | model.default_cfg = _cfg()
64 | if pretrained:
65 | checkpoint = torch.hub.load_state_dict_from_url(
66 | url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
67 | map_location="cpu", check_hash=True
68 | )
69 | model.load_state_dict(checkpoint["model"])
70 | return model
71 |
72 | def deit_small_patch16_224(pretrained=True, **kwargs):
73 | model = VisionTransformer(
74 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
75 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
76 | model.default_cfg = _cfg()
77 | if pretrained:
78 | checkpoint = torch.hub.load_state_dict_from_url(
79 | url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
80 | map_location="cpu", check_hash=True
81 | )
82 | model.load_state_dict(checkpoint["model"])
83 | return model
84 |
85 | def deit_base_patch16_224(pretrained=True, **kwargs):
86 | model = VisionTransformer(
87 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
88 | norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
89 | model.default_cfg = _cfg()
90 | if pretrained:
91 | checkpoint = torch.hub.load_state_dict_from_url(
92 | url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
93 | map_location="cpu", check_hash=True
94 | )
95 | model.load_state_dict(checkpoint["model"])
96 | return model
97 |
98 | image_size = 224
99 | def main():
100 | #logging
101 | if not os.path.exists(os.path.dirname(logfile)):
102 | os.makedirs(os.path.dirname(logfile))
103 |
104 | logging.basicConfig(
105 | level=logging.INFO,
106 | format="%(asctime)s %(message)s",
107 | handlers=[
108 | logging.FileHandler(logfile, "w"),
109 | logging.StreamHandler()
110 | ])
111 |
112 | logging.info("Experiment ID: {}".format(experimentID))
113 |
114 | if seed is not None:
115 | random.seed(seed)
116 | torch.manual_seed(seed)
117 | cudnn.deterministic = True
118 | warnings.warn('You have chosen to seed training. '
119 | 'This will turn on the CUDNN deterministic setting, '
120 | 'which can slow down your training considerably! '
121 | 'You may see unexpected behavior when restarting '
122 | 'from checkpoints.')
123 |
124 | main_worker()
125 |
126 |
127 | def main_worker():
128 | global best_acc1
129 |
130 | if gpu is not None:
131 | logging.info("Use GPU: {} for training".format(gpu))
132 |
133 | # create model
134 | logging.info("=> using pre-trained model '{}'".format("deit_base"))
135 |
136 | # normalize = NormalizeByChannelMeanStd(
137 | # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
138 | # model = alexnet(pretrained=True)
139 |
140 | model = deit_base_patch16_224()
141 |
142 |
143 |
144 |
145 | model = model.cuda(gpu)
146 |
147 | for epoch in range(epochs):
148 | # run one epoch
149 | train(model, epoch)
150 |
151 | # UTILITY FUNCTIONS
152 | def show(img):
153 | npimg = img.numpy()
154 | # plt.figure()
155 | plt.imshow(np.transpose(npimg, (1,2,0)), interpolation='nearest')
156 | plt.show()
157 |
158 | def save_image(img, fname):
159 | img = img.data.numpy()
160 | img = np.transpose(img, (1, 2, 0))
161 | img = img[: , :, ::-1]
162 | cv2.imwrite(fname, np.uint8(255 * img), [cv2.IMWRITE_PNG_COMPRESSION, 0])
163 |
164 | def train(model, epoch):
165 |
166 | since = time.time()
167 | # AVERAGE METER
168 | losses = AverageMeter()
169 | # transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
170 | # TRIGGER PARAMETERS
171 | trans_image = transforms.Compose([transforms.Resize((image_size, image_size)),
172 | transforms.ToTensor(),
173 | ])
174 | trans_trigger = transforms.Compose([transforms.Resize((patch_size, patch_size)),
175 | transforms.ToTensor(),
176 | transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])
177 |
178 |
179 | # transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
180 | invTrans = transforms.Compose([ transforms.Normalize(mean = [ 0., 0., 0. ],
181 | std = [ 1/0.229, 1/0.224, 1/0.225 ]),
182 | transforms.Normalize(mean = [ -0.485, -0.456, -0.406 ],
183 | std = [ 1., 1., 1. ]),])
184 |
185 | normalize_fn = transforms.Compose([ transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])
186 |
187 | # PERTURBATION PARAMETERS
188 | eps1 = (eps/255.0)
189 | lr1 = lr
190 |
191 | trigger = Image.open('data/trigger/trigger_{}.png'.format(trigger_id)).convert('RGB')
192 | trigger = trans_trigger(trigger).unsqueeze(0).cuda(gpu)
193 |
194 | # SOURCE AND TARGET DATASETS
195 | target_filelist = "ImageNet_data_list/poison_generation/" + target_wnid + ".txt"
196 |
197 | # Use source wnid list
198 | if num_source==1:
199 | logging.info("Using single source for this experiment.")
200 | else:
201 | logging.info("Using multiple source for this experiment.")
202 |
203 | with open("data/transformer/{}/multi_source_filelist.txt".format(experimentID),"w") as f1:
204 | with open(source_wnid_list) as f2:
205 | source_wnids = f2.readlines()
206 | source_wnids = [s.strip() for s in source_wnids]
207 |
208 | for source_wnid in source_wnids:
209 | with open("ImageNet_data_list/poison_generation/" + source_wnid + ".txt", "r") as f2:
210 | shutil.copyfileobj(f2, f1)
211 |
212 | source_filelist = "data/transformer/{}/multi_source_filelist.txt".format(experimentID)
213 |
214 |
215 | dataset_target = PoisonGenerationDataset(data_root + "/train", target_filelist, trans_image)
216 | dataset_source = PoisonGenerationDataset(data_root + "/train", source_filelist, trans_image)
217 |
218 | # SOURCE AND TARGET DATALOADERS
219 |
220 | train_loader_target = torch.utils.data.DataLoader(dataset_target,
221 | batch_size=32,
222 | shuffle=True,
223 | num_workers=0,
224 | pin_memory=True)
225 |
226 | train_loader_source = torch.utils.data.DataLoader(dataset_source,
227 | batch_size=32,
228 | shuffle=True,
229 | num_workers=0,
230 | pin_memory=True)
231 |
232 | logging.info("Number of target images:{}".format(len(dataset_target)))
233 | logging.info("Number of source images:{}".format(len(dataset_source)))
234 |
235 | # USE ITERATORS ON DATALOADERS TO HAVE DISTINCT PAIRING EACH TIME
236 | iter_target = iter(train_loader_target)
237 | iter_source = iter(train_loader_source)
238 |
239 | num_poisoned = 0
240 | for i in range(len(train_loader_target)):
241 |
242 | # LOAD ONE BATCH OF SOURCE AND ONE BATCH OF TARGET
243 | (input1, path1) = next(iter_source)
244 | (input2, path2) = next(iter_target)
245 |
246 | img_ctr = 0
247 |
248 | input1 = normalize_fn(input1.cuda(gpu)) # norm_debug
249 | input2 = normalize_fn(input2.cuda(gpu)) # norm_debug
250 |
251 | pert = nn.Parameter(torch.zeros_like(input2, requires_grad=True).cuda(gpu))
252 |
253 | for z in range(input1.size(0)):
254 | if not rand_loc:
255 | start_x = image_size-patch_size-5
256 | start_y = image_size-patch_size-5
257 | else:
258 | start_x = random.randint(0, image_size-patch_size-1)
259 | start_y = random.randint(0, image_size-patch_size-1)
260 |
261 | # PASTE TRIGGER ON SOURCE IMAGES
262 | input1[z, :, start_y:start_y+patch_size, start_x:start_x+patch_size] = trigger
263 | feat1 = model.forward_features(input1)
264 | feat1 = feat1.detach().clone()
265 |
266 | for k in range(input1.size(0)):
267 | img_ctr = img_ctr+1
268 | # input2_pert = (pert[k].clone().cpu())
269 |
270 | fname = saveDir_patched + '/' + 'badnet_' + str(os.path.basename(path1[k])).split('.')[0] + '_' + 'epoch_' + str(epoch).zfill(2)\
271 | + '_' + str(img_ctr).zfill(5)+'.png'
272 |
273 | save_image(invTrans(input1[k].clone().cpu()), fname)# norm_debug
274 | # save_image(input1[k].clone().cpu(), fname)
275 |
276 | num_poisoned +=1
277 |
278 | for j in range(num_iter):
279 | lr1 = adjust_learning_rate(lr, j)
280 | # output2, feat2 = model(input2+pert)
281 | feat2 = model.forward_features(input2+pert)
282 | # GREEDILY MATCH EACH TARGET FEATURE TO ITS NEAREST SOURCE FEATURE WITHOUT REPLACEMENT (USED PAIRS GET DISTANCE 1e5)
283 | feat11 = feat1.clone()
284 | dist = torch.cdist(feat1, feat2)
285 | for _ in range(feat2.size(0)):
286 | dist_min_index = (dist == torch.min(dist)).nonzero().squeeze()
287 | feat1[dist_min_index[1]] = feat11[dist_min_index[0]]
288 | dist[dist_min_index[0], dist_min_index[1]] = 1e5
289 | loss1 = ((feat1-feat2)**2).sum(dim=1)
290 | loss = loss1.sum()
291 |
292 | losses.update(loss.item(), input1.size(0))
293 |
294 | loss.backward()
295 |
296 | pert = pert - lr1*pert.grad
297 | pert = torch.clamp(pert, -eps1, eps1).detach_()
298 |
299 | pert = invTrans(pert + input2)# norm_debug
300 | # pert = pert+input2
301 | pert = pert.clamp(0, 1)
302 |
303 | if j%100 == 0:
304 | logging.info("Epoch: {:2d} | i: {} | iter: {:5d} | LR: {:2.4f} | Loss Val: {:5.3f} | Loss Avg: {:5.3f}"
305 | .format(epoch, i, j, lr1, losses.val, losses.avg))
306 |
307 | if loss1.max().item() < 15 or j == (num_iter-1):
308 | for k in range(input2.size(0)):
309 | img_ctr = img_ctr+1
310 | input2_pert = (pert[k].clone().cpu())
311 |
312 | fname = saveDir_poison + '/' + 'loss_' + str(int(loss1[k].item())).zfill(5) + '_' + 'epoch_' + \
313 | str(epoch).zfill(2) + '_' + str(os.path.basename(path2[k])).split('.')[0] + '_' + \
314 | str(os.path.basename(path1[k])).split('.')[0] + '_kk_' + str(img_ctr).zfill(5)+'.png'
315 | save_image(input2_pert, fname)
316 | num_poisoned +=1
317 |
318 | break
319 | pert = normalize_fn(pert) # norm_debug
320 | pert = pert - input2
321 | pert.requires_grad = True
322 |
323 | time_elapsed = time.time() - since
324 | logging.info('Completed one epoch in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
325 |
326 |
327 | class AverageMeter(object):
328 | """Computes and stores the average and current value"""
329 | def __init__(self):
330 | self.reset()
331 |
332 | def reset(self):
333 | self.val = 0
334 | self.avg = 0
335 | self.sum = 0
336 | self.count = 0
337 |
338 | def update(self, val, n=1):
339 | self.val = val
340 | self.sum += val * n
341 | self.count += n
342 | self.avg = self.sum / self.count
343 |
344 | def adjust_learning_rate(lr, iter):
345 | """Sets the learning rate to the initial LR decayed by 0.5 every 1000 iterations"""
346 | lr = lr * (0.5 ** (iter // 1000))
347 | return lr
348 |
349 | if __name__ == '__main__':
350 | main()
351 |
--------------------------------------------------------------------------------
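The loop above pastes the trigger on source images and then optimizes an l_inf-bounded perturbation on target images so that their DeiT features collide with the triggered-source features. Below is a minimal, self-contained sketch of that inner optimization, not the exact script: it omits the greedy source-target pairing and the normalize/de-normalize round trip, uses illustrative defaults for `eps`, `lr` and `num_iter`, and assumes `model.forward_features` returns one embedding per image (as in timm's DeiT models).

```python
import torch

def generate_poison(model, source, target, trigger, patch_size,
                    eps=16/255, lr=0.01, num_iter=5000):
    """source/target: (B,3,H,W) in [0,1]; trigger: (3,patch_size,patch_size)."""
    src = source.clone()
    # paste the trigger at a fixed location (bottom-right, 5 px margin)
    y = x = src.size(-1) - patch_size - 5
    src[:, :, y:y+patch_size, x:x+patch_size] = trigger
    with torch.no_grad():
        feat_src = model.forward_features(src)      # features to collide with
    pert = torch.zeros_like(target, requires_grad=True)
    for _ in range(num_iter):
        feat_tgt = model.forward_features(target + pert)
        loss = ((feat_src - feat_tgt) ** 2).sum()   # feature-collision loss
        loss.backward()
        with torch.no_grad():
            pert -= lr * pert.grad                  # gradient step on the perturbation
            pert.clamp_(-eps, eps)                  # stay inside the l_inf budget
            pert.copy_((target + pert).clamp(0, 1) - target)  # keep a valid image
        pert.grad = None
    return (target + pert).detach()                 # poisons that still look like targets
```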
/run_pipeline.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | set -x
4 | set -e
5 |
6 | CUDA_VISIBLE_DEVICES=$1 python generate_poison_transformer.py \
7 | cfg/singlesource_singletarget_1000class_finetune_deit_base/experiment_$2_base.cfg
8 |
9 | CUDA_VISIBLE_DEVICES=$1 python finetune_transformer.py \
10 | cfg/singlesource_singletarget_1000class_finetune_deit_base/experiment_$2_base.cfg
11 |
12 |
--------------------------------------------------------------------------------
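Note: `run_pipeline.sh` expects the GPU id as `$1` and the four-digit experiment number as `$2`, so `bash run_pipeline.sh 0 0001` generates poisons and then finetunes with `cfg/singlesource_singletarget_1000class_finetune_deit_base/experiment_0001_base.cfg` on GPU 0.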
/test_time_defense.py:
--------------------------------------------------------------------------------
1 | '''Test-time defense: use attention grad-rollout to locate the most salient
2 | region of each test image, block it out, and re-classify the blocked input.'''
3 |
4 | from PIL import Image
5 | import random
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.optim as optim
10 | import numpy as np
11 | from torchvision import datasets, models, transforms
12 | import time
13 | import os
14 | import copy
15 | import logging
16 | import sys
17 | import configparser
18 | import glob
19 | from tqdm import tqdm
20 | from dataset import LabeledDataset
21 | from timm.models.vision_transformer import VisionTransformer, _cfg, vit_large_patch16_224
22 | import pdb
23 | from functools import partial
24 | from vit_grad_rollout import *
25 | import cv2
26 | import torch.nn.functional as F
27 |
28 |
29 | config = configparser.ConfigParser()
30 | config.read(sys.argv[1])
31 |
32 | experimentID = config["experiment"]["ID"]
33 |
34 | options = config["finetune"]
35 | clean_data_root = options["clean_data_root"]
36 | poison_root = options["poison_root"]
37 | gpu = int(options["gpu"])
38 | epochs = int(options["epochs"])
39 | patch_size = int(options["patch_size"])
40 | eps = int(options["eps"])
41 | rand_loc = options.getboolean("rand_loc")
42 | trigger_id = int(options["trigger_id"])
43 | num_poison = int(options["num_poison"])
44 | num_classes = int(options["num_classes"])
45 | batch_size = 50
46 | logfile = options["logfile"].format(experimentID, rand_loc, eps, patch_size, num_poison, trigger_id)
47 | lr = float(options["lr"])
48 | momentum = float(options["momentum"])
49 |
50 | options = config["poison_generation"]
51 | target_wnid = options["target_wnid"]
52 | source_wnid_list = options["source_wnid_list"].format(experimentID)
53 | save = True
54 | with open(source_wnid_list) as f2:
55 | source_wnids = f2.readlines()
56 | source_wnids = [s.strip() for s in source_wnids]
57 | source_wnid = source_wnids[0]
58 | num_source = int(options["num_source"])
59 | edge_length = 30 # default: 30
60 | block = False
61 | checkpointDir = "checkpoints/" + experimentID + "/rand_loc_" + str(rand_loc) + "/eps_" + str(eps) + \
62 | "/patch_size_" + str(patch_size) + "/num_poison_" + str(num_poison) + "/trigger_" + str(trigger_id)
63 | save_path = experimentID + "/rand_loc_" + str(rand_loc) + "/eps_" + str(eps) + \
64 | "/patch_size_" + str(patch_size) + "/num_poison_" + str(num_poison) + "/trigger_" + str(trigger_id)
65 | #
66 | if not os.path.exists(checkpointDir):
67 | raise ValueError('Checkpoint directory does not exist')
68 | if not os.path.exists(save_path):
69 | os.makedirs(save_path)
70 | os.makedirs(os.path.join(save_path,'patched'))
71 | os.makedirs(os.path.join(save_path,'patched_top'))
72 | os.makedirs(os.path.join(save_path,'orig_image'))
73 | os.makedirs(os.path.join(save_path,'patched_blocked'))
74 | # create heatmap from mask on image
75 | def show_cam_on_image(img, mask):
76 | heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
77 | heatmap = np.float32(heatmap) / 255
78 | cam = heatmap + np.float32(img)
79 | cam = cam / np.max(cam)
80 | return np.uint8(255 * cam)
81 |
82 |
83 | model_name = 'deit_base_patch16_224'
84 |
85 | # Flag for feature extracting. When False, we finetune the whole model,
86 | # when True we only update the reshaped layer params
87 | feature_extract = True
88 | class_dir_list = sorted(os.listdir('/datasets/imagenet/train')) # NOTE: hardcoded ImageNet train path; adjust for your setup
89 |
90 | trans_trigger = transforms.Compose([transforms.Resize((patch_size, patch_size)),
91 | transforms.ToTensor(),
92 | transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
93 | ])
94 |
95 | trigger = Image.open('data/trigger/trigger_{}.png'.format(trigger_id)).convert('RGB')
96 | trigger = trans_trigger(trigger).unsqueeze(0).cuda(gpu)
97 |
98 | def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False):
99 | assert optimizer is None, 'optimizer must be None: this script only evaluates and must not train'
100 | since = time.time()
101 |
102 | best_model_wts = copy.deepcopy(model.state_dict())
103 | best_acc = 0.0
104 |
105 | test_acc_arr = np.zeros(num_epochs)
106 | zoomed_test_acc_arr = np.zeros(num_epochs)
107 | patched_acc_arr = np.zeros(num_epochs)
108 | notpatched_acc_arr = np.zeros(num_epochs)
109 |
110 |
111 | for epoch in range(1):
112 |
113 | print('Epoch:1')
114 |
115 | for phase in ['patched']:
116 | top_all_CH = list()
117 | target_all_CH = list()
118 | pos_x = list()
119 | pos_y = list()
120 | # save patch location
121 | patch_loc = list()
122 |
123 |
124 |
125 | target_IoU = list()
126 | top_IoU = list()
127 | target_success_IoU = list()
128 | if phase == 'train':
129 | assert False,'Model in Training mode'
130 | else:
131 | model.eval() # Set model to evaluate mode
132 |
133 | running_loss = 0.0
134 | running_corrects = 0
135 | running_source_corrects = 0
136 | zoomed_asr = 0
137 | zoomed_source_acc = 0
138 | zoomed_acc = 0
139 | # Set n_repeats higher in the patched phase if you want to cover variability in trigger placement
140 | if phase == 'patched':
141 | n_repeats = 1
142 | else:
143 | n_repeats = 1
144 | 
145 | for ctr in range(0, n_repeats):
146 | # Iterate over data.
147 | debug_idx= 0
148 | for inputs, labels,paths in tqdm(dataloaders[phase]):
149 | debug_idx+=1
150 | inputs = inputs.cuda(gpu)
151 | labels = labels.cuda(gpu)
152 | source_labels = class_dir_list.index(source_wnid)*torch.ones_like(labels).cuda(gpu)
153 | notpatched_inputs = inputs.clone()
154 | if phase == 'patched':
155 | random.seed(1) # fixed seed so trigger locations are reproducible across runs
156 | for z in range(inputs.size(0)):
157 | if not rand_loc:
158 | start_x = inputs.size(3)-patch_size-5
159 | start_y = inputs.size(3)-patch_size-5
160 | else:
161 | start_x = random.randint(0, inputs.size(3)-patch_size-1)
162 | start_y = random.randint(0, inputs.size(3)-patch_size-1)
163 | pos_y.append(start_y)
164 | pos_x.append(start_x)
165 | # patch_loc.append((start_x, start_y))
166 | inputs[z, :, start_y:start_y+patch_size, start_x:start_x+patch_size] = trigger#
167 |
168 | if True:
169 | if is_inception and phase == 'train':
170 | # From https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
171 | outputs, aux_outputs = model(inputs)
172 | loss1 = criterion(outputs, labels)
173 | loss2 = criterion(aux_outputs, labels)
174 | loss = loss1 + 0.4*loss2
175 | else:
176 | with torch.no_grad():
177 | outputs = model(inputs)
178 | loss = criterion(outputs, labels)
179 |
180 | _, preds = torch.max(outputs, 1)
181 | zoomed_outputs = torch.zeros(outputs.shape).cuda()
182 |
183 | if (phase == 'patched' or phase =='notpatched' or phase =='test') :
184 | for b1 in range(inputs.shape[0]):
185 | class_idx = outputs[b1].unsqueeze(0).data.topk(1, dim=1)[1][0].tolist()[0]
186 | # NOTE: a new rollout object registers fresh hooks on the model for every image and never removes them
187 | attention_rollout = VITAttentionGradRollout(model, discard_ratio=0.9)
188 |
189 | top_mask = attention_rollout(inputs[b1].unsqueeze(0).cuda(),category_index = class_idx)
190 | attention_rollout.clear_cache()
191 |
192 |
193 |
194 | attention_rollout.attentions = []
195 | attention_rollout.attention_gradients = []
196 | # target_mask = attention_rollout(inputs[b1].unsqueeze(0).cuda(),category_index = labels[b1].item())
197 | np_img = invTrans(inputs[b1]).permute(1, 2, 0).data.cpu().numpy()
198 | notpatched_np_img = invTrans(notpatched_inputs[b1]).permute(1, 2, 0).data.cpu().numpy()
199 | top_mask = cv2.resize(top_mask, (np_img.shape[1], np_img.shape[0]))
200 | # target_mask = cv2.resize(target_mask, (np_img.shape[1], np_img.shape[0]))
201 |
202 |
203 | box_filter = torch.ones((edge_length+1, edge_length+1)) # box filter for window sums of the saliency map
204 | box_filter = box_filter.view(1, 1, edge_length+1, edge_length+1)
205 | # convolve scaled gradcam with a filter to get max regions
206 | top_mask_torch = torch.from_numpy(top_mask)
207 | top_mask_torch = top_mask_torch.unsqueeze(0).unsqueeze(0)
208 |
209 | top_mask_conv = F.conv2d(input=top_mask_torch,
210 | weight=box_filter, padding=patch_size//2)
211 |
212 | # top_mask_conv = top_mask_torch.clone()
213 | top_mask_conv = top_mask_conv.squeeze()
214 | top_mask_conv = top_mask_conv.numpy()
215 |
216 | top_max_cam_ind = np.unravel_index(np.argmax(top_mask_conv), top_mask_conv.shape)
217 | top_y = top_max_cam_ind[0]
218 | top_x = top_max_cam_ind[1]
219 |
220 | # alternate way to choose the small region which ensures an edge_length x edge_length window is always chosen
221 | if int(top_y-(edge_length/2)) < 0:
222 | top_y_min = 0
223 | top_y_max = edge_length
224 | elif int(top_y+(edge_length/2)) > inputs.size(2):
225 | top_y_max = inputs.size(2)
226 | top_y_min = inputs.size(2) - edge_length
227 | else:
228 | top_y_min = int(top_y-(edge_length/2))
229 | top_y_max = int(top_y+(edge_length/2))
230 |
231 | if int(top_x-(edge_length/2)) < 0:
232 | top_x_min = 0
233 | top_x_max = edge_length
234 | elif int(top_x+(edge_length/2)) > inputs.size(3):
235 | top_x_max = inputs.size(3)
236 | top_x_min = inputs.size(3) - edge_length
237 | else:
238 | top_x_min = int(top_x-(edge_length/2))
239 | top_x_max = int(top_x+(edge_length/2))
240 |
241 | # BLOCK - with black patch
242 | zoomed_input = invTrans(copy.deepcopy(inputs[b1]))
243 |
244 | if phase == 'patched':
245 | zoomed_input[:, top_y_min:top_y_max, top_x_min:top_x_max] = torch.zeros(3, top_y_max-top_y_min, top_x_max-top_x_min)
246 | zoom_path = os.path.join(save_path,'patched_blocked','image_'+str(batch_size*(debug_idx-1) +b1)+'_target_'+str(labels[b1].item())+'_top_pred_'+str(class_idx)+'.png')
247 | else:
248 | zoomed_input[:, top_y_min:top_y_max, top_x_min:top_x_max] = torch.zeros(3, top_y_max-top_y_min, top_x_max-top_x_min)
249 | zoom_path = os.path.join(save_path,'notpatched_blocked','image_'+str(batch_size*(debug_idx-1) +b1)+'_target_'+str(labels[b1].item())+'_top_pred_'+str(class_idx)+'.png')
250 | if save:
251 | cv2.imwrite(zoom_path,np.uint8(255 * zoomed_input.permute(1, 2, 0).data.cpu().numpy()[:, :, ::-1]))
252 | with torch.no_grad():
253 | zoomed_outputs[b1] = model(normalize_fn(zoomed_input.unsqueeze(0).cuda()))[0]
254 |
255 | torch.cuda.empty_cache()
256 | if phase == 'patched':
257 | top_mask = show_cam_on_image(np_img, top_mask)
258 | top_im_path = os.path.join(save_path,'patched_top','image_'+str(b1)+'_target_'+str(labels[b1].item())+'_top_pred_'+str(class_idx)+'_attn.png')
259 |
260 | patched_path = os.path.join(save_path,'patched','image_'+str(b1)+'_target_'+str(labels[b1].item())+'_top_pred_'+str(class_idx)+'.png')
261 | orig_path = os.path.join(save_path,'orig_image','image_'+str(b1)+'_target_'+str(labels[b1].item())+'_top_pred_'+str(class_idx)+'.png')
262 | if save:
263 | cv2.imwrite(top_im_path, top_mask)
264 | cv2.imwrite(patched_path, np.uint8(255 * np_img[:, :, ::-1]))
265 | cv2.imwrite(orig_path, np.uint8(255 * notpatched_np_img[:, :, ::-1]))
266 | else:
267 | im_path = os.path.join(save_path,'notpatched_top','image_'+str(b1)+'_target_'+str(labels[b1].item())+'_top_pred_'+str(class_idx)+'_attn.png')
268 | if save:
269 | cv2.imwrite(im_path, top_mask)
270 |
271 | _, zoomed_preds = torch.max(zoomed_outputs, 1)
272 | # statistics
273 | running_loss += loss.item() * inputs.size(0)
274 | running_corrects += torch.sum(preds == labels.data)
275 | running_source_corrects += torch.sum(preds == source_labels.data)
276 | zoomed_asr += torch.sum(zoomed_preds == labels.data)
277 | zoomed_source_acc += torch.sum(zoomed_preds == source_labels.data)
278 |
279 | epoch_loss = running_loss / len(dataloaders[phase].dataset) / n_repeats
280 | epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset) / n_repeats
281 | epoch_source_acc = running_source_corrects.double() / len(dataloaders[phase].dataset) / n_repeats
282 | 
283 | zoomed_source_acc = zoomed_source_acc.double() / len(dataloaders[phase].dataset) / n_repeats
284 | zoomed_target_acc = zoomed_asr.double() / len(dataloaders[phase].dataset) / n_repeats
285 | 
286 | 
287 | zoomed_acc = zoomed_asr.double() / len(dataloaders[phase].dataset) / n_repeats
288 |
289 | print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
290 | if phase == 'test':
291 | print("\nVal_acc {:3f}".format(epoch_acc* 100))
292 | print("\nblocked_Val_acc {:3f}".format(zoomed_acc* 100))
293 | test_acc_arr[epoch] = epoch_acc
294 | zoomed_test_acc_arr[epoch] = zoomed_acc
295 | if phase == 'patched':
296 | patched_acc_arr[epoch] = epoch_acc
297 | print("\nblocked_target_acc {:3f}".format(zoomed_target_acc* 100))
298 | print("\nblocked_source_acc {:3f}".format(zoomed_source_acc* 100))
299 | print("\nsource_acc {:3f}".format(epoch_source_acc* 100))
300 | if phase == 'notpatched':
301 | notpatched_acc_arr[epoch] = epoch_acc
302 | print("\nsource_acc {:3f}".format(epoch_source_acc* 100))
303 | print("\nblocked_source_acc {:3f}".format(zoomed_source_acc* 100))
304 | if phase == 'test' and (epoch_acc > best_acc):
305 | best_acc = epoch_acc
306 | best_model_wts = copy.deepcopy(model.state_dict())
307 |
308 | time_elapsed = time.time() - since
309 |
310 | # save meta into pickle
311 | meta_dict = {'Val_acc': test_acc_arr,
312 | 'Patched_acc': patched_acc_arr,
313 | 'NotPatched_acc': notpatched_acc_arr
314 | }
315 |
316 | return model, meta_dict
317 |
318 |
319 | def set_parameter_requires_grad(model, feature_extracting):
320 | if feature_extracting:
321 | for param in model.parameters():
322 | param.requires_grad = False
323 |
324 |
325 | def initialize_model(model_name, num_classes, feature_extract, use_pretrained=False):
326 | # Initialize these variables which will be set in this if statement. Each of these
327 | # variables is model specific.
328 | model_ft = None
329 | input_size = 0
330 |
331 | if model_name == "resnet":
332 | """ Resnet18
333 | """
334 | model_ft = models.resnet18(pretrained=False)
335 | set_parameter_requires_grad(model_ft, feature_extract)
336 | num_ftrs = model_ft.fc.in_features
337 | # model_ft.fc = nn.Linear(num_ftrs, num_classes)
338 | input_size = 224
339 |
340 | elif model_name == "alexnet":
341 | """ Alexnet
342 | """
343 | model_ft = models.alexnet(pretrained=False)
344 | set_parameter_requires_grad(model_ft, feature_extract)
345 | num_ftrs = model_ft.classifier[6].in_features
346 | # model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
347 | input_size = 224
348 |
349 | elif model_name == "vgg":
350 | """ VGG11_bn
351 | """
352 | model_ft = models.vgg11_bn(pretrained=False)
353 | set_parameter_requires_grad(model_ft, feature_extract)
354 | num_ftrs = model_ft.classifier[6].in_features
355 | # model_ft.classifier[6] = nn.Linear(num_ftrs,num_classes)
356 | input_size = 224
357 |
358 | elif model_name == "squeezenet":
359 | """ Squeezenet
360 | """
361 | model_ft = models.squeezenet1_0(pretrained=False)
362 | set_parameter_requires_grad(model_ft, feature_extract)
363 | # model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1,1), stride=(1,1))
364 | model_ft.num_classes = num_classes
365 | input_size = 224
366 |
367 | elif model_name == "densenet":
368 | """ Densenet
369 | """
370 | model_ft = models.densenet121(pretrained=False)
371 | set_parameter_requires_grad(model_ft, feature_extract)
372 | num_ftrs = model_ft.classifier.in_features
373 | # model_ft.classifier = nn.Linear(num_ftrs, num_classes)
374 | input_size = 224
375 |
376 | elif model_name == "inception":
377 | """ Inception v3
378 | Be careful, expects (299,299) sized images and has auxiliary output
379 | """
380 | kwargs = {"transform_input": True}
381 | model_ft = models.inception_v3(pretrained=False, **kwargs)
382 | set_parameter_requires_grad(model_ft, feature_extract)
383 | # Handle the auxilary net
384 | num_ftrs = model_ft.AuxLogits.fc.in_features
385 | model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
386 | # Handle the primary net
387 | num_ftrs = model_ft.fc.in_features
388 | # model_ft.fc = nn.Linear(num_ftrs,num_classes)
389 | input_size = 299
390 |
391 | elif model_name == 'deit_small_patch16_224':
392 | model_ft = VisionTransformer(
393 | patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
394 | norm_layer=partial(nn.LayerNorm, eps=1e-6))
395 | model_ft.default_cfg = _cfg()
396 |
397 | checkpoint = torch.hub.load_state_dict_from_url(
398 | url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
399 | map_location="cpu", check_hash=True
400 | )
401 | model_ft.load_state_dict(checkpoint["model"])
402 | set_parameter_requires_grad(model_ft, feature_extract)
403 | num_ftrs = model_ft.num_features
404 | # model_ft.head = nn.Linear(num_ftrs, num_classes)
405 | input_size = 224
406 | elif model_name == 'deit_base_patch16_224':
407 | model_ft = VisionTransformer(
408 | patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
409 | norm_layer=partial(nn.LayerNorm, eps=1e-6))
410 | model_ft.default_cfg = _cfg()
411 | # checkpoint = torch.hub.load_state_dict_from_url(
412 | # url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
413 | # map_location="cpu", check_hash=True
414 | # )
415 | checkpoint = torch.load(os.path.join(checkpointDir, "poisoned_model.pt"))
416 | model_ft.load_state_dict(checkpoint['state_dict'])
417 | num_ftrs = model_ft.num_features
418 | input_size = 224
419 | elif model_name == 'vit_large_patch16_224':
420 | model_ft = vit_large_patch16_224(pretrained=False)
421 | model_ft.default_cfg = _cfg()
422 | checkpoint = torch.load(os.path.join(checkpointDir, "poisoned_model.pt"))
423 | model_ft.load_state_dict(checkpoint['state_dict'])
424 | num_ftrs = model_ft.num_features
425 | input_size = 224
426 |
427 | else:
428 | print("Invalid model name, exiting...")
429 | exit()
430 |
431 | return model_ft, input_size
432 |
433 | def adjust_learning_rate(optimizer, epoch):
434 | """Sets the learning rate to the initial LR decayed by a factor of 10 every 10 epochs"""
435 | global lr
436 | lr1 = lr * (0.1 ** (epoch // 10))
437 | for param_group in optimizer.param_groups:
438 | param_group['lr'] = lr1
439 |
440 |
441 | # Load and evaluate the poisoned model (this script does not train)
442 | print("Loading poisoned model...")
443 | # Initialize the model for this run
444 | model_ft, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=False)
445 | # logging.info(model_ft)
446 |
447 | # Transforms
448 | data_transforms = transforms.Compose([
449 | transforms.Resize((input_size, input_size)),
450 | transforms.ToTensor(),
451 | transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])
452 |
453 | invTrans = transforms.Compose([ transforms.Normalize(mean = [ 0., 0., 0. ],
454 | std = [ 1/0.229, 1/0.224, 1/0.225 ]),
455 | transforms.Normalize(mean = [ -0.485, -0.456, -0.406 ],
456 | std = [ 1., 1., 1. ]),])
457 |
458 | normalize_fn = transforms.Compose([ transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])])
459 |
460 |
461 | # logging.info("Initializing Datasets and Dataloaders...")
462 | print('Initializing Datasets and Dataloaders...')
463 |
464 | # Poisoned dataset
465 | if not block:
466 | saveDir = poison_root + "/" + experimentID + "/rand_loc_" + str(rand_loc) + "/eps_" + str(eps) + \
467 | "/patch_size_" + str(patch_size) + "/trigger_" + str(trigger_id)
468 | else:
469 | saveDir = poison_root + "/" + experimentID[:-6] + "/rand_loc_" + str(rand_loc) + "/eps_" + str(eps) + \
470 | "/patch_size_" + str(patch_size) + "/trigger_" + str(trigger_id)
471 |
472 | filelist = sorted(glob.glob(saveDir + "/*"))
473 | if num_poison > len(filelist):
474 | # logging.info("You have not generated enough poisons to run this experiment! Exiting.")
475 | print("You have not generated enough poisons to run this experiment! Exiting.")
476 | sys.exit()
477 |
478 | dataset_clean = LabeledDataset(clean_data_root + "/train",
479 | "data/transformer/{}/finetune_filelist.txt".format(experimentID), data_transforms)
480 | dataset_test = LabeledDataset(clean_data_root + "/val",
481 | "data/transformer/{}/test_filelist.txt".format(experimentID), data_transforms)
482 | dataset_patched = LabeledDataset(clean_data_root + "/val",
483 | "data/transformer/{}/patched_filelist.txt".format(experimentID), data_transforms)
484 | dataset_notpatched = LabeledDataset(clean_data_root + "/val",
485 | "data/transformer/{}/patched_filelist.txt".format(experimentID), data_transforms)
486 | dataset_poison = LabeledDataset(saveDir,
487 | "data/transformer/{}/poison_filelist.txt".format(experimentID), data_transforms)
488 | dataset_train = torch.utils.data.ConcatDataset((dataset_clean, dataset_poison))
489 |
490 | dataloaders_dict = {}
491 | dataloaders_dict['train'] = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size,
492 | shuffle=True, num_workers=4)
493 | dataloaders_dict['test'] = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size,
494 | shuffle=True, num_workers=4)
495 | dataloaders_dict['patched'] = torch.utils.data.DataLoader(dataset_patched, batch_size=batch_size,
496 | shuffle=False, num_workers=0)
497 | dataloaders_dict['notpatched'] = torch.utils.data.DataLoader(dataset_notpatched, batch_size=batch_size,
498 | shuffle=False, num_workers=0)
499 |
500 | print("Number of clean images: {}".format(len(dataset_clean)))
501 | print("Number of poison images: {}".format(len(dataset_poison)))
502 |
503 |
504 | # Gather the parameters to be optimized/updated in this run. If we are
505 | # finetuning we will be updating all parameters. However, if we are
506 | # doing feature extract method, we will only update the parameters
507 | # that we have just initialized, i.e. the parameters with requires_grad
508 | # is True.
509 | params_to_update = model_ft.parameters()
510 | # logging.info("Params to learn:")
511 | if feature_extract:
512 | params_to_update = []
513 | for name,param in model_ft.named_parameters():
514 | if param.requires_grad == True:
515 | params_to_update.append(param)
516 | # logging.info(name)
517 | # print(name)
518 | else:
519 | for name,param in model_ft.named_parameters():
520 | if param.requires_grad == True:
521 | # logging.info(name)
522 | # print(name)
523 | pass
524 | # params_to_update = model_ft.parameters() # debug
525 | # optimizer_ft = optim.SGD(params_to_update, lr=lr, momentum = momentum)
526 | optimizer_ft = None
527 | # Setup the loss fxn
528 | criterion = nn.CrossEntropyLoss()
529 |
530 | model = model_ft.cuda(gpu)
531 |
532 | # Evaluate (optimizer_ft is None, so train_model performs no training)
533 | model, meta_dict = train_model(model, dataloaders_dict, criterion, optimizer_ft,
534 | num_epochs=epochs, is_inception=(model_name=="inception"))
535 |
--------------------------------------------------------------------------------
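The defense in `test_time_defense.py` reduces to a short procedure: compute an attention grad-rollout saliency map for the top-1 class, find the `edge_length` x `edge_length` window with the largest attention mass via a box-filter convolution, black that window out, and re-classify. A minimal sketch under simplifying assumptions follows: `mask` is a precomputed (H,W) saliency map in [0,1] (e.g. from `VITAttentionGradRollout`), `image` is a (3,H,W) tensor in [0,1], and `model` is assumed to apply its own input normalization; the real script additionally handles the normalize/de-normalize round trip, batching and logging.

```python
import numpy as np
import torch
import torch.nn.functional as F

def block_and_reclassify(model, image, mask, edge_length=30):
    """Return the model's top-1 class after blocking the most attended window."""
    # window sums of the saliency map via a box filter (valid convolution)
    box = torch.ones(1, 1, edge_length, edge_length)
    sums = F.conv2d(torch.from_numpy(mask).float()[None, None], box)
    flat_idx = sums.squeeze(0).squeeze(0).argmax().item()
    y, x = np.unravel_index(flat_idx, sums.shape[-2:])
    blocked = image.clone()
    blocked[:, y:y+edge_length, x:x+edge_length] = 0  # black out the suspected trigger
    with torch.no_grad():
        return model(blocked.unsqueeze(0)).argmax(dim=1).item()
```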
/transformer_teaser5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/UCDvision/backdoor_transformer/54a6fa5425d101c6ef669c193b544610b5112d3e/transformer_teaser5.jpg
--------------------------------------------------------------------------------
/vit_grad_rollout.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from PIL import Image
3 | import sys
4 | from torchvision import transforms
5 | import numpy as np
6 | 
7 | import cv2
8 | import pdb
9 | import torch.nn.functional as F
10 |
11 | def grad_rollout(attentions, gradients, discard_ratio):
12 | result = torch.eye(attentions[0].size(-1))
13 | with torch.no_grad():
14 | for attention, grad in zip(attentions, gradients):
15 | weights = grad
16 | attention_heads_fused = (attention*weights).mean(axis=1)
17 | attention_heads_fused[attention_heads_fused < 0] = 0
18 | # Drop the lowest attentions, but
19 | # don't drop the class token
20 | flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
21 | _, indices = flat.topk(int(flat.size(-1)*discard_ratio), -1, False)
22 | #indices = indices[indices != 0]
23 | flat[0, indices] = 0
24 |
25 | I = torch.eye(attention_heads_fused.size(-1))
26 | a = (attention_heads_fused + 1.0*I)/2
27 |
28 | a = a / a.sum(dim=-1, keepdim=True) # row-normalize so each row of the mixed attention sums to 1
29 | result = torch.matmul(a, result)
30 |
31 | # Look at the total attention between the class token,
32 | # and the image patches
33 | mask = result[0, 0, 1:]
34 | # In case of 224x224 image, this brings us from 196 to 14
35 | width = int(mask.size(-1)**0.5)
36 | mask = mask.reshape(width, width).numpy()
37 | mask = mask / np.max(mask)
38 | return mask
39 |
40 |
41 |
42 |
43 | def grad_rollout_batch(attentions, gradients, discard_ratio):
44 | result = torch.eye(attentions[0].size(-1)).unsqueeze(0).repeat(attentions[0].size(0),1,1).to(attentions[0].device)
45 | # result = torch.eye(attentions[0].size(-1))
46 | with torch.no_grad():
47 | for attention, grad in zip(attentions, gradients):
48 | weights = grad
49 | attention_heads_fused = (attention*weights).mean(axis=1)
50 | attention_heads_fused[attention_heads_fused < 0] = 0
51 |
52 | # Drop the lowest attentions, but
53 | # don't drop the class token
54 | flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
55 | _, indices = flat.topk(int(flat.size(-1)*discard_ratio), -1, False)
56 | #indices = indices[indices != 0]
57 | flat.scatter_(-1, indices, torch.zeros(indices.shape, device=flat.device))
58 | I = torch.eye(attention_heads_fused.size(-1)).unsqueeze(0).repeat(attention_heads_fused.size(0),1,1).to(attention_heads_fused[0].device)
59 | a = (attention_heads_fused + 1.0*I)/2
60 | # row-normalize so each row of the mixed attention sums to 1
61 | a = a / a.sum(dim=-1, keepdim=True)
62 | result = torch.matmul(a, result)
63 |
64 | # Look at the total attention between the class token,
65 | # and the image patches
66 |
67 | mask = result[:, 0, 1:]
68 | # In case of 224x224 image, this brings us from 196 to 14
69 | width = int(mask.size(-1)**0.5)
70 | mask = mask.reshape(mask.size(0),width, width).cpu().numpy()
71 | # mask = mask / np.max(mask)
72 | # normalize each mask in the batch to [0, 1]
73 | max_div = np.max(mask, axis=(1, 2), keepdims=True)
74 | mask = mask/(max_div+1e-8)
75 | 
76 | return mask
77 |
78 |
79 |
80 |
81 | class VITAttentionGradRollout:
82 | def __init__(self, model, attention_layer_name='attn_drop',
83 | discard_ratio=0.9):
84 | self.model = model
85 | self.discard_ratio = discard_ratio
86 | for name, module in self.model.named_modules():
87 | if attention_layer_name in name:
88 | module.register_forward_hook(self.get_attention)
89 | module.register_backward_hook(self.get_attention_gradient)
90 |
91 | self.attentions = []
92 | self.attention_gradients = []
93 |
94 | def clear_cache(self):
95 | self.attentions = []
96 | self.attention_gradients = []
97 | def get_attention(self, module, input, output):
98 | self.attentions.append(output.cpu())
99 |
100 | def get_attention_gradient(self, module, grad_input, grad_output):
101 | self.attention_gradients.append(grad_input[0].cpu())
102 |
103 | def __call__(self, input_tensor, category_index):
104 | self.model.zero_grad()
105 | output = self.model(input_tensor)
106 | category_mask = torch.zeros(output.size())
107 | category_mask[:, category_index] = 1
108 | category_mask = category_mask.to(output.device)
109 |
110 | loss = (output*category_mask).sum()
111 | loss.backward()
112 | return grad_rollout(self.attentions, self.attention_gradients,
113 | self.discard_ratio)
114 |
115 |
116 | class VITAttentionGradRollout_Batch:
117 | def __init__(self, model, attention_layer_name='attn_drop',
118 | discard_ratio=0.9):
119 | self.model = model
120 | self.discard_ratio = discard_ratio
121 | self.hook_handles = []
122 | for name, module in self.model.named_modules():
123 | if attention_layer_name in name:
124 | self.hook_handles.append(module.register_forward_hook(self.get_attention))
125 | self.hook_handles.append(module.register_backward_hook(self.get_attention_gradient))
126 |
127 | self.attentions = []
128 | self.attention_gradients = []
129 |
130 |
131 | def remove_hooks(self):
132 | for i in range(len(self.hook_handles)):
133 | self.hook_handles[i].remove()
134 |
135 | def clear_cache(self):
136 | self.attentions = []
137 | self.attention_gradients = []
138 |
139 | def get_attention(self, module, input, output):
140 | # self.attentions.append(output.cpu())
141 | self.attentions.append(output)
142 |
143 | def get_attention_gradient(self, module, grad_input, grad_output):
144 | # self.attention_gradients.append(grad_input[0].cpu())
145 | self.attention_gradients.append(grad_input[0])
146 |
147 | # def __call__(self, input_tensor, category_indices):
148 | def __call__(self, input_tensor, top=True):
149 | self.model.zero_grad()
150 | output = self.model(input_tensor)
151 | class_idx = output.data.topk(1, dim=1)[1].squeeze(1) # top-1 class for every image in the batch
152 | # category_mask = torch.zeros(output.size())
153 | # category_mask[:, category_index] = 1
154 | # category_mask = F.one_hot(category_indices,num_classes =output.size(1))
155 | if top:
156 | category_mask = F.one_hot(class_idx,num_classes =output.size(1))
157 | else:
158 | raise NotImplementedError
159 | category_mask = category_mask.to(output.device)
160 | loss = (output*category_mask).sum()
161 | loss.backward()
162 | return grad_rollout_batch(self.attentions, self.attention_gradients,
163 | self.discard_ratio)
164 |
--------------------------------------------------------------------------------
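For reference, a hypothetical usage of `VITAttentionGradRollout` above: compute a 14x14 saliency map for a DeiT model's top-1 prediction on a single normalized 224x224 input. Here `model` and `img` are assumed to exist (e.g. the DeiT loaded in `test_time_defense.py` and one normalized image batch of size 1).

```python
rollout = VITAttentionGradRollout(model, discard_ratio=0.9)
logits = model(img)                       # img: (1, 3, 224, 224), normalized
top1 = logits.argmax(dim=1).item()
mask = rollout(img, category_index=top1)  # (14, 14) numpy array scaled to [0, 1]
rollout.clear_cache()                     # drop cached attentions before the next image
```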