├── .gitignore ├── LICENSE.md ├── README.md ├── data ├── test_annotation │ ├── 14_1.xml │ └── 15_1.xml └── test_images │ ├── 14_1.jpg │ └── 15_1.jpg ├── docker-compose.yml ├── evaluation.py ├── img.png ├── img_1.png ├── img_2.png ├── main.py ├── model ├── config │ ├── model_config.json │ └── server_config.json └── src │ ├── .gitkeep │ ├── RMQ.py │ ├── config.json │ ├── image.jpg │ ├── model.py │ ├── object-detection.pbtxt │ └── pipeline.config ├── requirements.txt ├── utils.py ├── xAI_config.json └── xai ├── adasise.py ├── density_map.py ├── drise.py ├── gradcam.py ├── kde.py ├── lime_method.py └── rise.py /.gitignore: -------------------------------------------------------------------------------- 1 | /data/samples/ 2 | /data/test_annotation/ 3 | /data/test_images/ 4 | /model/src/frozen_inference_graph.pb 5 | .idea 6 | /xai/KDE+density_map.ipynb 7 | /xai/Code_akamedic.ipynb 8 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Creative Commons Legal Code 2 | 3 | CC0 1.0 Universal 4 | 5 | CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE 6 | LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN 7 | ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS 8 | INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES 9 | REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS 10 | PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM 11 | THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED 12 | HEREUNDER. 13 | 14 | Statement of Purpose 15 | 16 | The laws of most jurisdictions throughout the world automatically confer 17 | exclusive Copyright and Related Rights (defined below) upon the creator 18 | and subsequent owner(s) (each and all, an "owner") of an original work of 19 | authorship and/or a database (each, a "Work"). 20 | 21 | Certain owners wish to permanently relinquish those rights to a Work for 22 | the purpose of contributing to a commons of creative, cultural and 23 | scientific works ("Commons") that the public can reliably and without fear 24 | of later claims of infringement build upon, modify, incorporate in other 25 | works, reuse and redistribute as freely as possible in any form whatsoever 26 | and for any purposes, including without limitation commercial purposes. 27 | These owners may contribute to the Commons to promote the ideal of a free 28 | culture and the further production of creative, cultural and scientific 29 | works, or to gain reputation or greater distribution for their Work in 30 | part through the use and efforts of others. 31 | 32 | For these and/or other purposes and motivations, and without any 33 | expectation of additional consideration or compensation, the person 34 | associating CC0 with a Work (the "Affirmer"), to the extent that he or she 35 | is an owner of Copyright and Related Rights in the Work, voluntarily 36 | elects to apply CC0 to the Work and publicly distribute the Work under its 37 | terms, with knowledge of his or her Copyright and Related Rights in the 38 | Work and the meaning and intended legal effect of CC0 on those rights. 39 | 40 | 1. Copyright and Related Rights. A Work made available under CC0 may be 41 | protected by copyright and related or neighboring rights ("Copyright and 42 | Related Rights"). Copyright and Related Rights include, but are not 43 | limited to, the following: 44 | 45 | i. 
the right to reproduce, adapt, distribute, perform, display, 46 | communicate, and translate a Work; 47 | ii. moral rights retained by the original author(s) and/or performer(s); 48 | iii. publicity and privacy rights pertaining to a person's image or 49 | likeness depicted in a Work; 50 | iv. rights protecting against unfair competition in regards to a Work, 51 | subject to the limitations in paragraph 4(a), below; 52 | v. rights protecting the extraction, dissemination, use and reuse of data 53 | in a Work; 54 | vi. database rights (such as those arising under Directive 96/9/EC of the 55 | European Parliament and of the Council of 11 March 1996 on the legal 56 | protection of databases, and under any national implementation 57 | thereof, including any amended or successor version of such 58 | directive); and 59 | vii. other similar, equivalent or corresponding rights throughout the 60 | world based on applicable law or treaty, and any national 61 | implementations thereof. 62 | 63 | 2. Waiver. To the greatest extent permitted by, but not in contravention 64 | of, applicable law, Affirmer hereby overtly, fully, permanently, 65 | irrevocably and unconditionally waives, abandons, and surrenders all of 66 | Affirmer's Copyright and Related Rights and associated claims and causes 67 | of action, whether now known or unknown (including existing as well as 68 | future claims and causes of action), in the Work (i) in all territories 69 | worldwide, (ii) for the maximum duration provided by applicable law or 70 | treaty (including future time extensions), (iii) in any current or future 71 | medium and for any number of copies, and (iv) for any purpose whatsoever, 72 | including without limitation commercial, advertising or promotional 73 | purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each 74 | member of the public at large and to the detriment of Affirmer's heirs and 75 | successors, fully intending that such Waiver shall not be subject to 76 | revocation, rescission, cancellation, termination, or any other legal or 77 | equitable action to disrupt the quiet enjoyment of the Work by the public 78 | as contemplated by Affirmer's express Statement of Purpose. 79 | 80 | 3. Public License Fallback. Should any part of the Waiver for any reason 81 | be judged legally invalid or ineffective under applicable law, then the 82 | Waiver shall be preserved to the maximum extent permitted taking into 83 | account Affirmer's express Statement of Purpose. In addition, to the 84 | extent the Waiver is so judged Affirmer hereby grants to each affected 85 | person a royalty-free, non transferable, non sublicensable, non exclusive, 86 | irrevocable and unconditional license to exercise Affirmer's Copyright and 87 | Related Rights in the Work (i) in all territories worldwide, (ii) for the 88 | maximum duration provided by applicable law or treaty (including future 89 | time extensions), (iii) in any current or future medium and for any number 90 | of copies, and (iv) for any purpose whatsoever, including without 91 | limitation commercial, advertising or promotional purposes (the 92 | "License"). The License shall be deemed effective as of the date CC0 was 93 | applied by Affirmer to the Work. 
Should any part of the License for any 94 | reason be judged legally invalid or ineffective under applicable law, such 95 | partial invalidity or ineffectiveness shall not invalidate the remainder 96 | of the License, and in such case Affirmer hereby affirms that he or she 97 | will not (i) exercise any of his or her remaining Copyright and Related 98 | Rights in the Work or (ii) assert any associated claims and causes of 99 | action with respect to the Work, in either case contrary to Affirmer's 100 | express Statement of Purpose. 101 | 102 | 4. Limitations and Disclaimers. 103 | 104 | a. No trademark or patent rights held by Affirmer are waived, abandoned, 105 | surrendered, licensed or otherwise affected by this document. 106 | b. Affirmer offers the Work as-is and makes no representations or 107 | warranties of any kind concerning the Work, express, implied, 108 | statutory or otherwise, including without limitation warranties of 109 | title, merchantability, fitness for a particular purpose, non 110 | infringement, or the absence of latent or other defects, accuracy, or 111 | the present or absence of errors, whether or not discoverable, all to 112 | the greatest extent permissible under applicable law. 113 | c. Affirmer disclaims responsibility for clearing rights of other persons 114 | that may apply to the Work or any use thereof, including without 115 | limitation any person's Copyright and Related Rights in the Work. 116 | Further, Affirmer disclaims responsibility for obtaining any necessary 117 | consents, permissions or other rights required for any use of the 118 | Work. 119 | d. Affirmer understands and acknowledges that Creative Commons is not a 120 | party to this document and has no duty or obligation with respect to 121 | this CC0 or use of the Work. 122 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![img_2.png](img_2.png) 2 | # Towards XAI in Thyroid Tumor Diagnosis # 3 | Our paper is accepted and presented as a long presentation at Health Intelligence Workshop (W3PHIAI-23) at AAAI-23. 4 | ## Source code for XAI Thyroid - an XAI object detection problem ## 5 | 6 | 1. **Install environment** 7 | 8 | ``` 9 | pip install -r requirements.txt 10 | ``` 11 | 12 | - Download the model at: https://drive.google.com/file/d/1IOyom78mexC6BPq4gPkfFLg-3vPzBnnO/view?usp=sharing 13 | 14 | Move the model to ``model/src/`` folder. 15 | 16 | 2. **Instruct the parameters to be run with each algorithm** 17 | 18 | ``` 19 | python main.py --help 20 | ``` 21 | 22 | 3. 
**Command line example with algorithms** 23 | Arguments options: 24 | 25 | - `--config-path`: path to the configuration file 26 | - `--method`: XAI method to run (options: eLRP, GradCAM, GradCAM++, RISE, LIME, DRISE, KDE, DensityMap, AdaSISE) 27 | - `--image-path`: path to the image to be processed 28 | - `--stage`: stage of the algorithm to be run (options: first_stage, second_stage, default: first_stage) 29 | - `--threshold`: threshold of output values to visualize 30 | - `--output-path`: path to the output directory 31 | 32 | For example, to run the XAI algorithms on images in test_images folder: 33 | 34 | - **GradCAM** 35 | 36 | In first stage: 37 | 38 | ``` 39 | python main.py --config-path xAI_config.json --method GradCAM --image-path data/test_images/ --output-path results/ 40 | ``` 41 | 42 | In second stage: 43 | 44 | ``` 45 | python main.py --config-path xAI_config.json --method GradCAM --image-path data/test_images/ --stage second_stage --output-path results/ 46 | ``` 47 | 48 | - **GradCAM++** 49 | 50 | In first stage: 51 | 52 | ``` 53 | python main.py --config-path xAI_config.json --method GradCAM++ --image-path data/test_images/ --output-path results/ 54 | ``` 55 | 56 | In second stage: 57 | 58 | ``` 59 | python main.py --config-path xAI_config.json --method GradCAM++ --image-path data/test_images/ --stage second_stage --output-path results/ 60 | 61 | ``` 62 | 63 | **Note:** To change input, change the path to new data and path to xml file in xAI_config.json 64 | 65 | ## Applicability 66 | ![img_1.png](img_1.png) 67 | • Region Proposal Generation (Which proposals are generated by the model during the model’s first stage?): Kernel Density Estimation (KDE), Density map (DM). 68 | 69 | • Classification (Which features of an image make the model classify an image containing a nodule(s) at the model’s second stage?): LRP, Grad-CAM, Grad-CAM++, LIME, RISE, Ada-SISE, D-RISE. 70 | 71 | • Localization (Which features of an image does the model consider to detect a specific box containing a nodule at the model’s second stage?): D-RISE. 72 | ## Results 73 | ![img.png](img.png) 74 | 75 | ## Citation 76 | If you find this repository helpful for your research. 
Please cite our paper as a small support for us too :) 77 | ``` 78 | @article{nguyen2023towards, 79 | title={Towards Trust of Explainable AI in Thyroid Nodule Diagnosis}, 80 | author={Nguyen, Truong Thanh Hung and Truong, Van Binh and Nguyen, Vo Thanh Khang and Cao, Quoc Hung and Nguyen, Quoc Khanh}, 81 | journal={arXiv preprint arXiv:2303.04731}, 82 | year={2023} 83 | } 84 | ``` 85 | -------------------------------------------------------------------------------- /data/test_annotation/14_1.xml: -------------------------------------------------------------------------------- 1 | 14_1.jpgxml5603603thyroid_cancer0Unspecified03631403611323571173501133411073371073351033349832990319813088129811130512131712932213131913430213729814029514329515129815530616231616632816933317133717235217535617336016836616436815836815136214236114229581369176thyroid_cancer0Unspecified026110726810627610628410528699284932828627882271792607924782223962219821810221710722011222511423311623911624711526511126410821779287117 -------------------------------------------------------------------------------- /data/test_annotation/15_1.xml: -------------------------------------------------------------------------------- 1 | 15_1.jpgxml5603603thyroid_cancer0Unspecified031599304982989626610125810524910824111323012022213022213021314121215021215721316521418322019322519722920223420624120824220824620826021026721626921828222528922929823631924233324334424435324635924937125238325239625140024940424641724242724143623643823144221844720744720044419144318744317644317144316144315244014543313942513341612741012040411539611138210837210736510735410634710633510432510032098316973119921296448253 -------------------------------------------------------------------------------- /data/test_images/14_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hungntt/xai_thyroid/4fe1392d9123b844d8947a161a44cb79ef876435/data/test_images/14_1.jpg -------------------------------------------------------------------------------- /data/test_images/15_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hungntt/xai_thyroid/4fe1392d9123b844d8947a161a44cb79ef876435/data/test_images/15_1.jpg -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: '3.7' 2 | 3 | volumes: 4 | explainer_logs: 5 | name: explainer_logs 6 | external: false 7 | 8 | services: 9 | explainer_tensorboard: 10 | container_name: explainer_tensorboard 11 | image: explainer_tensorboard:latest 12 | build: 13 | context: ./tensorboard/ 14 | dockerfile: ./Dockerfile 15 | ports: 16 | - 6006:6006 17 | volumes: 18 | - explainer_logs:/tmp/explainer_logs 19 | command: ["run", "//explainer_tensorboard", "--", "--logdir=/tmp/explainer_logs"] 20 | depends_on: 21 | - explainer_model-backend 22 | - explainer_assets-server 23 | 24 | explainer_model-backend: 25 | container_name: explainer_model-backend 26 | image: explainer_model-backend:latest 27 | build: 28 | context: ./model-backend/ 29 | dockerfile: ./Dockerfile 30 | ports: 31 | - 5000:5000 32 | command: ["python", "run.py"] 33 | 34 | explainer_assets-server: 35 | container_name: explainer_assets-server 36 | image: explainer_assets-server:latest 37 | build: 38 | context: ./assets-server/ 39 | dockerfile: ./Dockerfile 40 | ports: 41 | - 8000:8000 42 | command: ["python", "-u", "app.py", "/serve"] 
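    # Note: the services above expose TensorBoard (6006), the model backend (5000), and an
    # assets server (8000); explainer_summary below shares the same explainer_logs volume,
    # so TensorBoard reads whatever summaries that job writes.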
43 | 44 | explainer_summary: 45 | container_name: explainer_summary 46 | image: explainer_summary:latest 47 | build: 48 | context: ./tensorboard/ 49 | dockerfile: ./Dockerfile 50 | volumes: 51 | - explainer_logs:/tmp/explainer_logs 52 | command: ["run", "//explainer_plugin:explainer_demo"] 53 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import warnings 4 | from datetime import datetime 5 | 6 | import cv2 7 | import numpy as np 8 | import tensorflow.compat.v1 as tf 9 | from skimage import io, img_as_ubyte 10 | from tqdm import tqdm 11 | 12 | from utils import * 13 | from xai.adasise import AdaSISE 14 | from xai.density_map import DensityMap 15 | from xai.drise import DRISE 16 | from xai.gradcam import GradCAM, GradCAMPlusPlus 17 | from xai.kde import KDE 18 | from xai.lime_method import LIME 19 | from xai.rise import RISE 20 | 21 | warnings.filterwarnings('ignore') 22 | start = datetime.now() 23 | 24 | 25 | def main(args): 26 | # ---------------------------------Parameters------------------------------------- 27 | img_rs, output_tensor, last_conv_tensor, grads, num_sample, NMS = None, None, None, None, None, None 28 | config_xAI = get_config(args.config_path) 29 | config_models = get_config(config_xAI['Model']['file_config']) 30 | image_dict = {} 31 | sess, img_input, detection_boxes, detection_scores, num_detections, detection_classes = get_model( 32 | config_models[0]['model_path']) 33 | threshold = config_xAI['Model']['threshold'] 34 | 35 | # create array to save results for 5 metrics 36 | drop_rate = [] 37 | inc = [] 38 | ebpg_ = [] 39 | bbox_ = [] 40 | iou_ = [] 41 | 42 | # Run xAI for each image 43 | for j in tqdm(sorted(glob.glob(f'{args.image_path}/*.jpg'))): 44 | # Load image from input folder and extract ground-truth labels from xml file 45 | image = cv2.imread(j) 46 | img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 47 | image = img.reshape(1, img.shape[0], img.shape[1], 3) 48 | name_img = os.path.basename(j).split('.')[0] 49 | mu = np.mean(image) 50 | h_img, w_img = img.shape[:2] 51 | y_p_boxes, y_p_scores, y_p_num_detections, y_p_classes = sess.run( \ 52 | [detection_boxes, detection_scores, num_detections, detection_classes], \ 53 | feed_dict={img_input: image}) 54 | if y_p_num_detections == 0: 55 | continue 56 | # load the saliency map 57 | cam_map = np.load(os.path.join(args.output_numpy, f"{args.method}_{name_img}.npy")) 58 | if args.method == 'eLRP': 59 | cam_map = cam_map[:, :, 2] - cam_map[:, :, 0] 60 | cam_map = abs(cam_map) 61 | cam_map = (cam_map - cam_map.min()) / (cam_map.max() - cam_map.min()) 62 | elif args.method == 'D-RISE': 63 | map = 0 64 | for cam in cam_map: 65 | cam = (cam - cam.min()) / (cam.max() - cam.min()) 66 | map += cam 67 | cam_map = map 68 | # Coordinates of the predicted boxes 69 | box_predicted = [] 70 | mask = np.zeros_like(cam_map) 71 | for i in range(int(y_p_num_detections[0])): 72 | x1, x2 = int(y_p_boxes[0][i][1] * w_img), int(y_p_boxes[0][i][3] * w_img) 73 | y1, y2 = int(y_p_boxes[0][i][0] * h_img), int(y_p_boxes[0][i][2] * h_img) 74 | box_predicted.append([x1,y1,x2,y2]) 75 | mask[y1:y2, x1:x2] = 1 76 | # ---------------------------DROP-INC---------------------------- 77 | cam_map = (cam_map - cam_map.min()) / (cam_map.max() - cam_map.min()) 78 | invert = ((img / 255) * np.dstack([cam_map]*3)) * 255 79 | bias = mu * (1 -cam_map) 80 | masked = (invert + bias[:, :, 
np.newaxis]).astype(np.uint8) 81 | masked = masked[None, :] 82 | p_boxes, p_scores, p_num_detections, p_classes = sess.run( \ 83 | [detection_boxes, detection_scores, num_detections, detection_classes], \ 84 | feed_dict={img_input: masked}) 85 | prob = y_p_scores[0][:int(y_p_num_detections[0])].sum() 86 | prob_ex = p_scores[0][:int(p_num_detections[0])].sum() 87 | if prob < prob_ex: 88 | inc.append(1) 89 | drop = max((prob - prob_ex) / prob, 0) 90 | drop_rate.append(drop) 91 | # ---------------------------Localization evaluation---------------------------- 92 | bbox_.append(bounding_boxes(box_predicted, cam_map)) 93 | ebpg_.append(energy_point_game(box_predicted, cam_map)) 94 | iou_.append(IoU(mask, cam_map)) 95 | # print results with eps = 1e-10 avoid the case where the denominator is 0 96 | print("Drop rate:", sum(drop_rate)/(len(drop_rate)+1e-10)) 97 | print("Increases", sum(inc)/(len(inc)+1e-10)) 98 | print("EBPG:", sum(ebpg_)/(len(ebpg_)+1e-10)) 99 | print("Bbox:", sum(bbox_)/(len(bbox_)+1e-10)) 100 | print("IoU:", sum(iou_)/(len(iou_)+1e-10)) 101 | 102 | if __name__ == '__main__': 103 | arguments = get_parser() 104 | main(arguments) 105 | print(f'Total training time: {datetime.now() - start}') 106 | -------------------------------------------------------------------------------- /img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hungntt/xai_thyroid/4fe1392d9123b844d8947a161a44cb79ef876435/img.png -------------------------------------------------------------------------------- /img_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hungntt/xai_thyroid/4fe1392d9123b844d8947a161a44cb79ef876435/img_1.png -------------------------------------------------------------------------------- /img_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hungntt/xai_thyroid/4fe1392d9123b844d8947a161a44cb79ef876435/img_2.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import warnings 4 | from datetime import datetime 5 | 6 | import cv2 7 | import numpy as np 8 | import tensorflow.compat.v1 as tf 9 | from skimage import io, img_as_ubyte 10 | from tqdm import tqdm 11 | 12 | from utils import DeepExplain, get_config, get_model, get_info, save_image, gen_cam, GradientMethod, draw, create_file, get_parser 13 | from xai.adasise import AdaSISE 14 | from xai.density_map import DensityMap 15 | from xai.drise import DRISE 16 | from xai.gradcam import GradCAM, GradCAMPlusPlus 17 | from xai.kde import KDE 18 | from xai.lime_method import LIME 19 | from xai.rise import RISE 20 | 21 | warnings.filterwarnings('ignore') 22 | start = datetime.now() 23 | 24 | 25 | def main(args): 26 | # ---------------------------------Parameters------------------------------------- 27 | img_rs, output_tensor, last_conv_tensor, grads, num_sample, NMS = None, None, None, None, None, None 28 | config_xAI = get_config(args.config_path) 29 | config_models = get_config(config_xAI['Model']['file_config']) 30 | image_dict = {} 31 | sess, img_input, detection_boxes, detection_scores, num_detections, detection_classes = get_model( 32 | config_models[0]['model_path']) 33 | threshold = config_xAI['Model']['threshold'] 34 | # -------------------Create directory------------------- 
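    # args.output_path holds the rendered heatmap overlays; args.output_numpy holds the raw
    # saliency maps as .npy files, which evaluation.py later reloads to compute its metrics.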
35 | create_file(args.output_path) 36 | create_file(args.output_numpy) 37 | # -------------------------eLRP------------------------- 38 | if args.method in ['eLRP']: 39 | img_rs = sess.graph.get_tensor_by_name(config_xAI['Gradient']['target'] + ':0') 40 | output_tensor = sess.graph.get_tensor_by_name(config_xAI['Gradient']['output'] + ':0') 41 | with DeepExplain(session=sess) as de: 42 | explainer = de.get_explainer(args.method, np.sum(output_tensor[0, :, 1:2]), img_rs) 43 | 44 | # ---------------------------------GradCAM, GradCAM++------------------------------------- 45 | elif args.method in ['GradCAM', 'GradCAM++']: 46 | last_conv_tensor = sess.graph.get_tensor_by_name(config_xAI['CAM'][args.stage]['target'] + ':0') 47 | output_tensor = sess.graph.get_tensor_by_name(config_xAI['CAM'][args.stage]['output'] + ':0') 48 | 49 | if args.stage == 'first_stage': 50 | grads = tf.gradients(np.sum(output_tensor[0, :, 1:2]), last_conv_tensor)[0] 51 | else: 52 | NMS = sess.graph.get_tensor_by_name(config_xAI['CAM'][args.stage]['NMS'] + ':0') 53 | elif args.method in ['KDE', 'DensityMap', 'AdaSISE']: 54 | # These two methods do not need num_sample 55 | pass 56 | else: 57 | num_sample = config_xAI[args.method]['num_sample'] 58 | 59 | # Run xAI for each image 60 | for j in tqdm(sorted(glob.glob(f'{args.image_path}/*.jpg'))): 61 | # Load image from input folder and extract ground-truth labels from xml file 62 | image = cv2.imread(j) 63 | img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 64 | image = img.reshape(1, img.shape[0], img.shape[1], 3) 65 | name_img = os.path.basename(j).split('.')[0] 66 | try: 67 | gr_truth_boxes = get_info(config_xAI['Model']['folder_xml'] + f'{name_img}.xml') 68 | except FileNotFoundError: 69 | gr_truth_boxes = None 70 | 71 | # First stage of model: Extract 300 boxes 72 | if args.stage == 'first_stage': 73 | if args.method in ['eLRP']: 74 | # Run epsilon-LRP explainer 75 | image_rs = sess.run(img_rs, feed_dict={img_input: image}) 76 | baseline = np.zeros_like(image_rs) 77 | gradient = GradientMethod(sess, img_rs, output_tensor, explainer, baseline) 78 | image_dict[args.method] = gradient(image, args.method, img_input) 79 | np.save(os.path.join(args.output_numpy, f"{args.method}_{name_img}.npy"), image_dict[args.method]) 80 | save_image(image_dict, os.path.basename(j), args.output_path, index='full_image') 81 | 82 | elif args.method in ['GradCAM', 'GradCAM++']: 83 | y_p_boxes, y_p_num_detections = sess.run([detection_boxes, num_detections], 84 | feed_dict={img_input: image}) 85 | boxs = [] 86 | for i in range(int(y_p_num_detections[0])): 87 | h_img, w_img = image.shape[1:3] 88 | x1, x2 = int(y_p_boxes[0][i][1] * w_img), int(y_p_boxes[0][i][3] * w_img) 89 | y1, y2 = int(y_p_boxes[0][i][0] * h_img), int(y_p_boxes[0][i][2] * h_img) 90 | boxs.append([x1, y1, x2, y2]) 91 | if args.method == 'GradCAM': 92 | # Run GradCAM for each class 93 | gradcam = GradCAM(sess, last_conv_tensor, output_tensor) 94 | mask = gradcam(image, grads, img_input, args.stage, y_p_boxes) 95 | # Save image and heatmap 96 | image_dict[args.method], _ = gen_cam(img, mask, gr_truth_boxes, threshold, boxs) 97 | np.save(os.path.join(args.output_numpy, f"{args.method}_{name_img}.npy"), mask) 98 | save_image(image_dict, os.path.basename(j), args.output_path, 99 | index='gradcam_first_stage_full_image') 100 | else: 101 | # Run GradCAM++ for each class 102 | grad_cam_plus_plus = GradCAMPlusPlus(sess, last_conv_tensor, output_tensor) 103 | mask_plus_plus = grad_cam_plus_plus(image, grads, img_input, args.stage, y_p_boxes) 
# cam mask 104 | # Save image and heatmap 105 | image_dict[args.method], _ = gen_cam(img, mask_plus_plus, gr_truth_boxes, threshold, boxs) 106 | np.save(os.path.join(args.output_numpy, f"{args.method}_{name_img}.npy"), mask_plus_plus) 107 | save_image(image_dict, os.path.basename(j), args.output_path, 108 | index='gradcam_plus_first_stage_full_image') 109 | elif args.method == 'AdaSISE': 110 | # Run AdaSISE and save results 111 | adasise = AdaSISE(image=image, sess=sess) 112 | image_cam = adasise.explain(image, img_input) 113 | np.save(os.path.join(args.output_numpy, f"{args.method}_{name_img}.npy"), image_cam) 114 | image_dict[args.method], _ = gen_cam(img, image_cam, gr_truth_boxes, threshold) 115 | save_image(image_dict, os.path.basename(j), args.output_path, index=f'adasise_full_image') 116 | else: 117 | print('Method not supported for first stage') 118 | pass 119 | 120 | # Main part 121 | # Second stage of model: Detect final boxes containing the nodule(s) 122 | else: 123 | # Extract boxes from session 124 | y_p_boxes, y_p_scores, y_p_num_detections = sess.run([detection_boxes, 125 | detection_scores, 126 | num_detections], 127 | feed_dict={img_input: image}) 128 | 129 | if args.method in ['RISE']: 130 | boxs = [] 131 | grid_size = config_xAI['RISE']['grid_size'] 132 | prob = config_xAI['RISE']['prob'] 133 | index = config_xAI['RISE']['index'] 134 | assert y_p_scores[0][index] > args.threshold 135 | 136 | for i in range(int(y_p_num_detections[0])): 137 | h_img, w_img = image.shape[1:3] 138 | x1, x2 = int(y_p_boxes[0][i][1] * w_img), int(y_p_boxes[0][i][3] * w_img) 139 | y1, y2 = int(y_p_boxes[0][i][0] * h_img), int(y_p_boxes[0][i][2] * h_img) 140 | boxs.append([x1, y1, x2, y2]) 141 | 142 | rise = RISE(image=image, sess=sess, grid_size=grid_size, prob=prob, num_samples=num_sample) 143 | rs = rise.explain(image, index, img_input, detection_boxes, detection_scores, num_detections, 144 | detection_classes)[0] 145 | np.save(os.path.join(args.output_numpy, f"{args.method}_{name_img}.npy"), rs) 146 | image_dict[args.method], _ = gen_cam(img, rs, gr_truth_boxes, threshold, boxs) 147 | save_image(image_dict, os.path.basename(j), args.output_path, index=f'rise_box{index}') 148 | 149 | elif args.method in ['LIME']: 150 | index = config_xAI['LIME']['index'] 151 | num_features = config_xAI['LIME']['num_features'] 152 | feature_view = 1 153 | lime = LIME(sess, img_input=img_input, detection_scores=detection_scores, image=img, indices=index, num_features=num_features) 154 | image_dict[args.method] = lime.explain(feature_view, num_samples=num_sample) 155 | # ------------saliency map for LIME -------------- 156 | cam_map = np.zeros(image.shape[1:3]) 157 | for k, v in image_dict[args.method].result.local_exp[0]: 158 | cam_map[image_dict[args.method].segments == k] = v 159 | # save results 160 | np.save(os.path.join(args.output_numpy, f"{args.method}_{name_img}.npy"), cam_map) 161 | save_image(image_dict, os.path.basename(j), args.output_path, index=f'rise_box{index}') 162 | 163 | elif args.method in ['GradCAM', 'GradCAM++']: 164 | index = config_xAI['CAM'][args.stage]['index'] 165 | assert y_p_scores[0][index] > args.threshold 166 | NMS_tensor = sess.run(NMS, feed_dict={img_input: image}) 167 | indices = NMS_tensor[index] 168 | grads = tf.gradients(output_tensor[0][indices][1], last_conv_tensor)[0] 169 | 170 | if args.method == 'GradCAM': 171 | # Run GradCAM and save results 172 | gradcam = GradCAM(sess, last_conv_tensor, output_tensor) 173 | mask, x1, y1, x2, y2 = gradcam(image, 174 | grads, 175 | 
args.stage, 176 | y_p_boxes, 177 | indices=indices, 178 | index=index, 179 | y_p_boxes=y_p_boxes) # cam mask 180 | # Save image and heatmap 181 | image_dict['predict_box'] = img[y1:y2, x1:x2] # [H, W, C] 182 | image_dict[args.method], _ = gen_cam(img[y1:y2, x1:x2], mask, gr_truth_boxes, threshold) 183 | save_image(image_dict, os.path.basename(j), args.output_path, index=f'gradcam_2th_stage_box{index}') 184 | else: 185 | # Run GradCAM++ and save results 186 | grad_cam_plus_plus = GradCAMPlusPlus(sess, last_conv_tensor, output_tensor) 187 | mask_plus_plus, x1, y1, x2, y2 = grad_cam_plus_plus(image, 188 | grads, 189 | args.stage, 190 | y_p_boxes, 191 | indices=indices, 192 | index=index, 193 | y_p_boxes=y_p_boxes) # cam mask 194 | # Save image and heatmap 195 | image_dict[args.method], _ = gen_cam(img[y1:y2, x1:x2], mask_plus_plus, gr_truth_boxes, threshold) 196 | save_image(image_dict, os.path.basename(j), args.output_path, 197 | index=f'gradcam_plus_2th_stage_box{index}') 198 | elif args.method == 'DRISE': 199 | # Run DRISE and save results 200 | drise = DRISE(image=image, sess=sess, grid_size=8, prob=0.4, num_samples=100) 201 | rs = drise.explain(image, 202 | img_input, 203 | y_p_boxes, 204 | y_p_num_detections, 205 | detection_boxes, 206 | detection_scores, 207 | num_detections, 208 | detection_classes) 209 | boxs = [] 210 | for i in range(int(y_p_num_detections[0])): 211 | h_img, w_img = image.shape[1:3] 212 | x1, x2 = int(y_p_boxes[0][i][1] * w_img), int(y_p_boxes[0][i][3] * w_img) 213 | y1, y2 = int(y_p_boxes[0][i][0] * h_img), int(y_p_boxes[0][i][2] * h_img) 214 | boxs.append([x1, y1, x2, y2]) 215 | rs[0] -= np.min(rs[0]) 216 | rs[0] /= (np.max(rs[0]) - np.min(rs[0])) 217 | np.save(os.path.join(args.output_numpy, f"{args.method}_{name_img}.npy"), rs) 218 | image_dict[args.method], _ = gen_cam(img, rs[0], gr_truth_boxes, threshold, boxs) 219 | save_image(image_dict, os.path.basename(j), args.output_path, index='drise_result') 220 | elif args.method == 'KDE': 221 | # Run KDE and save results 222 | all_box = None 223 | kde = KDE(sess, image, j, y_p_num_detections, y_p_boxes) 224 | box, box_predicted = kde.get_box_predicted(img_input) 225 | kernel, f = kde.get_kde_map(box) 226 | np.save(os.path.join(args.output_numpy, f"{args.method}_{name_img}.npy"), f.T) 227 | kde_score = 1 / kde.get_kde_score(kernel, box_predicted) # Compute KDE score 228 | print('kde_score:', kde_score) 229 | for i in range(300): 230 | all_box = draw(image, boxs=[box[i]]) 231 | kde.show_kde_map(box_predicted, f, save_file=args.output_path) 232 | elif args.method == 'DensityMap': 233 | # Run DensityMap and save results 234 | density_map = DensityMap(sess, image, j) 235 | heatmap = density_map.explain(img_input, y_p_num_detections, y_p_boxes) 236 | np.save(os.path.join(args.output_numpy, f"{args.method}_{name_img}.npy"), heatmap) 237 | cv2.imwrite(os.path.join("/results/", f'{name_img}.jpg'), img_as_ubyte(heatmap)) 238 | 239 | 240 | if __name__ == '__main__': 241 | arguments = get_parser() 242 | main(arguments) 243 | print(f'Total training time: {datetime.now() - start}') -------------------------------------------------------------------------------- /model/config/model_config.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "model_name":"frcnn_inception_v2", 4 | "model_id":"model1", 5 | "model_path":"model/src/frozen_inference_graph.pb" 6 | } 7 | ] 8 | -------------------------------------------------------------------------------- /model/config/server_config.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "rmq_server":"172.27.169.162", 3 | "rmq_port":5672, 4 | "rmq_user":"tienbt", 5 | "rmq_password":"admin123", 6 | "rmq_virtual_host":"/", 7 | "rmq_source_queue":"queue.model.input", 8 | "rmq_completed_exchange":"DataClasses.Messages:FrameCompleted" 9 | } 10 | -------------------------------------------------------------------------------- /model/src/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hungntt/xai_thyroid/4fe1392d9123b844d8947a161a44cb79ef876435/model/src/.gitkeep -------------------------------------------------------------------------------- /model/src/RMQ.py: -------------------------------------------------------------------------------- 1 | import pika 2 | 3 | 4 | class BasicRMQClient: 5 | """A basic RabbitMQ client handling connection failures""" 6 | 7 | def __init__(self, server, port, user, password, virtual_host='/'): 8 | self.channel = None 9 | self.connection = None 10 | self.server = server 11 | self.port = port 12 | self.user = user 13 | self.password = password 14 | self.virtual_host = virtual_host 15 | 16 | # publishes a message to the specified exchange on the active channel 17 | @staticmethod 18 | def publish_exchange(channel, exchange, body, routing_key=''): 19 | channel.basic_publish(exchange=exchange, body=body, routing_key=routing_key) 20 | 21 | # processes message from RabbitMQ 22 | def process(self, callback_on_message, source_queue): 23 | # define our connection parameters 24 | creds = pika.PlainCredentials(self.user, self.password) 25 | connection_params = pika.ConnectionParameters(host=self.server, 26 | port=self.port, 27 | virtual_host=self.virtual_host, 28 | credentials=creds) 29 | # Connect to RMQ and wait until a message is received 30 | while True: 31 | try: 32 | print("Connecting to %s" % self.server) 33 | self.connection = pika.BlockingConnection(connection_params) 34 | 35 | # create channel and a queue bound to the source exchange 36 | self.channel = self.connection.channel() 37 | self.channel.basic_qos(prefetch_count=1) 38 | self.channel.basic_consume( 39 | queue=source_queue, on_message_callback=callback_on_message, auto_ack=False) 40 | 41 | # print(' [*] Waiting for messages. 
To exit press CTRL+C') 42 | try: 43 | self.channel.start_consuming() 44 | except KeyboardInterrupt: 45 | self.channel.stop_consuming() 46 | self.connection.close() 47 | break 48 | # Recover from server-initiated connection closure - handles manual RMQ restarts 49 | except pika.exceptions.ConnectionClosedByBroker: 50 | continue 51 | # Do not recover on channel errors 52 | except pika.exceptions.AMQPChannelError as err: 53 | print("Channel error: {}, stopping...".format(err)) 54 | break 55 | # Recover on all other connection errors 56 | except pika.exceptions.AMQPConnectionError: 57 | continue 58 | -------------------------------------------------------------------------------- /model/src/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "predictions": 15, 3 | "confidence": 10, 4 | "inference_engine_name": "tensorflow_detection", 5 | "framework": "tensorflow", 6 | "type": "detection", 7 | "network": "frc", 8 | "number_of_classes": 1 9 | } -------------------------------------------------------------------------------- /model/src/image.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hungntt/xai_thyroid/4fe1392d9123b844d8947a161a44cb79ef876435/model/src/image.jpg -------------------------------------------------------------------------------- /model/src/model.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import json 3 | import cv2 4 | import tensorflow.compat.v1 as tf 5 | import time 6 | import warnings 7 | import logging 8 | from datetime import datetime 9 | from RMQ import BasicRMQClient 10 | 11 | now = datetime.now() 12 | warnings.filterwarnings('ignore') 13 | 14 | logname = '/logs/log-{}.log'.format(now.strftime("%Y-%m-%d")) 15 | logging.basicConfig(filename=logname, 16 | filemode='w', 17 | format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s', 18 | datefmt='%H:%M:%S', 19 | level=logging.INFO) 20 | logging.info('=' * 10 + ' LOG FILE FOR ' + '=' * 10) 21 | 22 | 23 | def get_config(path_config): 24 | with open(path_config, 'r') as fin: 25 | config_model = json.load(fin) 26 | return config_model 27 | 28 | 29 | # default_model = None 30 | config_models = get_config('/config/model_config.json') 31 | config_rmq = get_config('/config/server_config.json') 32 | 33 | rmq_server = config_rmq['rmq_server'] 34 | rmq_port = config_rmq['rmq_port'] 35 | rmq_user = config_rmq['rmq_user'] 36 | rmq_password = config_rmq['rmq_password'] 37 | 38 | rmq_virtual_host = config_rmq['rmq_virtual_host'] 39 | rmq_source_queue = config_rmq['rmq_source_queue'] 40 | rmq_completed_exchange = config_rmq['rmq_completed_exchange'] 41 | 42 | 43 | def get_model(model_path): 44 | graph = tf.Graph() 45 | with graph.as_default(): 46 | with tf.gfile.GFile(model_path, 'rb') as file: 47 | graph_def = tf.GraphDef() 48 | graph_def.ParseFromString(file.read()) 49 | tf.import_graph_def(graph_def, name='') 50 | 51 | img_input = graph.get_tensor_by_name('image_tensor:0') 52 | detection_boxes = graph.get_tensor_by_name('detection_boxes:0') 53 | detection_scores = graph.get_tensor_by_name('detection_scores:0') 54 | num_detections = graph.get_tensor_by_name('num_detections:0') 55 | detection_classes = graph.get_tensor_by_name('detection_classes:0') 56 | 57 | sess = tf.Session(graph=graph) 58 | return sess, img_input, detection_boxes, detection_scores, num_detections, detection_classes 59 | 60 | 61 | def draw(image, boxs, color=(255, 0, 0), 
thickness=2, predict=False): 62 | for b in boxs: 63 | logging.info(b) 64 | start_point, end_point = (b[1], b[2]), (b[3], b[4]) 65 | image = cv2.rectangle(image, start_point, end_point, color, thickness) 66 | return image 67 | 68 | 69 | def detector(file_path, model_id, score_threshold=0.3, is_draw=False): 70 | img = cv2.imread(file_path, 1) # 1 is color 71 | img_original = img.copy() 72 | h_img, w_img = img.shape[:2] 73 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 74 | image = img.reshape(1, img.shape[0], img.shape[1], 3) # reshape of (1,w,h,3). Channels = 3 75 | boxs = [] 76 | global default_model 77 | if model_id != default_model: 78 | for model in config_models: 79 | if model['model_id'] == model_id: 80 | logging.info('======== Switch mode {} ========'.format(model_id)) 81 | global sess, img_input, detection_boxes, detection_scores, num_detections, detection_classes 82 | sess, img_input, detection_boxes, detection_scores, num_detections, detection_classes = get_model( 83 | model['model_path']) 84 | feed_dict = {img_input: image} 85 | y_p_boxes, y_p_scores, y_p_num_detections, y_p_classes = sess.run( 86 | [detection_boxes, detection_scores, num_detections, detection_classes], 87 | feed_dict=feed_dict) 88 | default_model = model_id 89 | break 90 | else: 91 | feed_dict = {img_input: image} 92 | y_p_boxes, y_p_scores, y_p_num_detections, y_p_classes = sess.run( 93 | [detection_boxes, detection_scores, num_detections, detection_classes], 94 | feed_dict=feed_dict) 95 | 96 | for i in range(int(y_p_num_detections[0])): 97 | 98 | if y_p_scores[0][i] > score_threshold: 99 | logging.info(y_p_classes[0][i]) 100 | logging.info(y_p_scores[0][i]) 101 | logging.info(y_p_boxes[0][i]) 102 | x1, x2 = int(y_p_boxes[0][i][1] * w_img), int(y_p_boxes[0][i][3] * w_img) 103 | y1, y2 = int(y_p_boxes[0][i][0] * h_img), int(y_p_boxes[0][i][2] * h_img) 104 | boxs.append((y_p_scores[0][i], x1, y1, x2, y2)) 105 | if is_draw: 106 | img_draw = draw(image=img_original, boxs=boxs, predict=True) 107 | return boxs, img_draw 108 | else: 109 | return boxs, img_original 110 | 111 | 112 | def callback_on_message(ch, method, properties, body): 113 | try: 114 | time_start = time.time() 115 | # byte array to bitmap 116 | str_json = body.decode('utf-8') 117 | data_total = json.loads(str_json.replace("'", '"')) 118 | data = data_total['message'] 119 | file_path = data['image_path'] 120 | model_id = config_models[0]['model_id'] 121 | try: 122 | model_id = data['model_id'] 123 | except: 124 | logging.error('Not selected model') 125 | try: 126 | content = data['content'] 127 | image_64_decode = base64.b64decode(content) 128 | image_result = open('image.jpg', 'wb') # create a writable image and write the decoding result 129 | image_result.write(image_64_decode) 130 | file_path = 'image.jpg' 131 | except: 132 | logging.error('Not encode to base64') 133 | boxs, image_draw = detector(file_path, model_id) 134 | logging.info(boxs) 135 | # Display the results 136 | data = { 137 | "bounding_boxes": [], 138 | "image_path": data['image_path'], 139 | "success": "true", 140 | 'image_id': data['image_id'] 141 | } 142 | id = 0 143 | for (conf, left, top, right, bottom) in boxs: 144 | object_detect = { 145 | "ObjectClassName": "thyroid_cancer", 146 | "confidence": conf, 147 | "coordinates": { 148 | "left": left, 149 | "top": top, 150 | "right": right, 151 | "bottom": bottom 152 | }, 153 | "ObjectClassId": id 154 | } 155 | data['bounding_boxes'].append(object_detect) 156 | id += 1 157 | data_total['messageType'] = 
['urn:message:{}'.format(rmq_completed_exchange)] 158 | data_total["destinationAddress"] = "rabbitmq://{}/DataClasses.Messages:FrameCompleted".format(rmq_server) 159 | data_total['message'] = data 160 | json_str = str(data_total) 161 | end_time = time.time() 162 | logging.info('=========> Time processing: {}s'.format(end_time - time_start)) 163 | ch.basic_ack(delivery_tag=method.delivery_tag) 164 | rmq_client.publish_exchange(ch, rmq_completed_exchange, json_str) 165 | except: 166 | logging.error('========= Error =============') 167 | 168 | 169 | sess, img_input, detection_boxes, detection_scores, num_detections, detection_classes = get_model( 170 | config_models[0]['model_path']) 171 | default_model = config_models[0]['model_id'] 172 | image = cv2.imread('data/samples/20653934_AI_14_NGUYEN_THI_HUONG_1970_20200529105316651.jpg') 173 | img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 174 | image = img.reshape(1, img.shape[0], img.shape[1], 3) 175 | feed_dict = {img_input: image} 176 | y_p_boxes, y_p_scores, y_p_num_detections, y_p_classes = sess.run( 177 | [detection_boxes, detection_scores, num_detections, detection_classes], 178 | feed_dict=feed_dict) 179 | # Create RMQ client 180 | rmq_client = BasicRMQClient(rmq_server, rmq_port, rmq_user, rmq_password, rmq_virtual_host) 181 | # start processing messages from the rmq_source_queue 182 | rmq_client.process(callback_on_message, rmq_source_queue) 183 | -------------------------------------------------------------------------------- /model/src/object-detection.pbtxt: -------------------------------------------------------------------------------- 1 | item { 2 | id: 1 3 | name: "thyroid_cancer" 4 | } 5 | -------------------------------------------------------------------------------- /model/src/pipeline.config: -------------------------------------------------------------------------------- 1 | model { 2 | faster_rcnn { 3 | num_classes: 1 4 | image_resizer { 5 | keep_aspect_ratio_resizer { 6 | min_dimension: 600 7 | max_dimension: 1024 8 | } 9 | } 10 | feature_extractor { 11 | type: "faster_rcnn_inception_resnet_v2" 12 | first_stage_features_stride: 8 13 | } 14 | first_stage_anchor_generator { 15 | grid_anchor_generator { 16 | height_stride: 8 17 | width_stride: 8 18 | scales: 0.050000001 19 | scales: 0.125 20 | scales: 0.25 21 | scales: 0.5 22 | scales: 1.0 23 | scales: 1.5 24 | aspect_ratios: 0.75 25 | aspect_ratios: 1.0 26 | aspect_ratios: 1.25 27 | aspect_ratios: 1.5 28 | aspect_ratios: 1.75 29 | aspect_ratios: 2.0 30 | } 31 | } 32 | first_stage_atrous_rate: 2 33 | first_stage_box_predictor_conv_hyperparams { 34 | op: CONV 35 | regularizer { 36 | l2_regularizer { 37 | weight: 0.0 38 | } 39 | } 40 | initializer { 41 | truncated_normal_initializer { 42 | stddev: 0.0099999998 43 | } 44 | } 45 | } 46 | first_stage_nms_score_threshold: 0.0 47 | first_stage_nms_iou_threshold: 0.69999999 48 | first_stage_max_proposals: 300 49 | first_stage_localization_loss_weight: 2.0 50 | first_stage_objectness_loss_weight: 1.0 51 | initial_crop_size: 17 52 | maxpool_kernel_size: 1 53 | maxpool_stride: 1 54 | second_stage_box_predictor { 55 | mask_rcnn_box_predictor { 56 | fc_hyperparams { 57 | op: FC 58 | regularizer { 59 | l2_regularizer { 60 | weight: 0.0 61 | } 62 | } 63 | initializer { 64 | variance_scaling_initializer { 65 | factor: 1.0 66 | uniform: true 67 | mode: FAN_AVG 68 | } 69 | } 70 | } 71 | use_dropout: true 72 | dropout_keep_probability: 0.9999001 73 | } 74 | } 75 | second_stage_post_processing { 76 | batch_non_max_suppression { 77 | 
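      # Second-stage NMS: discard detections scoring below ~0.3, suppress overlaps above 0.6 IoU,
      # and keep at most 100 boxes per class and in total.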
score_threshold: 0.30000001 78 | iou_threshold: 0.60000002 79 | max_detections_per_class: 100 80 | max_total_detections: 100 81 | } 82 | score_converter: SOFTMAX 83 | } 84 | second_stage_localization_loss_weight: 2.0 85 | second_stage_classification_loss_weight: 1.0 86 | } 87 | } 88 | train_config { 89 | batch_size: 1 90 | data_augmentation_options { 91 | random_horizontal_flip { 92 | } 93 | } 94 | optimizer { 95 | momentum_optimizer { 96 | learning_rate { 97 | manual_step_learning_rate { 98 | initial_learning_rate: 0.00010000001 99 | schedule { 100 | step: 80000 101 | learning_rate: 5.0000002e-05 102 | } 103 | } 104 | } 105 | momentum_optimizer_value: 0.89999998 106 | } 107 | use_moving_average: false 108 | } 109 | gradient_clipping_by_norm: 10.0 110 | fine_tune_checkpoint: "/weights/frcnn_inception_resnet_v2/model.ckpt" 111 | from_detection_checkpoint: true 112 | num_steps: 40000 113 | freeze_variables: "FirstStageBoxPredictor" 114 | } 115 | train_input_reader { 116 | label_map_path: "/training_dir/data/object-detection.pbtxt" 117 | tf_record_input_reader { 118 | input_path: "/training_dir/data/train.record" 119 | } 120 | } 121 | eval_config { 122 | num_examples: 50 123 | max_evals: 10 124 | metrics_set: "coco_detection_metrics" 125 | use_moving_averages: false 126 | retain_original_images: true 127 | } 128 | eval_input_reader { 129 | label_map_path: "/training_dir/data/object-detection.pbtxt" 130 | shuffle: false 131 | num_readers: 1 132 | tf_record_input_reader { 133 | input_path: "/training_dir/data/test.record" 134 | } 135 | } 136 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.7.1 2 | astor==0.7.1 3 | atomicwrites==1.3.0 4 | attrs==19.1.0 5 | certifi==2019.3.9 6 | chardet==3.0.4 7 | Click==7.0 8 | cloudpickle==0.8.1 9 | cycler==0.10.0 10 | dask==1.1.5 11 | decorator==4.4.0 12 | deepexplain @ git+https://github.com/marcoancona/DeepExplain.git@6ac43e729dc8db6ce05ca241dfeb35f1c3eb5b0b 13 | deeplift==0.6.9.0 14 | Flask==1.0.2 15 | Flask-Assets==0.12 16 | Flask-Cors==3.0.7 17 | future==0.17.1 18 | gast==0.2.2 19 | grpcio==1.19.0 20 | h5py==2.10.0 21 | idna==2.8 22 | innvestigate @ git+https://github.com/albermax/innvestigate@e76ec2d85a6d59d56b6bbb6dbbf1efee6cea6166 23 | itsdangerous==1.1.0 24 | Jinja2==2.10 25 | Keras~=2.2.4 26 | Keras-Applications==1.0.7 27 | Keras-Preprocessing==1.0.9 28 | kiwisolver==1.0.1 29 | lime==0.2.0.1 30 | Markdown==3.1 31 | MarkupSafe==1.1.1 32 | matplotlib==3.0.3 33 | mock==2.0.0 34 | more-itertools==7.0.0 35 | networkx==2.2 36 | numpy==1.16.2 37 | pandas 38 | pbr==5.1.3 39 | Pillow==9.3.0 40 | pluggy==0.9.0 41 | protobuf==3.7.1 42 | py==1.8.0 43 | pyparsing==2.3.1 44 | pytest==4.3.1 45 | python-dateutil==2.8.0 46 | pytz==2018.9 47 | PyWavelets==1.0.2 48 | PyYAML==5.1 49 | requests==2.21.0 50 | scikit-image==0.14.2 51 | scikit-learn==0.20.3 52 | scipy==1.2.1 53 | six==1.12.0 54 | tensorboard==1.13.1 55 | tensorflow==1.13.1 56 | tensorflow-estimator==1.13.0 57 | tensorflow-gpu==2.9.3 58 | termcolor==1.1.0 59 | toolz==0.9.0 60 | urllib3==1.24.1 61 | webassets==0.12.1 62 | Werkzeug==0.1 -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import argparse 6 | import json 
7 | import logging 8 | import os 9 | import sys 10 | import warnings 11 | import cv2 12 | import innvestigate.utils as iutils 13 | import innvestigate.utils.visualizations as ivis 14 | import numpy as np 15 | import tensorflow as tf 16 | import xml.etree.ElementTree as ET 17 | from collections import OrderedDict 18 | from skimage import io, img_as_ubyte 19 | from skimage.util import view_as_windows 20 | from tensorflow.python.framework import ops 21 | from tensorflow.python.ops import nn_grad, math_grad 22 | 23 | SUPPORTED_ACTIVATIONS = [ 24 | 'Relu', 'Elu', 'Sigmoid', 'Tanh', 'Softplus' 25 | ] 26 | 27 | UNSUPPORTED_ACTIVATIONS = [ 28 | 'CRelu', 'Relu6', 'Softsign' 29 | ] 30 | 31 | _ENABLED_METHOD_CLASS = None 32 | _GRAD_OVERRIDE_CHECKFLAG = 0 33 | 34 | 35 | # ----------------------------------------------------------------------------- 36 | # UTILITY FUNCTIONS 37 | # ----------------------------------------------------------------------------- 38 | 39 | 40 | def activation(type): 41 | """ 42 | Returns Tensorflow's activation op, given its type 43 | :param type: string 44 | :return: op 45 | """ 46 | if type not in SUPPORTED_ACTIVATIONS: 47 | warnings.warn('Activation function (%s) not supported' % type) 48 | f = getattr(tf.nn, type.lower()) 49 | return f 50 | 51 | 52 | def original_grad(op, grad): 53 | """ 54 | Return original Tensorflow gradient for an op 55 | :param op: op 56 | :param grad: Tensor 57 | :return: Tensor 58 | """ 59 | if op.type not in SUPPORTED_ACTIVATIONS: 60 | warnings.warn('Activation function (%s) not supported' % op.type) 61 | opname = '_%sGrad' % op.type 62 | if hasattr(nn_grad, opname): 63 | f = getattr(nn_grad, opname) 64 | else: 65 | f = getattr(math_grad, opname) 66 | return f(op, grad) 67 | 68 | 69 | # ----------------------------------------------------------------------------- 70 | # ATTRIBUTION METHODS BASE CLASSES 71 | # ----------------------------------------------------------------------------- 72 | 73 | 74 | class AttributionMethod(object): 75 | """ 76 | Attribution method base class 77 | """ 78 | 79 | def __init__(self, T, X, session, keras_learning_phase=None): 80 | self.T = T # target Tensor 81 | self.X = X # input Tensor 82 | self.Y_shape = [None, ] + T.get_shape().as_list()[1:] 83 | # Most often T contains multiple output units. In this case, it is often necessary to select 84 | # a single unit to compute contributions for. This can be achieved passing 'ys' as weight for the output Tensor. 
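        # Y acts as a weight mask over the target tensor: it is multiplied into T below, so feeding
        # a one-hot (or all-ones) array selects which output unit(s) attributions are computed for.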
85 | self.Y = tf.placeholder(tf.float32, self.Y_shape) 86 | # placeholder_from_data(ys) if ys is not None else 1.0 # Tensor that represents weights for T 87 | self.T = self.T * self.Y 88 | self.symbolic_attribution = None 89 | self.session = session 90 | self.keras_learning_phase = keras_learning_phase 91 | self.has_multiple_inputs = type(self.X) is list or type(self.X) is tuple 92 | logging.info('Model with multiple inputs: %s' % self.has_multiple_inputs) 93 | 94 | # References 95 | self._init_references() 96 | 97 | # Create symbolic explanation once during construction (affects only gradient-based methods) 98 | self.explain_symbolic() 99 | 100 | def explain_symbolic(self): 101 | return None 102 | 103 | def run(self, xs, ys=None, batch_size=None): 104 | pass 105 | 106 | def _init_references(self): 107 | pass 108 | 109 | def _check_input_compatibility(self, xs, ys=None, batch_size=None): 110 | if ys is not None: 111 | if not self.has_multiple_inputs and len(xs) != len(ys): 112 | raise RuntimeError( 113 | 'When provided, ys must have the same batch size as xs (xs has batch size {} and ys {})'.format( 114 | len(xs), len(ys))) 115 | elif self.has_multiple_inputs and np.all([len(i) != len(ys) for i in xs]): 116 | raise RuntimeError('When provided, ys must have the same batch size as all elements of xs') 117 | if batch_size is not None and batch_size > 0: 118 | if self.T.shape[0].value is not None and self.T.shape[0].value is not batch_size: 119 | raise RuntimeError('When using batch evaluation, the first dimension of the target tensor ' 120 | 'must be compatible with the batch size. Found %s instead' % self.T.shape[0].value) 121 | if isinstance(self.X, list): 122 | for x in self.X: 123 | if x.shape[0].value is not None and x.shape[0].value is not batch_size: 124 | raise RuntimeError('When using batch evaluation, the first dimension of the input tensor ' 125 | 'must be compatible with the batch size. Found %s instead' % x.shape[ 126 | 0].value) 127 | else: 128 | if self.X.shape[0].value is not None and self.X.shape[0].value is not batch_size: 129 | raise RuntimeError('When using batch evaluation, the first dimension of the input tensor ' 130 | 'must be compatible with the batch size. 
Found %s instead' % self.X.shape[ 131 | 0].value) 132 | 133 | def _session_run_batch(self, T, xs, ys=None): 134 | feed_dict = {} 135 | if self.has_multiple_inputs: 136 | for k, v in zip(self.X, xs): 137 | feed_dict[k] = v 138 | else: 139 | feed_dict[self.X] = xs 140 | 141 | # If ys is not passed, produce a vector of ones that will be broadcasted to all batch samples 142 | feed_dict[self.Y] = ys if ys is not None else np.ones([1, ] + self.Y_shape[1:]) 143 | 144 | if self.keras_learning_phase is not None: 145 | feed_dict[self.keras_learning_phase] = 0 146 | return self.session.run(T, feed_dict) 147 | 148 | def _session_run(self, T, xs, ys=None, batch_size=None): 149 | num_samples = len(xs) 150 | if self.has_multiple_inputs is True: 151 | num_samples = len(xs[0]) 152 | if len(xs) != len(self.X): 153 | raise RuntimeError('List of input tensors and input data have different lengths (%s and %s)' 154 | % (str(len(xs)), str(len(self.X)))) 155 | if batch_size is not None: 156 | for xi in xs: 157 | if len(xi) != num_samples: 158 | raise RuntimeError('Evaluation in batches requires all inputs to have ' 159 | 'the same number of samples') 160 | 161 | if batch_size is None or batch_size <= 0 or num_samples <= batch_size: 162 | return self._session_run_batch(T, xs, ys) 163 | else: 164 | outs = [] 165 | batches = make_batches(num_samples, batch_size) 166 | for batch_index, (batch_start, batch_end) in enumerate(batches): 167 | # Get a batch from data 168 | xs_batch = slice_arrays(xs, batch_start, batch_end) 169 | # If the target tensor has one entry for each sample, we need to batch it as well 170 | ys_batch = None 171 | if ys is not None: 172 | ys_batch = slice_arrays(ys, batch_start, batch_end) 173 | batch_outs = self._session_run_batch(T, xs_batch, ys_batch) 174 | batch_outs = to_list(batch_outs) 175 | if batch_index == 0: 176 | # Pre-allocate the results arrays. 
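                    # One result array per output tensor, sized for the full dataset; each later
                    # batch writes its slice into outs[i][batch_start:batch_end].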
177 | for batch_out in batch_outs: 178 | shape = (num_samples,) + batch_out.shape[1:] 179 | outs.append(np.zeros(shape, dtype=batch_out.dtype)) 180 | for i, batch_out in enumerate(batch_outs): 181 | outs[i][batch_start:batch_end] = batch_out 182 | return unpack_singleton(outs) 183 | 184 | 185 | class GradientBasedMethod(AttributionMethod): 186 | """ 187 | Base class for gradient-based attribution methods 188 | """ 189 | 190 | def get_symbolic_attribution(self): 191 | return tf.gradients(self.T, self.X) 192 | 193 | def explain_symbolic(self): 194 | if self.symbolic_attribution is None: 195 | self.symbolic_attribution = self.get_symbolic_attribution() 196 | return self.symbolic_attribution 197 | 198 | def run(self, xs, ys=None, batch_size=None): 199 | self._check_input_compatibility(xs, ys, batch_size) 200 | results = self._session_run(self.explain_symbolic(), xs, ys, batch_size) 201 | return results[0] if not self.has_multiple_inputs else results 202 | 203 | @classmethod 204 | def nonlinearity_grad_override(cls, op, grad): 205 | return original_grad(op, grad) 206 | 207 | 208 | class PerturbationBasedMethod(AttributionMethod): 209 | """ 210 | Base class for perturbation-based attribution methods 211 | """ 212 | 213 | def __init__(self, T, X, session, keras_learning_phase): 214 | super(PerturbationBasedMethod, self).__init__(T, X, session, keras_learning_phase) 215 | self.base_activation = None 216 | 217 | def get_symbolic_attribution(self): 218 | return tf.gradients(self.T, self.X) 219 | 220 | def explain_symbolic(self): 221 | if self.symbolic_attribution is None: 222 | self.symbolic_attribution = self.get_symbolic_attribution() 223 | return self.symbolic_attribution 224 | 225 | def run(self, xs, ys=None, batch_size=None): 226 | self._check_input_compatibility(xs, ys, batch_size) 227 | results = self._session_run(self.explain_symbolic(), xs, ys, batch_size) 228 | return results[0] if not self.has_multiple_inputs else results 229 | 230 | @classmethod 231 | def nonlinearity_grad_override(cls, op, grad): 232 | return original_grad(op, grad) 233 | 234 | 235 | # ----------------------------------------------------------------------------- 236 | # ATTRIBUTION METHODS 237 | # ----------------------------------------------------------------------------- 238 | 239 | 240 | class DummyZero(GradientBasedMethod): 241 | """ 242 | Returns zero attributions. For testing only. 
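    The nonlinearity gradient override returns zeros, so the resulting attribution maps are identically zero.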
243 | """ 244 | 245 | def get_symbolic_attribution(self, ): 246 | return tf.gradients(self.T, self.X) 247 | 248 | @classmethod 249 | def nonlinearity_grad_override(cls, op, grad): 250 | input = op.inputs[0] 251 | return tf.zeros_like(input) 252 | 253 | 254 | class Saliency(GradientBasedMethod): 255 | """ 256 | Saliency maps 257 | https://arxiv.org/abs/1312.6034 258 | """ 259 | 260 | def get_symbolic_attribution(self): 261 | return [tf.abs(g) for g in tf.gradients(self.T, self.X)] 262 | 263 | 264 | class GradientXInput(GradientBasedMethod): 265 | """ 266 | Gradient * Input 267 | https://arxiv.org/pdf/1704.02685.pdf - https://arxiv.org/abs/1611.07270 268 | """ 269 | 270 | def get_symbolic_attribution(self): 271 | return [g * x for g, x in zip( 272 | tf.gradients(self.T, self.X), 273 | self.X if self.has_multiple_inputs else [self.X])] 274 | 275 | 276 | class IntegratedGradients(GradientBasedMethod): 277 | """ 278 | Integrated Gradients 279 | https://arxiv.org/pdf/1703.01365.pdf 280 | """ 281 | 282 | def __init__(self, T, X, session, keras_learning_phase, steps=100, baseline=None): 283 | self.steps = steps 284 | self.baseline = baseline 285 | super(IntegratedGradients, self).__init__(T, X, session, keras_learning_phase) 286 | 287 | def run(self, xs, ys=None, batch_size=None): 288 | self._check_input_compatibility(xs, ys, batch_size) 289 | 290 | gradient = None 291 | for alpha in list(np.linspace(1. / self.steps, 1.0, self.steps)): 292 | xs_mod = [b + (x - b) * alpha for x, b in zip(xs, self.baseline)] if self.has_multiple_inputs \ 293 | else self.baseline + (xs - self.baseline) * alpha 294 | _attr = self._session_run(self.explain_symbolic(), xs_mod, ys, batch_size) 295 | if gradient is None: 296 | gradient = _attr 297 | else: 298 | gradient = [g + a for g, a in zip(gradient, _attr)] 299 | 300 | results = [g * (x - b) / self.steps for g, x, b in zip( 301 | gradient, 302 | xs if self.has_multiple_inputs else [xs], 303 | self.baseline if self.has_multiple_inputs else [self.baseline])] 304 | 305 | return results[0] if not self.has_multiple_inputs else results 306 | 307 | 308 | class EpsilonLRP(GradientBasedMethod): 309 | """ 310 | Layer-wise Relevance Propagation with epsilon rule 311 | http://journals.plos.org/plosone/article?id=10.1371/journal.pone.0130140 312 | """ 313 | 314 | def __init__(self, T, X, session, keras_learning_phase, epsilon=1e-4): 315 | assert epsilon > 0.0, 'LRP epsilon must be greater than zero' 316 | global eps 317 | eps = epsilon 318 | super(EpsilonLRP, self).__init__(T, X, session, keras_learning_phase) 319 | 320 | def get_symbolic_attribution(self): 321 | return [g * x for g, x in zip( 322 | tf.gradients(self.T, self.X), 323 | self.X if self.has_multiple_inputs else [self.X])] 324 | 325 | @classmethod 326 | def nonlinearity_grad_override(cls, op, grad): 327 | output = op.outputs[0] 328 | input = op.inputs[0] 329 | return grad * output / (input + eps * tf.where(input >= 0, tf.ones_like(input), -1 * tf.ones_like(input))) 330 | 331 | 332 | class DeepLIFTRescale(GradientBasedMethod): 333 | """ 334 | DeepLIFT 335 | This reformulation only considers the "Rescale" rule 336 | https://arxiv.org/abs/1704.02685 337 | """ 338 | _deeplift_ref = {} 339 | 340 | def __init__(self, T, X, session, keras_learning_phase, baseline=None): 341 | self.baseline = baseline 342 | super(DeepLIFTRescale, self).__init__(T, X, session, keras_learning_phase) 343 | 344 | def get_symbolic_attribution(self): 345 | return [g * (x - b) for g, x, b in zip( 346 | tf.gradients(self.T, self.X), 347 | self.X if 
self.has_multiple_inputs else [self.X], 348 | self.baseline if self.has_multiple_inputs else [self.baseline])] 349 | 350 | @classmethod 351 | def nonlinearity_grad_override(cls, op, grad): 352 | output = op.outputs[0] 353 | input = op.inputs[0] 354 | ref_input = cls._deeplift_ref[op.name] 355 | ref_output = activation(op.type)(ref_input) 356 | delta_out = output - ref_output 357 | delta_in = input - ref_input 358 | instant_grad = activation(op.type)(0.5 * (ref_input + input)) 359 | return tf.where(tf.abs(delta_in) > 1e-5, grad * delta_out / delta_in, 360 | original_grad(instant_grad.op, grad)) 361 | 362 | def _init_references(self): 363 | # print ('DeepLIFT: computing references...') 364 | sys.stdout.flush() 365 | self._deeplift_ref.clear() 366 | ops = [] 367 | g = tf.get_default_graph() 368 | for op in g.get_operations(): 369 | if len(op.inputs) > 0 and not op.name.startswith('gradients'): 370 | if op.type in SUPPORTED_ACTIVATIONS: 371 | ops.append(op) 372 | YR = self._session_run([o.inputs[0] for o in ops], self.baseline) 373 | for (r, op) in zip(YR, ops): 374 | self._deeplift_ref[op.name] = r 375 | # print('DeepLIFT: references ready') 376 | sys.stdout.flush() 377 | 378 | 379 | class Occlusion(PerturbationBasedMethod): 380 | """ 381 | Occlusion method 382 | Generalization of the grey-box method presented in https://arxiv.org/pdf/1311.2901.pdf 383 | This method performs a systematic perturbation of contiguous hyperpatches in the input, 384 | replacing each patch with a user-defined value (by default 0). 385 | 386 | window_shape : integer or tuple of length xs_ndim 387 | Defines the shape of the elementary n-dimensional orthotope the rolling window view. 388 | If an integer is given, the shape will be a hypercube of sidelength given by its value. 389 | 390 | step : integer or tuple of length xs_ndim 391 | Indicates step size at which extraction shall be performed. 392 | If integer is given, then the step is uniform in all dimensions. 
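Example (an illustrative sketch only; the tensor and array names are placeholders, not symbols defined
in this repository): inside a DeepExplain context, with an image placeholder X of shape [None, H, W, C]
and a scalar target T, calling de.explain('occlusion', T, X, xs, window_shape=(8, 8, 3), step=4)
returns an array with the same shape as xs, where each entry is the average drop in the target output
over all occluded patches that cover that position.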
393 | """ 394 | 395 | def __init__(self, T, X, session, keras_learning_phase, window_shape=None, step=None): 396 | super(Occlusion, self).__init__(T, X, session, keras_learning_phase) 397 | if self.has_multiple_inputs: 398 | raise RuntimeError('Multiple inputs not yet supported for perturbation methods') 399 | 400 | input_shape = X[0].get_shape().as_list() 401 | if window_shape is not None: 402 | assert len(window_shape) == len(input_shape), \ 403 | 'window_shape must have length of input (%d)' % len(input_shape) 404 | self.window_shape = tuple(window_shape) 405 | else: 406 | self.window_shape = (1,) * len(input_shape) 407 | 408 | if step is not None: 409 | assert isinstance(step, int) or len(step) == len(input_shape), \ 410 | 'step must be integer or tuple with the length of input (%d)' % len(input_shape) 411 | self.step = step 412 | else: 413 | self.step = 1 414 | self.replace_value = 0.0 415 | logging.info('Input shape: %s; window_shape %s; step %s' % (input_shape, self.window_shape, self.step)) 416 | 417 | def run(self, xs, ys=None, batch_size=None): 418 | self._check_input_compatibility(xs, ys, batch_size) 419 | input_shape = xs.shape[1:] 420 | batch_size = xs.shape[0] 421 | total_dim = np.asscalar(np.prod(input_shape)) 422 | 423 | # Create mask 424 | index_matrix = np.arange(total_dim).reshape(input_shape) 425 | idx_patches = view_as_windows(index_matrix, self.window_shape, self.step).reshape((-1,) + self.window_shape) 426 | heatmap = np.zeros_like(xs, dtype=np.float32).reshape((-1), total_dim) 427 | w = np.zeros_like(heatmap) 428 | 429 | # Compute original output 430 | eval0 = self._session_run(self.T, xs, ys, batch_size) 431 | 432 | # Start perturbation loop 433 | for i, p in enumerate(idx_patches): 434 | mask = np.ones(input_shape).flatten() 435 | mask[p.flatten()] = self.replace_value 436 | masked_xs = mask.reshape((1,) + input_shape) * xs 437 | delta = eval0 - self._session_run(self.T, masked_xs, ys, batch_size) 438 | delta_aggregated = np.sum(delta.reshape((batch_size, -1)), -1, keepdims=True) 439 | heatmap[:, p.flatten()] += delta_aggregated 440 | w[:, p.flatten()] += p.size 441 | 442 | attribution = np.reshape(heatmap / w, xs.shape) 443 | if np.isnan(attribution).any(): 444 | warnings.warn('Attributions generated by Occlusion method contain nans, ' 445 | 'probably because window_shape and step do not allow covering the whole input.') 446 | return attribution 447 | 448 | 449 | class ShapleySampling(PerturbationBasedMethod): 450 | """ 451 | Shapley Value sampling 452 | Computes approximate Shapley Values using "Polynomial calculation of the Shapley value based on sampling", 453 | Castro et al, 2009 (https://www.sciencedirect.com/science/article/pii/S0305054808000804) 454 | 455 | samples : integer (default 5) 456 | Defines the number of samples for each input feature. 457 | Notice that evaluating a model samples * n_input_feature times might take a while. 458 | 459 | sampling_dims : list of dimension indices to run sampling on (feature dimensions). 460 | By default, all dimensions except the batch dimension will be sampled. 461 | For example, with a 4-D tensor that contains color images, single color channels are sampled.
462 | To sample pixels, instead, use sampling_dims=[1,2] 463 | """ 464 | 465 | def __init__(self, T, X, session, keras_learning_phase, samples=5, sampling_dims=None): 466 | super(ShapleySampling, self).__init__(T, X, session, keras_learning_phase) 467 | if self.has_multiple_inputs: 468 | raise RuntimeError('Multiple inputs not yet supported for perturbation methods') 469 | dims = len(X.shape) 470 | if sampling_dims is not None: 471 | if not 0 < len(sampling_dims) <= (dims - 1): 472 | raise RuntimeError('sampling_dims must be a list containing 1 to %d elements' % (dims - 1)) 473 | if 0 in sampling_dims: 474 | raise RuntimeError('Cannot sample batch dimension: remove 0 from sampling_dims') 475 | if any([x < 1 or x > dims - 1 for x in sampling_dims]): 476 | raise RuntimeError('Invalid value in sampling_dims') 477 | else: 478 | sampling_dims = list(range(1, dims)) 479 | 480 | self.samples = samples 481 | self.sampling_dims = sampling_dims 482 | 483 | def run(self, xs, ys=None, batch_size=None): 484 | xs_shape = list(xs.shape) 485 | batch_size = xs.shape[0] 486 | n_features = int(np.asscalar(np.prod([xs.shape[i] for i in self.sampling_dims]))) 487 | result = np.zeros((xs_shape[0], n_features)) 488 | 489 | run_shape = list(xs_shape) # a copy 490 | run_shape = np.delete(run_shape, self.sampling_dims).tolist() 491 | run_shape.insert(1, -1) 492 | 493 | reconstruction_shape = [xs_shape[0]] 494 | for j in self.sampling_dims: 495 | reconstruction_shape.append(xs_shape[j]) 496 | 497 | for r in range(self.samples): 498 | p = np.random.permutation(n_features) 499 | x = xs.copy().reshape(run_shape) 500 | y = None 501 | for i in p: 502 | if y is None: 503 | y = self._session_run(self.T, x.reshape(xs_shape), ys, batch_size) 504 | x[:, i] = 0 505 | y0 = self._session_run(self.T, x.reshape(xs_shape), ys, batch_size) 506 | delta = y - y0 507 | delta_aggregated = np.sum(delta.reshape((batch_size, -1)), -1, keepdims=False) 508 | result[:, i] += delta_aggregated 509 | y = y0 510 | 511 | shapley = result / self.samples 512 | return shapley.reshape(reconstruction_shape) 513 | 514 | 515 | class DeepExplain(object): 516 | def __init__(self, graph=None, session=tf.get_default_session()): 517 | self.method = None 518 | self.batch_size = None 519 | self.session = session 520 | self.graph = session.graph if graph is None else graph 521 | self.graph_context = self.graph.as_default() 522 | self.override_context = self.graph.gradient_override_map(self.get_override_map()) 523 | self.keras_phase_placeholder = None 524 | self.context_on = False 525 | if self.session is None: 526 | raise RuntimeError('DeepExplain: could not retrieve a session. 
Use DeepExplain(session=your_session).') 527 | 528 | def __enter__(self): 529 | # Override gradient of all ops created in context 530 | self.graph_context.__enter__() 531 | self.override_context.__enter__() 532 | self.context_on = True 533 | return self 534 | 535 | def __exit__(self, type, value, traceback): 536 | self.graph_context.__exit__(type, value, traceback) 537 | self.override_context.__exit__(type, value, traceback) 538 | self.context_on = False 539 | 540 | def get_explainer(self, method, T, X, **kwargs): 541 | if not self.context_on: 542 | raise RuntimeError('Explain can be called only within a DeepExplain context.') 543 | global _ENABLED_METHOD_CLASS, _GRAD_OVERRIDE_CHECKFLAG 544 | self.method = method 545 | if self.method in attribution_methods: 546 | method_class, method_flag = attribution_methods[self.method] 547 | else: 548 | raise RuntimeError('Method must be in %s' % list(attribution_methods.keys())) 549 | if isinstance(X, list): 550 | for x in X: 551 | if 'tensor' not in str(type(x)).lower(): 552 | raise RuntimeError('If a list, X must contain only TensorFlow Tensor objects') 553 | else: 554 | if 'tensor' not in str(type(X)).lower(): 555 | raise RuntimeError('X must be a TensorFlow Tensor object or a list of them') 556 | 557 | if 'tensor' not in str(type(T)).lower(): 558 | raise RuntimeError('T must be a TensorFlow Tensor object') 559 | 560 | logging.info('DeepExplain: running "%s" explanation method (%d)' % (self.method, method_flag)) 561 | self._check_ops() 562 | _GRAD_OVERRIDE_CHECKFLAG = 0 563 | 564 | _ENABLED_METHOD_CLASS = method_class 565 | method = _ENABLED_METHOD_CLASS(T, X, 566 | self.session, 567 | keras_learning_phase=self.keras_phase_placeholder, 568 | **kwargs) 569 | 570 | if issubclass(_ENABLED_METHOD_CLASS, GradientBasedMethod) and _GRAD_OVERRIDE_CHECKFLAG == 0: 571 | warnings.warn('DeepExplain detected you are trying to use an attribution method that requires ' 572 | 'gradient override but the original gradient was used instead. You might have forgotten to ' 573 | '(re)create your graph within the DeepExplain context. Results are not reliable!') 574 | _ENABLED_METHOD_CLASS = None 575 | _GRAD_OVERRIDE_CHECKFLAG = 0 576 | self.keras_phase_placeholder = None 577 | return method 578 | 579 | def explain(self, method, T, X, xs, ys=None, batch_size=None, **kwargs): 580 | explainer = self.get_explainer(method, T, X, **kwargs) 581 | return explainer.run(xs, ys, batch_size) 582 | 583 | @staticmethod 584 | def get_override_map(): 585 | return dict((a, 'DeepExplainGrad') for a in SUPPORTED_ACTIVATIONS) 586 | 587 | def _check_ops(self): 588 | """ 589 | Heuristically check if any op is in the list of unsupported activation functions. 590 | This does not cover all cases where explanation methods would fail, and must be improved in the future. 591 | Also, check if the placeholder named 'keras_learning_phase' exists in the graph. This is used by Keras 592 | and needs to be passed in feed_dict. 593 | """ 594 | g = tf.get_default_graph() 595 | for op in g.get_operations(): 596 | if len(op.inputs) > 0 and not op.name.startswith('gradients'): 597 | if op.type in UNSUPPORTED_ACTIVATIONS: 598 | warnings.warn('Detected unsupported activation (%s). ' 599 | 'This might lead to unexpected or wrong results.'
% op.type) 600 | elif 'keras_learning_phase' in op.name: 601 | self.keras_phase_placeholder = op.outputs[0] 602 | 603 | 604 | class GradientMethod(object): 605 | def __init__(self, session, image_resize, output_tensor, explainer, baseline): 606 | """ 607 | Initialize GradientMethod 608 | :param session: Tensorflow session 609 | :param image_resize: Image resize 610 | :param output_tensor: Output tensor 611 | :param explainer: Explainer 612 | :param baseline: Baseline 613 | """ 614 | self.sess = session 615 | self.img_rs = image_resize 616 | self.output_tensor = output_tensor 617 | self.explainer = explainer 618 | self.baseline = baseline 619 | 620 | def __call__(self, imgs, method, img_input, img_resize=None): 621 | """ 622 | Calculate Gradient Method 623 | :param imgs: Input image 624 | :param method: Choose a method to calculate gradient: IntGrad, DeepLIFT or others 625 | :param img_resize: Resize image 626 | :return: Gradient Method explanation 627 | """ 628 | with DeepExplain(session=self.sess) as de: 629 | if method in ['intgrad', 'deeplift']: 630 | attributions = de.explain(method, 631 | np.sum(self.output_tensor[0, :, 1:2]), 632 | self.img_rs, img_resize, 633 | baseline=self.baseline) 634 | else: 635 | img_resize = self.sess.run(self.img_rs, feed_dict={img_input: imgs}) 636 | attributions = self.explainer.run(img_resize) 637 | analysis = attributions 638 | analysis = iutils.postprocess_images(analysis, 639 | color_coding='BGRtoRGB', 640 | channels_first=False) 641 | analysis = ivis.gamma(analysis, minamp=0, gamma=0.95) 642 | analysis = ivis.heatmap(analysis) 643 | analysis = cv2.resize(analysis[0], dsize=(imgs.shape[2], imgs.shape[1]), interpolation=cv2.INTER_LINEAR) 644 | return analysis 645 | 646 | 647 | # ----------------------------------------------------------------------------- 648 | # END ATTRIBUTION METHODS 649 | # ----------------------------------------------------------------------------- 650 | 651 | 652 | attribution_methods = OrderedDict({ 653 | 'zero': (DummyZero, 0), 654 | 'saliency': (Saliency, 1), 655 | 'grad*input': (GradientXInput, 2), 656 | 'intgrad': (IntegratedGradients, 3), 657 | 'eLRP': (EpsilonLRP, 4), 658 | 'deeplift': (DeepLIFTRescale, 5), 659 | 'occlusion': (Occlusion, 6), 660 | 'shapley_sampling': (ShapleySampling, 7) 661 | }) 662 | 663 | 664 | @ops.RegisterGradient("DeepExplainGrad") 665 | def deepexplain_grad(op, grad): 666 | global _ENABLED_METHOD_CLASS, _GRAD_OVERRIDE_CHECKFLAG 667 | _GRAD_OVERRIDE_CHECKFLAG = 1 668 | if _ENABLED_METHOD_CLASS is not None and issubclass(_ENABLED_METHOD_CLASS, GradientBasedMethod): 669 | return _ENABLED_METHOD_CLASS.nonlinearity_grad_override(op, grad) 670 | else: 671 | return original_grad(op, grad) 672 | 673 | 674 | def make_batches(size, batch_size): 675 | """Returns a list of batch indices (tuples of indices). 676 | # Arguments 677 | size: Integer, total size of the data to slice into batches. 678 | batch_size: Integer, batch size. 679 | # Returns 680 | A list of tuples of array indices. 681 | """ 682 | num_batches = (size + batch_size - 1) // batch_size # round up 683 | return [(i * batch_size, min(size, (i + 1) * batch_size)) 684 | for i in range(num_batches)] 685 | 686 | 687 | def to_list(x, allow_tuple=False): 688 | """Normalizes a list/tensor into a list. 689 | If a tensor is passed, we return list of size 1 containing the tensor. 690 | :param x: target object to be normalized. 691 | :param allow_tuple: If False and x is a tuple, it will be converted into a list with a single element (the tuple). 
Else converts the tuple to a list. 692 | :return: list of size 1 containing the tensor. 693 | """ 694 | if isinstance(x, list): 695 | return x 696 | if allow_tuple and isinstance(x, tuple): 697 | return list(x) 698 | return [x] 699 | 700 | 701 | def unpack_singleton(x): 702 | """Gets the equivalent np-array if the iterable has only one value. 703 | :param x: a list of tuples. 704 | :return: the same iterable or the iterable converted to a np-array. 705 | """ 706 | if len(x) == 1: 707 | return np.array(x) 708 | return x 709 | 710 | 711 | def slice_arrays(arrays, start=None, stop=None): 712 | """ 713 | Slices an array or list of arrays. 714 | :param arrays: list of arrays to slice. 715 | :param start: int, start index. 716 | :param stop: int, end index. 717 | :return: list of sliced arrays. 718 | """ 719 | if arrays is None: 720 | return [None] 721 | elif isinstance(arrays, list): 722 | return [None if x is None else x[start:stop] for x in arrays] 723 | else: 724 | return arrays[start:stop] 725 | 726 | 727 | def placeholder_from_data(numpy_array): 728 | """ 729 | Creates a placeholder from a numpy array. 730 | :param numpy_array: a numpy array. 731 | :return: a tensorflow placeholder. 732 | """ 733 | if numpy_array is None: 734 | return None 735 | return tf.placeholder('float', [None, ] + list(numpy_array.shape[1:])) 736 | 737 | 738 | def get_info(path): 739 | """ 740 | Get the ground-truth bounding boxes and labels of the image 741 | :param path: Path to the xml file 742 | :return: list of bounding boxes and labels 743 | """ 744 | gr_truth = [] 745 | root = ET.parse(path).getroot() 746 | for type_tag in root.findall('object'): 747 | xmin = int(type_tag.find('bndbox/xmin').text) 748 | ymin = int(type_tag.find('bndbox/ymin').text) 749 | xmax = int(type_tag.find('bndbox/xmax').text) 750 | ymax = int(type_tag.find('bndbox/ymax').text) 751 | gr_truth.append([xmin, ymin, xmax, ymax]) 752 | return gr_truth 753 | 754 | 755 | def bbox_iou(boxA, boxB, x1y1x2y2=False): 756 | """ 757 | Compute the intersection over union by taking the intersection area and dividing it by the sum of prediction + 758 | ground-truth areas - the interesection area 759 | :param boxA: array of shape [4*1] = [x1,y1,x2,y2] 760 | :param boxB: array of shape [4*1] = [x1,y1,x2,y2] 761 | :param x1y1x2y2: if True, interpret box coordinates as [x1,y1,w,h] 762 | :return: IoU 763 | """ 764 | if x1y1x2y2: 765 | my = min(boxA[0], boxB[0]) 766 | My = max(boxA[2], boxB[2]) 767 | mx = min(boxA[1], boxB[1]) 768 | Mx = max(boxA[3], boxB[3]) 769 | h1 = boxA[2] - boxA[0] 770 | w1 = boxA[3] - boxA[1] 771 | h2 = boxB[2] - boxB[0] 772 | w2 = boxB[3] - boxB[1] 773 | uw = Mx - mx 774 | uh = My - my 775 | cw = w1 + w2 - uw 776 | ch = h1 + h2 - uh 777 | 778 | if cw <= 0 or ch <= 0: 779 | return 0.0 780 | 781 | area1 = w1 * h1 782 | area2 = w2 * h2 783 | carea = cw * ch 784 | uarea = area1 + area2 - carea 785 | return carea / uarea 786 | else: 787 | xA = max(boxA[0], boxB[0]) 788 | yA = max(boxA[1], boxB[1]) 789 | xB = min(boxA[2], boxB[2]) 790 | yB = min(boxA[3], boxB[3]) 791 | interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1) 792 | boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1) 793 | boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1) 794 | iou = interArea / float(boxAArea + boxBArea - interArea) 795 | return iou 796 | 797 | 798 | def gen_cam(image, mask, gr_truth_boxes, threshold, boxs=None): 799 | """ 800 | Generate CAM map 801 | :param image: [H,W,C],the original image 802 | :param mask: [H,W], range 0~1 803 | :param 
gr_truth_boxes: ground-truth bounding boxes 804 | :param threshold: threshold to filter the bounding boxes 805 | :param boxs: [N,4], the bounding boxes 806 | :return: tuple(cam,heatmap) 807 | """ 808 | if boxs is None: 809 | boxs = [[0, 0, 0, 0]] 810 | heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET) 811 | # heatmap = np.float32(heatmap) / 255 812 | # heatmap = heatmap[..., ::-1] # gbr to rgb 813 | image_cam = cv2.addWeighted(heatmap, 0.5, image, 0.5, 0) 814 | 815 | image_cam = draw(image_cam, boxs, threshold, gr_truth_boxes) 816 | 817 | # heatmap = np.float32(heatmap) / 255 818 | heatmap = heatmap[..., ::-1] 819 | # image_cam = np.float32(image_cam) / 255 820 | image_cam = image_cam[..., ::-1] 821 | return image_cam, heatmap 822 | 823 | 824 | def draw(image, boxs, threshold=None, gr_truth_boxes=None, color=(0, 255, 0), thickness=2): 825 | """ 826 | Draw bounding boxes on image 827 | :param image: [H,W,C],the original image 828 | :param boxs: [N,4], the bounding boxes 829 | :param threshold: the threshold to filter the bounding boxes 830 | :param gr_truth_boxes: [N,4], the ground-truth bounding boxes 831 | :param color: the color of the bounding boxes 832 | :param thickness: the thickness of the bounding boxes 833 | :return: image with bounding boxes 834 | """ 835 | img_draw = image 836 | if gr_truth_boxes is not None: 837 | for a in boxs: 838 | iou = [] 839 | for b in gr_truth_boxes: 840 | iou.append(bbox_iou(a, b)) 841 | test_iou = any(l > threshold for l in iou) 842 | if test_iou: 843 | color = (255, 0, 0) 844 | img_draw = cv2.rectangle(image, (a[0], a[1]), (a[2], a[3]), color, thickness) 845 | else: 846 | color = (0, 0, 255) 847 | img_draw = cv2.rectangle(image, (a[0], a[1]), (a[2], a[3]), color, thickness) 848 | else: 849 | for b in boxs: 850 | start_point, end_point = (b[0], b[1]), (b[2], b[3]) 851 | image = cv2.rectangle(image, start_point, end_point, color, thickness) 852 | return img_draw 853 | 854 | 855 | def save_image(image_dicts, input_image_name, output_dir, index): 856 | """ 857 | Save output in folder named results 858 | :param image_dicts: Dictionary results 859 | :param input_image_name: Name of original image 860 | :param output_dir: Path to output directory 861 | :param index: Index of image 862 | """ 863 | name_img = os.path.splitext(input_image_name)[0] 864 | for key, image in image_dicts.items(): 865 | io.imsave(os.path.join(output_dir, f'{name_img}-{key}-{index}.jpg'), img_as_ubyte(image)) 866 | 867 | 868 | def get_config(path_config): 869 | """ 870 | Get config from json file 871 | :param path_config: Path to config file 872 | :return: config 873 | """ 874 | with open(path_config, 'r') as fin: 875 | config_xAI = json.load(fin) 876 | return config_xAI 877 | 878 | 879 | def get_model(model_path): 880 | """ 881 | Get model from file 882 | :param model_path: Path to model file 883 | :return: model 884 | """ 885 | graph = tf.Graph() 886 | with graph.as_default(): 887 | with tf.gfile.GFile(model_path, 'rb') as file: 888 | graph_def = tf.GraphDef() 889 | graph_def.ParseFromString(file.read()) 890 | tf.import_graph_def(graph_def, name='') 891 | img_input = graph.get_tensor_by_name('image_tensor:0') 892 | detection_boxes = graph.get_tensor_by_name('detection_boxes:0') 893 | detection_scores = graph.get_tensor_by_name('detection_scores:0') 894 | num_detections = graph.get_tensor_by_name('num_detections:0') 895 | detection_classes = graph.get_tensor_by_name('detection_classes:0') 896 | 897 | sess = tf.Session(graph=graph) 898 | return sess, img_input, 
detection_boxes, detection_scores, num_detections, detection_classes 899 | 900 | 901 | def get_tensor_mini(sess, layer_name, image, img_input): 902 | """ 903 | Get the value of an intermediate tensor from the model 904 | :param sess: Session 905 | :param layer_name: Name of layer 906 | :param image: Image 907 | :param img_input: Input tensor 908 | :return: tensor units 909 | """ 910 | print(layer_name) 911 | layer = sess.graph.get_tensor_by_name(layer_name + ':0') 912 | units = sess.run(layer, feed_dict={img_input: image}) 913 | return units 914 | 915 | 916 | def get_center(box): 917 | center_box_x = np.zeros(len(box)) 918 | center_box_y = np.zeros(len(box)) 919 | for i in range(len(box)): 920 | center_box_x[i] = int((box[i][2] + box[i][0]) / 2) 921 | center_box_y[i] = int((box[i][3] + box[i][1]) / 2) 922 | return center_box_x, center_box_y 923 | 924 | 925 | def softmax(x): 926 | f = np.exp(x) / np.sum(np.exp(x), axis=1, keepdims=True) 927 | return f 928 | 929 | def create_file(path): 930 | """ 931 | Create the directory if it does not exist 932 | """ 933 | if not os.path.exists(path): 934 | os.makedirs(path) 935 | 936 | def energy_point_game(bbox, saliency_map): 937 | """ 938 | Calculate the energy-based pointing game evaluation 939 | :param bbox: [N,4], the bounding boxes 940 | :param saliency_map: [H, W], final saliency map 941 | """ 942 | h, w = saliency_map.shape 943 | empty = np.zeros((h, w)) 944 | for b in bbox: 945 | x1, y1, x2, y2 = b 946 | # print(x1, y1, x2, y2, h, w) 947 | x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2) 948 | h, w = saliency_map.shape 949 | empty[y1:y2, x1:x2] = 1 950 | mask_bbox = saliency_map * empty 951 | energy_bbox = mask_bbox.sum() 952 | energy_whole = saliency_map.sum() 953 | proportion = energy_bbox / energy_whole 954 | return proportion 955 | 956 | def bounding_boxes(bboxs, saliency_map): 957 | """ 958 | Calculate the bounding box evaluation: the fraction of the most salient pixels that fall inside the boxes 959 | :param bboxs: [N,4], the bounding boxes 960 | :param saliency_map: [H, W], final saliency map 961 | """ 962 | height, width = saliency_map.shape 963 | HW = height*width 964 | area = 0 965 | mask = np.zeros((height, width)) 966 | for bbox in bboxs: 967 | xi, yi, xa, ya = bbox 968 | area += (xa-xi)*(ya-yi) 969 | mask[yi:ya, xi:xa] = 1 970 | sal_order = np.flip(np.argsort(saliency_map.reshape(HW, -1), axis=0), axis=0) 971 | y = sal_order // saliency_map.shape[1] 972 | x = sal_order - y*saliency_map.shape[1] 973 | mask_cam = np.zeros_like(saliency_map) 974 | mask_cam[y[0:area, :], x[0:area, :]] = 1 975 | ratio = (mask*mask_cam).sum()/(area) 976 | return ratio 977 | def IoU(mask, cam_map): 978 | heatmap = cv2.applyColorMap(np.uint8(255 * cam_map), cv2.COLORMAP_JET) 979 | area_mask = np.count_nonzero(mask == 1) 980 | gray = cv2.cvtColor(heatmap, cv2.COLOR_BGR2GRAY) 981 | thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)[1] 982 | # Find contours 983 | cnts = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) 984 | cnts = cnts[0] if len(cnts) == 2 else cnts[1] 985 | mask_cam = np.zeros_like(cam_map) 986 | for c in cnts: 987 | x,y,w,h = cv2.boundingRect(c) 988 | mask_cam[y:y+h, x:x+w] = 1 989 | area_mask_cam = np.count_nonzero(mask_cam == 1) 990 | mask_sum = mask*mask_cam 991 | area_sum = np.count_nonzero(mask_sum) 992 | iou = area_sum/(area_mask + area_mask_cam - area_sum) 993 | return iou 994 | def get_parser(): 995 | """ 996 | Parse command line arguments 997 | :return: parsed command-line arguments 998 | """ 999 | parser = argparse.ArgumentParser(description='xAI for thyroid cancer detection') 1000 |
parser.add_argument( 1001 | '--config-path', 1002 | default='xAI_config.json', 1003 | metavar='FILE', 1004 | help='path to config file', 1005 | ) 1006 | parser.add_argument('--method', 1007 | help='Choose Backpropagation methods: eLRP, GradCAM, GradCAM++, RISE, LIME, DRISE, KDE, ' 1008 | 'DensityMap, AdaSISE') 1009 | parser.add_argument('--image-path', help='path to input images', default='data/test_images') 1010 | parser.add_argument('--stage', default='first_stage', 1011 | help='Choose a stage to visualize: first_stage or second_stage') 1012 | parser.add_argument('--threshold', default=0.6, type=float, help='Threshold of output values to visualize') 1013 | parser.add_argument('--output-path', help='A file or directory to save output visualizations.') 1014 | parser.add_argument('--output-numpy', help='A file or directory to save saliency map(np.ndarray).') 1015 | return parser.parse_args() -------------------------------------------------------------------------------- /xAI_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "LIME": { 3 | "num_features": 60, 4 | "num_sample": 1000, 5 | "top_label": 1, 6 | "index": 0 7 | }, 8 | "RISE": { 9 | "num_sample": 500, 10 | "grid_size": 8, 11 | "prob": 0.5, 12 | "index": 0 13 | }, 14 | "DRISE": { 15 | "num_sample": 500, 16 | "grid_size": 8, 17 | "prob": 0.5, 18 | "index": 0 19 | }, 20 | "CAM": { 21 | "first_stage": { 22 | "output": "concat_1", 23 | "target": "Conv/Relu6", 24 | "NMS": "BatchMultiClassNonMaxSuppression/map/while/MultiClassNonMaxSuppression/non_max_suppression/NonMaxSuppressionV3" 25 | }, 26 | "second_stage": { 27 | "output": "SecondStagePostprocessor/scale_logits", 28 | "target": "SecondStageFeatureExtractor/InceptionResnetV2/Conv2d_7b_1x1/Relu", 29 | "index": 0 , 30 | "NMS": "SecondStagePostprocessor/BatchMultiClassNonMaxSuppression/map/while/MultiClassNonMaxSuppression/non_max_suppression/NonMaxSuppressionV3" 31 | } 32 | }, 33 | "Gradient": { 34 | "output": "concat_1", 35 | "target": "Preprocessor/sub" 36 | }, 37 | "Model": { 38 | "file_config": "model/config/model_config.json", 39 | "folder_xml": "data/test_annotation/", 40 | "threshold": 0.5, 41 | "box_prediction_true": "blue_color", 42 | "box_prediction_false": "red_color" 43 | } 44 | } -------------------------------------------------------------------------------- /xai/adasise.py: -------------------------------------------------------------------------------- 1 | from re import S 2 | import cv2 3 | import numpy as np 4 | import tensorflow as tf 5 | 6 | from utils import softmax 7 | 8 | 9 | class AdaSISE(object): 10 | def __init__(self, image, sess): 11 | self.image = image 12 | self.sess = sess 13 | self.tensor_names = [n.name for n in self.sess.graph.as_graph_def().node] 14 | self.last_block = [n.name for n in self.sess.graph.as_graph_def().node if 'Pool' in n.name] 15 | self.target_layer = self.sess.graph.get_tensor_by_name('Softmax:0') 16 | 17 | def explain(self, img, img_input): 18 | feed_dict = {img_input: self.image} 19 | 20 | l1 = self.sess.graph.get_tensor_by_name( 21 | 'FirstStageFeatureExtractor/InceptionResnetV2/InceptionResnetV2/Conv2d_2b_3x3/Relu:0') 22 | l2 = self.sess.graph.get_tensor_by_name( 23 | 'FirstStageFeatureExtractor/InceptionResnetV2/InceptionResnetV2/Conv2d_4a_3x3/Relu:0') 24 | l3 = self.sess.graph.get_tensor_by_name( 25 | 'FirstStageFeatureExtractor/InceptionResnetV2/InceptionResnetV2/Mixed_5b/Branch_2/Conv2d_0c_3x3/Relu:0') 26 | l4 = self.sess.graph.get_tensor_by_name( 27 | 
'FirstStageFeatureExtractor/InceptionResnetV2/InceptionResnetV2/Mixed_6a/Branch_1/Conv2d_1a_3x3/Relu:0') 28 | l5 = self.sess.graph.get_tensor_by_name('Conv/Relu6:0') 29 | 30 | grad1 = tf.gradients(np.sum(self.target_layer[0, :, 1:2]), l1)[0] 31 | grad2 = tf.gradients(np.sum(self.target_layer[0, :, 1:2]), l2)[0] 32 | grad3 = tf.gradients(np.sum(self.target_layer[0, :, 1:2]), l3)[0] 33 | grad4 = tf.gradients(np.sum(self.target_layer[0, :, 1:2]), l4)[0] 34 | grad5 = tf.gradients(np.sum(self.target_layer[0, :, 1:2]), l5)[0] 35 | 36 | o1, g1 = self.sess.run([l1, grad1], feed_dict=feed_dict) 37 | o2, g2 = self.sess.run([l2, grad2], feed_dict=feed_dict) 38 | o3, g3 = self.sess.run([l3, grad3], feed_dict=feed_dict) 39 | o4, g4 = self.sess.run([l4, grad4], feed_dict=feed_dict) 40 | o5, g5 = self.sess.run([l5, grad5], feed_dict=feed_dict) 41 | 42 | a1 = np.mean(g1[0], axis=(0, 1)) 43 | a2 = np.mean(g2[0], axis=(0, 1)) 44 | a3 = np.mean(g3[0], axis=(0, 1)) 45 | a4 = np.mean(g4[0], axis=(0, 1)) 46 | a5 = np.mean(g5[0], axis=(0, 1)) 47 | 48 | a1_nor = a1 / a1.max() 49 | a2_nor = a2 / a2.max() 50 | a3_nor = a3 / a3.max() 51 | a4_nor = a4 / a4.max() 52 | a5_nor = a5 / a5.max() 53 | 54 | b1 = np.where(a1_nor > 0)[0] 55 | b2 = np.where(a2_nor > 0)[0] 56 | b3 = np.where(a3_nor > 0)[0] 57 | b4 = np.where(a4_nor > 0)[0] 58 | b5 = np.where(a5_nor > 0)[0] 59 | 60 | a1_nor_p = a1_nor[b1] 61 | a2_nor_p = a2_nor[b2] 62 | a3_nor_p = a3_nor[b3] 63 | a4_nor_p = a4_nor[b4] 64 | a5_nor_p = a5_nor[b5] 65 | 66 | th1 = cv2.threshold( 67 | (a1_nor_p * 255).astype(np.uint8), 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU, 68 | )[0] / 255 69 | 70 | th2 = cv2.threshold( 71 | (a2_nor_p * 255).astype(np.uint8), 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU, 72 | )[0] / 255 73 | 74 | th3 = cv2.threshold( 75 | (a3_nor_p * 255).astype(np.uint8), 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU, 76 | )[0] / 255 77 | 78 | th4 = cv2.threshold( 79 | (a4_nor_p * 255).astype(np.uint8), 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU, 80 | )[0] / 255 81 | 82 | th5 = cv2.threshold( 83 | (a5_nor_p * 255).astype(np.uint8), 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU, 84 | )[0] / 255 85 | 86 | b1b = np.where(a1_nor > th1)[0] 87 | b2b = np.where(a2_nor > th2)[0] 88 | b3b = np.where(a3_nor > th3)[0] 89 | b4b = np.where(a4_nor > th4)[0] 90 | b5b = np.where(a5_nor > th5)[0] 91 | 92 | score_saliency_map = [] 93 | 94 | o = [o1, o2, o3, o4, o5] 95 | b = [b1b, b2b, b3b, b4b, b5b] 96 | 97 | for k in range(len(o)): 98 | score_saliency = 0 99 | for j in (b[k]): 100 | md1 = cv2.resize(o[k][0, :, :, j], (img.shape[2], img.shape[1]), interpolation=cv2.INTER_LINEAR) 101 | if md1.max() == md1.min(): 102 | continue 103 | md1 = (md1 - np.min(md1)) / (np.max(md1) - np.min(md1)) 104 | img_md1 = ((img * (md1[None,:, :, None].astype(np.float32))).astype(np.uint8)) 105 | output_md1 = self.sess.run(self.target_layer, feed_dict={img_input: img_md1}) 106 | x = softmax(output_md1[0, :]) 107 | score = np.sum(x[:, 1]) 108 | score_saliency += score * md1 109 | score_saliency_map.append(score_saliency) 110 | heatmap = 0 111 | for i in range(len(score_saliency_map)): 112 | if type(score_saliency_map[i]) == int: 113 | continue 114 | score_saliency_map[i] = (score_saliency_map[i] - score_saliency_map[i].min()) / ( 115 | score_saliency_map[i].max() - score_saliency_map[i].min()) 116 | s = np.array(score_saliency_map[i] * 255, dtype=np.uint8) 117 | block = cv2.threshold(s, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU, )[1] / 255 118 | heatmap += score_saliency_map[i] 119 | if i > 0: 120 | 
heatmap *= block 121 | return heatmap 122 | -------------------------------------------------------------------------------- /xai/density_map.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | 4 | from utils import get_tensor_mini 5 | 6 | 7 | class DensityMap(object): 8 | def __init__(self, sess, image, j): 9 | self.sess = sess 10 | img_color = cv2.cvtColor(cv2.imread(j), cv2.COLOR_BGR2RGB) 11 | self.image = img_color.reshape((1, img_color.shape[0], img_color.shape[1], 3)) 12 | 13 | def explain(self, img_input, y_p_num_detections, y_p_boxes): 14 | b = [n.name for n in self.sess.graph.as_graph_def().node if 'SecondStageBoxPredictor' in n.name] 15 | post_process = [n.name for n in self.sess.graph.as_graph_def().node if 'SecondStagePostprocessor' in n.name] 16 | 17 | boxes = get_tensor_mini(self.sess, b[18], self.image, img_input) 18 | boxes_post_process = get_tensor_mini(self.sess, post_process[14], self.image, img_input) 19 | 20 | box = [] 21 | h_img, w_img = self.image.shape[:2] 22 | ratio = w_img / h_img 23 | 24 | num_box = int(y_p_num_detections[0]) 25 | box_predicted = [] 26 | for i in range(num_box): 27 | x1, x2 = int(y_p_boxes[0][i][1] * w_img), int(y_p_boxes[0][i][3] * w_img) 28 | y1, y2 = int(y_p_boxes[0][i][0] * h_img), int(y_p_boxes[0][i][2] * h_img) 29 | box_predicted.append([x1, y1, x2, y2]) 30 | 31 | for i in range(300): 32 | x1, x2 = int(boxes_post_process[i][1] * ratio), int(boxes_post_process[i][3] * ratio) 33 | y1, y2 = int(boxes_post_process[i][0] * ratio), int(boxes_post_process[i][2] * ratio) 34 | box.append([x1, y1, x2, y2]) 35 | 36 | h_image, w_image = self.image.shape[0], self.image.shape[1] 37 | num_box = len(box) 38 | density_map = np.zeros((h_image, w_image)) 39 | for i in range(num_box): 40 | box_a = box[i] 41 | for j in range(max(box_a[1], 0), min(box_a[3], h_image)): 42 | for k in range(max(box_a[0], 0), min(box_a[2], w_image)): 43 | density_map[j][k] += 1 44 | max_density = np.max(density_map) 45 | density_map_normalize = (density_map * 256 / max_density).astype(np.uint8) 46 | heatmap = cv2.applyColorMap(density_map_normalize, cv2.COLORMAP_JET) 47 | 48 | return heatmap 49 | -------------------------------------------------------------------------------- /xai/drise.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import warnings 3 | import cv2 4 | import numpy as np 5 | import tensorflow.compat.v1 as tf 6 | import base64 7 | import json 8 | # import pysmile 9 | import time 10 | import logging 11 | import os 12 | from numpy import dot 13 | from numpy.linalg import norm 14 | from tqdm import tqdm 15 | import math 16 | import cv2 17 | from datetime import datetime 18 | 19 | from utils import bbox_iou 20 | 21 | now = datetime.now() 22 | 23 | # from RMQ import BasicRMQClient 24 | 25 | warnings.filterwarnings('ignore') 26 | 27 | logname = '/logs/log-{}.log'.format(now.strftime("%Y-%m-%d")) 28 | 29 | logging.info('=' * 10 + ' LOG FILE FOR IMAGE ' + '=' * 10) 30 | 31 | 32 | class DRISE(object): 33 | def __init__(self, image, sess, grid_size, prob, num_samples=500, batch_size=1): 34 | self.image = image 35 | self.num_samples = num_samples 36 | self.sess = sess 37 | self.grid_size = grid_size 38 | self.prob = prob 39 | self.image_size = (image.shape[1], image.shape[2]) 40 | self.batch_size = batch_size 41 | 42 | def generate_mask(self, ): 43 | """ 44 | Return a mask with shape [H, W] 45 | :return: mask generated by bilinear interpolation 
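Worked example (informal, following the code below): with image_size=(600, 800), grid_size=8 and
prob=0.5, cell_h = ceil(600 / 8) = 75 and cell_w = ceil(800 / 8) = 100, so the random binary 8x8 grid
is bilinearly resized to 675x900 and a 600x800 window is cropped at a random offset in
[0, cell_h) x [0, cell_w), yielding a smooth mask with values in [0, 1].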
46 | """ 47 | image_h, image_w = self.image_size 48 | grid_h, grid_w = self.grid_size, self.grid_size 49 | 50 | # Create cell for mask 51 | cell_w, cell_h = math.ceil(image_w / grid_w), math.ceil(image_h / grid_h) 52 | up_w, up_h = (grid_w + 1) * cell_w, (grid_h + 1) * cell_h 53 | 54 | # Create {0, 1} mask 55 | mask = (np.random.uniform(0, 1, size=(grid_h, grid_w)) < self.prob).astype(np.float32) 56 | # Up-size to get value in [0, 1] 57 | mask = cv2.resize(mask, (up_w, up_h), interpolation=cv2.INTER_LINEAR) 58 | 59 | # Randomly crop the mask 60 | offset_w = np.random.randint(0, cell_w) 61 | offset_h = np.random.randint(0, cell_h) 62 | mask = mask[offset_h:offset_h + image_h, offset_w:offset_w + image_w] 63 | return mask 64 | 65 | def mask_image(self, image, mask): 66 | masked = ((image.astype(np.float32) / 255 * np.dstack([mask] * 3)) * 255).astype(np.uint8) 67 | return masked 68 | 69 | def explain(self, image, img_input, y_p_boxes, y_p_num_detections, detection_boxes, detection_scores, 70 | num_detections, 71 | detection_classes): 72 | num_objs = int(y_p_num_detections[0]) 73 | h, w = self.image_size 74 | res = np.zeros((num_objs, h, w), dtype=np.float32) # ---> shape[num_objs, h, w] 75 | max_score = np.zeros((num_objs,), dtype=np.float32) # ---> shape[num_objs,] 76 | for i in range(0, self.num_samples): 77 | mask = self.generate_mask() 78 | # masked = mask * image 79 | masked = self.mask_image(image, mask) 80 | input_dict = {img_input: masked} 81 | p_boxes, p_scores, p_num_detections, p_classes = self.sess.run( 82 | [detection_boxes, detection_scores, num_detections, detection_classes], 83 | feed_dict=input_dict) 84 | if int(p_num_detections[0]) == 0: 85 | continue 86 | for idx in range(num_objs): 87 | iou = np.array([bbox_iou(p_boxes[0][k], y_p_boxes[0][idx]) 88 | for k in range(int(p_num_detections[0]))]) 89 | cos_sin = p_scores[0][0:int(p_num_detections[0])] 90 | score = cos_sin * iou 91 | max_score[idx] = max(score) 92 | res[idx] += mask * max_score[idx] 93 | 94 | return res 95 | -------------------------------------------------------------------------------- /xai/gradcam.py: -------------------------------------------------------------------------------- 1 | # --------------------------------GradCAM--------------------------------------- 2 | import cv2 3 | import numpy as np 4 | 5 | 6 | class GradCAM(object): 7 | """ 8 | 1: GradCAM calculate gradient on two stages 9 | 2: Output tensor: Prediction boxes before Non-Max Suppression (first_stage) 10 | 3: Get index target boxes to backpropagation (second_stage), output: The final prediction of the model 11 | """ 12 | 13 | def __init__(self, session, conv_tensor, output_tensor): 14 | """ 15 | Initialize GradCAM 16 | :param session: Tensorflow session 17 | :param conv_tensor: Tensor of convolution layer 18 | :param output_tensor: Tensor of output layer 19 | """ 20 | self.sess = session 21 | self.conv_tensor = conv_tensor 22 | self.output_tensor = output_tensor 23 | 24 | def __call__(self, imgs, grads, img_input, stage, y_p_boxes, indices=0, index=0): 25 | """ 26 | Calculate GradCAM 27 | :param imgs: Input image 28 | :param grads: Gradient of output layer 29 | 30 | :param stage: Choose a stage to visualize: first_stage or second_stage 31 | :param indices: Index of target boxes to backpropagation (second_stage) 32 | :param index: Index of image 33 | :return: GradCAM explanation 34 | """ 35 | if stage == 'first_stage': 36 | # first image in batch 37 | conv_output, grads_val = self.sess.run([self.conv_tensor, grads], feed_dict={img_input: 
imgs}) 38 | weights = np.mean(grads_val[indices], axis=(0, 1)) 39 | feature = conv_output[indices] 40 | cam = feature * weights[np.newaxis, np.newaxis, :] 41 | else: 42 | conv_output, grads_val = self.sess.run([self.conv_tensor, grads], feed_dict={img_input: imgs}) 43 | weights = np.mean(grads_val[indices], axis=(0, 1)) 44 | feature = conv_output[indices] 45 | cam = feature * weights[np.newaxis, np.newaxis, :] 46 | cam = np.sum(cam, axis=2) 47 | # cam = np.maximum(cam, 0) #Relu 48 | # Normalize data (0, 1) 49 | cam -= np.min(cam) 50 | cam /= (np.max(cam) - np.min(cam)) 51 | h_img, w_img = imgs.shape[1:3] 52 | x1, x2 = int(y_p_boxes[0][index][1] * w_img), int(y_p_boxes[0][index][3] * w_img) 53 | y1, y2 = int(y_p_boxes[0][index][0] * h_img), int(y_p_boxes[0][index][2] * h_img) 54 | # Resize CAM 55 | if stage == 'first_stage': 56 | cam = cv2.resize(cam, (w_img, h_img)) 57 | return cam 58 | else: 59 | cam = cv2.resize(cam, (x2 - x1, y2 - y1)) 60 | return cam, x1, y1, x2, y2 61 | 62 | 63 | # --------------------------------GradCAM++------------------------------------- 64 | 65 | class GradCAMPlusPlus(GradCAM): 66 | def __init__(self, session, conv_tensor, output_tensor): 67 | """ 68 | Initialize GradCAM++ 69 | :param session: Tensorflow session 70 | """ 71 | super().__init__(session, conv_tensor, output_tensor) 72 | 73 | def __call__(self, imgs, grads, img_input, stage, y_p_boxes, indices=0, index=0): 74 | """ 75 | Calculate GradCAM++ 76 | :param imgs: Input image 77 | :param grads: Gradient of output layer 78 | :param stage: Choose a stage to visualize: first_stage or second_stage 79 | :param indices: Index of target boxes to backpropagation (second_stage) 80 | :param index: Index of image 81 | :return: GradCAM++ explanation 82 | """ 83 | if stage == 'first_stage': 84 | outputs, grads_val_1 = self.sess.run([self.conv_tensor, grads], feed_dict={img_input: imgs}) 85 | grads_val_2 = grads_val_1 ** 2 86 | grads_val_3 = grads_val_2 * grads_val_1 87 | global_sum = np.sum(outputs[0], axis=(0, 1)) 88 | eps = 0.000001 89 | aij = grads_val_2[indices] / ( 90 | 2 * grads_val_2[indices] + global_sum[None, None, :] * grads_val_3[indices] + eps) 91 | aij = np.where(grads_val_1[indices] != 0, aij, 0) 92 | weights = np.maximum(grads_val_1[indices], 0) * aij # Relu * aij = weight 93 | weights = np.sum(weights, axis=(0, 1)) 94 | cam = outputs[indices] * weights[np.newaxis, np.newaxis, :] 95 | else: 96 | outputs, grads_val_1 = self.sess.run([self.conv_tensor, grads], feed_dict={img_input: imgs}) 97 | grads_val_2 = grads_val_1 ** 2 98 | grads_val_3 = grads_val_2 * grads_val_1 99 | global_sum = np.sum(outputs[0], axis=(0, 1)) 100 | eps = 0.000001 101 | aij = grads_val_2[indices] / ( 102 | 2 * grads_val_2[indices] + global_sum[None, None, :] * grads_val_3[indices] + eps) 103 | aij = np.where(grads_val_1[indices] != 0, aij, 0) 104 | weights = np.maximum(grads_val_1[indices], 0) * aij # Relu * aij = weight 105 | weights = np.sum(weights, axis=(0, 1)) 106 | cam = outputs[indices] * weights[np.newaxis, np.newaxis, :] 107 | cam = np.sum(cam, axis=2) 108 | # cam = np.maximum(cam, 0) #Relu 109 | # Normalize 110 | cam -= np.min(cam) 111 | cam /= (np.max(cam) - np.min(cam)) 112 | h_img, w_img = imgs.shape[1:3] 113 | x1, x2 = int(y_p_boxes[0][index][1] * w_img), int(y_p_boxes[0][index][3] * w_img) 114 | y1, y2 = int(y_p_boxes[0][index][0] * h_img), int(y_p_boxes[0][index][2] * h_img) 115 | # Resize cam 116 | 117 | if stage == 'first_stage': 118 | cam = cv2.resize(cam, (w_img, h_img)) 119 | return cam 120 | else: 121 | cam = 
cv2.resize(cam, (x2 - x1, y2 - y1)) 122 | return cam, x1, y1, x2, y2 123 | -------------------------------------------------------------------------------- /xai/kde.py: -------------------------------------------------------------------------------- 1 | import os 2 | import matplotlib.pyplot as plt 3 | import numpy as np 4 | import scipy.stats as st 5 | 6 | from utils import get_center, get_tensor_mini, bbox_iou 7 | 8 | 9 | class KDE(object): 10 | def __init__(self, sess, image, image_path, y_p_num_detections, y_p_boxes): 11 | self.image = image 12 | self.h_img, self.w_img = image.shape[1:3] 13 | self.sess = sess 14 | self.y_p_num_detections = y_p_num_detections 15 | self.y_p_boxes = y_p_boxes 16 | self.image_name = os.path.basename(image_path) 17 | 18 | def get_kde_map(self, box): 19 | """ 20 | This function is to estimate the probability density function (PDF) of 21 | a random variable in a non-parametric way 22 | :return: 23 | kernel: KDE object 24 | f: KDE map add resize to the image size 25 | """ 26 | x_box, y_box = get_center(box) 27 | x_train = np.vstack([x_box, y_box]).T 28 | x, y = x_train[:, 0], x_train[:, 1] 29 | xmin, xmax = 0, self.w_img 30 | ymin, ymax = 0, self.h_img 31 | xx, yy = np.mgrid[xmin:xmax, ymin:ymax] 32 | positions = np.vstack([xx.ravel(), yy.ravel()]) 33 | values = np.vstack([x, y]) 34 | kernel = st.gaussian_kde(values) 35 | f = np.reshape(kernel(positions).T, xx.shape) 36 | return kernel, f 37 | 38 | def get_kde_score(self, kde_kernel, box_predicted): 39 | """ 40 | This function is to find the kde score for the box predicted by AI model 41 | KDE score is the ratio of the KDE value of the predicted box center 42 | divided by the highest KDE value on the KDE map. 43 | Input: 44 | - box_predicted: the coordinates of the predicted box 45 | - kde_kernel: KDE object containing the kde map that estimated from 46 | the previous phase 47 | Output: 48 | - predict_value_kde: KDE score of the predicted box center 49 | """ 50 | x_predict, y_predict = get_center(box_predicted) 51 | kde_map = np.zeros((self.h_img, self.w_img)) 52 | for i in range(self.h_img): 53 | for j in range(self.w_img): 54 | kde_map[i][j] = kde_kernel.evaluate([i, j]) 55 | predict_value_kde = kde_kernel.evaluate([x_predict, y_predict]) / np.max(kde_map) 56 | return predict_value_kde 57 | 58 | def show_kde_map(self, box_predicted, f, save_file=None): 59 | """ 60 | This function is to show the kde map in the input image 61 | Input: 62 | - box: list of coordinate of 300 boxes 63 | - box_predicted: coodinate of the predicted box 64 | - image: the predict image 65 | """ 66 | xmin, xmax = 0, self.w_img 67 | ymin, ymax = 0, self.h_img 68 | xx, yy = np.mgrid[0:self.w_img, 0:self.h_img] 69 | 70 | fig = plt.figure() 71 | ax = fig.gca() 72 | plt.axis([xmin, xmax, ymin, ymax]) 73 | ax.imshow(self.image[0]) 74 | 75 | ax.set_xlim(xmin, xmax) 76 | ax.set_ylim(ax.get_ylim()[::-1]) 77 | x_predict, y_predict = get_center(box_predicted) 78 | plt.scatter(x_predict, y_predict, c='red') 79 | cfset = ax.contourf(xx, yy, f, cmap='Blues') 80 | cset = ax.contour(xx, yy, f, colors='g') 81 | ax.set_xlabel('Y1') 82 | ax.set_ylabel('Y0') 83 | if save_file is not None: 84 | plt.savefig(os.path.join(save_file, 'kde_' + self.image_name + '_kde_blue.jpg'), dpi=600) 85 | plt.show() 86 | 87 | def get_box_predicted(self, img_input): 88 | # Get boxes 89 | box = [] 90 | box_predicted = [] 91 | b = [n.name for n in self.sess.graph.as_graph_def().node if 'SecondStageBoxPredictor' in n.name] 92 | boxes = get_tensor_mini(self.sess, b[18], 
self.image, img_input) 93 | # Get predicted box 94 | post_process = [n.name for n in self.sess.graph.as_graph_def().node if 'SecondStagePostprocessor' in n.name] 95 | boxes_post_process = get_tensor_mini(self.sess, post_process[14], self.image, img_input) 96 | # Preprocess boxes 97 | ratio = self.w_img / self.h_img 98 | for i in range(300): 99 | x1, x2 = int(boxes_post_process[i][1] * ratio), int(boxes_post_process[i][3] * ratio) 100 | y1, y2 = int(boxes_post_process[i][0] * ratio), int(boxes_post_process[i][2] * ratio) 101 | box.append([x1, y1, x2, y2]) 102 | 103 | num_box = int(self.y_p_num_detections[0]) 104 | 105 | for i in range(num_box): 106 | x1, x2 = int(self.y_p_boxes[0][i][1] * self.w_img), int(self.y_p_boxes[0][i][3] * self.w_img) 107 | y1, y2 = int(self.y_p_boxes[0][i][0] * self.h_img), int(self.y_p_boxes[0][i][2] * self.h_img) 108 | box_predicted.append([x1, y1, x2, y2]) 109 | 110 | iou = [] 111 | for i in range(num_box): 112 | for j in range(300): 113 | iou.append(bbox_iou(box[j], box_predicted[i])) 114 | 115 | return box, box_predicted -------------------------------------------------------------------------------- /xai/lime_method.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import skimage 3 | 4 | from skimage.segmentation import slic 5 | from lime import lime_image 6 | 7 | 8 | class LIME(object): 9 | def __init__(self, session, img_input, detection_scores, image, indices, top_labels=1, num_features=60): 10 | """ 11 | Initialize LIME 12 | :param session: Tensorflow session 13 | :param image: Input image 14 | :param indices: Indices of image 15 | :param top_labels: Number of top labels 16 | :param num_features: Number of features for segmentation 17 | """ 18 | self.image = image 19 | self.sess = session 20 | self.img_input = img_input 21 | self.detection_scores = detection_scores 22 | self.top_labels = top_labels 23 | self.result = None 24 | self.num_features = num_features 25 | self.indices = indices 26 | self.segments = self.segment_fn(self.image) 27 | 28 | def segment_fn(self, image): 29 | """ 30 | Segment image 31 | :param image: Input image 32 | :return: Segmented image 33 | """ 34 | segments_slic = slic(image, n_segments=self.num_features, compactness=30, sigma=3) 35 | return segments_slic 36 | 37 | def _predict_(self, sample, flag=False): 38 | """ 39 | Predict image 40 | :param sample: Input image 41 | :param flag: Flag for prediction 42 | :return: Prediction 43 | """ 44 | img = sample 45 | input_dict = {self.img_input: img} 46 | p_scores = self.sess.run(self.detection_scores, feed_dict=input_dict) 47 | if flag: 48 | return p_scores 49 | else: 50 | rs = np.array([]) 51 | n = img.shape[0] 52 | for i in range(n): 53 | rs = np.append(rs, p_scores[i][self.indices]) 54 | return rs.reshape(n, 1) 55 | 56 | def explain(self, num_features, num_samples=100, top_labels=0, positive=False): 57 | """ 58 | Calculate LIME explanation 59 | :param num_features: Number of features for segmentation 60 | :param num_samples: Number of samples 61 | :param top_labels: Number of top labels 62 | :param positive: Flag for positive explanation 63 | :return: LIME explanation 64 | """ 65 | explainer = lime_image.LimeImageExplainer() 66 | explanation = explainer.explain_instance(self.image, 67 | self._predict_, 68 | num_samples=num_samples, 69 | top_labels=self.top_labels, 70 | hide_color=0, 71 | segmentation_fn=self.segment_fn) 72 | self.result = explanation 73 | temp, mask = 
self.result.get_image_and_mask(self.result.top_labels[top_labels], 74 | positive_only=positive, 75 | num_features=num_features, 76 | hide_rest=False, 77 | min_weight=0.) 78 | img_boundary = skimage.segmentation.mark_boundaries(temp, mask) 79 | return img_boundary -------------------------------------------------------------------------------- /xai/rise.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from skimage.transform import resize 3 | from tqdm import tqdm 4 | 5 | 6 | class RISE(object): 7 | def __init__(self, image, sess, grid_size, prob, num_samples=500, batch_size=1): 8 | """ 9 | Initialize RISE 10 | :param image: Input image 11 | :param sess: Tensorflow session 12 | :param grid_size: Grid size 13 | :param prob: Probability of sampling 14 | :param num_samples: Number of samples 15 | :param batch_size: Batch size 16 | """ 17 | self.image = image 18 | if num_samples > 700 or num_samples <= 0: 19 | num_samples = 700 20 | self.num_samples = num_samples 21 | self.sess = sess 22 | self.grid_size = grid_size 23 | self.prob = prob 24 | self.image_size = (image.shape[1], image.shape[2]) 25 | self.batch_size = batch_size 26 | self.mask = self.generate_mask(self.num_samples, self.grid_size, self.prob) 27 | 28 | def generate_mask(self, num_samples, grid_size, prob): 29 | """ 30 | Generate mask 31 | :param num_samples: Number of samples 32 | :param grid_size: Grid size 33 | :param prob: Probability of sampling 34 | :return: Mask 35 | """ 36 | cell_size = np.ceil(np.array(self.image_size) / grid_size) 37 | up_size = (grid_size + 1) * cell_size 38 | grid = np.random.rand(num_samples, grid_size, grid_size) < prob 39 | grid = grid.astype('float32') 40 | masks = np.empty((num_samples, *self.image_size)) 41 | for i in tqdm(range(num_samples), desc='Generating masks'): 42 | # Random shifts 43 | x = np.random.randint(0, cell_size[0]) 44 | y = np.random.randint(0, cell_size[1]) 45 | # Linear upsampling and cropping 46 | masks[i, :, :] = resize(grid[i], up_size, order=1, mode='reflect', 47 | anti_aliasing=False)[x:x + self.image_size[0], y:y + self.image_size[1]] 48 | masks = masks.reshape(-1, 1) 49 | return masks 50 | 51 | def explain(self, image, index, mask, detection_boxes, detection_scores, num_detections, detection_classes): 52 | """ 53 | Calculate RISE explanation 54 | :param image: Input image 55 | :param index: Index of image 56 | :param mask: Mask 57 | :param detection_boxes: Detection boxes 58 | :param detection_scores: Detection scores 59 | :param num_detections: Number of detections 60 | :param detection_classes: Detection classes 61 | :return: RISE explanation as saliency map 62 | """ 63 | N = self.num_samples 64 | p = self.prob 65 | preds = np.array([]) 66 | masked = self.mask * image 67 | for i in tqdm(range(0, N, self.batch_size)): 68 | input_dict = {mask: masked[i:i + self.batch_size]} 69 | p_boxes, p_scores, p_num_detections, p_classes = self.sess.run( 70 | [detection_boxes, detection_scores, num_detections, detection_classes], 71 | feed_dict=input_dict) 72 | for j in range(self.batch_size): 73 | preds = np.append(preds, p_scores[j][index]) 74 | preds = preds.reshape(N, 1) 75 | sal = preds.T.dot(self.mask.reshape(N, -1)) 76 | sal = sal.reshape(-1, *self.image_size) 77 | sal = sal / N / p 78 | sal -= np.min(sal) 79 | sal /= (np.max(sal) - np.min(sal)) 80 | return sal 81 | --------------------------------------------------------------------------------
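Usage sketch (hedged; not part of the repository): a minimal, self-contained example of how the evaluation helpers bbox_iou and energy_point_game from utils.py can be combined, assuming utils.py and its dependencies are importable; the boxes and the saliency map below are synthetic placeholders.

import numpy as np

from utils import bbox_iou, energy_point_game

# Ground-truth and (hypothetical) predicted boxes in [x1, y1, x2, y2] pixel coordinates
gt_boxes = [[30, 40, 120, 160]]
pred_box = [35, 45, 118, 150]
iou = bbox_iou(pred_box, gt_boxes[0])  # overlap between prediction and ground truth, in [0, 1]

# Synthetic saliency map whose energy falls entirely inside the ground-truth box
saliency = np.zeros((256, 256), dtype=np.float32)
saliency[40:160, 30:120] = 1.0
proportion = energy_point_game(gt_boxes, saliency)  # -> 1.0 when all energy lies inside the boxes

print('IoU: %.3f, energy proportion: %.3f' % (iou, proportion))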