├── .github
│   └── workflows
│       └── main.yml
├── .gitignore
├── LICENSE
├── README.md
├── Untitled.ipynb
├── blacbox.ipynb
├── blacbox
│   ├── __init__.py
│   ├── activation_visualizers
│   │   ├── __init__.py
│   │   ├── gcam.py
│   │   └── gcampp.py
│   ├── architectures
│   │   ├── __init__.py
│   │   ├── images
│   │   │   ├── black.png
│   │   │   └── dog.png
│   │   └── rclf.py
│   └── pixel_visualizers
│       ├── __init__.py
│       ├── gbp.py
│       └── saliency.py
├── images
│   ├── gbp_another_ex_input.png
│   ├── gbp_another_ex_output.png
│   ├── gbp_input.png
│   ├── gbp_output.png
│   ├── gcam_hot.png
│   ├── gcam_input.png
│   ├── gcam_output.png
│   ├── gcam_overlay.png
│   ├── saliency_input.png
│   └── saliency_output.png
├── requirements.txt
├── setup.py
└── tests
    ├── __init__.py
    └── test.py
/.github/workflows/main.yml:
--------------------------------------------------------------------------------
1 | # This is a basic workflow to help you get started with Actions
2 |
3 | name: build
4 |
5 | # Controls when the action will run.
6 | on:
7 |   # Triggers the workflow on push or pull request events but only for the main branch
8 | push:
9 | branches: [ main ]
10 | pull_request:
11 | branches: [ main ]
12 |
13 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel
14 | jobs:
15 | # This workflow contains a single job called "build"
16 | build:
17 | # The type of runner that the job will run on
18 | runs-on: ubuntu-latest
19 |
20 | # Steps represent a sequence of tasks that will be executed as part of the job
21 | steps:
22 | - uses: actions/checkout@v2
23 | - name: Install Python 3
24 | uses: actions/setup-python@v1
25 | with:
26 | python-version: 3.8
27 | - name: Install dependencies
28 | run: |
29 | python -m pip install --upgrade pip
30 | pip install -r requirements.txt
31 | - name: Run tests
32 | run: python3 -m unittest tests/test.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Rishab Mudliar
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # blacbox
2 |
3 | Making CNNs interpretable, because accuracy alone isn't the whole story anymore.
4 |
5 | ## Summary of Contents
6 |
7 |
8 | - [About The Project](#about-the-project)
9 | - [New!](#new)
10 | - [Coming up](#coming-up)
11 | - [Getting Started](#getting-started)
12 | - [Usage](#usage)
13 | - [Example](#example)
14 | - [Roadmap](#roadmap)
15 | - [Contributing](#contributing)
16 | - [License](#license)
17 | - [Contact](#contact)
30 |
31 |
32 |
33 |
34 | ## About The Project
35 | Consider a scenario where a military commander wants to identify enemy tanks, and he uses an object-detection algorithm. Based on some data, he reports an accuracy of 99%. Naturally he is happy (his enemies definitely aren't). But when he puts the algorithm to use, it fails terribly. It turns out the data contained images of tanks only at night, not during the day, so instead of learning about the tanks the algorithm learned to distinguish day from night. He would have lost, but thanks to blacbox he can use our techniques to save the day!
36 |
37 |
38 | ### Built With
39 | * [OpenCV](https://opencv.org/)
40 | * [PyTorch](https://pytorch.org/)
41 | * [NumPy](https://numpy.org/), [Pandas](https://pandas.pydata.org/)
42 |
43 | ## New!
44 | - The project is fresh, so the whole thing is new right now :p
45 |
46 | ## Coming up
47 | - Improved version of Grad-CAM (Score-CAM, Grad-CAM++)
48 | - A relatively new method just for object-detection algorithms. The DRISE method! (Paper: https://arxiv.org/abs/2006.03204)
49 | - Support for Image-Captioning and VQA models
50 |
51 |
52 | ## Getting Started
53 |
54 | ### Prerequisites
55 | - Python versions 3.6-3.10
56 |
57 | ### Installation
58 | 1. Install the latest version from GitHub
59 | ```
60 | pip install git+https://github.com/lazyCodes7/blacbox.git
61 | ```
62 | Note: A stable version will be released to PyPI soon
63 | ## Usage
64 | ### 1. Saliency Maps
65 | A saliency map is a way to measure the spatial support of a particular class in each image. It is the oldest and most frequently used explanation method for interpreting the predictions of convolutional neural networks. The saliency map is built using gradients of the output over the input.
66 | Paper link: https://arxiv.org/abs/1911.11293
67 |
68 | ```python
69 |
70 | from blacbox import Saliency
71 | import matplotlib.pyplot as plt
72 | import torchvision.models as models
73 | # Load a model
74 | model = models.resnet50(pretrained = True)
75 |
76 | # Pass the model to Saliency generator
77 | maps = Saliency(
78 | model,
79 | device = 'cuda'
80 | )
81 |
82 | # images: torch.Tensor of shape (B,C,H,W); class_idx specifies the class to calculate gradients against
83 | # saliencies return type: np.ndarray of shape (B,H,W,C)
84 | saliencies = maps.reveal(
85 | images = images,
86 | class_idx = "keepmax"
87 | )
88 |
89 | ```
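The example imports matplotlib but never uses it; here is a minimal sketch of displaying the first returned map (assuming `saliencies` from the call above, with shape (B,H,W,1)):

```python
# Drop the trailing channel axis and render the saliency map
plt.imshow(saliencies[0].squeeze(), cmap = 'gray')
plt.axis('off')
plt.show()
```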
90 | ### Input image
91 |
92 |
93 | ### Output saliency
94 |
95 |
96 | ### 2. Guided Backpropagation
97 | Guided backpropagation is an interesting variation of vanilla backprop in which negative gradients are zeroed out during the backward pass instead of being passed through, since we are only interested in the positive influences on the output.
98 | Paper link: https://arxiv.org/abs/1412.6806
99 |
100 | ```python
101 | import torch
102 | from blacbox import GuidedBackPropagation
103 | from blacbox import RaceClassifier
104 |
105 | clf = RaceClassifier()
106 |
107 | ## Provide a batch of images (image1, image2: preprocessed tensors of shape (C,H,W))
108 | images = torch.stack((image1, image2), axis = 0)
109 |
110 | ## Note that we pass clf.model, since clf itself is not of type nn.Module
111 | gbp = GuidedBackPropagation(
112 | model = clf.model,
113 | device = 'cuda'
114 | )
115 |
116 | ## Provide batch
117 | grads = gbp.reveal(
118 | images = images
119 | )
120 |
121 | ## Or provide a path instead
122 | grads = gbp.reveal(
123 | path = 'blacbox/architectures/images/dog.png'
124 | )
125 | ```
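The snippets in this README assume `image1` and `image2` already exist as preprocessed tensors; a minimal sketch of producing them, mirroring the preprocessing in `tests/test.py`:

```python
import cv2
import torchvision.transforms as transforms

# Read with OpenCV (BGR) and convert to RGB
image = cv2.imread('blacbox/architectures/images/dog.png')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Resize and convert to a (C,H,W) float tensor in [0,1]
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
image1 = transform(image)
```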
126 | ### Inputs
127 | 
128 |
129 | ### Outputs
130 | 
131 |
132 | ### 3. Grad-CAM
133 | Grad-CAM is another interesting way of visualizing activation maps: it weights a layer's feature maps by the gradients of a target class and produces a coarse localization heatmap.
134 | Paper link: https://arxiv.org/abs/1610.02391
135 |
136 | ```python
137 | import torch
138 | from blacbox import RaceClassifier, GCAM
139 |
140 | ## Initialize
141 | clf = RaceClassifier()
142 | ## Provide a batch of images (image1, image2: preprocessed tensors of shape (C,H,W))
143 | images = torch.stack((image1, image2), axis = 0)
144 |
145 | ## For interpolation modes refer to torch.nn.functional.interpolate
146 | gcam = GCAM(
147 | model = clf.model,
148 | interpolate = 'bilinear',
149 | device = 'cuda'
150 | )
151 |
152 | ## Generating the heatmaps
153 | heatmaps = gcam.reveal(
154 | images = images,
155 | module = clf.model.layer4[0].conv1,
156 | class_idx = 'keepmax',
157 | colormap = 'hot'
158 | )
159 |
160 | # module: the module to compute gradients and feature maps from
161 | # colormap: type of colormap to apply; use gcam.colormap_dict to see valid types
162 |
163 | ## Overlay on images
164 | ret_images = gcam.overlay(
165 | images = images,
166 | heatmaps = heatmaps,
167 | is_np = True
168 | )
169 | # is_np: set True to convert tensor images to numpy before overlaying
170 |
171 | ```
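A quick way to eyeball the result (a sketch; `ret_images` is a numpy batch of blended images):

```python
import matplotlib.pyplot as plt

# Clip to the displayable 0-255 range and show the first overlaid image
plt.imshow(ret_images[0].clip(0, 255))
plt.axis('off')
plt.show()
```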
172 | ### Input
173 | 
174 |
175 | ### Heatmaps
176 | 
177 |
178 | ### Overlay
179 | 
180 |
181 | ### Note:
182 | To use RaceClassifier, please download the pretrained model [here](https://github.com/lazyCodes7/blacbox/tree/main/blacbox/pretrained_models) and provide its path:
183 | ```python
184 | from blacbox import RaceClassifier
185 | clf = RaceClassifier(model_path = "blacbox/pretrained_models/checkpoint.pt")
186 | ```
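For reference, `RaceClassifier` wraps a ResNet-34 whose final layer is replaced with an 18-dimensional head covering race, gender, and age scores, loaded with weights trained on the FairFace dataset (see `blacbox/architectures/rclf.py`).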
187 |
188 | ## Example
189 | 1. Clone the repo.
190 | ```
191 | git clone https://github.com/lazyCodes7/blacbox.git
192 | ```
193 | 2. Use jupyter to try out the notebook titled blacbox.ipynb!
194 | ```
195 | jupyter notebook
196 | ```
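3. Optionally, run the unit tests (the same command the CI workflow uses):
```
python -m unittest tests/test.py
```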
197 |
198 |
199 | ## Roadmap
200 |
201 | See the [open issues](https://github.com/lazyCodes7/blacbox/issues) for a list of proposed features (and known issues).
202 |
203 |
204 |
205 |
206 | ## Contributing
207 |
208 | Contributions are what make the open source community such an amazing place to learn, inspire, and create. Any contributions you make are **greatly appreciated**.
209 |
210 | 1. Fork the Project
211 | 2. Create your Feature Branch (`git checkout -b feature/AmazingFeature`)
212 | 3. Commit your Changes (`git commit -m 'Add some AmazingFeature'`)
213 | 4. Push to the Branch (`git push origin feature/AmazingFeature`)
214 | 5. Open a Pull Request
215 |
216 |
217 |
218 |
219 | ## License
220 |
221 | Distributed under the MIT License. See `LICENSE` for more information.
222 |
223 |
224 |
225 |
226 | ## Contact
227 |
228 | Rishab Mudliar - [@cheesetaco19](https://twitter.com/cheesetaco19) - rishabmudliar@gmail.com
229 |
230 | Telegram: [lazyCodes7](https://t.me/lazyCodes7)
231 |
--------------------------------------------------------------------------------
/blacbox/__init__.py:
--------------------------------------------------------------------------------
1 | import blacbox
2 | from blacbox.activation_visualizers import GCAM
3 | from blacbox.pixel_visualizers import GuidedBackPropagation
4 | from blacbox.pixel_visualizers import Saliency
5 | from blacbox.architectures import RaceClassifier
6 | from blacbox.activation_visualizers import GCAM_plus
--------------------------------------------------------------------------------
/blacbox/activation_visualizers/__init__.py:
--------------------------------------------------------------------------------
1 | from .gcam import GCAM
2 | from .gcampp import GCAM_plus
--------------------------------------------------------------------------------
/blacbox/activation_visualizers/gcam.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import torch
4 | import torchvision.transforms as transforms
5 | import torch.nn.functional as F
6 | import torch.nn as nn
7 | import cv2
8 | class GCAM:
9 | def __init__(self, model, interpolate = 'bilinear', device = 'cpu'):
10 | '''
11 | Description: Method to visualize CNN layers w.r.t output
12 |
13 | Args:
14 | model -> nn.Module
15 | interpolate -> interpolation types (refer to https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html)
16 | device -> 'cpu'/'cuda' (device to set model to)
17 |
18 | Example:
19 | gcam = GCAM(model, interpolate = 'bilinear', device = 'cuda')
20 |
21 | '''
22 | self.device = device
23 | self.model = model.to(self.device)
24 | self.interpolate = interpolate
25 | self.activations = {}
26 | self.gradients = {}
27 | self.colormap_dict = {
28 | "autumn": cv2.COLORMAP_AUTUMN,
29 | "bone": cv2.COLORMAP_BONE,
30 | "jet" : cv2.COLORMAP_JET,
31 | "winter" : cv2.COLORMAP_WINTER,
32 | "rainbow" : cv2.COLORMAP_RAINBOW,
33 | "ocean" : cv2.COLORMAP_OCEAN,
34 | "summer" : cv2.COLORMAP_SUMMER,
35 | "spring" : cv2.COLORMAP_SPRING,
36 | "cool" : cv2.COLORMAP_COOL,
37 | "hsv" : cv2.COLORMAP_HSV,
38 | "pink" : cv2.COLORMAP_PINK,
39 | "hot" : cv2.COLORMAP_HOT
40 |
41 | }
42 |
43 | def get_activations(self,key):
44 | '''
45 | Description: Method to get activations for a particular layer
46 |
47 | Args:
48 | key -> nn.Module(the module to calculate activations for)
49 |
50 | '''
51 | def hook(module, input, out):
52 | self.activations[key] = out.detach()
53 | return hook
54 | def get_gradients(self,key):
55 | '''
56 | Description: Method to get gradients for a particular layer
57 |
58 | Args:
59 | key -> nn.Module(the module to calculate gradients for)
60 |
61 | '''
62 | def hook(module, grad_in, grad_out):
63 | self.gradients[key] = grad_out[0]
64 | return hook
65 | def reveal(
66 | self,
67 | images = None,
68 | module = None,
69 | class_idx = 0,
70 | path = None,
71 | colormap = "jet"
72 | ):
73 | '''
74 | Description: Where it all takes place and we get the output visualize a CNN layer.
75 |
76 | Args:
77 | images -> type = torch.Tensor, shape = (B,C,H,W)
78 | path -> str(required if images is None)
79 | module -> nn.Module (to perform Grad-CAM on)
80 | class_idx -> int (class_idx to calculate gradients on)
81 | colormap -> colormap to apply on heatmap
82 |
83 | Types of colormaps:
84 | "autumn": cv2.COLORMAP_AUTUMN,
85 | "bone": cv2.COLORMAP_BONE,
86 | "jet" : cv2.COLORMAP_JET,
87 | "winter" : cv2.COLORMAP_WINTER,
88 | "rainbow" : cv2.COLORMAP_RAINBOW,
89 | "ocean" : cv2.COLORMAP_OCEAN,
90 | "summer" : cv2.COLORMAP_SUMMER,
91 | "spring" : cv2.COLORMAP_SPRING,
92 | "cool" : cv2.COLORMAP_COOL,
93 | "hsv" : cv2.COLORMAP_HSV,
94 | "pink" : cv2.COLORMAP_PINK,
95 | "hot" : cv2.COLORMAP_HOT
96 |
97 | Types of class_idx:
98 | 'keepmax' -> Visualize the max score from classification
99 | 'keepmin' -> Visualize the min score from classification
100 | '0-len(output)' -> For visualizing a certain output
101 |
102 | Example:
103 | from blacbox import GCAM
104 | gcam = GCAM(model, interpolate = True)
105 | output_viz = gcam.reveal(
106 | images = images,
107 | module = model.resnet.layer[0].conv1,
108 | class_idx = 'keepmax',
109 | colormap = 'hot'
110 |
111 | )
112 | '''
113 | # Raise error if both path and images are provided
114 | if(path is not None and images is not None):
115 | raise ValueError("Image batches cannot be passed when path is provided")
116 |
117 | # If path is provided
118 | elif(path is not None):
119 | images = cv2.imread(path)
120 | images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB)
121 | images = self.preprocess_image(images).unsqueeze(0)
122 |
123 | # If batches of image is provided
124 | if(images is not None):
125 | if(isinstance(images, torch.Tensor) == False):
126 | raise ValueError("reveal() expects images to be of type tensor with shape (B,C,H,W)")
127 | if(isinstance(module, nn.Module)):
128 | key = str(module)
129 | self.ac_handler = module.register_forward_hook(self.get_activations(key))
130 | self.grad_handler = module.register_backward_hook(self.get_gradients(key))
131 | gcams = self.retrieve_gcams(images, class_idx, key, colormap)
132 | return gcams
133 | else:
134 | error = "Module argument expects variable of type nn.Module"
135 | raise ValueError(error)
136 |
137 | # If None then raise errors
138 | else:
139 | raise ValueError("Either path or images need to be provided to reveal GCAM visualization.")
140 |
141 | @staticmethod
142 | def overlay(heatmaps, images, influence = 0.3, is_np = True):
143 | if(is_np):
144 | images = images.cpu().detach().permute(0,2,3,1).numpy()
145 | images = cv2.normalize(
146 | images,
147 | None,
148 | alpha = 0,
149 | beta = 255,
150 | norm_type = cv2.NORM_MINMAX,
151 | dtype = cv2.CV_32F
152 | )
153 | superimposed_img = heatmaps*(influence) + images
154 | return superimposed_img.astype(int)
155 |
156 | def preprocess_image(self, images):
157 | '''
158 | Description:
159 | Takes in an images and applies some transformations to it
160 |
161 | Args:
162 | images -> np.ndarray
163 |
164 | '''
165 | if(isinstance(images, np.ndarray)):
166 | transform = transforms.Compose([
167 | transforms.ToPILImage(),
168 | transforms.Resize((224, 224)),
169 | transforms.ToTensor(),
170 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
171 | ])
172 |
173 | images = transform(images)
174 | return images
175 |
176 | else:
177 | raise AttributeError("Preprocessing requires a np.ndarray")
178 |
179 | def retrieve_gcams(self, images, class_idx, key, colormap, apply_cmap = True):
180 | """
181 | Retrieving gcams from the images provided
182 | """
183 | gcams = []
184 | # Selecting the index
185 |
186 | for image in images:
187 | # Taking the images
188 | self.image = image.unsqueeze(0).to(self.device)
189 | self.image.requires_grad = True
190 | output = self.model(self.image)
191 | if(class_idx == "keepmax"):
192 | class_idx = output.argmax(dim = 1)
193 | elif(class_idx == "keepmin"):
194 | class_idx = output.argmin(dim = 1)
195 |
196 | # Calculating gradients w.r.t to the idx selected
197 | class_required = output[0,class_idx]
198 | class_required.backward()
199 |
200 |
201 | # Retrieving gradients and activation maps
202 | fmaps = self.activations[key]
203 | weights = self.gradients[key]
204 |
205 | # Freeing up the gradients from the images
206 | self.image.grad.zero_()
207 |
208 |
209 | # Avg pooling as in the paper
210 | weights = F.adaptive_avg_pool2d(weights, 1)
211 |
212 | # Retrieving the gcam outputs
213 | gcam = torch.mul(fmaps, weights).sum(dim=1, keepdim=True)
214 | gcam = F.relu(gcam)
215 | gcam = F.interpolate(
216 | gcam, self.image.shape[2:4], mode=self.interpolate, align_corners=False
217 | )
218 | gcam = gcam.cpu().squeeze(0).detach().permute(1,2,0).numpy()
219 | gcam = cv2.normalize(
220 | gcam,
221 | None,
222 | alpha = 0,
223 | beta = 255,
224 | norm_type = cv2.NORM_MINMAX,
225 | dtype = cv2.CV_8UC3
226 | )
227 | gcam = cv2.applyColorMap(gcam, self.colormap_dict[colormap])
228 | #gcam = (gcam*255).astype(np.uint8)
229 | gcams.append(gcam)
230 |
231 | # Removing the hooks, activations, gradients stored
232 | self.ac_handler.remove()
233 | self.grad_handler.remove()
234 | self.activations = {}
235 | self.gradients = {}
236 | return np.array(gcams)
237 |
238 |
239 |
--------------------------------------------------------------------------------
/blacbox/activation_visualizers/gcampp.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import torch
4 | import torchvision.transforms as transforms
5 | import torch.nn.functional as F
6 | import torch.nn as nn
7 | import cv2
8 | from .gcam import GCAM
9 | class GCAM_plus(GCAM):
10 | def __init__(self, model, interpolate = 'bilinear', device = 'cpu'):
11 | super().__init__(model, interpolate, device)
12 |
13 | def retrieve_gcams(self, images, class_idx, key, colormap, apply_cmap = True):
14 | """
15 | Retrieving gcams from the images provided
16 | """
17 | gcams = []
18 | # Selecting the index
19 |
20 | for image in images:
21 | # Taking the images
22 | self.image = image.unsqueeze(0).to(self.device)
23 | self.image.requires_grad = True
24 | output = self.model(self.image)
25 | output = F.softmax(output, dim = 1)
26 | if(class_idx == "keepmax"):
27 | class_idx = output.argmax(dim = 1)
28 | elif(class_idx == "keepmin"):
29 | class_idx = output.argmin(dim = 1)
30 |
31 | # Calculating gradients w.r.t to the idx selected
32 | cr = output[0,class_idx]
33 | cr.backward()
34 |
35 |
36 | # Retrieving gradients and activation maps
37 | fmaps = self.activations[key]
38 | weights = self.gradients[key]
39 | weights = weights/cr
40 |
41 | # Freeing up the gradients from the images
42 | self.image.grad.zero_()
43 | weights = weights.squeeze().unsqueeze(0)
44 | b, k, u, v = weights.size()
45 | a_num = weights.pow(2)
46 | a_dm = weights.pow(2).mul(2) + \
47 | fmaps.mul(weights.pow(3)).view(b, k, u*v).sum(-1, keepdim=True).view(b, k, 1, 1)
48 | a_dm = torch.where(a_dm != 0.0, a_dm, torch.ones_like(a_dm))
49 | alpha = a_num.div(a_dm)
50 | weights = (alpha*weights).view(b, k, u*v).sum(-1).view(b, k, 1, 1)
51 |
52 | # Retrieving the gcam outputs
53 | gcam = torch.mul(fmaps, weights).sum(dim=1, keepdim=True)
54 | gcam = F.relu(gcam)
55 |
56 | gcam = F.interpolate(
57 | gcam, self.image.shape[2:4], mode=self.interpolate, align_corners=False
58 | )
59 | gcam = gcam.cpu().squeeze(0).detach().permute(1,2,0).numpy()
60 |
61 | gcam = cv2.normalize(
62 | gcam,
63 | None,
64 | alpha = 0,
65 | beta = 255,
66 | norm_type = cv2.NORM_MINMAX,
67 | dtype = cv2.CV_8UC3
68 | )
69 | gcam = cv2.applyColorMap(gcam, self.colormap_dict[colormap])
70 | #gcam = (gcam*255).astype(np.uint8)
71 |
72 |
73 | gcams.append(gcam)
74 |
75 | # Removing the hooks, activations, gradients stored
76 | self.ac_handler.remove()
77 | self.grad_handler.remove()
78 | self.activations = {}
79 | self.gradients = {}
80 | return np.array(gcams)
--------------------------------------------------------------------------------
/blacbox/architectures/__init__.py:
--------------------------------------------------------------------------------
1 | from .rclf import RaceClassifier
--------------------------------------------------------------------------------
/blacbox/architectures/images/black.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lazyCodes7/blacbox/de32d3b4b3d31a95179bd75b3259456ca34a0b4a/blacbox/architectures/images/black.png
--------------------------------------------------------------------------------
/blacbox/architectures/images/dog.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lazyCodes7/blacbox/de32d3b4b3d31a95179bd75b3259456ca34a0b4a/blacbox/architectures/images/dog.png
--------------------------------------------------------------------------------
/blacbox/architectures/rclf.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import warnings
3 | warnings.filterwarnings("ignore")
4 | import os.path
5 | import pandas as pd
6 | import torch
7 | import torch.nn as nn
8 | import numpy as np
9 | import torchvision
10 | from torchvision import datasets, models, transforms
11 | import os
12 | import argparse
13 | import cv2
14 | class RaceClassifier:
15 | def __init__(self,model_path = 'blacbox/pretrained_models/checkpoint.pt'):
16 | ## Loading the Resnet34 pretrained model
17 | self.model = torchvision.models.resnet34(pretrained=True)
18 | # Selecting the device
19 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20 | # Output has lots of features like the age, race which sum upto 18
21 | self.model.fc = nn.Linear(self.model.fc.in_features, 18)
22 | #model_fair_7.load_state_dict(torch.load('fair_face_models/fairface_alldata_20191111.pt'))
23 |
24 | # Loading the pretrained model trained on the Fairface dataset
25 | self.model.load_state_dict(torch.load(model_path,map_location=torch.device('cpu')))
26 | self.model = self.model.to(self.device)
27 |
28 | # Setting eval mode (uses running BatchNorm stats and disables dropout; it does not disable gradients)
29 | self.model.eval()
30 |
31 | def preprocess_image(self, image):
32 | transform = transforms.Compose([
33 | transforms.ToPILImage(),
34 | transforms.Resize((224, 224)),
35 | transforms.ToTensor(),
36 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
37 | ])
38 |
39 | image = transform(image)
40 | image = image.unsqueeze(0)
41 | return image
42 |
43 |
44 | # Predict function used to predict the race of the face detected
45 | def predict(self,image):
46 | self.image = image.to(self.device)
47 | self.image = self.image.requires_grad_()
48 |
49 | # fair
50 | self.outputs = self.model(self.image)
51 | self.grad_outputs = self.outputs
52 | self.outputs = self.outputs.cpu().detach().numpy()
53 | #self.outputs = np.squeeze(self.outputs)
54 |
55 | '''
56 | self.race_outputs = self.outputs[:7]
57 | self.gender_outputs = self.outputs[7:9]
58 | self.age_outputs = self.outputs[9:18]
59 | race_score = np.exp(self.race_outputs) / np.sum(np.exp(self.race_outputs))
60 | gender_score = np.exp(self.gender_outputs) / np.sum(np.exp(self.gender_outputs))
61 | age_score = np.exp(self.age_outputs) / np.sum(np.exp(self.age_outputs))
62 | race_pred = np.argmax(race_score, dim = 1)
63 | gender_pred = np.argmax(gender_score, dim = 1)
64 | age_pred = np.argmax(age_score)
65 | return (race_pred,gender_pred,age_pred)
66 | '''
67 |
68 | def __call__(self, image):
69 | preds = self.predict(image)
70 | return self.grad_outputs
--------------------------------------------------------------------------------
/blacbox/pixel_visualizers/__init__.py:
--------------------------------------------------------------------------------
1 | from .gbp import GuidedBackPropagation
2 | from .saliency import Saliency
--------------------------------------------------------------------------------
/blacbox/pixel_visualizers/gbp.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import torch
4 | import torchvision.transforms as transforms
5 | import torch.nn.functional as F
6 | import torch.nn as nn
7 | import cv2
8 | class _BaseWrapper(object):
9 |
10 | def __init__(self, model, device = 'cpu'):
11 | '''
12 | Description: BaseWrapper for BackProp and Guided BackProp
13 |
14 | Args:
15 |
16 | model -> nn.Module
17 | device -> cpu/cuda
18 | '''
19 | super(_BaseWrapper, self).__init__()
20 | self.device = device
21 | self.model = model.to(self.device)
22 | self.handlers = [] # a set of hook function handlers
23 |
24 | def _encode_one_hot(self, ids):
25 | '''
26 | Description: One hot encoder for the model outputs
27 |
28 | Args:
29 |
30 | ids -> int (id to encode)
31 | '''
32 | one_hot = torch.zeros_like(self.logits).to(self.device)
33 | one_hot[0, ids] = 1.0
34 | #print(one_hot)
35 | return one_hot
36 | def forward(self, image):
37 | '''
38 | Description: Method for feed-forward process
39 |
40 | Args:
41 | image -> torch.Tensor(image to be processed through the neural network)
42 | '''
43 | self.image_shape = image.shape[2:]
44 | self.logits = self.model(image)
45 | self.probs = F.softmax(self.logits, dim=1)
46 | return self.probs.sort(dim=1, descending=True)
47 |
48 | def backward(self, ids):
49 | '''
50 | Description: Backward prop for the neural network
51 |
52 | Args:
53 | ids -> int(id to calculate the gradient for)
54 |
55 | '''
56 | one_hot = self._encode_one_hot(ids)
57 | self.model.zero_grad()
58 | self.logits.backward(gradient=one_hot, retain_graph=True)
59 |
60 | def get_gradients(self):
61 | raise NotImplementedError
62 |
63 | def remove_hook(self):
64 | """
65 | Remove all the forward/backward hook functions
66 | """
67 | for handle in self.handlers:
68 | handle.remove()
69 |
70 | class BackPropagation(_BaseWrapper):
71 | def forward(self, image):
72 | '''
73 | Description: Method for feed-forward process
74 |
75 | Args:
76 | image -> torch.Tensor(image to be processed through the neural network)
77 | '''
78 | self.image = image.requires_grad_()
79 | return super(BackPropagation, self).forward(self.image)
80 |
81 | def get_gradients(self):
82 | """
83 | Generating the gradients for the image
84 | """
85 | gradient = self.image.grad.clone()
86 | self.image.grad.zero_()
87 | return gradient
88 |
89 | class GuidedBackPropagation(BackPropagation):
90 | """
91 | "Striving for Simplicity: the All Convolutional Net"
92 | https://arxiv.org/pdf/1412.6806.pdf
93 | Look at Figure 1 on page 8.
94 | """
95 | def backward_hook(self,module, grad_in, grad_out):
96 | """
97 | PyTorch hook function for zeroing out the negative gradients during backprop
98 | """
99 | # Binarize the positive gradients and cut off the negative ones for ReLU modules
100 | if isinstance(module, nn.ReLU):
101 | grad_in[0][grad_in[0] > 0] = 1
102 | return (F.relu(grad_in[0]),)
104 |
105 | def __init__(self, model, device):
106 | '''
107 | Description: GBP module for CNN visualization
108 |
109 | Args:
110 |
111 | model -> nn.Module
112 | device -> cpu/cuda
113 |
114 | Example:
115 | gbp = GuidedBackPropagation(model = model, device = 'cuda')
116 | '''
117 | super(GuidedBackPropagation, self).__init__(model, device)
118 |
119 |
120 |
121 |
122 | def reveal(self, images = None, path = None):
123 | '''
124 | Description: function for forwarding the images and getting the gradients to visualize
125 |
126 | Args:
127 | images -> type = torch.Tensor, shape = (B,C,H,W)
128 | path -> str(required if images is None)
129 |
130 | Example:
131 | gbp = GuidedBackPropagation(model = model, device = 'cuda')
132 | gradients = gbp.reveal(images = images)
133 | '''
134 | for module in self.model.named_modules():
135 | self.handlers.append(module[1].register_backward_hook(self.backward_hook))
136 |
137 | # Raise error if both path and images are provided
138 | if(path is not None and images is not None):
139 | raise ValueError("Image batches cannot be passed when path is provided")
140 |
141 | # If path is provided
142 | elif(path is not None):
143 | images = cv2.imread(path)
144 | images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB)
145 | images = self.preprocess_image(images)
146 |
147 | # If batches of image is provided
148 | if(images is not None):
149 | grads = self.retrieve_gradients(images)
150 | return np.array(grads)
151 |
152 | # If None then raise errors
153 | else:
154 | raise ValueError("Either path or images need to be provided to reveal GBP visualization.")
155 | def preprocess_image(self, image):
156 | '''
157 | Description:
158 | Takes in an image and applies some transformations to it
159 |
160 | Args:
161 | image -> np.ndarray
162 |
163 | '''
164 | if(isinstance(image, np.ndarray)):
165 | transform = transforms.Compose([
166 | transforms.ToPILImage(),
167 | transforms.Resize((224, 224)),
168 | transforms.ToTensor(),
169 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
170 | ])
171 |
172 | image = transform(image)
173 | image = image.unsqueeze(0)
174 | return image
175 |
176 | else:
177 | raise ValueError("Preprocessing requires type np.ndarray")
178 |
179 | def retrieve_gradients(self, images):
180 | """
181 | Generating the gradients
182 | """
183 | grads = []
184 | for image in images:
185 | image = image.unsqueeze(0).to(self.device)
186 | ids = self.forward(image)
187 | self.backward(ids=ids.indices[0,0])
188 | gradient = self.get_gradients()
189 | gradient = gradient.cpu().squeeze().permute(1,2,0).detach().numpy()
190 | gradient -= gradient.min()
191 | gradient /= gradient.max()
192 | gradient *= 255.0
193 | gradient = gradient.astype(int)
194 | grads.append(gradient)
195 |
196 | self.remove_hook()
197 | self.handlers = []
198 | return grads
199 |
200 |
201 |
202 |
203 |
--------------------------------------------------------------------------------
/blacbox/pixel_visualizers/saliency.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import torch
4 | import torchvision.transforms as transforms
5 | import torch.nn.functional as F
6 | import torch.nn as nn
7 | import cv2
10 | class Saliency:
11 | def __init__(
12 | self,
13 | model,
14 | device = 'cpu'):
15 | '''
16 | Description:
17 | 1. Pass the model to the saliency visualizer
18 | 2. set the device to CPU/GPU
19 |
20 | Args:
21 | model -> nn.Module
22 | device -> cpu/cuda
23 | '''
24 | self.device = device
25 | self.model = model.to(self.device)
26 |
27 | def reveal(
28 | self,
29 | images = None,
30 | path = None,
31 | class_idx = 'keepmax'):
32 | '''
33 | Description:
34 | The class takes in images of type tensor or a path and returns the saliency for the same
35 |
36 | Args:
37 | images -> type = torch.Tensor, shape = (B,C,H,W)
38 | path -> str(required if images is None)
39 | class_idx -> int(class to visualize)/str
40 |
41 |
42 | Types of class_idx:
43 | 'keepmax' -> Visualize the max score from classification
44 | 'keepmin' -> Visualize the min score from classification
45 | '0-len(output)' -> For visualizing a certain output
46 |
47 | '''
48 |
49 | # Raise error if both path and images are provided
50 | if(path is not None and images is not None):
51 | raise ValueError("Image batches cannot be passed when path is provided")
52 |
53 | # If path is provided
54 | elif(path is not None):
55 | images = cv2.imread(path)
56 | images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB)
57 | images = self.preprocess_image(images)
58 |
59 | # If batches of image is provided
60 | if(images is not None):
61 | saliencies = self.retrieve_saliencies(images, class_idx)
62 | return np.array(saliencies)
63 |
64 | # If None then raise errors
65 | else:
66 | raise ValueError("Either path or images need to be provided to reveal Saliency visualization.")
67 |
68 | def preprocess_image(self, image):
69 | '''
70 | Description:
71 | Takes in an image and applies some transformations to it
72 |
73 | Args:
74 | image -> np.ndarray
75 |
76 | '''
77 | if(isinstance(image, np.ndarray)):
78 | transform = transforms.Compose([
79 | transforms.ToPILImage(),
80 | transforms.Resize((224, 224)),
81 | transforms.ToTensor(),
82 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
83 | ])
84 |
85 | image = transform(image)
86 | image = image.unsqueeze(0)
87 | return image
88 |
89 | else:
90 | raise ValueError("Preprocessing requires type np.ndarray")
91 |
92 | def retrieve_saliencies(self, images, class_idx):
93 | saliencies = []
94 | for image in images:
95 | self.image = image.unsqueeze(0).to(self.device)
96 | self.image.requires_grad = True
97 | '''
98 | Calculating the gradients and computing the
99 | max across each channel and returning the saliency map
100 | '''
101 | self.output = self.model(self.image)
102 | # Resolve the index locally so the 'keepmax'/'keepmin' option is not overwritten across iterations
103 | if(isinstance(class_idx, str)):
104 | if(class_idx == 'keepmax'):
105 | idx = self.output.argmax(dim = 1)
106 | elif(class_idx == 'keepmin'):
107 | idx = self.output.argmin(dim = 1)
108 | else:
109 | error = "class_idx only supports two str arguments:\n \
110 | 1. keepmax\n \
111 | 2. keepmin\n \
112 | "
113 | raise ValueError(error)
114 | else:
115 | idx = class_idx
116 | self.output[0, idx].backward()
116 | saliency, _ = torch.max(self.image.grad.data.abs(), dim=1)
117 | saliency = saliency.reshape(self.image.shape[2],self.image.shape[3],1)
118 | saliency = saliency.cpu().numpy()
119 | saliencies.append(saliency)
120 | return saliencies
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
--------------------------------------------------------------------------------
/images/gbp_another_ex_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lazyCodes7/blacbox/de32d3b4b3d31a95179bd75b3259456ca34a0b4a/images/gbp_another_ex_input.png
--------------------------------------------------------------------------------
/images/gbp_another_ex_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lazyCodes7/blacbox/de32d3b4b3d31a95179bd75b3259456ca34a0b4a/images/gbp_another_ex_output.png
--------------------------------------------------------------------------------
/images/gbp_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lazyCodes7/blacbox/de32d3b4b3d31a95179bd75b3259456ca34a0b4a/images/gbp_input.png
--------------------------------------------------------------------------------
/images/gbp_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lazyCodes7/blacbox/de32d3b4b3d31a95179bd75b3259456ca34a0b4a/images/gbp_output.png
--------------------------------------------------------------------------------
/images/gcam_hot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lazyCodes7/blacbox/de32d3b4b3d31a95179bd75b3259456ca34a0b4a/images/gcam_hot.png
--------------------------------------------------------------------------------
/images/gcam_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lazyCodes7/blacbox/de32d3b4b3d31a95179bd75b3259456ca34a0b4a/images/gcam_input.png
--------------------------------------------------------------------------------
/images/gcam_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lazyCodes7/blacbox/de32d3b4b3d31a95179bd75b3259456ca34a0b4a/images/gcam_output.png
--------------------------------------------------------------------------------
/images/gcam_overlay.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lazyCodes7/blacbox/de32d3b4b3d31a95179bd75b3259456ca34a0b4a/images/gcam_overlay.png
--------------------------------------------------------------------------------
/images/saliency_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lazyCodes7/blacbox/de32d3b4b3d31a95179bd75b3259456ca34a0b4a/images/saliency_input.png
--------------------------------------------------------------------------------
/images/saliency_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lazyCodes7/blacbox/de32d3b4b3d31a95179bd75b3259456ca34a0b4a/images/saliency_output.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | cycler
2 | kiwisolver
3 | numpy
4 | opencv-python
5 | packaging
6 | pandas
7 | pillow
8 | pyparsing
9 | python-dateutil
10 | pytz
11 | setuptools-scm
12 | six
13 | tomli
14 | torch>=1.8.0
15 | torchvision>=0.8.1
16 | typing_extensions
17 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 | from os import path
3 |
4 |
5 | with open('requirements.txt') as f:
6 | reqs = f.read().splitlines()
7 | with open('README.md', encoding='utf-8') as fr:
8 | long_description = fr.read()
9 |
10 | setup(
11 | name="blacbox",
12 | version="0.1.0",
13 | description="A visualization library to make CNNs more interpretable",
14 | long_description=long_description,
15 | long_description_content_type="text/markdown",
16 | url="https://github.com/lazyCodes7/blacbox",
17 | author="Rishab Mudliar",
18 | author_email="rishabmudliar@gmail.com",
19 | license="MIT",
20 | classifiers=[
21 | "Intended Audience :: Developers",
22 | "License :: OSI Approved :: MIT License",
23 | "Programming Language :: Python",
24 | "Programming Language :: Python :: 3",
25 | "Programming Language :: Python :: 3.6",
26 | "Programming Language :: Python :: 3.7",
27 | "Programming Language :: Python :: 3.8",
28 | "Programming Language :: Python :: 3.9",
29 | "Operating System :: OS Independent"
30 | ],
31 | packages=find_packages(),
32 | include_package_data=True,
33 | install_requires=reqs,
34 | )
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lazyCodes7/blacbox/de32d3b4b3d31a95179bd75b3259456ca34a0b4a/tests/__init__.py
--------------------------------------------------------------------------------
/tests/test.py:
--------------------------------------------------------------------------------
1 | import unittest
2 | import math
3 | from blacbox import GCAM
4 | from blacbox import Saliency
5 | from blacbox import GuidedBackPropagation
6 | import numpy as np
7 | import torchvision.models as models
8 | import torchvision.transforms as transforms
9 | import cv2
10 | import torch
11 | class Tester(unittest.TestCase):
12 | def setUp(self):
13 | image1 = cv2.imread('blacbox/architectures/images/black.png')
14 | image1= cv2.cvtColor(image1, cv2.COLOR_BGR2RGB)
15 | image2 = cv2.imread('blacbox/architectures/images/dog.png')
16 | image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)
17 | image1 = self.preprocess_image(image1)
18 | image2 = self.preprocess_image(image2)
19 | self.images = torch.stack((image1, image2), axis = 0)
20 | self.model_test1 = models.resnet18(pretrained = True)
21 | self.model_test2 = models.resnet50(pretrained=True)
22 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
23 |
24 |
25 | def test_gcam(self):
26 | gcam = GCAM(
27 | model = self.model_test2,
28 | interpolate = 'bilinear',
29 | device = self.device
30 | )
31 | heatmap = gcam.reveal(
32 | images = self.images,
33 | module = self.model_test2.layer4[0].conv1,
34 | class_idx = 'keepmax',
35 | colormap = 'hot'
36 | )
37 | self.assertEqual(type(heatmap), np.ndarray)
38 |
39 | def test_gbp(self):
40 | gbp = GuidedBackPropagation(
41 | model = self.model_test1,
42 | device = self.device
43 | )
44 | generated_img = gbp.reveal(
45 | path = 'blacbox/architectures/images/black.png' ,
46 | )
47 | self.assertEqual(type(generated_img), np.ndarray)
48 |
49 | def test_saliency(self):
50 | sal = Saliency(
51 | model = self.model_test1,
52 | device=self.device
53 | )
54 | maps = sal.reveal(images=self.images, class_idx = 'keepmax')
55 | self.assertEqual(type(maps), np.ndarray)
56 |
57 | def test_out_size(self):
58 | sal = Saliency(
59 | self.model_test1
60 | )
61 | maps = sal.reveal(images=self.images, class_idx = 'keepmax')
62 | self.assertEqual(len(maps.shape), 4)
63 |
64 |
65 |
66 | def preprocess_image(self, image):
67 | transform = transforms.Compose([
68 |
69 | transforms.ToPILImage(),
70 | transforms.Resize((224, 224)),
71 | transforms.ToTensor(),
72 |
73 | ])
74 |
75 | image = transform(image)
76 | return image
77 |
78 |
79 | if __name__ == '__main__':
80 | unittest.main()
81 |
--------------------------------------------------------------------------------