├── .gitignore
├── LICENSE
├── README.md
├── generated
│   ├── ga_adv_class_398.jpg
│   ├── ga_adv_class_420.jpg
│   ├── ga_adv_class_543.jpg
│   ├── ga_fooling_class_340.jpg
│   ├── ga_fooling_class_457.jpg
│   ├── ga_fooling_class_483.jpg
│   ├── targeted_adv_img_from_948_to_35.jpg
│   ├── targeted_adv_img_from_948_to_62.jpg
│   ├── targeted_adv_noise_from_948_to_35.jpg
│   ├── targeted_adv_noise_from_948_to_62.jpg
│   ├── untargeted_adv_img_from_13_to_19.jpg
│   ├── untargeted_adv_img_from_390_to_397.jpg
│   ├── untargeted_adv_noise_from_13_to_19.jpg
│   └── untargeted_adv_noise_from_390_to_397.jpg
├── input_images
│   ├── apple.JPEG
│   ├── bird.JPEG
│   ├── cat_dog.png
│   ├── eel.JPEG
│   ├── snake.jpg
│   └── spider.png
└── src
    ├── fast_gradient_sign_targeted.py
    ├── fast_gradient_sign_untargeted.py
    ├── gradient_ascent_adv.py
    ├── gradient_ascent_fooling.py
    └── misc_functions.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | env/
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | local_settings.py
56 |
57 | # Flask stuff:
58 | instance/
59 | .webassets-cache
60 |
61 | # Scrapy stuff:
62 | .scrapy
63 |
64 | # Sphinx documentation
65 | docs/_build/
66 |
67 | # PyBuilder
68 | target/
69 |
70 | # Jupyter Notebook
71 | .ipynb_checkpoints
72 |
73 | # pyenv
74 | .python-version
75 |
76 | # celery beat schedule file
77 | celerybeat-schedule
78 |
79 | # SageMath parsed files
80 | *.sage.py
81 |
82 | # dotenv
83 | .env
84 |
85 | # virtualenv
86 | .venv
87 | venv/
88 | ENV/
89 |
90 | # Spyder project settings
91 | .spyderproject
92 | .spyproject
93 |
94 | # Rope project settings
95 | .ropeproject
96 |
97 | # mkdocs documentation
98 | /site
99 |
100 | # mypy
101 | .mypy_cache/
102 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Utku Ozbulak
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Convolutional Neural Network Adversarial Attacks
2 |
3 |
4 | **Note**: I am aware that there are some issues with the code; I will update this repository soon (and will also move away from cv2 to PIL).
5 |
6 | This repo branched off of [CNN Visualisations](https://github.com/utkuozbulak/pytorch-cnn-visualizations) because that repository was starting to get bloated. It contains the following CNN adversarial attacks implemented in PyTorch:
7 |
8 | * Fast Gradient Sign, Untargeted [1]
9 | * Fast Gradient Sign, Targeted [1]
10 | * Gradient Ascent, Adversarial Images [2]
11 | * Gradient Ascent, Fooling Images (Unrecognizable images predicted as classes with high confidence) [2]
12 |
13 | More adversarial attack and defense techniques will be added in the future.
14 |
15 | The code uses the pretrained AlexNet from the torchvision model zoo. You can swap in your own model (see the sketch below), but don't forget to update the target class parameters as well.
16 |
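A rough sketch of what that swap looks like (the VGG-16 line is only an illustration, not something this repo ships); any ImageNet classifier from the torchvision model zoo behaves the same way, as long as the class ids you target still refer to ImageNet classes:

```python
from torchvision import models

# Model used throughout this repo
pretrained_model = models.alexnet(pretrained=True)

# Illustrative swap: e.g. VGG-16. Keep the target class ids (e.g. 948 for apple)
# consistent with the examples in misc_functions.get_params.
# pretrained_model = models.vgg16(pretrained=True)
```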
17 | All images are pre-processed with the mean and std of the ImageNet dataset before being fed to the model (see the normalization sketch below). None of the code uses the GPU, as these operations are quite fast for a single image; you can make use of the GPU with very little effort. The examples below include a number in brackets after the description, like *Mastiff (243)*; this number represents the class id in the ImageNet dataset.
18 |
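For reference, the pre-processing in `src/misc_functions.py` boils down to the standard ImageNet normalization. A minimal sketch of the per-channel math (the real `preprocess_image` also resizes to 224x224, converts BGR to RGB, transposes to C,H,W and wraps the result in a `Variable`):

```python
import numpy as np

# ImageNet channel statistics used by preprocess_image in src/misc_functions.py
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def normalize(img_rgb_uint8):
    """Scale an H x W x 3 RGB uint8 image to [0, 1] and normalize per channel."""
    scaled = img_rgb_uint8.astype(np.float32) / 255.0
    return (scaled - mean) / std  # broadcasts over the last (channel) axis
```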
19 | I tried to comment the code as much as possible; if you have any issues understanding it or porting it, don't hesitate to reach out.
20 |
21 | Below are some sample results for each operation.
22 |
23 | ## Fast Gradient Sign - Untargeted
24 | In this operation we update the original image with the sign of the gradient received at the input (the first layer). The untargeted version aims to reduce the confidence of the initial class. The loop breaks as soon as the image stops being classified as the original label; the sketch below shows the core update.
25 |
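A condensed sketch of the loop body in `src/fast_gradient_sign_untargeted.py` (the helper name `fgsm_untargeted_step` is mine; `processed_image` is the `Variable` returned by `preprocess_image`, and `im_label_as_var` holds the original class id):

```python
import torch
from torch import nn

def fgsm_untargeted_step(model, processed_image, im_label_as_var, alpha=0.01):
    ce_loss = nn.CrossEntropyLoss()
    processed_image.grad = None                    # zero out previous gradients
    out = model(processed_image)                   # forward pass
    pred_loss = ce_loss(out, im_label_as_var)      # loss w.r.t. the original class
    pred_loss.backward()                           # gradient w.r.t. the input image
    # Ascend the loss of the original class: x <- x + alpha * sign(grad)
    adv_noise = alpha * torch.sign(processed_image.grad.data)
    processed_image.data = processed_image.data + adv_noise
    return processed_image
```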
26 | | Predicted as Eel (390) Confidence: 0.96 | Adversarial Noise | Predicted as Blowfish (397) Confidence: 0.81 |
27 | | :---: | :---: | :---: |
28 | | ![Original image](input_images/eel.JPEG) | ![Adversarial noise](generated/untargeted_adv_noise_from_390_to_397.jpg) | ![Adversarial image](generated/untargeted_adv_img_from_390_to_397.jpg) |
29 |
30 | | Predicted as Snowbird (13) Confidence: 0.99 | Adversarial Noise | Predicted as Chickadee (19) Confidence: 0.95 |
31 | | :---: | :---: | :---: |
32 | | ![Original image](input_images/bird.JPEG) | ![Adversarial noise](generated/untargeted_adv_noise_from_13_to_19.jpg) | ![Adversarial image](generated/untargeted_adv_img_from_13_to_19.jpg) |
33 |
54 | ## Fast Gradient Sign - Targeted
55 | The targeted version of FGS works almost the same as the untargeted version. The only difference is that instead of reducing the confidence of the original label, we maximize the confidence of the target label (see the sketch below). The loop breaks as soon as the image is predicted as the target class.
56 |
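The corresponding sketch for `src/fast_gradient_sign_targeted.py` (again, the helper name is mine): the loss is now computed against the target class, and the signed gradient is subtracted rather than added.

```python
import torch
from torch import nn

def fgsm_targeted_step(model, processed_image, target_label_as_var, alpha=0.01):
    ce_loss = nn.CrossEntropyLoss()
    processed_image.grad = None                      # zero out previous gradients
    out = model(processed_image)
    pred_loss = ce_loss(out, target_label_as_var)    # loss w.r.t. the *target* class
    pred_loss.backward()
    # Descend the target-class loss: x <- x - alpha * sign(grad)
    adv_noise = alpha * torch.sign(processed_image.grad.data)
    processed_image.data = processed_image.data - adv_noise
    return processed_image
```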
57 | | Predicted as Apple (948) Confidence: 0.95 | Adversarial Noise | Predicted as Rock python (62) Confidence: 0.16 |
58 | | :---: | :---: | :---: |
59 | | ![Original image](input_images/apple.JPEG) | ![Adversarial noise](generated/targeted_adv_noise_from_948_to_62.jpg) | ![Adversarial image](generated/targeted_adv_img_from_948_to_62.jpg) |
60 |
61 | | Predicted as Apple (948) Confidence: 0.95 | Adversarial Noise | Predicted as Mud turtle (35) Confidence: 0.54 |
62 | | :---: | :---: | :---: |
63 | | ![Original image](input_images/apple.JPEG) | ![Adversarial noise](generated/targeted_adv_noise_from_948_to_35.jpg) | ![Adversarial image](generated/targeted_adv_img_from_948_to_35.jpg) |
64 |
88 | ## Gradient Ascent - Fooling Image Generation
89 | In this operation we start with a random image and continuously update it with targeted backpropagation (gradient ascent on the score of a certain class), stopping when we reach the target confidence for that class; the sketch below condenses the loop. All of the images below were generated to fool a pretrained AlexNet.
90 |
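A condensed sketch of the loop in `src/gradient_ascent_fooling.py` (run from the `src` folder; it keeps the old-style PyTorch 0.2/0.3 API used throughout this repo, and the image writing and confirmation pass are omitted here):

```python
import numpy as np
from torch.optim import SGD
from torch.nn import functional
from torchvision import models

from misc_functions import preprocess_image, recreate_image

model = models.alexnet(pretrained=True)
model.eval()
target_class, min_confidence = 483, 0.99                     # Castle
image = np.uint8(np.random.uniform(0, 255, (224, 224, 3)))   # start from noise

for i in range(1, 200):
    processed_image = preprocess_image(image)                # Variable with requires_grad=True
    optimizer = SGD([processed_image], lr=6)
    output = model(processed_image)
    confidence = functional.softmax(output)[0][target_class].data.numpy()[0]
    if confidence > min_confidence:
        break                                                # fooling image reached
    class_loss = -output[0, target_class]                    # maximize the raw class score
    model.zero_grad()
    class_loss.backward()
    optimizer.step()
    image = recreate_image(processed_image)                  # back to a uint8 image
```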
91 | | Predicted as Zebra (340) Confidence: 0.94 | Predicted as Bow tie (457) Confidence: 0.95 | Predicted as Castle (483) Confidence: 0.99 |
92 | | :---: | :---: | :---: |
93 | | ![Fooling image](generated/ga_fooling_class_340.jpg) | ![Fooling image](generated/ga_fooling_class_457.jpg) | ![Fooling image](generated/ga_fooling_class_483.jpg) |
94 |
107 | ## Gradient Ascent - Adversarial Image Generation
108 | This operation works exactly the same way as the previous one, except that we start from a real image instead of random noise and keep the learning rate a bit smaller, so that the image does not receive huge updates and continues to look like the original (see the usage sketch below). As can be seen from the samples, on some images it is almost impossible to notice the difference between the two, while on others it can clearly be observed that something is wrong. All of the examples below were created from and tested on a pretrained AlexNet.
109 |
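A sketch of the usage in `src/gradient_ascent_adv.py` (run from the `src` folder): the loop is the same as above, but it starts from a real image and uses a smaller learning rate (0.7 instead of 6) so the result keeps resembling the original.

```python
from misc_functions import get_params
from gradient_ascent_adv import DisguisedFoolingSampleGeneration

# Example 0 is the apple image (ImageNet class 948)
original_image, prep_img, _, _, pretrained_model = get_params(0)

fool = DisguisedFoolingSampleGeneration(pretrained_model,
                                        original_image,
                                        target_class=398,        # Abacus
                                        minimum_confidence=0.99)
generated_image = fool.generate()   # also written to ../generated/ga_adv_class_398.jpg
```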
110 | | Predicted as Eel (390) Confidence: 0.96 | Predicted as Apple (948) Confidence: 0.95 | Predicted as Snowbird (13) Confidence: 0.99 |
111 | | :---: | :---: | :---: |
112 | | ![Original image](input_images/eel.JPEG) | ![Original image](input_images/apple.JPEG) | ![Original image](input_images/bird.JPEG) |
113 | | Predicted as Banjo (420) Confidence: 0.99 | Predicted as Abacus (398) Confidence: 0.99 | Predicted as Dumbbell (543) Confidence: 1.00 |
114 | | ![Adversarial image](generated/ga_adv_class_420.jpg) | ![Adversarial image](generated/ga_adv_class_398.jpg) | ![Adversarial image](generated/ga_adv_class_543.jpg) |
115 |
137 | ## Requirements:
138 | ```
139 | torch >= 0.2.0.post4
140 | torchvision >= 0.1.9
141 | numpy >= 1.13.0
142 | opencv >= 3.1.0
143 | ```
144 |
145 |
146 | ## References:
147 |
148 | [1] I. J. Goodfellow, J. Shlens, C. Szegedy. *Explaining and Harnessing Adversarial Examples* https://arxiv.org/abs/1412.6572
149 |
150 | [2] A. Nguyen, J. Yosinski, J. Clune. *Deep Neural Networks are Easily Fooled: High Confidence Predictions for Unrecognizable Images* https://arxiv.org/abs/1412.1897
151 |
--------------------------------------------------------------------------------
/generated/ga_adv_class_398.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/utkuozbulak/pytorch-cnn-adversarial-attacks/8af67f124c904401947c90078481f93f3456e107/generated/ga_adv_class_398.jpg
--------------------------------------------------------------------------------
/generated/ga_adv_class_420.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/utkuozbulak/pytorch-cnn-adversarial-attacks/8af67f124c904401947c90078481f93f3456e107/generated/ga_adv_class_420.jpg
--------------------------------------------------------------------------------
/generated/ga_adv_class_543.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/utkuozbulak/pytorch-cnn-adversarial-attacks/8af67f124c904401947c90078481f93f3456e107/generated/ga_adv_class_543.jpg
--------------------------------------------------------------------------------
/generated/ga_fooling_class_340.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/utkuozbulak/pytorch-cnn-adversarial-attacks/8af67f124c904401947c90078481f93f3456e107/generated/ga_fooling_class_340.jpg
--------------------------------------------------------------------------------
/generated/ga_fooling_class_457.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/utkuozbulak/pytorch-cnn-adversarial-attacks/8af67f124c904401947c90078481f93f3456e107/generated/ga_fooling_class_457.jpg
--------------------------------------------------------------------------------
/generated/ga_fooling_class_483.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/utkuozbulak/pytorch-cnn-adversarial-attacks/8af67f124c904401947c90078481f93f3456e107/generated/ga_fooling_class_483.jpg
--------------------------------------------------------------------------------
/generated/targeted_adv_img_from_948_to_35.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/utkuozbulak/pytorch-cnn-adversarial-attacks/8af67f124c904401947c90078481f93f3456e107/generated/targeted_adv_img_from_948_to_35.jpg
--------------------------------------------------------------------------------
/generated/targeted_adv_img_from_948_to_62.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/utkuozbulak/pytorch-cnn-adversarial-attacks/8af67f124c904401947c90078481f93f3456e107/generated/targeted_adv_img_from_948_to_62.jpg
--------------------------------------------------------------------------------
/generated/targeted_adv_noise_from_948_to_35.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/utkuozbulak/pytorch-cnn-adversarial-attacks/8af67f124c904401947c90078481f93f3456e107/generated/targeted_adv_noise_from_948_to_35.jpg
--------------------------------------------------------------------------------
/generated/targeted_adv_noise_from_948_to_62.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/utkuozbulak/pytorch-cnn-adversarial-attacks/8af67f124c904401947c90078481f93f3456e107/generated/targeted_adv_noise_from_948_to_62.jpg
--------------------------------------------------------------------------------
/generated/untargeted_adv_img_from_13_to_19.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/utkuozbulak/pytorch-cnn-adversarial-attacks/8af67f124c904401947c90078481f93f3456e107/generated/untargeted_adv_img_from_13_to_19.jpg
--------------------------------------------------------------------------------
/generated/untargeted_adv_img_from_390_to_397.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/utkuozbulak/pytorch-cnn-adversarial-attacks/8af67f124c904401947c90078481f93f3456e107/generated/untargeted_adv_img_from_390_to_397.jpg
--------------------------------------------------------------------------------
/generated/untargeted_adv_noise_from_13_to_19.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/utkuozbulak/pytorch-cnn-adversarial-attacks/8af67f124c904401947c90078481f93f3456e107/generated/untargeted_adv_noise_from_13_to_19.jpg
--------------------------------------------------------------------------------
/generated/untargeted_adv_noise_from_390_to_397.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/utkuozbulak/pytorch-cnn-adversarial-attacks/8af67f124c904401947c90078481f93f3456e107/generated/untargeted_adv_noise_from_390_to_397.jpg
--------------------------------------------------------------------------------
/input_images/apple.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/utkuozbulak/pytorch-cnn-adversarial-attacks/8af67f124c904401947c90078481f93f3456e107/input_images/apple.JPEG
--------------------------------------------------------------------------------
/input_images/bird.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/utkuozbulak/pytorch-cnn-adversarial-attacks/8af67f124c904401947c90078481f93f3456e107/input_images/bird.JPEG
--------------------------------------------------------------------------------
/input_images/cat_dog.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/utkuozbulak/pytorch-cnn-adversarial-attacks/8af67f124c904401947c90078481f93f3456e107/input_images/cat_dog.png
--------------------------------------------------------------------------------
/input_images/eel.JPEG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/utkuozbulak/pytorch-cnn-adversarial-attacks/8af67f124c904401947c90078481f93f3456e107/input_images/eel.JPEG
--------------------------------------------------------------------------------
/input_images/snake.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/utkuozbulak/pytorch-cnn-adversarial-attacks/8af67f124c904401947c90078481f93f3456e107/input_images/snake.jpg
--------------------------------------------------------------------------------
/input_images/spider.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/utkuozbulak/pytorch-cnn-adversarial-attacks/8af67f124c904401947c90078481f93f3456e107/input_images/spider.png
--------------------------------------------------------------------------------
/src/fast_gradient_sign_targeted.py:
--------------------------------------------------------------------------------
1 | """
2 | Created on Fri Dec 16 01:24:11 2017
3 |
4 | @author: Utku Ozbulak - github.com/utkuozbulak
5 | """
6 | import os
7 | import numpy as np
8 | import cv2
9 |
10 | import torch
11 | from torch import nn
12 | from torch.autograd import Variable
13 | # from torch.autograd.gradcheck import zero_gradients # See processed_image.grad = None
14 |
15 | from misc_functions import preprocess_image, recreate_image, get_params
16 |
17 |
18 | class FastGradientSignTargeted():
19 | """
20 | Fast gradient sign targeted adversarial attack, maximizes the target class activation
21 | with iterative grad sign updates
22 | """
23 | def __init__(self, model, alpha):
24 | self.model = model
25 | self.model.eval()
26 | # Movement multiplier per iteration
27 | self.alpha = alpha
28 | # Create the folder to export images if not exists
29 | if not os.path.exists('../generated'):
30 | os.makedirs('../generated')
31 |
32 | def generate(self, original_image, org_class, target_class):
33 | # I honestly don't know a better way to create a variable with a specific value
34 | # Targeting the specific class
35 | im_label_as_var = Variable(torch.from_numpy(np.asarray([target_class])))
36 | # Define loss functions
37 | ce_loss = nn.CrossEntropyLoss()
38 | # Process image
39 | processed_image = preprocess_image(original_image)
40 | # Start iteration
41 | for i in range(10):
42 | print('Iteration:', str(i))
43 | # zero_gradients(x)
44 | # Zero out previous gradients
45 | # Can also use zero_gradients(x)
46 | processed_image.grad = None
47 | # Forward pass
48 | out = self.model(processed_image)
49 | # Calculate CE loss
50 | pred_loss = ce_loss(out, im_label_as_var)
51 | # Do backward pass
52 | pred_loss.backward()
53 | # Create Noise
54 | # Here, processed_image.grad.data is the same thing as the backward gradient from
55 | # the first layer; hooks could be used to get it as well
56 | adv_noise = self.alpha * torch.sign(processed_image.grad.data)
57 | # Subtract noise from the processed image (move towards the target class)
58 | processed_image.data = processed_image.data - adv_noise
59 |
60 | # Confirming if the image is indeed adversarial with added noise
61 | # This is necessary (for some cases) because when we recreate image
62 | # the values become integers between 0 and 255 and sometimes the adversarial property
63 | # is lost in the recreation process
64 |
65 | # Generate confirmation image
66 | recreated_image = recreate_image(processed_image)
67 | # Process confirmation image
68 | prep_confirmation_image = preprocess_image(recreated_image)
69 | # Forward pass
70 | confirmation_out = self.model(prep_confirmation_image)
71 | # Get prediction
72 | _, confirmation_prediction = confirmation_out.data.max(1)
73 | # Get Probability
74 | confirmation_confidence = \
75 | nn.functional.softmax(confirmation_out)[0][confirmation_prediction].data.numpy()[0]
76 | # Convert tensor to int
77 | confirmation_prediction = confirmation_prediction.numpy()[0]
78 | # Check if the image is now predicted as the target class
79 | if confirmation_prediction == target_class:
80 | print('Original image was predicted as:', org_class,
81 | 'with adversarial noise converted to:', confirmation_prediction,
82 | 'and predicted with confidence of:', confirmation_confidence)
83 | # Create the image for noise as: Original image - generated image
84 | noise_image = original_image - recreated_image
85 | cv2.imwrite('../generated/targeted_adv_noise_from_' + str(org_class) + '_to_' +
86 | str(confirmation_prediction) + '.jpg', noise_image)
87 | # Write image
88 | cv2.imwrite('../generated/targeted_adv_img_from_' + str(org_class) + '_to_' +
89 | str(confirmation_prediction) + '.jpg', recreated_image)
90 | break
91 |
92 | return 1
93 |
94 |
95 | if __name__ == '__main__':
96 | target_example = 0 # Apple
97 | (original_image, prep_img, org_class, _, pretrained_model) =\
98 | get_params(target_example)
99 | target_class = 62 # Rock python
100 |
101 | FGS_targeted = FastGradientSignTargeted(pretrained_model, 0.01)
102 | FGS_targeted.generate(original_image, org_class, target_class)
103 |
--------------------------------------------------------------------------------
/src/fast_gradient_sign_untargeted.py:
--------------------------------------------------------------------------------
1 | """
2 | Created on Fri Dec 15 19:57:34 2017
3 |
4 | @author: Utku Ozbulak - github.com/utkuozbulak
5 | """
6 | import os
7 | import numpy as np
8 | import cv2
9 |
10 | import torch
11 | from torch import nn
12 | from torch.autograd import Variable
13 | # from torch.autograd.gradcheck import zero_gradients # See processed_image.grad = None
14 |
15 | from misc_functions import preprocess_image, recreate_image, get_params
16 |
17 |
18 | class FastGradientSignUntargeted():
19 | """
20 | Fast gradient sign untargeted adversarial attack, minimizes the initial class activation
21 | with iterative grad sign updates
22 | """
23 | def __init__(self, model, alpha):
24 | self.model = model
25 | self.model.eval()
26 | # Movement multiplier per iteration
27 | self.alpha = alpha
28 | # Create the folder to export images if not exists
29 | if not os.path.exists('../generated'):
30 | os.makedirs('../generated')
31 |
32 | def generate(self, original_image, im_label):
33 | # I honestly don't know a better way to create a variable with a specific value
34 | im_label_as_var = Variable(torch.from_numpy(np.asarray([im_label])))
35 | # Define loss functions
36 | ce_loss = nn.CrossEntropyLoss()
37 | # Process image
38 | processed_image = preprocess_image(original_image)
39 | # Start iteration
40 | for i in range(10):
41 | print('Iteration:', str(i))
42 | # zero_gradients(x)
43 | # Zero out previous gradients
44 | # Can also use zero_gradients(x)
45 | processed_image.grad = None
46 | # Forward pass
47 | out = self.model(processed_image)
48 | # Calculate CE loss
49 | pred_loss = ce_loss(out, im_label_as_var)
50 | # Do backward pass
51 | pred_loss.backward()
52 | # Create Noise
53 | # Here, processed_image.grad.data is the same thing as the backward gradient from
54 | # the first layer; hooks could be used to get it as well
55 | adv_noise = self.alpha * torch.sign(processed_image.grad.data)
56 | # Add Noise to processed image
57 | processed_image.data = processed_image.data + adv_noise
58 |
59 | # Confirming if the image is indeed adversarial with added noise
60 | # This is necessary (for some cases) because when we recreate image
61 | # the values become integers between 0 and 255 and sometimes the adversarial property
62 | # is lost in the recreation process
63 |
64 | # Generate confirmation image
65 | recreated_image = recreate_image(processed_image)
66 | # Process confirmation image
67 | prep_confirmation_image = preprocess_image(recreated_image)
68 | # Forward pass
69 | confirmation_out = self.model(prep_confirmation_image)
70 | # Get prediction
71 | _, confirmation_prediction = confirmation_out.data.max(1)
72 | # Get Probability
73 | confirmation_confidence = \
74 | nn.functional.softmax(confirmation_out)[0][confirmation_prediction].data.numpy()[0]
75 | # Convert tensor to int
76 | confirmation_prediction = confirmation_prediction.numpy()[0]
77 | # Check if the prediction is now different from the original class
78 | if confirmation_prediction != im_label:
79 | print('Original image was predicted as:', im_label,
80 | 'with adversarial noise converted to:', confirmation_prediction,
81 | 'and predicted with confidence of:', confirmation_confidence)
82 | # Create the image for noise as: Original image - generated image
83 | noise_image = original_image - recreated_image
84 | cv2.imwrite('../generated/untargeted_adv_noise_from_' + str(im_label) + '_to_' +
85 | str(confirmation_prediction) + '.jpg', noise_image)
86 | # Write image
87 | cv2.imwrite('../generated/untargeted_adv_img_from_' + str(im_label) + '_to_' +
88 | str(confirmation_prediction) + '.jpg', recreated_image)
89 | break
90 |
91 | return 1
92 |
93 |
94 | if __name__ == '__main__':
95 | target_example = 2 # Eel
96 | (original_image, prep_img, target_class, _, pretrained_model) =\
97 | get_params(target_example)
98 |
99 | FGS_untargeted = FastGradientSignUntargeted(pretrained_model, 0.01)
100 | FGS_untargeted.generate(original_image, target_class)
101 |
--------------------------------------------------------------------------------
/src/gradient_ascent_adv.py:
--------------------------------------------------------------------------------
1 | """
2 | Created on Thu Oct 29 14:09:01 2017
3 |
4 | @author: Utku Ozbulak - github.com/utkuozbulak
5 | """
6 | import os
7 | import cv2
8 |
9 | from torch.optim import SGD
10 | from torchvision import models
11 | from torch.nn import functional
12 |
13 | from misc_functions import preprocess_image, recreate_image, get_params
14 |
15 |
16 | class DisguisedFoolingSampleGeneration():
17 | """
18 | Produces an image that maximizes a certain class with gradient ascent, breaks as soon as
19 | the target prediction confidence is captured
20 | """
21 | def __init__(self, model, initial_image, target_class, minimum_confidence):
22 | self.model = model
23 | self.model.eval()
24 | self.target_class = target_class
25 | self.minimum_confidence = minimum_confidence
26 | # Start from the provided initial image (not random, unlike the fooling version)
27 | self.initial_image = initial_image
28 | # Create the folder to export images if not exists
29 | if not os.path.exists('../generated'):
30 | os.makedirs('../generated')
31 |
32 | def generate(self):
33 | for i in range(1, 500):
34 | # Process image and return variable
35 | self.processed_image = preprocess_image(self.initial_image)
36 | # Define optimizer for the image
37 | optimizer = SGD([self.processed_image], lr=0.7)
38 | # Forward
39 | output = self.model(self.processed_image)
40 | # Get confidence from softmax
41 | target_confidence = functional.softmax(output)[0][self.target_class].data.numpy()[0]
42 | if target_confidence > self.minimum_confidence:
43 | # Read the image that was written to disk and push it through the model again.
44 | # This is needed because the preprocessed image is a float array and it is
45 | # converted to uint8 when written to file, so there is a chance that some of
46 | # the perturbation is lost in the process
47 | confirmation_image = cv2.imread('../generated/ga_adv_class_' +
48 | str(self.target_class) + '.jpg', 1)
49 | # Preprocess image
50 | confirmation_processed_image = preprocess_image(confirmation_image)
51 | # Get prediction
52 | confirmation_output = self.model(confirmation_processed_image)
53 | # Get confidence
54 | softmax_confirmation = \
55 | functional.softmax(confirmation_output)[0][self.target_class].data.numpy()[0]
56 | if softmax_confirmation > self.minimum_confidence:
57 | print('Generated disguised fooling image with', "{0:.2f}".format(softmax_confirmation),
58 | 'confidence at', str(i) + 'th iteration.')
59 | break
60 | # Target specific class
61 | class_loss = -output[0, self.target_class]
62 | print('Iteration:', str(i), 'Target confidence', "{0:.4f}".format(target_confidence))
63 | # Zero grads
64 | self.model.zero_grad()
65 | # Backward
66 | class_loss.backward()
67 | # Update image
68 | optimizer.step()
69 | # Recreate image
70 | self.initial_image = recreate_image(self.processed_image)
71 | # Save image
72 | cv2.imwrite('../generated/ga_adv_class_' + str(self.target_class) + '.jpg',
73 | self.initial_image)
74 | return self.initial_image  # last recreated image (defined even if the target confidence is never reached)
75 |
76 |
77 | if __name__ == '__main__':
78 | target_example = 0 # Apple
79 | (original_image, prep_img, _, _, pretrained_model) =\
80 | get_params(target_example)
81 |
82 | fooling_target_class = 398 # Abacus
83 | min_confidence = 0.99
84 | fool = DisguisedFoolingSampleGeneration(pretrained_model,
85 | original_image,
86 | fooling_target_class,
87 | min_confidence)
88 | generated_image = fool.generate()
89 |
--------------------------------------------------------------------------------
/src/gradient_ascent_fooling.py:
--------------------------------------------------------------------------------
1 | """
2 | Created on Thu Oct 28 08:12:10 2017
3 |
4 | @author: Utku Ozbulak - github.com/utkuozbulak
5 | """
6 | import os
7 | import cv2
8 | import numpy as np
9 |
10 | from torch.optim import SGD
11 | from torchvision import models
12 | from torch.nn import functional
13 |
14 | from misc_functions import preprocess_image, recreate_image
15 |
16 |
17 | class FoolingSampleGeneration():
18 | """
19 | Produces an image that maximizes a certain class with gradient ascent, breaks as soon as
20 | the target prediction confidence is captured
21 | """
22 | def __init__(self, model, target_class, minimum_confidence):
23 | self.model = model
24 | self.model.eval()
25 | self.target_class = target_class
26 | self.minimum_confidence = minimum_confidence
27 | # Generate a random image
28 | self.created_image = np.uint8(np.random.uniform(0, 255, (224, 224, 3)))
29 | # Create the folder to export images if not exists
30 | if not os.path.exists('../generated'):
31 | os.makedirs('../generated')
32 |
33 | def generate(self):
34 | for i in range(1, 200):
35 | # Process image and return variable
36 | self.processed_image = preprocess_image(self.created_image)
37 | # Define optimizer for the image
38 | optimizer = SGD([self.processed_image], lr=6)
39 | # Forward
40 | output = self.model(self.processed_image)
41 | # Get confidence from softmax
42 | target_confidence = functional.softmax(output)[0][self.target_class].data.numpy()[0]
43 | if target_confidence > self.minimum_confidence:
44 | # Read the image that was written to disk and push it through the model again.
45 | # This is needed because the preprocessed image is a float array and it is
46 | # converted to uint8 when written to file, so there is a chance that some of
47 | # the perturbation is lost in the process
48 | confirmation_image = cv2.imread('../generated/ga_fooling_class_' +
49 | str(self.target_class) + '.jpg', 1)
50 | # Preprocess image
51 | confirmation_processed_image = preprocess_image(confirmation_image)
52 | # Get prediction
53 | confirmation_output = self.model(confirmation_processed_image)
54 | # Get confidence
55 | softmax_confirmation = \
56 | functional.softmax(confirmation_output)[0][self.target_class].data.numpy()[0]
57 | if softmax_confirmation > self.minimum_confidence:
58 | print('Generated fooling image with', "{0:.2f}".format(softmax_confirmation),
59 | 'confidence at', str(i) + 'th iteration.')
60 | break
61 | # Target specific class
62 | class_loss = -output[0, self.target_class]
63 | print('Iteration:', str(i), 'Target Confidence', "{0:.4f}".format(target_confidence))
64 | # Zero grads
65 | self.model.zero_grad()
66 | # Backward
67 | class_loss.backward()
68 | # Update image
69 | optimizer.step()
70 | # Recreate image
71 | self.created_image = recreate_image(self.processed_image)
72 | # Save image
73 | cv2.imwrite('../generated/ga_fooling_class_' + str(self.target_class) + '.jpg',
74 | self.created_image)
75 | return self.processed_image
76 |
77 |
78 | if __name__ == '__main__':
79 | target_class = 483 # Castle
80 | pretrained_model = models.alexnet(pretrained=True)
81 | cig = FoolingSampleGeneration(pretrained_model, target_class, 0.99)
82 | cig.generate()
83 |
--------------------------------------------------------------------------------
/src/misc_functions.py:
--------------------------------------------------------------------------------
1 | """
2 | Created on Thu Oct 21 11:09:09 2017
3 |
4 | @author: Utku Ozbulak - github.com/utkuozbulak
5 | """
6 | import copy
7 | import cv2
8 | import numpy as np
9 |
10 | import torch
11 | from torch.autograd import Variable
12 | from torchvision import models
13 |
14 |
15 | def preprocess_image(cv2im, resize_im=True):
16 | """
17 | Processes image for CNNs
18 |
19 | Args:
20 | cv2im (numpy arr): BGR image (as read by cv2) to process
21 | resize_im (bool): Resize to 224 or not
22 | returns:
23 | im_as_var (Pytorch variable): Variable that contains processed float tensor
24 | """
25 | # mean and std list for channels (Imagenet)
26 | mean = [0.485, 0.456, 0.406]
27 | std = [0.229, 0.224, 0.225]
28 | # Resize image
29 | if resize_im:
30 | cv2im = cv2.resize(cv2im, (224, 224))
31 | im_as_arr = np.float32(cv2im)
32 | im_as_arr = np.ascontiguousarray(im_as_arr[..., ::-1])  # Convert BGR (cv2) to RGB
33 | im_as_arr = im_as_arr.transpose(2, 0, 1)  # Convert array from H,W,C to C,H,W
34 | # Normalize the channels
35 | for channel, _ in enumerate(im_as_arr):
36 | im_as_arr[channel] /= 255
37 | im_as_arr[channel] -= mean[channel]
38 | im_as_arr[channel] /= std[channel]
39 | # Convert to float tensor
40 | im_as_ten = torch.from_numpy(im_as_arr).float()
41 | # Add a batch dimension to the beginning. Tensor shape = 1,3,224,224
42 | im_as_ten.unsqueeze_(0)
43 | # Convert to Pytorch variable
44 | im_as_var = Variable(im_as_ten, requires_grad=True)
45 | return im_as_var
46 |
47 |
48 | def recreate_image(im_as_var):
49 | """
50 | Recreates images from a torch variable, sort of reverse preprocessing
51 |
52 | Args:
53 | im_as_var (torch variable): Image to recreate
54 |
55 | returns:
56 | recreated_im (numpy arr): Recreated image in array
57 | """
58 | reverse_mean = [-0.485, -0.456, -0.406]
59 | reverse_std = [1/0.229, 1/0.224, 1/0.225]
60 | recreated_im = copy.copy(im_as_var.data.numpy()[0])
61 | for c in range(3):
62 | recreated_im[c] /= reverse_std[c]
63 | recreated_im[c] -= reverse_mean[c]
64 | recreated_im[recreated_im > 1] = 1
65 | recreated_im[recreated_im < 0] = 0
66 | recreated_im = np.round(recreated_im * 255)
67 |
68 | recreated_im = np.uint8(recreated_im).transpose(1, 2, 0)
69 | # Convert RGB back to BGR (cv2 format)
70 | recreated_im = recreated_im[..., ::-1]
71 | return recreated_im
72 |
73 |
74 | def get_params(example_index):
75 | """
76 | Gets used variables for almost all operations, like the image, model etc.
77 |
78 | Args:
79 | example_index (int): Image id to use from examples
80 |
81 | returns:
82 | original_image (numpy arr): Original image read from the file
83 | prep_img (numpy_arr): Processed image
84 | target_class (int): Target class for the image
85 | file_name_to_export (string): File name to export the visualizations
86 | pretrained_model(Pytorch model): Model to use for the operations
87 | """
88 | # Pick one of the examples
89 | example_list = [['../input_images/apple.JPEG', 948],
90 | ['../input_images/eel.JPEG', 390],
91 | ['../input_images/bird.JPEG', 13]]
92 | selected_example = example_index
93 | img_path = example_list[selected_example][0]
94 | target_class = example_list[selected_example][1]
95 | file_name_to_export = img_path[img_path.rfind('/')+1:img_path.rfind('.')]
96 | # Read image
97 | original_image = cv2.imread(img_path, 1)
98 | # Process image
99 | prep_img = preprocess_image(original_image)
100 | # Define model
101 | pretrained_model = models.alexnet(pretrained=True)
102 | return (original_image,
103 | prep_img,
104 | target_class,
105 | file_name_to_export,
106 | pretrained_model)
107 |
--------------------------------------------------------------------------------