├── .github
│   └── workflows
│       └── main.yml
├── .gitignore
├── LICENSE
├── README.md
├── Untitled.ipynb
├── blacbox.ipynb
├── blacbox
│   ├── __init__.py
│   ├── activation_visualizers
│   │   ├── __init__.py
│   │   ├── gcam.py
│   │   └── gcampp.py
│   ├── architectures
│   │   ├── __init__.py
│   │   ├── images
│   │   │   ├── black.png
│   │   │   └── dog.png
│   │   └── rclf.py
│   └── pixel_visualizers
│       ├── __init__.py
│       ├── gbp.py
│       └── saliency.py
├── images
│   ├── gbp_another_ex_input.png
│   ├── gbp_another_ex_output.png
│   ├── gbp_input.png
│   ├── gbp_output.png
│   ├── gcam_hot.png
│   ├── gcam_input.png
│   ├── gcam_output.png
│   ├── gcam_overlay.png
│   ├── saliency_input.png
│   └── saliency_output.png
├── requirements.txt
├── setup.py
└── tests
    ├── __init__.py
    └── test.py

/.github/workflows/main.yml:
--------------------------------------------------------------------------------
1 | # This is a basic workflow to help you get started with Actions
2 | 
3 | name: build
4 | 
5 | # Controls when the action will run.
6 | on:
7 |   # Triggers the workflow on push or pull request events, but only for the main branch
8 |   push:
9 |     branches: [ main ]
10 |   pull_request:
11 |     branches: [ main ]
12 | 
13 | # A workflow run is made up of one or more jobs that can run sequentially or in parallel
14 | jobs:
15 |   # This workflow contains a single job called "build"
16 |   build:
17 |     # The type of runner that the job will run on
18 |     runs-on: ubuntu-latest
19 | 
20 |     # Steps represent a sequence of tasks that will be executed as part of the job
21 |     steps:
22 |       - uses: actions/checkout@v2
23 |       - name: Install Python 3
24 |         uses: actions/setup-python@v1
25 |         with:
26 |           python-version: 3.8
27 |       - name: Install dependencies
28 |         run: |
29 |           python -m pip install --upgrade pip
30 |           pip install -r requirements.txt
31 |       - name: Run tests
32 |         run: python3 -m unittest tests/test.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | 
6 | # C extensions
7 | *.so
8 | 
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 | 
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 | 
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 | 
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 | 
54 | # Translations
55 | *.mo
56 | *.pot
57 | 
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 | 
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 | 
68 | # Scrapy stuff:
69 | .scrapy
70 | 
71 | # Sphinx documentation
72 | docs/_build/
73 | 
74 | # PyBuilder
75 | target/
76 | 
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 | 
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 | 
84 | # pyenv
85 | .python-version
86 | 
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 | 
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 | 
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 | 
101 | # SageMath parsed files
102 | *.sage.py
103 | 
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 | 
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 | 
117 | # Rope project settings
118 | .ropeproject
119 | 
120 | # mkdocs documentation
121 | /site
122 | 
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 | 
128 | # Pyre type checker
129 | .pyre/
130 | 
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2021 Rishab Mudliar
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # blacbox
2 | 
3 | Making CNNs interpretable. Well, because accuracy alone is not the whole story anymore.
4 | 
5 | ## Summary of Contents
6 | 
7 | 1. [About The Project](#about-the-project)
8 |    - [Built With](#built-with)
9 | 2. [New!](#new)
10 | 3. [Coming up](#coming-up)
11 | 4. [Getting Started](#getting-started)
12 |    - [Prerequisites](#prerequisites)
13 |    - [Installation](#installation)
14 | 5. [Usage](#usage)
15 | 6. [Example](#example)
16 | 7. [Roadmap](#roadmap)
17 | 8. [Contributing](#contributing)
18 | 9. [License](#license)
19 | 10. [Contact](#contact)
32 | 
33 | 
34 | ## About The Project
35 | Consider a scenario where a military commander wants to identify enemy tanks, so he uses an object-detection algorithm. Based on some data, he reports an accuracy of 99%. Naturally he would be happy (his enemies definitely wouldn't). But when he puts the algorithm to use, it fails terribly. It turns out the data had images of tanks only at night, not during the day, so instead of learning about the tank the algorithm learned to distinguish day from night. He would have lost, but thanks to blacbox he can use our techniques to save his day!
36 | 
37 | 
38 | ### Built With
39 | * [OpenCV](https://opencv.org/)
40 | * [PyTorch](https://pytorch.org/)
41 | * [Numpy, Pandas](https://pandas.pydata.org/)
42 | 
43 | ## New!
44 | - The project is fresh, so the whole thing is new right now :p
45 | 
46 | ## Coming up
47 | - Improved versions of Grad-CAM (Score-CAM, Grad-CAM++)
48 | - A relatively new method just for object-detection algorithms: the D-RISE method! (Paper: https://arxiv.org/abs/2006.03204)
49 | - Support for Image-Captioning and VQA models
50 | 
51 | 
52 | ## Getting Started
53 | 
54 | ### Prerequisites
55 | - Python versions 3.6-3.10
56 | 
57 | ### Installation
58 | 1. Install the latest version from GitHub
59 | ```
60 | pip install git+https://github.com/lazyCodes7/blacbox.git
61 | ```
62 | Note: A stable version will be released to PyPI soon
63 | ## Usage
64 | ### 1. Saliency Maps
65 | A saliency map is a way to measure the spatial support of a particular class in each image. It is the oldest and most frequently used explanation method for interpreting the predictions of convolutional neural networks. The saliency map is built from the gradients of the output with respect to the input.
66 | Paper link: https://arxiv.org/abs/1911.11293
67 | 
68 | ```python
69 | 
70 | from blacbox import Saliency
71 | import matplotlib.pyplot as plt
72 | import torchvision.models as models
73 | # Load a model
74 | model = models.resnet50(pretrained = True)
75 | 
76 | # Pass the model to the Saliency generator
77 | maps = Saliency(
78 |     model,
79 |     device = 'cuda'
80 | )
81 | 
82 | # images.shape = (B,C,H,W); class_idx specifies the class to calculate gradients against
83 | # saliencies return_type = np.ndarray(B,H,W,C)
84 | saliencies = maps.reveal(
85 |     images = images,
86 |     class_idx = "keepmax"
87 | )
88 | 
89 | ```
90 | ### Input image
91 | ![saliency input](images/saliency_input.png)
92 | 
93 | ### Output saliency
94 | ![saliency output](images/saliency_output.png)
95 | 
96 | ### 2. Guided Backpropagation
97 | Guided backpropagation is an interesting variation of vanilla backprop in which we zero out the negative gradients during the backward pass instead of passing them along, since we are only interested in the positive influences on the output.
98 | Paper link: https://arxiv.org/abs/1412.6806
99 | 
100 | ```python
101 | from blacbox import GuidedBackPropagation
102 | from blacbox import RaceClassifier
103 | import torchvision.models as models
104 | 
105 | clf = RaceClassifier()
106 | 
107 | ## Provide a batch of images
108 | images = torch.stack((image1, image2), axis = 0)
109 | 
110 | ## Notice we are providing clf.model: the clf wrapper itself is not of type nn.Module
111 | gbp = GuidedBackPropagation(
112 |     model = clf.model,
113 |     device = device
114 | )
115 | 
116 | ## Provide a batch
117 | grads = gbp.reveal(
118 |     images = images
119 | )
120 | 
121 | ## Or provide a path
122 | grads = gbp.reveal(
123 |     path = 'blacbox/architectures/images/dog.png'
124 | )
125 | ```
126 | ### Inputs
127 | ![gbp input](images/gbp_input.png) ![gbp another input](images/gbp_another_ex_input.png)
128 | 
129 | ### Outputs
130 | ![gbp output](images/gbp_output.png) ![gbp another output](images/gbp_another_ex_output.png)
131 | 
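The snippets above assume that `images` is already a preprocessed `(B, C, H, W)` float tensor (and that `image1`/`image2` are individual `(C, H, W)` tensors). A minimal sketch of building such a batch, mirroring the preprocessing used in `tests/test.py` (paths are relative to the repo root):

```python
import cv2
import torch
import torchvision.transforms as transforms

# Resize to a fixed size and convert to (C,H,W) float tensors
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# OpenCV loads BGR; convert to RGB before transforming
image1 = cv2.cvtColor(cv2.imread('blacbox/architectures/images/black.png'), cv2.COLOR_BGR2RGB)
image2 = cv2.cvtColor(cv2.imread('blacbox/architectures/images/dog.png'), cv2.COLOR_BGR2RGB)

image1, image2 = transform(image1), transform(image2)

# Stack into the (B,C,H,W) batch expected by reveal()
images = torch.stack((image1, image2), axis = 0)
```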
132 | ### 3. Grad-CAM
133 | Grad-CAM is another interesting way of visualizing activation maps: it weights them by the gradients of a target class and produces a coarse heatmap from that.
134 | Paper link: https://arxiv.org/abs/1610.02391
135 | 
136 | ```python
137 | from blacbox import RaceClassifier
138 | from blacbox import GCAM
139 | 
140 | ## Initialize
141 | clf = RaceClassifier()
142 | ## Provide a batch of images
143 | images = torch.stack((image1, image2), axis = 0)
144 | 
145 | ## For interpolation types refer to torch.nn.functional.interpolate
146 | gcam = GCAM(
147 |     model = clf.model,
148 |     interpolate = 'bilinear',
149 |     device = 'cuda'
150 | )
151 | 
152 | ## Generate the heatmaps
153 | heatmaps = gcam.reveal(
154 |     images = images,
155 |     module = clf.model.layer4[0].conv1,
156 |     class_idx = 'keepmax',
157 |     colormap = 'hot'
158 | )
159 | 
160 | # module: the module to compute gradients and fmaps for
161 | # colormap: type of colormap to apply; use gcam.colormap_dict to see valid types
162 | 
163 | ## Overlay on images
164 | ret_images = gcam.overlay(
165 |     images = images,
166 |     heatmaps = heatmaps,
167 |     is_np = True
168 | )
169 | # is_np: converts the image tensor to numpy before blending
170 | 
171 | ```
172 | ### Input
173 | ![gcam input](images/gcam_input.png)
174 | 
175 | ### Heatmaps
176 | ![gcam heatmap](images/gcam_output.png) ![gcam heatmap hot](images/gcam_hot.png)
177 | 
178 | ### Overlay
179 | ![gcam overlay](images/gcam_overlay.png)
180 | 
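For reference, `reveal` implements the weighting scheme from the paper: the channel weights are the global-average-pooled gradients of the class score with respect to the chosen module's feature maps, and the heatmap is the ReLU of the weighted sum,

$$\alpha_k^c = \frac{1}{Z}\sum_{i,j}\frac{\partial y^c}{\partial A^k_{ij}}, \qquad L^c = \mathrm{ReLU}\Bigl(\sum_k \alpha_k^c A^k\Bigr)$$

(in the code, the pooling is `F.adaptive_avg_pool2d` and the weighted sum is `torch.mul(fmaps, weights).sum(dim=1)`). A Grad-CAM++ variant is also exported as `GCAM_plus`; since it subclasses `GCAM` and only overrides the heatmap computation, it should work with the same `reveal` call:

```python
from blacbox import GCAM_plus

gcampp = GCAM_plus(
    model = clf.model,
    interpolate = 'bilinear',
    device = 'cuda'
)
heatmaps = gcampp.reveal(
    images = images,
    module = clf.model.layer4[0].conv1,
    class_idx = 'keepmax',
    colormap = 'hot'
)
```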
181 | ### Note:
182 | To use the RaceClassifier, please download the pretrained model [here](https://github.com/lazyCodes7/blacbox/tree/main/blacbox/pretrained_models) and provide it as a path
183 | ```python
184 | from blacbox import RaceClassifier
185 | clf = RaceClassifier(model_path = "")
186 | ```
187 | 
188 | ## Example
189 | 1. Clone the repo.
190 | ```
191 | git clone https://github.com/lazyCodes7/blacbox.git
192 | ```
193 | 2. Use jupyter to try out the notebook titled blacbox.ipynb!
194 | ```
195 | jupyter notebook
196 | ```
197 | 
198 | 
199 | ## Roadmap
200 | 
201 | See the [open issues](https://github.com/lazyCodes7/blacbox/issues) for a list of proposed features (and known issues).
202 | 
203 | 
204 | 
205 | 
206 | ## Contributing
207 | 
208 | Contributions are what make the open source community such an amazing place to learn, inspire, and create. Any contributions you make are **greatly appreciated**.
209 | 
210 | 1. Fork the Project
211 | 2. Create your Feature Branch (`git checkout -b feature/AmazingFeature`)
212 | 3. Commit your Changes (`git commit -m 'Add some AmazingFeature'`)
213 | 4. Push to the Branch (`git push origin feature/AmazingFeature`)
214 | 5. Open a Pull Request
215 | 
216 | 
217 | 
218 | 
219 | ## License
220 | 
221 | Distributed under the MIT License. See `LICENSE` for more information.
222 | 
223 | 
224 | 
225 | 
226 | ## Contact
227 | 
228 | Rishab Mudliar - [@cheesetaco19](https://twitter.com/cheesetaco19) - rishabmudliar@gmail.com
229 | 
230 | Telegram: [lazyCodes7](https://t.me/lazyCodes7)
231 | 
--------------------------------------------------------------------------------
/blacbox/__init__.py:
--------------------------------------------------------------------------------
1 | import blacbox
2 | from blacbox.activation_visualizers import GCAM
3 | from blacbox.pixel_visualizers import GuidedBackPropagation
4 | from blacbox.pixel_visualizers import Saliency
5 | from blacbox.architectures import RaceClassifier
6 | from blacbox.activation_visualizers import GCAM_plus
--------------------------------------------------------------------------------
/blacbox/activation_visualizers/__init__.py:
--------------------------------------------------------------------------------
1 | from .gcam import GCAM
2 | from .gcampp import GCAM_plus
--------------------------------------------------------------------------------
/blacbox/activation_visualizers/gcam.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import torch
4 | import torchvision.transforms as transforms
5 | import torch.nn.functional as F
6 | import torch.nn as nn
7 | import cv2
8 | class GCAM:
9 |     def __init__(self, model, interpolate = 'bilinear', device = 'cpu'):
10 |         '''
11 |         Description: Method to visualize CNN layers w.r.t the output
12 | 
13 |         Args:
14 |             model -> nn.Module
15 |             interpolate -> interpolation types (refer to https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html)
16 |             device -> 'cpu'/'cuda' (device to set the model to)
17 | 
18 |         Example:
19 |             gcam = GCAM(model, interpolate = 'bilinear', device = 'cuda')
20 | 
21 |         '''
22 |         self.device = device
23 |         self.model = model.to(self.device)
24 |         self.interpolate = interpolate
25 |         self.activations = {}
26 |         self.gradients = {}
27 |         self.colormap_dict = {
28 |             "autumn": cv2.COLORMAP_AUTUMN,
29 |             "bone": cv2.COLORMAP_BONE,
30 |             "jet" : cv2.COLORMAP_JET,
31 |             "winter" : cv2.COLORMAP_WINTER,
32 |             "rainbow" : cv2.COLORMAP_RAINBOW,
33 |             "ocean" : cv2.COLORMAP_OCEAN,
34 |             "summer" : cv2.COLORMAP_SUMMER,
35 |             "spring" : cv2.COLORMAP_SPRING,
36 |             "cool" : cv2.COLORMAP_COOL,
37 |             "hsv" : cv2.COLORMAP_HSV,
38 |             "pink" : cv2.COLORMAP_PINK,
39 |             "hot" : cv2.COLORMAP_HOT
40 | 
41 |         }
42 | 
43 |     def get_activations(self,key):
44 |         '''
45 |         Description: Method to get activations for a particular layer
46 | 
47 |         Args:
48 |             key -> str (key under which to cache the module's activations)
49 | 
50 |         '''
51 |         def hook(module, input, out):
52 |             self.activations[key] = out.detach()
53 |         return hook
54 |     def get_gradients(self,key):
55 |         '''
56 |         Description: Method to get gradients for a particular layer
57 | 
58 |         Args:
59 |             key -> str (key under which to cache the module's gradients)
60 | 
61 |         '''
62 |         def hook(module, grad_in, grad_out):
63 |             self.gradients[key] = grad_out[0]
64 |         return hook
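    # Note: the two factories above return PyTorch hooks. reveal() registers
    # them on the module you pass in (keyed by str(module)): the forward hook
    # caches the layer's activations and the backward hook its gradients, and
    # both are removed again once the heatmaps have been computed.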
65 |     def reveal(
66 |         self,
67 |         images = None,
68 |         module = None,
69 |         class_idx = 0,
70 |         path = None,
71 |         colormap = "jet"
72 |     ):
73 |         '''
74 |         Description: The main entry point: runs the model and computes the Grad-CAM visualization for a CNN layer.
75 | 
76 |         Args:
77 |             images -> type = torch.Tensor, shape = (B,C,H,W)
78 |             path -> str (required if images is None)
79 |             module -> nn.Module (to perform Grad-CAM on)
80 |             class_idx -> int (class to calculate gradients on)
81 |             colormap -> colormap to apply on the heatmap
82 | 
83 |         Types of colormaps:
84 |             "autumn": cv2.COLORMAP_AUTUMN,
85 |             "bone": cv2.COLORMAP_BONE,
86 |             "jet" : cv2.COLORMAP_JET,
87 |             "winter" : cv2.COLORMAP_WINTER,
88 |             "rainbow" : cv2.COLORMAP_RAINBOW,
89 |             "ocean" : cv2.COLORMAP_OCEAN,
90 |             "summer" : cv2.COLORMAP_SUMMER,
91 |             "spring" : cv2.COLORMAP_SPRING,
92 |             "cool" : cv2.COLORMAP_COOL,
93 |             "hsv" : cv2.COLORMAP_HSV,
94 |             "pink" : cv2.COLORMAP_PINK,
95 |             "hot" : cv2.COLORMAP_HOT
96 | 
97 |         Types of class_idx:
98 |             'keepmax' -> Visualize the max score from classification
99 |             'keepmin' -> Visualize the min score from classification
100 |             '0-len(output)' -> For visualizing a certain output
101 | 
102 |         Example:
103 |             from blacbox import GCAM
104 |             gcam = GCAM(model, interpolate = 'bilinear')
105 |             output_viz = gcam.reveal(
106 |                 images = images,
107 |                 module = model.resnet.layer[0].conv1,
108 |                 class_idx = 'keepmax',
109 |                 colormap = 'hot'
110 | 
111 |             )
112 |         '''
113 |         # Raise an error if both path and images are provided
114 |         if(path is not None and images is not None):
115 |             raise ValueError("Image batches cannot be passed when path is provided")
116 | 
117 |         # If a path is provided
118 |         elif(path is not None):
119 |             images = cv2.imread(path)
120 |             images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB)
121 |             images = self.preprocess_image(images).unsqueeze(0)
122 | 
123 |         # If a batch of images is provided
124 |         if(images is not None):
125 |             if(isinstance(images, torch.Tensor) == False):
126 |                 raise ValueError("reveal() expects images to be of type tensor with shape (B,C,H,W)")
127 |             if(isinstance(module, nn.Module)):
128 |                 key = str(module)
129 |                 self.ac_handler = module.register_forward_hook(self.get_activations(key))
130 |                 self.grad_handler = module.register_backward_hook(self.get_gradients(key))
131 |                 gcams = self.retrieve_gcams(images, class_idx, key, colormap)
132 |                 return gcams
133 |             else:
134 |                 error = "Module argument expects variable of type nn.Module"
135 |                 raise ValueError(error)
136 | 
137 |         # If None then raise errors
138 |         else:
139 |             raise ValueError("Either path or images need to be provided to reveal the GCAM visualization.")
140 | 
141 |     @staticmethod
142 |     def overlay(heatmaps, images, influence = 0.3, is_np = True):
143 |         # Blends heatmaps into the source images; influence controls the heatmap
144 |         # opacity, is_np converts the (B,C,H,W) tensor to a (B,H,W,C) numpy array
145 |         if(is_np):
146 |             images = images.cpu().detach().permute(0,2,3,1).numpy()
147 |         images = cv2.normalize(
148 |             images,
149 |             None,
150 |             alpha = 0,
151 |             beta = 255,
152 |             norm_type = cv2.NORM_MINMAX,
153 |             dtype = cv2.CV_32F
154 |         )
155 |         superimposed_img = heatmaps*(influence) + images
156 |         return superimposed_img.astype(int)
157 | 
158 |     def preprocess_image(self, images):
159 |         '''
160 |         Description:
161 |             Takes in an image and applies some transformations to it
162 | 
163 |         Args:
164 |             images -> np.ndarray
165 | 
166 |         '''
167 |         if(isinstance(images, np.ndarray)):
168 |             transform = transforms.Compose([
169 |                 transforms.ToPILImage(),
170 |                 transforms.Resize((224, 224)),
171 |                 transforms.ToTensor(),
172 |                 transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
173 |             ])
174 | 
175 |             images = transform(images)
176 |             return images
177 | 
178 |         else:
179 |             raise AttributeError("Preprocessing requires a np.ndarray")
180 | 
181 |     def retrieve_gcams(self, images, class_idx, key, colormap, apply_cmap = True):
182 |         """
183 |         Retrieving gcams from the images provided
184 |         """
185 |         gcams = []
186 | 
187 |         for image in images:
188 |             # Taking the image
189 |             self.image = image.unsqueeze(0).to(self.device)
190 |             self.image.requires_grad = True
191 |             output = self.model(self.image)
192 |             # Resolve the target index per image, so that 'keepmax'/'keepmin'
193 |             # is re-evaluated for every image in the batch
194 |             if(class_idx == "keepmax"):
195 |                 idx = output.argmax(dim = 1)
196 |             elif(class_idx == "keepmin"):
197 |                 idx = output.argmin(dim = 1)
198 |             else:
199 |                 idx = class_idx
200 | 
201 |             # Calculating gradients w.r.t the idx selected
202 |             class_required = output[0, idx]
203 |             class_required.backward()
204 | 
205 |             # Retrieving gradients and activation maps
206 |             fmaps = self.activations[key]
207 |             weights = self.gradients[key]
208 | 
209 |             # Freeing up the gradients from the image
210 |             self.image.grad.zero_()
211 | 
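            # Grad-CAM channel weights: global-average-pool the gradients,
            # alpha_k = (1/Z) * sum_ij dY/dA_k[i,j], then weight the feature
            # maps by alpha and pass the sum through a ReLU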
212 |             # Avg pooling as in the paper
213 |             weights = F.adaptive_avg_pool2d(weights, 1)
214 | 
215 |             # Retrieving the gcam outputs
216 |             gcam = torch.mul(fmaps, weights).sum(dim=1, keepdim=True)
217 |             gcam = F.relu(gcam)
218 |             gcam = F.interpolate(
219 |                 gcam, self.image.shape[2:4], mode=self.interpolate, align_corners=False
220 |             )
221 |             gcam = gcam.cpu().squeeze(0).detach().permute(1,2,0).numpy()
222 |             gcam = cv2.normalize(
223 |                 gcam,
224 |                 None,
225 |                 alpha = 0,
226 |                 beta = 255,
227 |                 norm_type = cv2.NORM_MINMAX,
228 |                 dtype = cv2.CV_8UC3
229 |             )
230 |             gcam = cv2.applyColorMap(gcam, self.colormap_dict[colormap])
231 |             gcams.append(gcam)
232 | 
233 |         # Removing the hooks, activations and gradients stored
234 |         self.ac_handler.remove()
235 |         self.grad_handler.remove()
236 |         self.activations = {}
237 |         self.gradients = {}
238 |         return np.array(gcams)
--------------------------------------------------------------------------------
/blacbox/activation_visualizers/gcampp.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import torch
4 | import torchvision.transforms as transforms
5 | import torch.nn.functional as F
6 | import torch.nn as nn
7 | import cv2
8 | from .gcam import GCAM
9 | class GCAM_plus(GCAM):
10 |     def __init__(self, model, interpolate = 'bilinear', device = 'cpu'):
11 |         super().__init__(model, interpolate, device)
12 | 
13 |     def retrieve_gcams(self, images, class_idx, key, colormap, apply_cmap = True):
14 |         """
15 |         Retrieving gcams from the images provided
16 |         """
17 |         gcams = []
18 | 
19 |         for image in images:
20 |             # Taking the image
21 |             self.image = image.unsqueeze(0).to(self.device)
22 |             self.image.requires_grad = True
23 |             output = self.model(self.image)
24 |             output = F.softmax(output, dim = 1)
25 |             # Resolve the target index per image (see GCAM.retrieve_gcams)
26 |             if(class_idx == "keepmax"):
27 |                 idx = output.argmax(dim = 1)
28 |             elif(class_idx == "keepmin"):
29 |                 idx = output.argmin(dim = 1)
30 |             else:
31 |                 idx = class_idx
32 | 
33 |             # Calculating gradients w.r.t the idx selected
34 |             cr = output[0, idx]
35 |             cr.backward()
36 | 
37 |             # Retrieving gradients and activation maps
38 |             fmaps = self.activations[key]
39 |             weights = self.gradients[key]
40 |             weights = weights/cr
41 | 
42 |             # Freeing up the gradients from the image
43 |             self.image.grad.zero_()
44 |             weights = weights.squeeze().unsqueeze(0)
45 |             b, k, u, v = weights.size()
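            # Grad-CAM++ pixel-wise weights, as implemented here with the
            # gradient-power substitution: alpha = g^2 / (2*g^2 + sum_ij(A * g^3)),
            # where g is the gradient of the (softmax) score w.r.t. the feature
            # map A; the denominator is guarded against zeros before dividing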
46 |             a_num = weights.pow(2)
47 |             a_dm = weights.pow(2).mul(2) + \
48 |                     fmaps.mul(weights.pow(3)).view(b, k, u*v).sum(-1, keepdim=True).view(b, k, 1, 1)
49 |             a_dm = torch.where(a_dm != 0.0, a_dm, torch.ones_like(a_dm))
50 |             alpha = a_num.div(a_dm)
51 |             weights = (alpha*weights).view(b, k, u*v).sum(-1).view(b, k, 1, 1)
52 | 
53 |             # Retrieving the gcam outputs
54 |             gcam = torch.mul(fmaps, weights).sum(dim=1, keepdim=True)
55 |             gcam = F.relu(gcam)
56 | 
57 |             gcam = F.interpolate(
58 |                 gcam, self.image.shape[2:4], mode=self.interpolate, align_corners=False
59 |             )
60 |             gcam = gcam.cpu().squeeze(0).detach().permute(1,2,0).numpy()
61 | 
62 |             gcam = cv2.normalize(
63 |                 gcam,
64 |                 None,
65 |                 alpha = 0,
66 |                 beta = 255,
67 |                 norm_type = cv2.NORM_MINMAX,
68 |                 dtype = cv2.CV_8UC3
69 |             )
70 |             gcam = cv2.applyColorMap(gcam, self.colormap_dict[colormap])
71 | 
72 |             gcams.append(gcam)
73 | 
74 |         # Removing the hooks, activations and gradients stored
75 |         self.ac_handler.remove()
76 |         self.grad_handler.remove()
77 |         self.activations = {}
78 |         self.gradients = {}
79 |         return np.array(gcams)
--------------------------------------------------------------------------------
/blacbox/architectures/__init__.py:
--------------------------------------------------------------------------------
1 | from .rclf import RaceClassifier
--------------------------------------------------------------------------------
/blacbox/architectures/images/black.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lazyCodes7/blacbox/de32d3b4b3d31a95179bd75b3259456ca34a0b4a/blacbox/architectures/images/black.png
--------------------------------------------------------------------------------
/blacbox/architectures/images/dog.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lazyCodes7/blacbox/de32d3b4b3d31a95179bd75b3259456ca34a0b4a/blacbox/architectures/images/dog.png
--------------------------------------------------------------------------------
/blacbox/architectures/rclf.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import warnings
3 | warnings.filterwarnings("ignore")
4 | import os.path
5 | import pandas as pd
6 | import torch
7 | import torch.nn as nn
8 | import numpy as np
9 | import torchvision
10 | from torchvision import datasets, models, transforms
11 | import os
12 | import argparse
13 | import cv2
14 | class RaceClassifier:
15 |     def __init__(self,model_path = 'blacbox/pretrained_models/checkpoint.pt'):
16 |         ## Loading the pretrained Resnet34 model
17 |         self.model = torchvision.models.resnet34(pretrained=True)
18 |         # Selecting the device
19 |         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20 |         # The head predicts 18 outputs in total: race, gender and age classes
21 |         self.model.fc = nn.Linear(self.model.fc.in_features, 18)
22 |         #model_fair_7.load_state_dict(torch.load('fair_face_models/fairface_alldata_20191111.pt'))
23 | 
24 |         # Loading the pretrained model trained on the FairFace dataset
25 |         self.model.load_state_dict(torch.load(model_path,map_location=torch.device('cpu')))
26 |         self.model = self.model.to(self.device)
27 | 
28 |         # Setting eval mode (disables dropout/batch-norm updates; gradients can still flow for visualization)
29 |         self.model.eval()
30 | 
31 |     def preprocess_image(self, image):
32 |         transform = transforms.Compose([
33 |             transforms.ToPILImage(),
34 |             transforms.Resize((224, 224)),
35 |             transforms.ToTensor(),
36 |             transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
37 |         ])
38 | 
39 |         image = transform(image)
40 |         image = image.unsqueeze(0)
41 |         return image
42 | 
43 | 
44 |     # Predict the race of the detected face
45 |     def 
predict(self,image): 46 | self.image = image.to(self.device) 47 | self.image = self.image.requires_grad_() 48 | 49 | # fair 50 | self.outputs = self.model(self.image) 51 | self.grad_outputs = self.outputs 52 | self.outputs = self.outputs.cpu().detach().numpy() 53 | #self.outputs = np.squeeze(self.outputs) 54 | 55 | ''' 56 | self.race_outputs = self.outputs[:7] 57 | self.gender_outputs = self.outputs[7:9] 58 | self.age_outputs = self.outputs[9:18] 59 | race_score = np.exp(self.race_outputs) / np.sum(np.exp(self.race_outputs)) 60 | gender_score = np.exp(self.gender_outputs) / np.sum(np.exp(self.gender_outputs)) 61 | age_score = np.exp(self.age_outputs) / np.sum(np.exp(self.age_outputs)) 62 | race_pred = np.argmax(race_score, dim = 1) 63 | gender_pred = np.argmax(gender_score, dim = 1) 64 | age_pred = np.argmax(age_score) 65 | return (race_pred,gender_pred,age_pred) 66 | ''' 67 | 68 | def __call__(self, image): 69 | preds = self.predict(image) 70 | return self.grad_outputs -------------------------------------------------------------------------------- /blacbox/pixel_visualizers/__init__.py: -------------------------------------------------------------------------------- 1 | from .gbp import GuidedBackPropagation 2 | from .saliency import Saliency -------------------------------------------------------------------------------- /blacbox/pixel_visualizers/gbp.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import numpy as np 3 | import torch 4 | import torchvision.transforms as transforms 5 | import torch.nn.functional as F 6 | import torch.nn as nn 7 | import cv2 8 | class _BaseWrapper(object): 9 | 10 | def __init__(self, model, device = 'cpu'): 11 | ''' 12 | Description: BaseWrapper for BackProp and Guided BackProp 13 | 14 | Args: 15 | 16 | model -> nn.Module 17 | device -> cpu/cuda 18 | ''' 19 | super(_BaseWrapper, self).__init__() 20 | self.device = device 21 | self.model = model.to(self.device) 22 | self.handlers = [] # a set of hook function handlers 23 | 24 | def _encode_one_hot(self, ids): 25 | ''' 26 | Description: One hot encoder for the model outputs 27 | 28 | Args: 29 | 30 | ids -> int (id to encode) 31 | ''' 32 | one_hot = torch.zeros_like(self.logits).to(self.device) 33 | one_hot[0, ids] = 1.0 34 | #print(one_hot) 35 | return one_hot 36 | def forward(self, image): 37 | ''' 38 | Description: Method for feed-forward process 39 | 40 | Args: 41 | image -> torch.Tensor(image to be processed through the neural network) 42 | ''' 43 | self.image_shape = image.shape[2:] 44 | self.logits = self.model(image) 45 | self.probs = F.softmax(self.logits, dim=1) 46 | return self.probs.sort(dim=1, descending=True) 47 | 48 | def backward(self, ids): 49 | ''' 50 | Description: Backward prop for the neural network 51 | 52 | Args: 53 | ids -> int(id to calculate the gradient for) 54 | 55 | ''' 56 | one_hot = self._encode_one_hot(ids) 57 | self.model.zero_grad() 58 | self.logits.backward(gradient=one_hot, retain_graph=True) 59 | 60 | def get_gradients(self): 61 | raise NotImplementedError 62 | 63 | def remove_hook(self): 64 | """ 65 | Remove all the forward/backward hook functions 66 | """ 67 | for handle in self.handlers: 68 | handle.remove() 69 | 70 | class BackPropagation(_BaseWrapper): 71 | def forward(self, image): 72 | ''' 73 | Description: Method for feed-forward process 74 | 75 | Args: 76 | image -> torch.Tensor(image to be processed through the neural network) 77 | ''' 78 | self.image = image.requires_grad_() 79 | return 
super(BackPropagation, self).forward(self.image)
80 | 
81 |     def get_gradients(self):
82 |         """
83 |         Generating the gradients for the image
84 |         """
85 |         gradient = self.image.grad.clone()
86 |         self.image.grad.zero_()
87 |         return gradient
88 | 
89 | class GuidedBackPropagation(BackPropagation):
90 |     """
91 |     "Striving for Simplicity: the All Convolutional Net"
92 |     https://arxiv.org/pdf/1412.6806.pdf
93 |     Look at Figure 1 on page 8.
94 |     """
95 |     def backward_hook(self,module, grad_in, grad_out):
96 |         """
97 |         PyTorch hook function for zeroing out the negative gradients during backprop
98 |         """
99 |         # Cut off the negative gradients, keeping the positive ones (and
100 |         # their magnitudes) intact -- the guided backpropagation rule
101 |         if isinstance(module, nn.ReLU):
102 |             return (F.relu(grad_in[0]),)
103 | 
104 |     def __init__(self, model, device):
105 |         '''
106 |         Description: GBP module for CNN visualization
107 | 
108 |         Args:
109 | 
110 |             model -> nn.Module
111 |             device -> cpu/cuda
112 | 
113 |         Example:
114 |             gbp = GuidedBackPropagation(model = model, device = 'cuda')
115 |         '''
116 |         super(GuidedBackPropagation, self).__init__(model, device)
117 | 
118 |     def reveal(self, images = None, path = None):
119 |         '''
120 |         Description: Function for forwarding the images and getting the gradients to visualize
121 | 
122 |         Args:
123 |             images -> type = torch.Tensor, shape = (B,C,H,W)
124 |             path -> str (required if images is None)
125 | 
126 |         Example:
127 |             gbp = GuidedBackPropagation(model = model, device = 'cuda')
128 |             gradients = gbp.reveal(images = images)
129 |         '''
130 |         for module in self.model.named_modules():
131 |             self.handlers.append(module[1].register_backward_hook(self.backward_hook))
132 | 
133 |         # Raise an error if both path and images are provided
134 |         if(path is not None and images is not None):
135 |             raise ValueError("Image batches cannot be passed when path is provided")
136 | 
137 |         # If a path is provided
138 |         elif(path is not None):
139 |             images = cv2.imread(path)
140 |             images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB)
141 |             images = self.preprocess_image(images)
142 | 
143 |         # If a batch of images is provided
144 |         if(images is not None):
145 |             grads = self.retrieve_gradients(images)
146 |             return np.array(grads)
147 | 
148 |         # If None then raise errors
149 |         else:
150 |             raise ValueError("Either path or images need to be provided to reveal the GBP visualization.")
151 |     def preprocess_image(self, image):
152 |         '''
153 |         Description:
154 |             Takes in an image and applies some transformations to it
155 | 
156 |         Args:
157 |             image -> np.ndarray
158 | 
159 |         '''
160 |         if(isinstance(image, np.ndarray)):
161 |             transform = transforms.Compose([
162 |                 transforms.ToPILImage(),
163 |                 transforms.Resize((224, 224)),
164 |                 transforms.ToTensor(),
165 |                 transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
166 |             ])
167 | 
168 |             image = transform(image)
169 |             image = image.unsqueeze(0)
170 |             return image
171 | 
172 |         else:
173 |             raise ValueError("Preprocessing requires type np.ndarray")
174 | 
175 |     def retrieve_gradients(self, images):
176 |         """
177 |         Generating the gradients
178 |         """
179 |         grads = []
180 |         for image in images:
181 |             image = image.unsqueeze(0).to(self.device)
182 |             ids = self.forward(image)
183 |             self.backward(ids=ids.indices[0,0])
184 |             gradient = self.get_gradients()
185 |             gradient = gradient.cpu().squeeze().permute(1,2,0).detach().numpy()
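            # Normalize the gradient image for display: shift to a zero
            # minimum, scale by the maximum, then expand to [0, 255] ints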
186 |             gradient -= gradient.min()
187 |             gradient /= gradient.max()
188 |             gradient *= 255.0
189 |             gradient = gradient.astype(int)
190 |             grads.append(gradient)
191 | 
192 |         self.remove_hook()
193 |         self.handlers = []
194 |         return grads
195 | 
--------------------------------------------------------------------------------
/blacbox/pixel_visualizers/saliency.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import torch
4 | import torchvision.transforms as transforms
5 | import torch.nn.functional as F
6 | import torch.nn as nn
7 | import cv2
8 | import logging
9 | class Saliency:
10 |     def __init__(
11 |         self,
12 |         model,
13 |         device = 'cpu'):
14 |         '''
15 |         Description:
16 |             1. Pass the model to the saliency visualizer
17 |             2. Set the device to CPU/GPU
18 | 
19 |         Args:
20 |             model -> nn.Module
21 |             device -> cpu/cuda
22 |         '''
23 |         self.device = device
24 |         self.model = model.to(self.device)
25 | 
26 |     def reveal(
27 |         self,
28 |         images = None,
29 |         path = None,
30 |         class_idx = 'keepmax'):
31 |         '''
32 |         Description:
33 |             Takes in images of type tensor or a path and returns the saliency maps for the same
34 | 
35 |         Args:
36 |             images -> type = torch.Tensor, shape = (B,C,H,W)
37 |             path -> str (required if images is None)
38 |             class_idx -> int (class to visualize)/str
39 | 
40 |         Types of class_idx:
41 |             'keepmax' -> Visualize the max score from classification
42 |             'keepmin' -> Visualize the min score from classification
43 |             '0-len(output)' -> For visualizing a certain output
44 | 
45 |         '''
46 |         # Raise an error if both path and images are provided
47 |         if(path is not None and images is not None):
48 |             raise ValueError("Image batches cannot be passed when path is provided")
49 | 
50 |         # If a path is provided
51 |         elif(path is not None):
52 |             images = cv2.imread(path)
53 |             images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB)
54 |             images = self.preprocess_image(images)
55 | 
56 |         # If a batch of images is provided
57 |         if(images is not None):
58 |             saliencies = self.retrieve_saliencies(images, class_idx)
59 |             return np.array(saliencies)
60 | 
61 |         # If None then raise errors
62 |         else:
63 |             raise ValueError("Either path or images need to be provided to reveal the Saliency visualization.")
64 | 
65 |     def preprocess_image(self, image):
66 |         '''
67 |         Description:
68 |             Takes in an image and applies some transformations to it
69 | 
70 |         Args:
71 |             image -> np.ndarray
72 | 
73 |         '''
74 |         if(isinstance(image, np.ndarray)):
75 |             transform = transforms.Compose([
76 |                 transforms.ToPILImage(),
77 |                 transforms.Resize((224, 224)),
78 |                 transforms.ToTensor(),
79 |                 transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
80 |             ])
81 | 
82 |             image = transform(image)
83 |             image = image.unsqueeze(0)
84 |             return image
85 | 
86 |         else:
87 |             raise ValueError("Preprocessing requires type np.ndarray")
88 | 
89 |     def retrieve_saliencies(self, images, class_idx):
90 |         saliencies = []
91 |         for image in images:
92 |             self.image = image.unsqueeze(0).to(self.device)
93 |             self.image.requires_grad = True
94 |             self.output = self.model(self.image)
95 |             # Resolve the target index per image, so that 'keepmax'/'keepmin'
96 |             # is re-evaluated for every image in the batch
97 |             if(isinstance(class_idx, str)):
98 |                 if(class_idx == 'keepmax'):
99 |                     idx = self.output.argmax(dim = 1)
100 |                 elif(class_idx == 'keepmin'):
101 |                     idx = self.output.argmin(dim = 1)
102 |                 else:
103 |                     raise ValueError("class_idx only supports two str arguments: 1. keepmax 2. keepmin")
104 |             else:
105 |                 idx = class_idx
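            # Saliency rule: backprop the chosen class score to the input and
            # take the max of the absolute gradient across the channels, giving
            # one importance value per pixel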
106 |             self.output[0, idx].backward()
107 |             saliency, _ = torch.max(self.image.grad.data.abs(), dim=1)
108 |             saliency = saliency.reshape(self.image.shape[2], self.image.shape[3], 1)
109 |             saliency = saliency.cpu().numpy()
110 |             saliencies.append(saliency)
111 |         return saliencies
--------------------------------------------------------------------------------
/images/gbp_another_ex_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lazyCodes7/blacbox/de32d3b4b3d31a95179bd75b3259456ca34a0b4a/images/gbp_another_ex_input.png
--------------------------------------------------------------------------------
/images/gbp_another_ex_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lazyCodes7/blacbox/de32d3b4b3d31a95179bd75b3259456ca34a0b4a/images/gbp_another_ex_output.png
--------------------------------------------------------------------------------
/images/gbp_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lazyCodes7/blacbox/de32d3b4b3d31a95179bd75b3259456ca34a0b4a/images/gbp_input.png
--------------------------------------------------------------------------------
/images/gbp_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lazyCodes7/blacbox/de32d3b4b3d31a95179bd75b3259456ca34a0b4a/images/gbp_output.png
--------------------------------------------------------------------------------
/images/gcam_hot.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lazyCodes7/blacbox/de32d3b4b3d31a95179bd75b3259456ca34a0b4a/images/gcam_hot.png
--------------------------------------------------------------------------------
/images/gcam_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lazyCodes7/blacbox/de32d3b4b3d31a95179bd75b3259456ca34a0b4a/images/gcam_input.png
--------------------------------------------------------------------------------
/images/gcam_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lazyCodes7/blacbox/de32d3b4b3d31a95179bd75b3259456ca34a0b4a/images/gcam_output.png
--------------------------------------------------------------------------------
/images/gcam_overlay.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lazyCodes7/blacbox/de32d3b4b3d31a95179bd75b3259456ca34a0b4a/images/gcam_overlay.png
--------------------------------------------------------------------------------
/images/saliency_input.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lazyCodes7/blacbox/de32d3b4b3d31a95179bd75b3259456ca34a0b4a/images/saliency_input.png
--------------------------------------------------------------------------------
/images/saliency_output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lazyCodes7/blacbox/de32d3b4b3d31a95179bd75b3259456ca34a0b4a/images/saliency_output.png
--------------------------------------------------------------------------------
/requirements.txt: -------------------------------------------------------------------------------- 1 | cycler 2 | kiwisolver 3 | numpy 4 | opencv-python 5 | packaging 6 | pandas 7 | pillow 8 | pyparsing 9 | python-dateutil 10 | pytz 11 | setuptools-scm 12 | six 13 | tomli 14 | torch>=1.8.0 15 | torchvision>=0.8.1 16 | typing_extensions 17 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | from os import path 3 | 4 | 5 | with open('requirements.txt') as f: 6 | reqs = f.read().splitlines() 7 | with open('README.md', encoding='utf-8') as fr: 8 | long_description = fr.read() 9 | 10 | setup( 11 | name="blacbox", 12 | version="0.1.0", 13 | description="A visualization library to make CNNs more interpretable", 14 | long_description=long_description, 15 | long_description_content_type="text/markdown", 16 | url="https://github.com/lazyCodes7/blacbox", 17 | author="Rishab Mudliar", 18 | author_email="rishabmudliar@gmail.com", 19 | license="MIT", 20 | classifiers=[ 21 | "Intended Audience :: Developers", 22 | "License :: OSI Approved :: MIT License", 23 | "Programming Language :: Python", 24 | "Programming Language :: Python :: 3", 25 | "Programming Language :: Python :: 3.6", 26 | "Programming Language :: Python :: 3.7", 27 | "Programming Language :: Python :: 3.8", 28 | "Programming Language :: Python :: 3.9", 29 | "Operating System :: OS Independent" 30 | ], 31 | packages=find_packages(), 32 | include_package_data=True, 33 | install_requires=reqs, 34 | ) -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lazyCodes7/blacbox/de32d3b4b3d31a95179bd75b3259456ca34a0b4a/tests/__init__.py -------------------------------------------------------------------------------- /tests/test.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | import math 3 | from blacbox import GCAM 4 | from blacbox import Saliency 5 | from blacbox import GuidedBackPropagation 6 | import numpy as np 7 | import torchvision.models as models 8 | import torchvision.transforms as transforms 9 | import cv2 10 | import torch 11 | class Tester(unittest.TestCase): 12 | def setUp(self): 13 | image1 = cv2.imread('blacbox/architectures/images/black.png') 14 | image1= cv2.cvtColor(image1, cv2.COLOR_BGR2RGB) 15 | image2 = cv2.imread('blacbox/architectures/images/dog.png') 16 | image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB) 17 | image1 = self.preprocess_image(image1) 18 | image2 = self.preprocess_image(image2) 19 | self.images = torch.stack((image1, image2), axis = 0) 20 | self.model_test1 = models.resnet18(pretrained = True) 21 | self.model_test2 = models.resnet50(pretrained=True) 22 | self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 23 | 24 | 25 | def test_gcam(self): 26 | gcam = GCAM( 27 | model = self.model_test2, 28 | interpolate = 'bilinear', 29 | device = self.device 30 | ) 31 | heatmap = gcam.reveal( 32 | images = self.images, 33 | module = self.model_test2.layer4[0].conv1, 34 | class_idx = 'keepmax', 35 | colormap = 'hot' 36 | ) 37 | self.assertEqual(type(heatmap), np.ndarray) 38 | 39 | def test_gbp(self): 40 | gbp = GuidedBackPropagation( 41 | model = self.model_test1, 42 | device = self.device 43 | ) 44 | generated_img 
= gbp.reveal( 45 | path = 'blacbox/architectures/images/black.png' , 46 | ) 47 | self.assertEqual(type(generated_img), np.ndarray) 48 | 49 | def test_saliency(self): 50 | sal = Saliency( 51 | model = self.model_test1, 52 | device=self.device 53 | ) 54 | maps = sal.reveal(images=self.images, class_idx = 'keepmax') 55 | self.assertEqual(type(maps), np.ndarray) 56 | 57 | def test_out_size(self): 58 | sal = Saliency( 59 | self.model_test1 60 | ) 61 | maps = sal.reveal(images=self.images, class_idx = 'keepmax') 62 | self.assertEqual(len(maps.shape), 4) 63 | 64 | 65 | 66 | def preprocess_image(self, image): 67 | transform = transforms.Compose([ 68 | 69 | transforms.ToPILImage(), 70 | transforms.Resize((224, 224)), 71 | transforms.ToTensor(), 72 | 73 | ]) 74 | 75 | image = transform(image) 76 | return image 77 | 78 | 79 | if __name__ == '__main__': 80 | unittest.main() 81 | --------------------------------------------------------------------------------