├── .github └── workflows │ └── ci.yml ├── .gitignore ├── LICENSE ├── README.md ├── YOLOv8_Explainer.code-workspace ├── YOLOv8_Explainer ├── __init__.py ├── core.py └── utils.py ├── docs ├── functions.md ├── index.md └── references.md ├── example.ipynb ├── mkdocs.yml └── setup.py /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: ci 2 | on: 3 | push: 4 | branches: 5 | - master 6 | - main 7 | permissions: 8 | contents: write 9 | jobs: 10 | deploy: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v3 14 | - uses: actions/setup-python@v4 15 | with: 16 | python-version: 3.x 17 | - uses: actions/cache@v2 18 | with: 19 | key: ${{ github.ref }} 20 | path: .cache 21 | - run: pip install mkdocs-material 22 | - run: pip install "mkdocstrings[python]" 23 | - run: pip install pillow cairosvg 24 | - run: mkdocs gh-deploy --force 25 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | *__pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | test/ 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | .vscode 162 | YOLOv8_Explainer.egg-info/* 163 | .code-workspace/ 164 | /Items 165 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Proyash Paban Sarma Borah 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # YOLOv8_Explainer 2 | [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 3 | [![Downloads](https://static.pepy.tech/badge/YOLOv8-Explainer)](https://pepy.tech/project/YOLOv8-Explainer) 4 | ## Simplify your understanding of YOLOv8 Results 5 | This is a package with state-of-the-art methods for Explainable AI for computer vision using YOLOv8. It can be used for diagnosing model predictions, either in production or while developing models. The aim is also to serve as a benchmark of algorithms and metrics for research on new explainability methods. 6 | 7 | ### Install Environment & Dependencies 8 | 9 | `YOLOv8-Explainer` can be seamlessly integrated into your projects with a straightforward installation process: 10 | 11 | #### Installation as a Package 12 | 13 | To incorporate `YOLOv8-Explainer` into your project as a dependency, execute the following command in your terminal: 14 | 15 | ```bash 16 | pip install YOLOv8-Explainer 17 | ``` 18 | ---------- 19 | 20 | ## Features and Functionality 21 | 22 | `YOLOv8-Explainer` can be used to deploy various CAM methods for cutting-edge XAI in `YOLOv8` for images: 23 | 24 | - GradCAM : Weight the 2D activations by the average gradient 25 | - GradCAM++ : Like GradCAM but uses second-order gradients 26 | - XGradCAM : Like GradCAM but scale the gradients by the normalized activations 27 | - EigenCAM : Takes the first principal component of the 2D Activations (no class discrimination, but seems to give great results) 28 | - HiResCAM : Like GradCAM but element-wise multiply the activations with the gradients; provably guaranteed faithfulness for certain models 29 | - LayerCAM : Spatially weight the activations by positive gradients. Works better especially in lower layers 30 | - EigenGradCAM : Like EigenCAM but with class discrimination: First principal component of Activations*Grad. Looks like GradCAM, but cleaner 31 | 32 | ---------- 33 | 34 | # Using from code as a library 35 | 36 | ```python 37 | 38 | from YOLOv8_Explainer import yolov8_heatmap, display_images 39 | 40 | model = yolov8_heatmap( 41 | weight="/location/model.pt", 42 | conf_threshold=0.4, 43 | method="EigenCAM", 44 | layer=[10, 12, 14, 16, 18, -3], 45 | ratio=0.02, 46 | show_box=True, 47 | renormalize=False, 48 | ) 49 | 50 | imagelist = model( 51 | img_path="/location/image.jpg", 52 | ) 53 | 54 | display_images(imagelist) 55 | 56 | ``` 57 | 58 | You can choose from the following CAM methods for version 0.0.5: 59 | 60 | `GradCAM`, `HiResCAM`, `GradCAMPlusPlus`, `XGradCAM`, `LayerCAM`, `EigenGradCAM` and `EigenCAM`. 61 | 62 | You can add a single image or a directory of images to be used by the `Module`. The output will be a corresponding list of images (a list containing one PIL Image for a single-image input, and a list containing as many PIL images as there are images in the input directory). 63 | 64 | ---------- 65 | 66 | ## Citation 67 | 68 | If you use this for research, please cite it. Here is an example BibTeX entry: 69 | 70 | ``` 71 | @ARTICLE{Sarma_Borah2024-un, 72 | title = "A comprehensive study on Explainable {AI} using {YOLO} and post 73 | hoc method on medical diagnosis", 74 | author = "Sarma Borah, Proyash Paban and Kashyap, Devraj and Laskar, 75 | Ruhini Aktar and Sarmah, Ankur Jyoti", 76 | journal = "J. Phys. Conf. 
Ser.", 77 | publisher = "IOP Publishing", 78 | volume = 2919, 79 | number = 1, 80 | pages = "012045", 81 | month = dec, 82 | year = 2024, 83 | copyright = "https://creativecommons.org/licenses/by/4.0/" 84 | } 85 | 86 | ``` 87 | ---------- 88 | 89 | # References 90 | https://github.com/jacobgil/pytorch-grad-cam
91 | `PyTorch library for CAM methods 92 | Jacob Gildenblat and contributors` 93 | 94 | https://github.com/z1069614715/objectdetection_script
95 | `Object Detection Script 96 | Devil's Mask` 97 | 98 | https://arxiv.org/abs/1610.02391
99 | `Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization 100 | Ramprasaath R. Selvaraju, Michael Cogswell, Abhishek Das, Ramakrishna Vedantam, Devi Parikh, Dhruv Batra` 101 | 102 | https://arxiv.org/abs/2011.08891
103 | `Use HiResCAM instead of Grad-CAM for faithful explanations of convolutional neural networks 104 | Rachel L. Draelos, Lawrence Carin` 105 | 106 | https://arxiv.org/abs/1710.11063
107 | `Grad-CAM++: Improved Visual Explanations for Deep Convolutional Networks 108 | Aditya Chattopadhyay, Anirban Sarkar, Prantik Howlader, Vineeth N Balasubramanian` 109 | 110 | https://arxiv.org/abs/1910.01279
111 | `Score-CAM: Score-Weighted Visual Explanations for Convolutional Neural Networks 112 | Haofan Wang, Zifan Wang, Mengnan Du, Fan Yang, Zijian Zhang, Sirui Ding, Piotr Mardziel, Xia Hu` 113 | 114 | https://ieeexplore.ieee.org/abstract/document/9093360/
115 | `Ablation-cam: Visual explanations for deep convolutional network via gradient-free localization. 116 | Saurabh Desai and Harish G Ramaswamy. In WACV, pages 972–980, 2020` 117 | 118 | https://arxiv.org/abs/2008.02312
119 | `Axiom-based Grad-CAM: Towards Accurate Visualization and Explanation of CNNs 120 | Ruigang Fu, Qingyong Hu, Xiaohu Dong, Yulan Guo, Yinghui Gao, Biao Li` 121 | 122 | https://arxiv.org/abs/2008.00299
123 | `Eigen-CAM: Class Activation Map using Principal Components 124 | Mohammed Bany Muhammad, Mohammed Yeasin` 125 | 126 | http://mftp.mmcheng.net/Papers/21TIP_LayerCAM.pdf
127 | `LayerCAM: Exploring Hierarchical Class Activation Maps for Localization 128 | Peng-Tao Jiang; Chang-Bin Zhang; Qibin Hou; Ming-Ming Cheng; Yunchao Wei` 129 | 130 | https://arxiv.org/abs/1905.00780
131 | `Full-Gradient Representation for Neural Network Visualization 132 | Suraj Srinivas, Francois Fleuret` 133 | 134 | https://arxiv.org/abs/1806.10206
135 | `Deep Feature Factorization For Concept Discovery 136 | Edo Collins, Radhakrishna Achanta, Sabine Süsstrunk` 137 | -------------------------------------------------------------------------------- /YOLOv8_Explainer.code-workspace: -------------------------------------------------------------------------------- 1 | { 2 | "folders": [ 3 | { 4 | "path": "." 5 | } 6 | ], 7 | "settings": {} 8 | } -------------------------------------------------------------------------------- /YOLOv8_Explainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .core import yolov8_heatmap, yolov8_target 2 | from .utils import letterbox, display_images 3 | 4 | 5 | __all__ =[ 6 | 'yolov8_heatmap', 7 | 'yolov8_target', 8 | 'letterbox', 9 | "display_images" 10 | ] -------------------------------------------------------------------------------- /YOLOv8_Explainer/core.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from typing import List, Optional, Tuple, Union 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | from pytorch_grad_cam import (EigenCAM, EigenGradCAM, GradCAM, GradCAMPlusPlus, 10 | HiResCAM, LayerCAM, RandomCAM, XGradCAM) 11 | from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients 12 | from pytorch_grad_cam.utils.image import scale_cam_image, show_cam_on_image 13 | from ultralytics.nn.tasks import attempt_load_weights 14 | from ultralytics.utils.ops import non_max_suppression, xywh2xyxy 15 | 16 | from .utils import letterbox 17 | 18 | 19 | class ActivationsAndGradients: 20 | """ Class for extracting activations and 21 | registering gradients from targetted intermediate layers """ 22 | 23 | def __init__(self, model: torch.nn.Module, 24 | target_layers: List[torch.nn.Module], 25 | reshape_transform: Optional[callable]) -> None: # type: ignore 26 | """ 27 | Initializes the ActivationsAndGradients object. 28 | 29 | Args: 30 | model (torch.nn.Module): The neural network model. 31 | target_layers (List[torch.nn.Module]): List of target layers from which to extract activations and gradients. 32 | reshape_transform (Optional[callable]): A function to transform the shape of the activations and gradients if needed. 33 | """ 34 | self.model = model 35 | self.gradients = [] 36 | self.activations = [] 37 | self.reshape_transform = reshape_transform 38 | self.handles = [] 39 | for target_layer in target_layers: 40 | self.handles.append( 41 | target_layer.register_forward_hook(self.save_activation)) 42 | # Because of https://github.com/pytorch/pytorch/issues/61519, 43 | # we don't use backward hook to record gradients. 44 | self.handles.append( 45 | target_layer.register_forward_hook(self.save_gradient)) 46 | 47 | def save_activation(self, module: torch.nn.Module, 48 | input: Union[torch.Tensor, Tuple[torch.Tensor, ...]], 49 | output: torch.Tensor) -> None: 50 | """ 51 | Saves the activation of the targeted layer. 52 | 53 | Args: 54 | module (torch.nn.Module): The targeted layer module. 55 | input (Union[torch.Tensor, Tuple[torch.Tensor, ...]]): The input to the targeted layer. 56 | output (torch.Tensor): The output activation of the targeted layer. 
57 | """ 58 | activation = output 59 | 60 | if self.reshape_transform is not None: 61 | activation = self.reshape_transform(activation) 62 | self.activations.append(activation.cpu().detach()) 63 | 64 | def save_gradient(self, module: torch.nn.Module, 65 | input: Union[torch.Tensor, Tuple[torch.Tensor, ...]], 66 | output: torch.Tensor) -> None: 67 | """ 68 | Saves the gradient of the targeted layer. 69 | 70 | Args: 71 | module (torch.nn.Module): The targeted layer module. 72 | input (Union[torch.Tensor, Tuple[torch.Tensor, ...]]): The input to the targeted layer. 73 | output (torch.Tensor): The output activation of the targeted layer. 74 | """ 75 | if not hasattr(output, "requires_grad") or not output.requires_grad: 76 | # You can only register hooks on tensor requires grad. 77 | return 78 | 79 | # Gradients are computed in reverse order 80 | def _store_grad(grad: torch.Tensor) -> None: 81 | if self.reshape_transform is not None: 82 | grad = self.reshape_transform(grad) 83 | self.gradients = [grad.cpu().detach()] + self.gradients 84 | 85 | output.register_hook(_store_grad) 86 | 87 | def post_process(self, result: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, np.ndarray]: 88 | """ 89 | Post-processes the result. 90 | 91 | Args: 92 | result (torch.Tensor): The result tensor. 93 | 94 | Returns: 95 | Tuple[torch.Tensor, torch.Tensor, np.ndarray]: A tuple containing the post-processed result. 96 | """ 97 | logits_ = result[:, 4:] 98 | boxes_ = result[:, :4] 99 | sorted, indices = torch.sort(logits_.max(1)[0], descending=True) 100 | return torch.transpose(logits_[0], dim0=0, dim1=1)[indices[0]], torch.transpose(boxes_[0], dim0=0, dim1=1)[ 101 | indices[0]], xywh2xyxy(torch.transpose(boxes_[0], dim0=0, dim1=1)[indices[0]]).cpu().detach().numpy() 102 | 103 | def __call__(self, x: torch.Tensor) -> List[List[Union[torch.Tensor, np.ndarray]]]: 104 | """ 105 | Calls the ActivationsAndGradients object. 106 | 107 | Args: 108 | x (torch.Tensor): The input tensor. 109 | 110 | Returns: 111 | List[List[Union[torch.Tensor, np.ndarray]]]: A list containing activations and gradients. 112 | """ 113 | self.gradients = [] 114 | self.activations = [] 115 | model_output = self.model(x) 116 | post_result, pre_post_boxes, post_boxes = self.post_process( 117 | model_output[0]) 118 | return [[post_result, pre_post_boxes]] 119 | 120 | def release(self) -> None: 121 | """Removes hooks.""" 122 | for handle in self.handles: 123 | handle.remove() 124 | 125 | 126 | class yolov8_target(torch.nn.Module): 127 | def __init__(self, ouput_type, conf, ratio) -> None: 128 | super().__init__() 129 | self.ouput_type = ouput_type 130 | self.conf = conf 131 | self.ratio = ratio 132 | 133 | def forward(self, data): 134 | post_result, pre_post_boxes = data 135 | result = [] 136 | for i in range(post_result.size(0)): 137 | if float(post_result[i].max()) >= self.conf: 138 | if self.ouput_type == 'class' or self.ouput_type == 'all': 139 | result.append(post_result[i].max()) 140 | if self.ouput_type == 'box' or self.ouput_type == 'all': 141 | for j in range(4): 142 | result.append(pre_post_boxes[i, j]) 143 | return sum(result) 144 | 145 | 146 | class yolov8_heatmap: 147 | """ 148 | This class is used to implement the YOLOv8 target layer. 149 | 150 | Args: 151 | weight (str): The path to the checkpoint file. 152 | device (str): The device to use for inference. Defaults to "cuda:0" if a GPU is available, otherwise "cpu". 153 | method (str): The method to use for computing the CAM. Defaults to "EigenGradCAM". 
154 | layer (list): The indices of the layers to use for computing the CAM. Defaults to [10, 12, 14, 16, 18, -3]. 155 | conf_threshold (float): The confidence threshold for detections. Defaults to 0.2. 156 | ratio (float): The ratio of maximum scores to return. Defaults to 0.02. 157 | show_box (bool): Whether to show bounding boxes with the CAM. Defaults to True. 158 | renormalize (bool): Whether to renormalize the CAM to be in the range [0, 1] across the entire image. Defaults to False. 159 | 160 | Returns: 161 | A tensor containing the output. 162 | 163 | """ 164 | 165 | def __init__( 166 | self, 167 | weight: str, 168 | device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), 169 | method="EigenGradCAM", 170 | layer=[12, 17, 21], 171 | conf_threshold=0.2, 172 | ratio=0.02, 173 | show_box=True, 174 | renormalize=False, 175 | ) -> None: 176 | """ 177 | Initialize the YOLOv8 heatmap layer. 178 | """ 179 | device = device 180 | backward_type = "all" 181 | ckpt = torch.load(weight) 182 | model_names = ckpt['model'].names 183 | model = attempt_load_weights(weight, device) 184 | model.info() 185 | for p in model.parameters(): 186 | p.requires_grad_(True) 187 | model.eval() 188 | 189 | target = yolov8_target(backward_type, conf_threshold, ratio) 190 | target_layers = [model.model[l] for l in layer] 191 | 192 | method = eval(method)(model, target_layers, 193 | use_cuda=device.type == 'cuda') 194 | method.activations_and_grads = ActivationsAndGradients( 195 | model, target_layers, None) 196 | 197 | colors = np.random.uniform( 198 | 0, 255, size=(len(model_names), 3)).astype(int) 199 | self.__dict__.update(locals()) 200 | 201 | def post_process(self, result): 202 | """ 203 | Perform non-maximum suppression on the detections and process results. 204 | 205 | Args: 206 | result (torch.Tensor): The raw detections from the model. 207 | 208 | Returns: 209 | torch.Tensor: Filtered and processed detections. 210 | """ 211 | # Perform non-maximum suppression 212 | processed_result = non_max_suppression( 213 | result, 214 | conf_thres=self.conf_threshold, # Use the class's confidence threshold 215 | iou_thres=0.45 # Intersection over Union threshold 216 | ) 217 | 218 | # If no detections, return an empty tensor 219 | if len(processed_result) == 0 or processed_result[0].numel() == 0: 220 | return torch.empty(0, 6) # Return an empty tensor with 6 columns 221 | 222 | # Take the first batch of detections (assuming single image) 223 | detections = processed_result[0] 224 | 225 | # Filter detections based on confidence 226 | mask = detections[:, 4] >= self.conf_threshold 227 | filtered_detections = detections[mask] 228 | 229 | return filtered_detections 230 | 231 | def draw_detections(self, box, color, name, img): 232 | """ 233 | Draw bounding boxes and labels on an image for multiple detections. 234 | 235 | Args: 236 | box (torch.Tensor or np.ndarray): The bounding box coordinates in the format [x1, y1, x2, y2] 237 | color (list): The color of the bounding box in the format [B, G, R] 238 | name (str): The label for the bounding box. 239 | img (np.ndarray): The image on which to draw the bounding box 240 | 241 | Returns: 242 | np.ndarray: The image with the bounding box drawn. 
243 | """ 244 | # Ensure box coordinates are integers 245 | xmin, ymin, xmax, ymax = map(int, box[:4]) 246 | 247 | # Draw rectangle 248 | cv2.rectangle(img, (xmin, ymin), (xmax, ymax), 249 | tuple(int(x) for x in color), 2) 250 | 251 | # Draw label 252 | cv2.putText(img, name, (xmin, ymin - 5), 253 | cv2.FONT_HERSHEY_SIMPLEX, 254 | 0.8, tuple(int(x) for x in color), 2, 255 | lineType=cv2.LINE_AA) 256 | 257 | return img 258 | 259 | def renormalize_cam_in_bounding_boxes( 260 | self, 261 | boxes: np.ndarray, # type: ignore 262 | image_float_np: np.ndarray, # type: ignore 263 | grayscale_cam: np.ndarray, # type: ignore 264 | ) -> np.ndarray: 265 | """ 266 | Normalize the CAM to be in the range [0, 1] 267 | inside every bounding boxes, and zero outside of the bounding boxes. 268 | 269 | Args: 270 | boxes (np.ndarray): The bounding boxes. 271 | image_float_np (np.ndarray): The image as a numpy array of floats in the range [0, 1]. 272 | grayscale_cam (np.ndarray): The CAM as a numpy array of floats in the range [0, 1]. 273 | 274 | Returns: 275 | np.ndarray: The renormalized CAM. 276 | """ 277 | renormalized_cam = np.zeros(grayscale_cam.shape, dtype=np.float32) 278 | for x1, y1, x2, y2 in boxes: 279 | x1, y1 = max(x1, 0), max(y1, 0) 280 | x2, y2 = min(grayscale_cam.shape[1] - 1, 281 | x2), min(grayscale_cam.shape[0] - 1, y2) 282 | renormalized_cam[y1:y2, x1:x2] = scale_cam_image( 283 | grayscale_cam[y1:y2, x1:x2].copy()) 284 | renormalized_cam = scale_cam_image(renormalized_cam) 285 | eigencam_image_renormalized = show_cam_on_image( 286 | image_float_np, renormalized_cam, use_rgb=True) 287 | return eigencam_image_renormalized 288 | 289 | def renormalize_cam(self, boxes, image_float_np, grayscale_cam): 290 | """Normalize the CAM to be in the range [0, 1] 291 | across the entire image.""" 292 | renormalized_cam = scale_cam_image(grayscale_cam) 293 | eigencam_image_renormalized = show_cam_on_image( 294 | image_float_np, renormalized_cam, use_rgb=True) 295 | return eigencam_image_renormalized 296 | 297 | def process(self, img_path): 298 | """Process the input image and generate CAM visualization.""" 299 | img = cv2.imread(img_path) 300 | img = letterbox(img)[0] 301 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 302 | img = np.float32(img) / 255.0 303 | 304 | tensor = ( 305 | torch.from_numpy(np.transpose(img, axes=[2, 0, 1])) 306 | .unsqueeze(0) 307 | .to(self.device) 308 | ) 309 | 310 | try: 311 | grayscale_cam = self.method(tensor, [self.target]) 312 | except AttributeError as e: 313 | print(e) 314 | return 315 | 316 | grayscale_cam = grayscale_cam[0, :] 317 | 318 | pred1 = self.model(tensor)[0] 319 | pred = non_max_suppression( 320 | pred1, 321 | conf_thres=self.conf_threshold, 322 | iou_thres=0.45 323 | )[0] 324 | 325 | # Debugging print 326 | 327 | if self.renormalize: 328 | cam_image = self.renormalize_cam( 329 | pred[:, :4].cpu().detach().numpy().astype(np.int32), 330 | img, 331 | grayscale_cam 332 | ) 333 | else: 334 | cam_image = show_cam_on_image(img, grayscale_cam, use_rgb=True) 335 | 336 | if self.show_box and len(pred) > 0: 337 | for detection in pred: 338 | detection = detection.cpu().detach().numpy() 339 | 340 | # Get class index and confidence 341 | class_index = int(detection[5]) 342 | conf = detection[4] 343 | 344 | # Draw detection 345 | cam_image = self.draw_detections( 346 | detection[:4], # Box coordinates 347 | self.colors[class_index], # Color for this class 348 | f"{self.model_names[class_index]}", # Label with confidence 349 | cam_image, 350 | ) 351 | 352 | cam_image = 
Image.fromarray(cam_image) 353 | return cam_image 354 | 355 | def __call__(self, img_path): 356 | """Generate CAM visualizations for one or more images. 357 | 358 | Args: 359 | img_path (str): Path to the input image or directory containing images. 360 | 361 | Returns: 362 | None 363 | """ 364 | if os.path.isdir(img_path): 365 | image_list = [] 366 | for img_path_ in os.listdir(img_path): 367 | img_pil = self.process(f"{img_path}/{img_path_}") 368 | image_list.append(img_pil) 369 | return image_list 370 | else: 371 | return [self.process(img_path)] 372 | -------------------------------------------------------------------------------- /YOLOv8_Explainer/utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | 5 | from PIL import Image 6 | import matplotlib.pyplot as plt 7 | 8 | 9 | def letterbox( 10 | im: np.ndarray, 11 | new_shape=(640, 640), 12 | color=(114, 114, 114), 13 | auto=True, 14 | scaleFill=False, 15 | scaleup=True, 16 | stride=32, 17 | ): 18 | """ 19 | Resize and pad image while meeting stride-multiple constraints. 20 | 21 | Args: 22 | im (numpy.ndarray): Input image. 23 | new_shape (tuple, optional): Desired output shape. Defaults to (640, 640). 24 | color (tuple, optional): Color of the border. Defaults to (114, 114, 114). 25 | auto (bool, optional): Whether to automatically determine padding. Defaults to True. 26 | scaleFill (bool, optional): Whether to stretch the image to fill the new shape. Defaults to False. 27 | scaleup (bool, optional): Whether to scale the image up if necessary. Defaults to True. 28 | stride (int, optional): Stride of the sliding window. Defaults to 32. 29 | 30 | Returns: 31 | numpy.ndarray: Letterboxed image. 32 | tuple: Ratio of the resized image. 33 | tuple: Padding sizes. 34 | 35 | """ 36 | shape = im.shape[:2] # current shape [height, width] 37 | if isinstance(new_shape, int): 38 | new_shape = (new_shape, new_shape) 39 | 40 | # Scale ratio (new / old) 41 | r = min(new_shape[0] / shape[0], new_shape[1] / shape[1]) 42 | if not scaleup: # only scale down, do not scale up (for better val mAP) 43 | r = min(r, 1.0) 44 | 45 | # Compute padding 46 | ratio = r, r # width, height ratios 47 | new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r)) 48 | dw, dh = new_shape[1] - new_unpad[0], new_shape[0] - new_unpad[1] # wh padding 49 | if auto: # minimum rectangle 50 | dw, dh = np.mod(dw, stride), np.mod(dh, stride) # wh padding 51 | elif scaleFill: # stretch 52 | dw, dh = 0.0, 0.0 53 | new_unpad = (new_shape[1], new_shape[0]) 54 | ratio = new_shape[1] / shape[1], new_shape[0] / shape[0] # width, height ratios 55 | 56 | dw /= 2 # divide padding into 2 sides 57 | dh /= 2 58 | 59 | if shape[::-1] != new_unpad: # resize 60 | im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR) 61 | top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1)) 62 | left, right = int(round(dw - 0.1)), int(round(dw + 0.1)) 63 | im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color) # add border 64 | 65 | return im, ratio, (dw, dh) 66 | 67 | def display_images(images): 68 | """ 69 | Display a list of PIL images in a grid. 70 | 71 | Args: 72 | images (list[PIL.Image]): A list of PIL images to display. 
73 | 74 | Returns: 75 | None 76 | """ 77 | fig, axes = plt.subplots(1, len(images), figsize=(15, 7)) 78 | if len(images) == 1: 79 | axes = [axes] 80 | for ax, img in zip(axes, images): 81 | ax.imshow(img) 82 | ax.axis('off') 83 | plt.show() -------------------------------------------------------------------------------- /docs/functions.md: -------------------------------------------------------------------------------- 1 | # Functions 2 | 3 | ## Main Functions 4 | The core functions that can be used to visualise the different Class Activation Mapping (CAM) methods are given below. 5 | ::: YOLOv8_Explainer.core 6 | 7 | ## Helper Functions 8 | The functions that can be used to display images and provide various other functionalities can be found here. 9 | 10 | ::: YOLOv8_Explainer.utils -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Welcome to YOLOv8 Explainer 2 | 3 | ## Simplify your understanding of YOLOv8 Results 4 | This is a package with state-of-the-art Class Activation Mapping (CAM) methods for Explainable AI for computer vision using YOLOv8. It can be used for diagnosing model predictions, either in production or while developing models. The aim is also to serve as a benchmark of algorithms and metrics for research on new explainability methods. 5 | 6 | ### Install Environment & Dependencies 7 | 8 | `YOLOv8-Explainer` can be seamlessly integrated into your projects with a straightforward installation process: 9 | 10 | #### Installation as a Package 11 | 12 | To incorporate `YOLOv8-Explainer` into your project as a dependency, execute the following command in your terminal: 13 | 14 | ```bash 15 | pip install YOLOv8-Explainer 16 | ``` 17 | ## Features and Functionality 18 | 19 | `YOLOv8-Explainer` can be used to deploy various CAM methods for cutting-edge XAI in `YOLOv8` for images: 20 | 21 | - GradCAM : Weight the 2D activations by the average gradient 22 | - GradCAM++ : Like GradCAM but uses second-order gradients 23 | - XGradCAM : Like GradCAM but scale the gradients by the normalized activations 24 | - EigenCAM : Takes the first principal component of the 2D Activations (no class discrimination, but seems to give great results) 25 | - HiResCAM : Like GradCAM but element-wise multiply the activations with the gradients; provably guaranteed faithfulness for certain models 26 | - LayerCAM : Spatially weight the activations by positive gradients. Works better especially in lower layers 27 | - EigenGradCAM : Like EigenCAM but with class discrimination: First principal component of Activations*Grad. Looks like GradCAM, but cleaner 28 | 29 | #### Using from code as a library 30 | 31 | ```python 32 | 33 | from YOLOv8_Explainer import yolov8_heatmap, display_images 34 | 35 | model = yolov8_heatmap( 36 | weight="/location/model.pt", 37 | conf_threshold=0.4, 38 | method="EigenCAM", 39 | layer=[10, 12, 14, 16, 18, -3], 40 | ratio=0.02, 41 | show_box=True, 42 | renormalize=False, 43 | ) 44 | 45 | images = model( 46 | img_path="/location/image.jpg", 47 | ) 48 | 49 | display_images(images) 50 | 51 | ``` 52 | - Here the line `from YOLOv8_Explainer import yolov8_heatmap, display_images` imports the required functionalities. 
53 | 54 | - The line `model = yolov8_heatmap( weight="/location/model.pt", conf_threshold=0.4, method="EigenCAM", layer=[10, 12, 14, 16, 18, -3], ratio=0.02, show_box=True, renormalize=False)` allows the user to pass a pretrained YOLO weight for the CAM to explain, along with additional parameters such as the desired CAM method, target layers, and the confidence threshold. 55 | 56 | You can choose from the following CAM methods for version 0.0.5: 57 | 58 | `GradCAM`, `HiResCAM`, `GradCAMPlusPlus`, `XGradCAM`, `LayerCAM`, `EigenGradCAM` and `EigenCAM`. 59 | 60 | - The line `images = model( img_path="/location/image.jpg" )` passes the image(s) the model will process. 61 | 62 | - The line `display_images(images)` displays the output of the model (bounding box), along with the CAM method's output (saliency map). 63 | 64 | You can add a single image or a directory of images to be used by the `Module`. The output will be a corresponding list of images (a list containing one PIL Image for a single-image input, and a list containing as many PIL images as there are images in the input directory). 65 | 66 | ## Citing in Works 67 | 68 | If you use this for research, please cite it. Here is an example BibTeX entry: 69 | 70 | ``` 71 | @ARTICLE{Sarma_Borah2024-un, 72 | title = "A comprehensive study on Explainable {AI} using {YOLO} and post 73 | hoc method on medical diagnosis", 74 | author = "Sarma Borah, Proyash Paban and Kashyap, Devraj and Laskar, 75 | Ruhini Aktar and Sarmah, Ankur Jyoti", 76 | journal = "J. Phys. Conf. Ser.", 77 | publisher = "IOP Publishing", 78 | volume = 2919, 79 | number = 1, 80 | pages = "012045", 81 | month = dec, 82 | year = 2024, 83 | copyright = "https://creativecommons.org/licenses/by/4.0/" 84 | } 85 | 86 | ``` 87 | -------------------------------------------------------------------------------- /docs/references.md: -------------------------------------------------------------------------------- 1 | # References 2 | 3 | https://github.com/jacobgil/pytorch-grad-cam
4 | `PyTorch library for CAM methods 5 | Jacob Gildenblat and contributors` 6 | 7 | https://github.com/z1069614715/objectdetection_script
8 | `Object Detection Script 9 | Devil's Mask` 10 | 11 | https://arxiv.org/abs/1610.02391
12 | `Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization 13 | Ramprasaath R. Selvaraju, Michael Cogswell, Abhishek Das, Ramakrishna Vedantam, Devi Parikh, Dhruv Batra` 14 | 15 | https://arxiv.org/abs/2011.08891
16 | `Use HiResCAM instead of Grad-CAM for faithful explanations of convolutional neural networks 17 | Rachel L. Draelos, Lawrence Carin` 18 | 19 | https://arxiv.org/abs/1710.11063
20 | `Grad-CAM++: Improved Visual Explanations for Deep Convolutional Networks 21 | Aditya Chattopadhyay, Anirban Sarkar, Prantik Howlader, Vineeth N Balasubramanian` 22 | 23 | https://arxiv.org/abs/1910.01279
24 | `Score-CAM: Score-Weighted Visual Explanations for Convolutional Neural Networks 25 | Haofan Wang, Zifan Wang, Mengnan Du, Fan Yang, Zijian Zhang, Sirui Ding, Piotr Mardziel, Xia Hu` 26 | 27 | https://ieeexplore.ieee.org/abstract/document/9093360/
28 | `Ablation-cam: Visual explanations for deep convolutional network via gradient-free localization. 29 | Saurabh Desai and Harish G Ramaswamy. In WACV, pages 972–980, 2020` 30 | 31 | https://arxiv.org/abs/2008.02312
32 | `Axiom-based Grad-CAM: Towards Accurate Visualization and Explanation of CNNs 33 | Ruigang Fu, Qingyong Hu, Xiaohu Dong, Yulan Guo, Yinghui Gao, Biao Li` 34 | 35 | https://arxiv.org/abs/2008.00299
36 | `Eigen-CAM: Class Activation Map using Principal Components 37 | Mohammed Bany Muhammad, Mohammed Yeasin` 38 | 39 | http://mftp.mmcheng.net/Papers/21TIP_LayerCAM.pdf
40 | `LayerCAM: Exploring Hierarchical Class Activation Maps for Localization 41 | Peng-Tao Jiang; Chang-Bin Zhang; Qibin Hou; Ming-Ming Cheng; Yunchao Wei` 42 | 43 | https://arxiv.org/abs/1905.00780
44 | `Full-Gradient Representation for Neural Network Visualization 45 | Suraj Srinivas, Francois Fleuret` 46 | 47 | https://arxiv.org/abs/1806.10206
48 | `Deep Feature Factorization For Concept Discovery 49 | Edo Collins, Radhakrishna Achanta, Sabine Süsstrunk` -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | 2 | 3 | site_name: YOLOv8 Explainer 4 | theme: 5 | name: material 6 | features: 7 | - search.suggest 8 | - search.highlight 9 | - navigation.tabs 10 | - navigation.sections 11 | - navigation.top 12 | - navigation.tracking 13 | - navigation.footer 14 | - navigation.indexes 15 | - toc.integrate 16 | - navigation.top 17 | - content.tabs.link 18 | - content.code.annotation 19 | - content.code.copy 20 | 21 | language: en 22 | palette: 23 | - media: '(prefers-color-scheme: light)' 24 | scheme: slate 25 | primary: teal 26 | accent: amber 27 | toggle: 28 | icon: material/lightbulb 29 | name: Switch to light mode 30 | - media: '(prefers-color-scheme: dark)' 31 | scheme: default 32 | primary: teal 33 | accent: amber 34 | toggle: 35 | icon: material/lightbulb-outline 36 | name: Switch to dark mode 37 | 38 | 39 | extra: 40 | social: 41 | - icon: fontawesome/brands/github-alt 42 | link: https://github.com/dev-rajk 43 | - icon: fontawesome/brands/github-alt 44 | link: https://github.com/Spritan 45 | # - icon: fontawesome/brands/twitter 46 | # link: https://twitter.com/TheJamesWillett 47 | # - icon: fontawesome/brands/linkedin 48 | # link: https://www.linkedin.com/in/willettjames/ 49 | markdown_extensions: 50 | - pymdownx.highlight: 51 | anchor_linenums: true 52 | - pymdownx.inlinehilite 53 | - pymdownx.snippets 54 | - admonition 55 | - pymdownx.arithmatex: 56 | generic: true 57 | - footnotes 58 | - pymdownx.details 59 | - pymdownx.superfences 60 | - pymdownx.mark 61 | - attr_list 62 | - pymdownx.emoji: 63 | emoji_index: !!python/name:material.extensions.emoji.twemoji 64 | emoji_generator: !!python/name:materialx.emoji.to_svg 65 | 66 | copyright: | 67 | © 2024 Proyash Paban Sarma Borah and Devraj Kashyap 68 | 69 | plugins: 70 | - search: null 71 | - mkdocstrings: 72 | default_handler: python 73 | handlers: 74 | python: 75 | paths: [src] 76 | 77 | nav: 78 | - Home: index.md 79 | - Detailed Documentation : functions.md 80 | - References: references.md 81 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | with open("./README.md", "r") as f: 4 | long_description = f.read() 5 | 6 | setup( 7 | name="YOLOv8_Explainer", 8 | version="0.0.05", 9 | description="Python packages that enable XAI methods for YOLOv8", 10 | packages=find_packages(), 11 | long_description=long_description, 12 | long_description_content_type="text/markdown", 13 | url="https://github.com/Spritan/YOLOv8_Explainer", 14 | # Github="https://github.com/Spritan/YOLOv8_Explainer", 15 | author="Spritan", 16 | author_email="proypabsab@gmail.com", 17 | license="MIT", 18 | classifiers=[ 19 | "License :: OSI Approved :: MIT License", 20 | "Development Status :: 5 - Production/Stable", 21 | "Intended Audience :: Education", 22 | "Programming Language :: Python", 23 | "Operating System :: OS Independent", 24 | ], 25 | install_requires=[ 26 | "grad-cam==1.4.8", 27 | "ultralytics", 28 | "Pillow", 29 | "tqdm", 30 | "torch", 31 | "matplotlib", 32 | ], 33 | extras_require={ 34 | "dev": ["twine>=4.0.2"], 35 | }, 36 | project_urls={ 37 | 'Homepage': 'https://spritan.github.io/YOLOv8_Explainer/', # 
Homepage URL 38 | 'Source': 'https://github.com/Spritan/YOLOv8_Explainer', # GitHub URL 39 | }, 40 | ) 41 | --------------------------------------------------------------------------------