├── dataaugment.sh ├── list_requirements.sh ├── see_prototype_grid.sh ├── plot_graph.sh ├── run_gradCAM.sh ├── save.py ├── log.py ├── train_vanilla_malignancy.sh ├── train.sh ├── preprocess.py ├── LICENSE ├── settings.py ├── last_layer.py ├── see_explanations.sh ├── vis_protos.py ├── run_global_analysis.sh ├── helpers.py ├── README.md ├── delong_2.py ├── prune.py ├── receptive_field.py ├── run_pruning.py ├── dataHandling.py ├── global_analysis.py ├── our_vgg.py ├── gradcam.py ├── delong.py ├── dataHelper.py ├── vanilla_vgg.py ├── local_analysis_vis.py ├── vgg_features.py ├── gradcam_utils.py ├── resnet_features.py ├── train_and_test.py ├── main.py ├── gradcam_APs.py ├── densenet_features.py └── find_nearest.py /dataaugment.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source /home/virtual_envs/ml/bin/activate 4 | 5 | srun -u python dataHandling.py 6 | -------------------------------------------------------------------------------- /list_requirements.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source /home/virtual_envs/ml/bin/activate 4 | 5 | srun -u python -m pip freeze 6 | -------------------------------------------------------------------------------- /see_prototype_grid.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source /home/virtual_envs/ml/bin/activate 4 | 5 | nvidia-smi 6 | 7 | srun -u python3 vis_protos.py 8 | -------------------------------------------------------------------------------- /plot_graph.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source /home/virtual_envs/ml/bin/activate 4 | 5 | nvidia-smi 6 | 7 | echo "To get the output from paper, you must have access to the entire test set." 8 | 9 | srun -u python3 graphing.py -------------------------------------------------------------------------------- /run_gradCAM.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source /home/virtual_envs/ml/bin/activate 4 | 5 | echo "start running" 6 | 7 | nvidia-smi 8 | 9 | python gradcam_APs.py -save_loc /usr/xtmp/IAIABL/gradCAM_imgs/view.png 10 | 11 | echo "finish running" -------------------------------------------------------------------------------- /save.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | def save_model_w_condition(model, model_dir, model_name, accu, target_accu, log=print): 5 | ''' 6 | model: this is not the multigpu model 7 | ''' 8 | if accu > target_accu: 9 | log('\tabove {0:.2f}%'.format(target_accu * 100)) 10 | # torch.save(obj=model.state_dict(), f=os.path.join(model_dir, (model_name + '{0:.4f}.pth').format(accu))) 11 | torch.save(obj=model, f=os.path.join(model_dir, (model_name + '{0:.4f}.pth').format(accu))) 12 | -------------------------------------------------------------------------------- /log.py: -------------------------------------------------------------------------------- 1 | import os 2 | def create_logger(log_filename, display=True): 3 | f = open(log_filename, 'a') 4 | counter = [0] 5 | # this function will still have access to f after create_logger terminates 6 | def logger(text): 7 | if display: 8 | print(text) 9 | f.write(text + '\n') 10 | counter[0] += 1 11 | if counter[0] % 10 == 0: 12 | f.flush() 13 | os.fsync(f.fileno()) 14 | # Question: do we need to flush() 15 | return logger, f.close 16 | -------------------------------------------------------------------------------- /train_vanilla_malignancy.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source /home/virtual_envs/ml/bin/activate 4 | 5 | nvidia-smi 6 | 7 | srun -u python vanilla_vgg.py -model="vgg16" \ 8 | -train_dir="/usr/xtmp/IAIABL/Lo1136i/bymal/train/" \ 9 | -test_dir="/usr/xtmp/IAIABL/Lo1136i/bymal/test/"\ 10 | -name="0202_vanilla_2mal_vgg16_latent512_random=12"\ 11 | -lr="1e-5" \ 12 | -wd="1e-1" \ 13 | -num_classes="2" -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source /home/virtual_envs/ml/bin/activate 4 | 5 | nvidia-smi 6 | 7 | echo "You would require an entire dataset to train using this script." 8 | 9 | srun -u python main.py -latent=512 -experiment_run='0112_topkk=9_fa=0.001_random=4' \ 10 | -base="vgg16" \ 11 | -last_layer_weight=-1 \ 12 | -fa_coeff=0.001 \ 13 | -topk_k=9 \ 14 | -train_dir="/usr/xtmp/mammo/Lo1136i_with_fa/train_augmented_5000/" \ 15 | -push_dir="/usr/xtmp/mammo/Lo1136i_finer/by_margin/train/" \ 16 | -test_dir="/usr/xtmp/mammo/Lo1136i_with_fa/validation/" \ 17 | -random_seed=4 \ 18 | -finer_dir="/usr/xtmp/mammo/Lo1136i_finer/by_margin/train_augmented_250/" \ 19 | # -model=".pth" \ 20 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | mean = (0.485, 0.456, 0.406) 4 | std = (0.229, 0.224, 0.225) 5 | 6 | def preprocess(x, mean, std): 7 | assert x.size(1) == 3 8 | y = torch.zeros_like(x) 9 | for i in range(3): 10 | y[:, i, :, :] = (x[:, i, :, :] - mean[i]) / std[i] 11 | return y 12 | 13 | 14 | def preprocess_input_function(x): 15 | ''' 16 | allocate new tensor like x and apply the normalization used in the 17 | pretrained model 18 | ''' 19 | return preprocess(x, mean=mean, std=std) 20 | 21 | def undo_preprocess(x, mean, std): 22 | assert x.size(1) == 3 23 | y = torch.zeros_like(x) 24 | for i in range(3): 25 | y[:, i, :, :] = x[:, i, :, :] * std[i] + mean[i] 26 | return y 27 | 28 | def undo_preprocess_input_function(x): 29 | ''' 30 | allocate new tensor like x and undo the normalization used in the 31 | pretrained model 32 | ''' 33 | return undo_preprocess(x, mean=mean, std=std) 34 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The copyrights of this software are owned by Duke University. As such, two licenses to this software are offered: 2 | 1. An open-source license under the MIT license for non-commercial use. 3 | 2. A custom license with Duke University, for commercial use or for use without the MIT license restrictions. 4 | 5 | As a recipient of this software, you may choose which license to receive the code under. Outside contributions to the Duke owned code base cannot be accepted unless the contributor transfers the copyright to those changes over to Duke University. 6 | To enter a custom license agreement without the MIT license restrictions, please contact the Digital Innovations department at Duke Office for Translation & Commercialization (OTC) (https://olv.duke.edu/software/) at olvquestions@duke.edu with reference to “OLV File No. 007674” in your email. 7 | 8 | Please note that this software is distributed AS IS, WITHOUT ANY WARRANTY; and without the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 9 | -- 10 | 11 | (c) Copyright 2021. Duke University. All Rights Reserved. 12 | Developed by Alina Jade Barnett, Fides Regina Schwartz, Chaofan Tao, Chaofan Chen, Yinhao Ren, Joseph Y. Lo, 13 | and Cynthia Rudin at Duke University. 14 | -------------------------------------------------------------------------------- /settings.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | base_architecture = 'vgg16' 5 | img_size = 224 6 | prototype_shape = (15, 512, 1, 1) 7 | num_classes = 3 8 | 9 | prototype_activation_function = "log" 10 | prototype_activation_function_in_numpy = prototype_activation_function 11 | 12 | class_specific = True 13 | 14 | add_on_layers_type = 'regular' 15 | 16 | experiment_run = '1218_fa=' 17 | data_path = '/usr/xtmp/mammo/Lo1136i_with_fa/' 18 | train_dir = data_path + 'train_augmented_5000/' 19 | test_dir = data_path + 'validation/' 20 | train_push_dir = '/usr/xtmp/mammo/Lo1136i_finer/by_margin/train/' 21 | 22 | train_batch_size = 75 23 | test_batch_size = 100 24 | train_push_batch_size = 75 25 | 26 | joint_optimizer_lrs = {'features': 2e-4, 27 | 'add_on_layers': 3e-3, 28 | 'prototype_vectors': 3e-3} 29 | joint_lr_step_size = 5 30 | 31 | warm_optimizer_lrs = {'add_on_layers': 2e-3, 32 | 'prototype_vectors': 3e-3} 33 | 34 | last_layer_optimizer_lr = 1e-3 35 | 36 | coefs = { 37 | 'crs_ent': 1, 38 | 'clst': 0.8, 39 | 'sep': -0.08, 40 | 'l1': 1e-4, 41 | 'fine': 0.001, 42 | } 43 | 44 | num_train_epochs = 130 45 | num_warm_epochs = 10 46 | 47 | push_start = 10 48 | push_epochs = [i for i in range(num_train_epochs) if i % 10 == 0] 49 | -------------------------------------------------------------------------------- /last_layer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | import torchvision.transforms as transforms 4 | import torchvision.datasets as datasets 5 | from torch.autograd import Variable 6 | import numpy as np 7 | import matplotlib.pyplot as plt 8 | import cv2 9 | from PIL import Image 10 | from dataHelper import DatasetFolder 11 | import re 12 | import numpy as np 13 | import os 14 | import copy 15 | from skimage.transform import resize 16 | from helpers import makedir, find_high_activation_crop 17 | import model 18 | import push 19 | import train_and_test as tnt 20 | import save 21 | from log import create_logger 22 | from preprocess import mean, std, preprocess_input_function, undo_preprocess_input_function 23 | 24 | def show_last_layer_connections(ppnet): 25 | print(ppnet.num_prototypes, ppnet.num_classes) 26 | last_layer_connections = np.zeros((ppnet.num_prototypes, ppnet.num_classes)) 27 | last_layer_connections = ppnet.last_layer.weight 28 | return last_layer_connections 29 | 30 | def show_last_layer_connections_T(ppnet): 31 | print(ppnet.num_prototypes, ppnet.num_classes) 32 | last_layer_connections = np.zeros((ppnet.num_prototypes, ppnet.num_classes)) 33 | last_layer_connections = ppnet.last_layer.weight 34 | last_layer_connections_T = torch.transpose(last_layer_connections, 0, 1) 35 | return last_layer_connections_T -------------------------------------------------------------------------------- /see_explanations.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source /home/virtual_envs/ml/bin/activate 4 | 5 | nvidia-smi 6 | 7 | echo "Begun generating explanations." 8 | 9 | MODELFOLDER=/usr/xtmp/IAIABL/saved_models/vgg16/0129_pushonall_topkk=9_fa=0.001_random=4/pruned_prototypes_epoch50_k6_pt3 10 | MODELNAME=50_4_0prune0.9533.pth 11 | 12 | for FILENAME in DP_AJOU_197104_1.npy DP_AKKN_7728_1.npy 13 | do 14 | srun -u python local_analysis.py -test_img_name "$FILENAME" \ 15 | -test_img_dir '/usr/xtmp/IAIABL/Lo1136i/test/Circumscribed/' \ 16 | -test_img_label 0 \ 17 | -test_model_dir "$MODELFOLDER/" \ 18 | -test_model_name "$MODELNAME" &>/dev/null 19 | 20 | srun -u python3 local_analysis_vis.py -local_analysis_directory "$MODELFOLDER/$FILENAME/" 21 | done 22 | 23 | for FILENAME in DP_AKAY_89028_1.npy DP_AKVP_18401_1.npy DP_ALFQ_28102_1.npy 24 | do 25 | srun -u python local_analysis.py -test_img_name "$FILENAME" \ 26 | -test_img_dir '/usr/xtmp/IAIABL/Lo1136i/test/Spiculated/' \ 27 | -test_img_label 2 \ 28 | -test_model_dir "$MODELFOLDER/" \ 29 | -test_model_name "$MODELNAME" &>/dev/null 30 | 31 | srun -u python3 local_analysis_vis.py -local_analysis_directory "$MODELFOLDER/$FILENAME/" 32 | done 33 | 34 | echo "End." -------------------------------------------------------------------------------- /vis_protos.py: -------------------------------------------------------------------------------- 1 | from matplotlib.pyplot import imsave, imread 2 | import numpy as np 3 | from skimage.transform import resize 4 | import os 5 | 6 | img_dir = "/usr/xtmp/IAIABL/saved_models/vgg16/0129_pushonall_topkk=9_fa=0.001_random=4/pruned_prototypes_epoch50_k6_pt3/img/epoch-50/" 7 | paths = [] 8 | size = 5 9 | sizeh = 3 10 | sizew = 5 11 | save_dir = img_dir + "model_results_proto_visualization/" 12 | if not os.path.exists(save_dir): 13 | os.makedirs(save_dir) 14 | 15 | for i in range(15): 16 | paths.append(img_dir + "prototype-img-original_with_self_act"+ str(i) + ".png") 17 | 18 | tosave = np.zeros((sizeh * 250, sizew * 250, 4)) 19 | index = 0 20 | index_ = 1 21 | for path in paths: 22 | try: 23 | arr = imread(path) 24 | # print("size = ", arr.shape) 25 | except: 26 | arr = np.ones((224,224,4)) 27 | # print(arr.shape) 28 | # print(path) 29 | arr = np.pad(arr, ((13, 13), (13, 13), (0, 0)), constant_values=0) 30 | tosave[(index // sizew) * 250:(index // sizew) * 250 + 250, (index % sizew) * 250:(index % sizew) * 250 + 250] = arr 31 | index += 1 32 | if index == sizeh * sizew: 33 | imsave(save_dir+ "/"+str(index_), tosave, cmap="gray") 34 | tosave = np.zeros((sizeh * 250, sizew * 250)) 35 | index_ += 1 36 | index = 0 37 | # print("Saved!") 38 | imsave(save_dir + "/last", tosave) 39 | print(f'Saved to {save_dir}.') 40 | 41 | -------------------------------------------------------------------------------- /run_global_analysis.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source /home/virtual_envs/ml/bin/activate 4 | 5 | nvidia-smi 6 | 7 | echo "This cannot be run without the original dataset." 8 | 9 | srun -u python3 global_analysis.py -modeldir='/usr/xtmp/mammo/saved_models/vgg16/0129_pushonall_topkk=9_fa=0.001_random=4/pruned_prototypes_epoch50_k6_pt3/' \ 10 | -model='50_4_0prune0.9533.pth' \ 11 | -push_dir='/usr/xtmp/mammo/Lo1136i_with_fa/train_plus_val/' \ 12 | -test_dir='/usr/xtmp/mammo/Lo1136i_with_fa/test/' 13 | 14 | echo "The above is for pruned IAIA-BL" 15 | 16 | srun -u python3 global_analysis.py -modeldir='/usr/xtmp/mammo/saved_models/vgg16/0129_pushonall_topkk=9_fa=0.001_random=4/' \ 17 | -model='50_4push0.9546.pth' \ 18 | -push_dir='/usr/xtmp/mammo/Lo1136i_with_fa/train_plus_val/' \ 19 | -test_dir='/usr/xtmp/mammo/Lo1136i_with_fa/test/' 20 | 21 | echo "The above is for UNpruned IAIA-BL" 22 | 23 | srun -u python3 global_analysis.py -modeldir='/usr/xtmp/mammo/saved_models/vgg16/0125_topkk=1_fa=0.0_random=4/' \ 24 | -model='50_5push0.9209.pth' \ 25 | -push_dir='/usr/xtmp/mammo/Lo1136i_with_fa/train_plus_val/' \ 26 | -test_dir='/usr/xtmp/mammo/Lo1136i_with_fa/test/' 27 | 28 | echo "The above is for original protopnet (Baseline 1)" -------------------------------------------------------------------------------- /helpers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | 5 | def silent_print(string): 6 | pass 7 | 8 | def list_of_distances(X, Y): 9 | return torch.sum((torch.unsqueeze(X, dim=2) - torch.unsqueeze(Y.t(), dim=0)) ** 2, dim=1) 10 | 11 | def make_one_hot(target, target_one_hot): 12 | target = target.view(-1,1) 13 | target_one_hot.zero_() 14 | target_one_hot.scatter_(dim=1, index=target, value=1.) 15 | 16 | def makedir(path): 17 | ''' 18 | if path does not exist in the file system, create it 19 | ''' 20 | if not os.path.exists(path): 21 | os.makedirs(path) 22 | 23 | def print_and_write(str, file): 24 | print(str) 25 | file.write(str + '\n') 26 | 27 | def find_high_activation_crop(activation_map, percentile=95): 28 | threshold = np.percentile(activation_map, percentile) 29 | mask = np.ones(activation_map.shape) 30 | mask[activation_map < threshold] = 0 31 | lower_y, upper_y, lower_x, upper_x = 0, 0, 0, 0 32 | for i in range(mask.shape[0]): 33 | if np.amax(mask[i]) > 0.5: 34 | lower_y = i 35 | break 36 | for i in reversed(range(mask.shape[0])): 37 | if np.amax(mask[i]) > 0.5: 38 | upper_y = i 39 | break 40 | for j in range(mask.shape[1]): 41 | if np.amax(mask[:,j]) > 0.5: 42 | lower_x = j 43 | break 44 | for j in reversed(range(mask.shape[1])): 45 | if np.amax(mask[:,j]) > 0.5: 46 | upper_x = j 47 | break 48 | return lower_y, upper_y+1, lower_x, upper_x+1 49 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # IAIA-BL 2 | 3 | This code implements IAIA-BL from the manuscript "A Case-based 4 | Interpretable Deep Learning Model for Classification of Mass Lesions 5 | in Digital Mammography" published in Nature Machine Intelligence, Dec 2021, 6 | by Alina Jade Barnett, Fides Regina Schwartz, Chaofan Tao, Chaofan Chen, 7 | Yinhao Ren, Joseph Y. Lo, and Cynthia Rudin. 8 | 9 | This code package was developed by the authors at Duke University and 10 | University of Maine, and licensed as described in LICENSE (for more 11 | information regarding the use and the distribution of this code package). 12 | 13 | ## Prerequisites 14 | Any operating system on which you can run GPU-accelerated 15 | PyTorch. Python 3.6.9. For packages see requirements.txt. 16 | ### Recommended hardware 17 | 2 NVIDIA Tesla P-100 GPUs or 2 NVIDIA Tesla V-100 GPUs 18 | 19 | ## Installation instructions 20 | 1. Git clone the repository to /usr/xtmp/IAIABL/. 21 | 3. Set up your environment using Python 3.6.9 and requirements.txt. 22 | (Optional) Set up your environment using requirements.txt so that "source 23 | /home/virtual_envs/ml/bin/activate" activates your environment. You can 24 | set up the environment differently if you choose, but all .sh scripts 25 | included will attempt to activate the environment at 26 | /home/virtual_envs/ml/bin/activate. 27 | Typical install time: Less than 10 minutes. 28 | 29 | ## Train the model 30 | 1. In train.sh, the appropriate file locations should be set for train_dir, 31 | test_dir, push_dir and finer_dir: 32 | 1. train_dir is the directory containing the augmented training set 33 | 2. test_dir is the directory containing the test set 34 | 3. push_dir is the directory containing the original (unaugmented) training 35 | set, onto which prototypes can be projected 36 | 4. finer_dir is the directory containing the augmented set of training 37 | examples with fine-scale annotations 38 | 39 | 2. Run train.sh 40 | 41 | ## Reproducing figures 42 | No data is provided with this code repository. The following scripts are 43 | included to demonstrate how figures and results were created for the 44 | paper. The following scripts require data to be provided. Type "source 45 | scriptname.sh" into the command line to run. 46 | 47 | 1. see_explanations.sh 48 | 49 | Expected output from see_explanations.sh are figures from the 50 | manuscript that begin with "An automatically generated explanation 51 | of mass margin classification." The paths to the output images will 52 | appear in the relative file location "./visualizations_of_expl/". 53 | 54 | 2. see_prototype_grid.sh 55 | 56 | Expected output from see_prototype_grid.sh will be a grid of prototypes 57 | for a given model. The file location where the output image can be 58 | found will be printed onto the command line. 59 | 60 | 3. run_gradCAM.sh 61 | 62 | Expected output from run_gradCAM.sh will show the activation precision of the 63 | sample data. It will also print a visualization in 64 | /usr/xtmp/IAIABL/gradCAM_imgs/view.png. The columns from left to right are 65 | "Original Image," "GradCAM heatmap," "GradCAM++ heatmap," "GradCAM heatmap 66 | overlayed on the original image," and "GradCAM++ heatmap overlayed on the 67 | original image." The rows are "Last layer, using a network trained on natural 68 | images," "6th layer, using a network trained on natural images," "Blank," and 69 | "Last layer, using a network trained to identify the mass margin." 70 | 71 | 4. The mal_for_reviewers.ipynb Jupyter notebook is also included. 72 | 73 | Expected output from mal_for_reviewers.ipynb is in the cells of the notebook. 74 | 75 | Expected run time for these four demo files: 10 minutes. 76 | 77 | ## Other functions 78 | The following scripts require the more of the (private) dataset in order to 79 | run correctly, but are included to aid in reproducibility: 80 | 1. dataaugment.sh - for offline data augmentation 81 | 2. plot_graph.sh - plots a variety of graphs 82 | 3. run_global_analysis.sh - provides a global analysis of the model 83 | 4. train_vanilla_malignancy.sh - for training the baseline models 84 | 85 | ## Expected Data Location 86 | Scripts are set up to expect data as numpy arrays in 87 | /usr/xtmp/IAIABL/Lo1136i/test/Circumscribed/ where Circumscribed is the 88 | mass margin label. The first channel of the numpy array should be image 89 | data and the second (optional) channel should be the fine annotation 90 | label. 91 | -------------------------------------------------------------------------------- /delong_2.py: -------------------------------------------------------------------------------- 1 | #Retrieved from https://github.com/yandexdataschool/roc_comparison/blob/master/compare_auc_delong_xu.py March 12, 2021. 2 | 3 | import pandas as pd 4 | import numpy as np 5 | import scipy.stats 6 | 7 | # AUC comparison adapted from 8 | # https://github.com/Netflix/vmaf/ 9 | def compute_midrank(x): 10 | """Computes midranks. 11 | Args: 12 | x - a 1D numpy array 13 | Returns: 14 | array of midranks 15 | """ 16 | J = np.argsort(x) 17 | Z = x[J] 18 | N = len(x) 19 | T = np.zeros(N, dtype=np.float) 20 | i = 0 21 | while i < N: 22 | j = i 23 | while j < N and Z[j] == Z[i]: 24 | j += 1 25 | T[i:j] = 0.5*(i + j - 1) 26 | i = j 27 | T2 = np.empty(N, dtype=np.float) 28 | # Note(kazeevn) +1 is due to Python using 0-based indexing 29 | # instead of 1-based in the AUC formula in the paper 30 | T2[J] = T + 1 31 | return T2 32 | 33 | 34 | def fastDeLong(predictions_sorted_transposed, label_1_count): 35 | """ 36 | The fast version of DeLong's method for computing the covariance of 37 | unadjusted AUC. 38 | Args: 39 | predictions_sorted_transposed: a 2D numpy.array[n_classifiers, n_examples] 40 | sorted such as the examples with label "1" are first 41 | Returns: 42 | (AUC value, DeLong covariance) 43 | Reference: 44 | @article{sun2014fast, 45 | title={Fast Implementation of DeLong's Algorithm for 46 | Comparing the Areas Under Correlated Receiver Operating Characteristic Curves}, 47 | author={Xu Sun and Weichao Xu}, 48 | journal={IEEE Signal Processing Letters}, 49 | volume={21}, 50 | number={11}, 51 | pages={1389--1393}, 52 | year={2014}, 53 | publisher={IEEE} 54 | } 55 | """ 56 | # Short variables are named as they are in the paper 57 | m = label_1_count 58 | n = predictions_sorted_transposed.shape[1] - m 59 | positive_examples = predictions_sorted_transposed[:, :m] 60 | negative_examples = predictions_sorted_transposed[:, m:] 61 | k = predictions_sorted_transposed.shape[0] 62 | 63 | tx = np.empty([k, m], dtype=np.float) 64 | ty = np.empty([k, n], dtype=np.float) 65 | tz = np.empty([k, m + n], dtype=np.float) 66 | for r in range(k): 67 | tx[r, :] = compute_midrank(positive_examples[r, :]) 68 | ty[r, :] = compute_midrank(negative_examples[r, :]) 69 | tz[r, :] = compute_midrank(predictions_sorted_transposed[r, :]) 70 | aucs = tz[:, :m].sum(axis=1) / m / n - float(m + 1.0) / 2.0 / n 71 | v01 = (tz[:, :m] - tx[:, :]) / n 72 | v10 = 1.0 - (tz[:, m:] - ty[:, :]) / m 73 | sx = np.cov(v01) 74 | sy = np.cov(v10) 75 | delongcov = sx / m + sy / n 76 | return aucs, delongcov 77 | 78 | 79 | def calc_pvalue(aucs, sigma): 80 | """Computes log(10) of p-values. 81 | Args: 82 | aucs: 1D array of AUCs 83 | sigma: AUC DeLong covariances 84 | Returns: 85 | log10(pvalue) 86 | """ 87 | l = np.array([[1, -1]]) 88 | z = np.abs(np.diff(aucs)) / np.sqrt(np.dot(np.dot(l, sigma), l.T)) 89 | return np.log10(2) + scipy.stats.norm.logsf(z, loc=0, scale=1) / np.log(10) 90 | 91 | 92 | def compute_ground_truth_statistics(ground_truth): 93 | assert np.array_equal(np.unique(ground_truth), [0, 1]) 94 | order = (-ground_truth).argsort() 95 | label_1_count = int(ground_truth.sum()) 96 | return order, label_1_count 97 | 98 | 99 | def delong_roc_variance(ground_truth, predictions): 100 | """ 101 | Computes ROC AUC variance for a single set of predictions 102 | Args: 103 | ground_truth: np.array of 0 and 1 104 | predictions: np.array of floats of the probability of being class 1 105 | """ 106 | order, label_1_count = compute_ground_truth_statistics(ground_truth) 107 | predictions_sorted_transposed = predictions[np.newaxis, order] 108 | aucs, delongcov = fastDeLong(predictions_sorted_transposed, label_1_count) 109 | assert len(aucs) == 1, "There is a bug in the code, please forward this to the developers" 110 | return aucs[0], delongcov 111 | 112 | 113 | def delong_roc_test(ground_truth, predictions_one, predictions_two): 114 | """ 115 | Computes log(p-value) for hypothesis that two ROC AUCs are different 116 | Args: 117 | ground_truth: np.array of 0 and 1 118 | predictions_one: predictions of the first model, 119 | np.array of floats of the probability of being class 1 120 | predictions_two: predictions of the second model, 121 | np.array of floats of the probability of being class 1 122 | """ 123 | order, label_1_count = compute_ground_truth_statistics(ground_truth) 124 | predictions_sorted_transposed = np.vstack((predictions_one, predictions_two))[:, order] 125 | aucs, delongcov = fastDeLong(predictions_sorted_transposed, label_1_count) 126 | return calc_pvalue(aucs, delongcov) -------------------------------------------------------------------------------- /prune.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | from collections import Counter 4 | import numpy as np 5 | import torch 6 | 7 | from helpers import makedir 8 | import find_nearest 9 | 10 | def prune_prototypes(dataloader, 11 | prototype_network_parallel, 12 | k, 13 | prune_threshold, 14 | preprocess_input_function, 15 | original_model_dir, 16 | epoch_number, 17 | #model_name=None, 18 | log=print, 19 | copy_prototype_imgs=True, 20 | prototypes_to_keep=None): 21 | ### find prototypes to prune 22 | original_num_prototypes = prototype_network_parallel.module.num_prototypes 23 | if not prototypes_to_keep: 24 | nearest_train_patch_class_ids = \ 25 | find_nearest.find_k_nearest_patches_to_prototypes(dataloader=dataloader, 26 | prototype_network_parallel=prototype_network_parallel, 27 | k=k, 28 | preprocess_input_function=preprocess_input_function, 29 | full_save=False, 30 | log=log) 31 | prototypes_to_prune = [] 32 | for j in range(prototype_network_parallel.module.num_prototypes): 33 | class_j = torch.argmax(prototype_network_parallel.module.prototype_class_identity[j]).item() 34 | nearest_train_patch_class_counts_j = Counter(nearest_train_patch_class_ids[j]) 35 | # if no such element is in Counter, it will return 0 36 | if nearest_train_patch_class_counts_j[class_j] < prune_threshold: 37 | prototypes_to_prune.append(j) 38 | 39 | log('k = {}, prune_threshold = {}'.format(k, prune_threshold)) 40 | log('{} prototypes will be pruned'.format(len(prototypes_to_prune))) 41 | 42 | else: 43 | log("prototype to keep is selected to be {}".format(prototypes_to_keep)) 44 | prototypes_to_prune = list(set(range(original_num_prototypes)) - set(prototypes_to_keep)) 45 | 46 | ### bookkeeping of prototypes to be pruned 47 | class_of_prototypes_to_prune = \ 48 | torch.argmax(prototype_network_parallel.module.prototype_class_identity[prototypes_to_prune], dim=1).numpy().reshape(-1, 1) 49 | prototypes_to_prune_np = np.array(prototypes_to_prune).reshape(-1, 1) 50 | prune_info = np.hstack((prototypes_to_prune_np, class_of_prototypes_to_prune)) 51 | makedir(os.path.join(original_model_dir, 'pruned_prototypes_epoch{}_k{}_pt{}'.format(epoch_number, 52 | k, 53 | prune_threshold))) 54 | np.save(os.path.join(original_model_dir, 'pruned_prototypes_epoch{}_k{}_pt{}'.format(epoch_number, 55 | k, 56 | prune_threshold), 'prune_info.npy'), 57 | prune_info) 58 | ### prune prototypes 59 | prototype_network_parallel.module.prune_prototypes(prototypes_to_prune) 60 | #torch.save(obj=prototype_network_parallel.module, 61 | # f=os.path.join(original_model_dir, 'pruned_prototypes_epoch{}_k{}_pt{}'.format(epoch_number, 62 | # k, 63 | # prune_threshold), 64 | # model_name + '-pruned.pth')) 65 | if copy_prototype_imgs: 66 | original_img_dir = os.path.join(original_model_dir, 'img', 'epoch-%d' % epoch_number) 67 | dst_img_dir = os.path.join(original_model_dir, 68 | 'pruned_prototypes_epoch{}_k{}_pt{}'.format(epoch_number, 69 | k, 70 | prune_threshold), 71 | 'img', 'epoch-%d' % epoch_number) 72 | makedir(dst_img_dir) 73 | if not prototypes_to_keep: 74 | prototypes_to_keep = list(set(range(original_num_prototypes)) - set(prototypes_to_prune)) 75 | 76 | for idx in range(len(prototypes_to_keep)): 77 | shutil.copyfile(src=os.path.join(original_img_dir, 'prototype-img%d.png' % prototypes_to_keep[idx]), 78 | dst=os.path.join(dst_img_dir, 'prototype-img%d.png' % idx)) 79 | 80 | shutil.copyfile(src=os.path.join(original_img_dir, 'prototype-img-original%d.png' % prototypes_to_keep[idx]), 81 | dst=os.path.join(dst_img_dir, 'prototype-img-original%d.png' % idx)) 82 | 83 | shutil.copyfile(src=os.path.join(original_img_dir, 'prototype-img-original_with_self_act%d.png' % prototypes_to_keep[idx]), 84 | dst=os.path.join(dst_img_dir, 'prototype-img-original_with_self_act%d.png' % idx)) 85 | 86 | shutil.copyfile(src=os.path.join(original_img_dir, 'prototype-self-act%d.npy' % prototypes_to_keep[idx]), 87 | dst=os.path.join(dst_img_dir, 'prototype-self-act%d.npy' % idx)) 88 | 89 | 90 | bb = np.load(os.path.join(original_img_dir, 'bb%d.npy' % epoch_number)) 91 | bb = bb[prototypes_to_keep] 92 | np.save(os.path.join(dst_img_dir, 'bb%d.npy' % epoch_number), 93 | bb) 94 | 95 | bb_rf = np.load(os.path.join(original_img_dir, 'bb-receptive_field%d.npy' % epoch_number)) 96 | bb_rf = bb_rf[prototypes_to_keep] 97 | np.save(os.path.join(dst_img_dir, 'bb-receptive_field%d.npy' % epoch_number), 98 | bb_rf) 99 | 100 | return prune_info 101 | -------------------------------------------------------------------------------- /receptive_field.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | def compute_layer_rf_info(layer_filter_size, layer_stride, layer_padding, 4 | previous_layer_rf_info): 5 | n_in = previous_layer_rf_info[0] # input size 6 | j_in = previous_layer_rf_info[1] # receptive field jump of input layer 7 | r_in = previous_layer_rf_info[2] # receptive field size of input layer 8 | start_in = previous_layer_rf_info[3] # center of receptive field of input layer 9 | 10 | if layer_padding == 'SAME': 11 | n_out = math.ceil(float(n_in) / float(layer_stride)) 12 | if (n_in % layer_stride == 0): 13 | pad = max(layer_filter_size - layer_stride, 0) 14 | else: 15 | pad = max(layer_filter_size - (n_in % layer_stride), 0) 16 | assert(n_out == math.floor((n_in - layer_filter_size + pad)/layer_stride) + 1) # sanity check 17 | assert(pad == (n_out-1)*layer_stride - n_in + layer_filter_size) # sanity check 18 | elif layer_padding == 'VALID': 19 | n_out = math.ceil(float(n_in - layer_filter_size + 1) / float(layer_stride)) 20 | pad = 0 21 | assert(n_out == math.floor((n_in - layer_filter_size + pad)/layer_stride) + 1) # sanity check 22 | assert(pad == (n_out-1)*layer_stride - n_in + layer_filter_size) # sanity check 23 | else: 24 | # layer_padding is an int that is the amount of padding on one side 25 | pad = layer_padding * 2 26 | n_out = math.floor((n_in - layer_filter_size + pad)/layer_stride) + 1 27 | 28 | pL = math.floor(pad/2) 29 | 30 | j_out = j_in * layer_stride 31 | r_out = r_in + (layer_filter_size - 1)*j_in 32 | start_out = start_in + ((layer_filter_size - 1)/2 - pL)*j_in 33 | return [n_out, j_out, r_out, start_out] 34 | 35 | def compute_rf_protoL_at_spatial_location(img_size, height_index, width_index, protoL_rf_info): 36 | n = protoL_rf_info[0] 37 | j = protoL_rf_info[1] 38 | r = protoL_rf_info[2] 39 | start = protoL_rf_info[3] 40 | assert(height_index < n) 41 | assert(width_index < n) 42 | 43 | center_h = start + (height_index*j) 44 | center_w = start + (width_index*j) 45 | 46 | rf_start_height_index = max(int(center_h - (r/2)), 0) 47 | rf_end_height_index = min(int(center_h + (r/2)), img_size) 48 | 49 | rf_start_width_index = max(int(center_w - (r/2)), 0) 50 | rf_end_width_index = min(int(center_w + (r/2)), img_size) 51 | 52 | return [rf_start_height_index, rf_end_height_index, 53 | rf_start_width_index, rf_end_width_index] 54 | 55 | def compute_rf_prototype(img_size, prototype_patch_index, protoL_rf_info): 56 | img_index = prototype_patch_index[0] 57 | height_index = prototype_patch_index[1] 58 | width_index = prototype_patch_index[2] 59 | rf_indices = compute_rf_protoL_at_spatial_location(img_size, 60 | height_index, 61 | width_index, 62 | protoL_rf_info) 63 | return [img_index, rf_indices[0], rf_indices[1], 64 | rf_indices[2], rf_indices[3]] 65 | 66 | def compute_rf_prototypes(img_size, prototype_patch_indices, protoL_rf_info): 67 | rf_prototypes = [] 68 | for prototype_patch_index in prototype_patch_indices: 69 | img_index = prototype_patch_index[0] 70 | height_index = prototype_patch_index[1] 71 | width_index = prototype_patch_index[2] 72 | rf_indices = compute_rf_protoL_at_spatial_location(img_size, 73 | height_index, 74 | width_index, 75 | protoL_rf_info) 76 | rf_prototypes.append([img_index, rf_indices[0], rf_indices[1], 77 | rf_indices[2], rf_indices[3]]) 78 | return rf_prototypes 79 | 80 | def compute_proto_layer_rf_info(img_size, cfg, prototype_kernel_size): 81 | rf_info = [img_size, 1, 1, 0.5] 82 | 83 | for v in cfg: 84 | if v == 'M': 85 | rf_info = compute_layer_rf_info(layer_filter_size=2, 86 | layer_stride=2, 87 | layer_padding='SAME', 88 | previous_layer_rf_info=rf_info) 89 | else: 90 | rf_info = compute_layer_rf_info(layer_filter_size=3, 91 | layer_stride=1, 92 | layer_padding='SAME', 93 | previous_layer_rf_info=rf_info) 94 | 95 | proto_layer_rf_info = compute_layer_rf_info(layer_filter_size=prototype_kernel_size, 96 | layer_stride=1, 97 | layer_padding='VALID', 98 | previous_layer_rf_info=rf_info) 99 | 100 | return proto_layer_rf_info 101 | 102 | def compute_proto_layer_rf_info_v2(img_size, layer_filter_sizes, layer_strides, layer_paddings, prototype_kernel_size): 103 | 104 | assert(len(layer_filter_sizes) == len(layer_strides)) 105 | assert(len(layer_filter_sizes) == len(layer_paddings)) 106 | 107 | rf_info = [img_size, 1, 1, 0.5] 108 | 109 | for i in range(len(layer_filter_sizes)): 110 | filter_size = layer_filter_sizes[i] 111 | stride_size = layer_strides[i] 112 | padding_size = layer_paddings[i] 113 | 114 | rf_info = compute_layer_rf_info(layer_filter_size=filter_size, 115 | layer_stride=stride_size, 116 | layer_padding=padding_size, 117 | previous_layer_rf_info=rf_info) 118 | 119 | proto_layer_rf_info = compute_layer_rf_info(layer_filter_size=prototype_kernel_size, 120 | layer_stride=1, 121 | layer_padding='VALID', 122 | previous_layer_rf_info=rf_info) 123 | 124 | return proto_layer_rf_info 125 | 126 | -------------------------------------------------------------------------------- /run_pruning.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import torch 4 | import torch.utils.data 5 | import torchvision.transforms as transforms 6 | import torchvision.datasets as datasets 7 | import numpy as np 8 | import argparse 9 | from dataHelper import DatasetFolder 10 | from helpers import makedir 11 | import model 12 | import last_layer 13 | import push 14 | import prune 15 | import find_nearest 16 | import train_and_test as tnt 17 | import save 18 | from log import create_logger 19 | from preprocess import mean, std, preprocess_input_function 20 | 21 | parser = argparse.ArgumentParser() 22 | parser.add_argument('-modeldir', nargs=1, type=str) 23 | parser.add_argument('-model', nargs=1, type=str) 24 | parser.add_argument('-train_dir', nargs=1, type=str) 25 | parser.add_argument('-test_dir', nargs=1, type=str) 26 | parser.add_argument('-push_dir', nargs=1, type=str) 27 | args = parser.parse_args() 28 | 29 | optimize_last_layer = True 30 | 31 | proto_to_keep = [0,1,5,9,10,11] #for model /usr/xtmp/mammo/saved_models/vgg16/0125_topkk=9_fa=0.001_random=4/50_9push0.9645.pth 32 | # pruning parameters 33 | 34 | k = 5 35 | prune_threshold = 3 36 | 37 | original_model_dir = args.modeldir[0] 38 | original_model_name = args.model[0] 39 | train_dir, test_dir, train_push_dir = args.train_dir[0], args.test_dir[0], args.push_dir[0] 40 | 41 | need_push = ('nopush' in original_model_name) 42 | if need_push: 43 | assert(False) # pruning must happen after push 44 | else: 45 | epoch = original_model_name.split('push')[0] 46 | 47 | if '_' in epoch: 48 | epoch = int(epoch.split('_')[0]) 49 | else: 50 | epoch = int(epoch) 51 | 52 | model_dir = os.path.join(original_model_dir, 'pruned_prototypes_epoch{}_k{}_pt{}'.format(epoch, 53 | k, 54 | prune_threshold)) 55 | makedir(model_dir) 56 | shutil.copy(src=os.path.join(os.getcwd(), __file__), dst=model_dir) 57 | 58 | log, logclose = create_logger(log_filename=os.path.join(model_dir, 'prune.log')) 59 | 60 | ppnet = torch.load(original_model_dir + original_model_name) 61 | ppnet = ppnet.cuda() 62 | ppnet_multi = torch.nn.DataParallel(ppnet) 63 | class_specific = True 64 | 65 | train_batch_size = 80 66 | test_batch_size = 100 67 | img_size = 224 68 | train_push_batch_size = 80 69 | 70 | # all datasets 71 | # train set 72 | train_dataset = DatasetFolder( 73 | train_dir, 74 | augmentation=False, 75 | loader=np.load, 76 | extensions=("npy",), 77 | transform = transforms.Compose([ 78 | torch.from_numpy, 79 | ])) 80 | train_loader = torch.utils.data.DataLoader( 81 | train_dataset, batch_size=train_batch_size, shuffle=True, 82 | num_workers=4, pin_memory=False) 83 | 84 | # push set 85 | train_push_dataset = DatasetFolder( 86 | root = train_push_dir, 87 | loader = np.load, 88 | extensions=("npy",), 89 | transform = transforms.Compose([ 90 | torch.from_numpy, 91 | ])) 92 | train_push_loader = torch.utils.data.DataLoader( 93 | train_push_dataset, batch_size=train_push_batch_size, shuffle=False, 94 | num_workers=4, pin_memory=False) 95 | 96 | # test set 97 | test_dataset =DatasetFolder( 98 | test_dir, 99 | loader=np.load, 100 | extensions=("npy",), 101 | transform = transforms.Compose([ 102 | torch.from_numpy, 103 | ])) 104 | test_loader = torch.utils.data.DataLoader( 105 | test_dataset, batch_size=test_batch_size, shuffle=False, 106 | num_workers=4, pin_memory=False) 107 | 108 | 109 | log('push set size: {0}'.format(len(train_push_loader.dataset))) 110 | 111 | tnt.test(model=ppnet_multi, dataloader=test_loader, 112 | class_specific=class_specific, log=log) 113 | print(find_nearest.find_k_nearest_patches_to_prototypes(dataloader=train_push_loader, 114 | prototype_network_parallel=ppnet_multi, 115 | k=5, 116 | preprocess_input_function=preprocess_input_function, 117 | full_save=False, 118 | log=log)) 119 | print("last layer trasnpose: \n", last_layer.show_last_layer_connections_T(ppnet)) 120 | 121 | # prune prototypes 122 | log('========================================================prune======================================================') 123 | prune.prune_prototypes(dataloader=train_push_loader, 124 | prototype_network_parallel=ppnet_multi, 125 | k=k, 126 | prune_threshold=prune_threshold, 127 | preprocess_input_function=preprocess_input_function, # normalize 128 | original_model_dir=original_model_dir, 129 | epoch_number=epoch, 130 | #model_name=None, 131 | log=log, 132 | copy_prototype_imgs=True, 133 | prototypes_to_keep=proto_to_keep) 134 | accu = tnt.test(model=ppnet_multi, dataloader=test_loader, 135 | class_specific=class_specific, log=log) 136 | print(find_nearest.find_k_nearest_patches_to_prototypes(dataloader=train_push_loader, 137 | prototype_network_parallel=ppnet_multi, 138 | k=5, 139 | preprocess_input_function=preprocess_input_function, 140 | full_save=False, 141 | log=log)) 142 | print("last layer trasnpose: \n", last_layer.show_last_layer_connections_T(ppnet)) 143 | save.save_model_w_condition(model=ppnet, model_dir=model_dir, 144 | model_name=original_model_name.split('push')[0] + 'prune', 145 | accu=accu, 146 | target_accu=0.70, log=log) 147 | 148 | # last layer optimization 149 | if optimize_last_layer: 150 | last_layer_optimizer_specs = [{'params': ppnet.last_layer.parameters(), 'lr': 1e-4}] 151 | last_layer_optimizer = torch.optim.Adam(last_layer_optimizer_specs) 152 | 153 | from settings import coefs 154 | 155 | log('optimize last layer') 156 | tnt.last_only(model=ppnet_multi, log=log) 157 | for i in range(25): 158 | log('iteration: \t{0}'.format(i)) 159 | _ = tnt.train(model=ppnet_multi, dataloader=train_loader, optimizer=last_layer_optimizer, 160 | class_specific=class_specific, coefs=coefs, log=log) 161 | accu = tnt.test(model=ppnet_multi, dataloader=test_loader, 162 | class_specific=class_specific, log=log) 163 | print("last layer trasnpose: \n", last_layer.show_last_layer_connections_T(ppnet)) 164 | save.save_model_w_condition(model=ppnet, model_dir=model_dir, 165 | model_name=original_model_name.split('push')[0] + '_' + str(i) + 'prune', 166 | accu=accu, 167 | target_accu=0.70, log=log) 168 | logclose() 169 | -------------------------------------------------------------------------------- /dataHandling.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | import os 4 | import pandas as pd 5 | import argparse 6 | import sys 7 | import random 8 | import png 9 | from matplotlib.pyplot import imsave, imread 10 | import matplotlib 11 | import matplotlib.pyplot as plt 12 | from PIL import Image 13 | import cv2 14 | matplotlib.use("Agg") 15 | import torchvision.datasets as datasets 16 | from skimage.transform import resize 17 | import ast 18 | import pickle 19 | import csv 20 | import pydicom as dcm 21 | import Augmentor 22 | from tqdm import tqdm 23 | import pathlib 24 | from torch import randint, manual_seed 25 | from copy import copy 26 | from collections import defaultdict 27 | 28 | def random_flip(input, axis, with_fa=False): 29 | ran = random.random() 30 | if ran > 0.5: 31 | if with_fa: 32 | axis += 1 33 | return np.flip(input, axis=axis) 34 | else: 35 | return input 36 | 37 | def random_crop(input, with_fa=False): 38 | ran = random.random() 39 | if ran > 0.2: 40 | # find a random place to be the left upper corner of the crop 41 | if with_fa: 42 | rx = int(random.random() * input.shape[1] // 10) 43 | ry = int(random.random() * input.shape[2] // 10) 44 | return input[:, rx: rx + int(input.shape[1] * 9 // 10), ry: ry + int(input.shape[2] * 9 // 10)] 45 | else: 46 | rx = int(random.random() * input.shape[0] // 10) 47 | ry = int(random.random() * input.shape[1] // 10) 48 | return input[rx: rx + int(input.shape[0] * 9 // 10), ry: ry + int(input.shape[1] * 9 // 10)] 49 | else: 50 | return input 51 | 52 | def random_rotate_90(input, with_fa=False): 53 | ran = random.random() 54 | if ran > 0.5: 55 | if with_fa: 56 | return np.rot90(input, axes=(1,2)) 57 | return np.rot90(input) 58 | else: 59 | return input 60 | 61 | def random_rotation(x, chance, with_fa=False): 62 | ran = random.random() 63 | if with_fa: 64 | img = Image.fromarray(x[0]) 65 | mask = Image.fromarray(x[1]) 66 | if ran > 1 - chance: 67 | # create black edges 68 | angle = np.random.randint(0, 90) 69 | img = img.rotate(angle=angle, expand=1) 70 | mask = mask.rotate(angle=angle, expand=1, fillcolor=1) 71 | return np.stack([np.asarray(img), np.asarray(mask)]) 72 | else: 73 | return np.stack([np.asarray(img), np.asarray(mask)]) 74 | img = Image.fromarray(x) 75 | if ran > 1 - chance: 76 | # create black edges 77 | angle = np.random.randint(0, 90) 78 | img = img.rotate(angle=angle, expand=1) 79 | return np.asarray(img) 80 | else: 81 | return np.asarray(img) 82 | 83 | def augment_numpy_images(path, targetNumber, targetDir, skip=None, rot=True, with_fa=False): 84 | classes = os.listdir(path) 85 | if not os.path.exists(targetDir): 86 | os.mkdir(targetDir) 87 | for class_ in classes: 88 | if not os.path.exists(targetDir + class_): 89 | os.makedirs(targetDir + class_) 90 | 91 | for class_ in classes: 92 | count, round = 0, 0 93 | while count < targetNumber: 94 | round += 1 95 | for root, dir, files in os.walk(os.path.join(path, class_)): 96 | for file in files: 97 | if skip and skip in file: 98 | continue 99 | filepath = os.path.join(root, file) 100 | arr = np.load(filepath) 101 | print("loaded ", file) 102 | print(arr.shape) 103 | try: 104 | arr = random_crop(arr, with_fa) 105 | print(arr.shape) 106 | if rot: 107 | arr = random_rotation(arr, 0.9, with_fa) 108 | print(arr.shape) 109 | arr = random_flip(arr, 0, with_fa) 110 | arr = random_flip(arr, 1, with_fa) 111 | arr = random_rotate_90(arr, with_fa) 112 | arr = random_rotate_90(arr, with_fa) 113 | arr = random_rotate_90(arr, with_fa) 114 | print(arr.shape) 115 | if with_fa: 116 | whites = arr.shape[2] * arr.shape[1] - np.count_nonzero(np.round(arr[0] - np.amax(arr[0]), 2)) 117 | black = arr.shape[2] * arr.shape[1] - np.count_nonzero(np.round(arr[0], 2)) 118 | if arr.shape[2] < 10 or arr.shape[1] < 10 or black >= arr.shape[2] * arr.shape[1] * 0.8 or \ 119 | whites >= arr.shape[2] * arr.shape[1] * 0.8: 120 | print("illegal content") 121 | continue 122 | 123 | else: 124 | whites = arr.shape[0] * arr.shape[1] - np.count_nonzero(np.round(arr - np.amax(arr), 2)) 125 | black = arr.shape[0] * arr.shape[1] - np.count_nonzero(np.round(arr, 2)) 126 | 127 | if arr.shape[0] < 10 or arr.shape[1] < 10 or black >= arr.shape[0] * arr.shape[1] * 0.8 or \ 128 | whites >= arr.shape[0] * arr.shape[1] * 0.8: 129 | print("illegal content") 130 | continue 131 | 132 | if count % 10 == 0: 133 | if not os.path.exists("./visualizations_of_augmentation/" + class_ + "/"): 134 | os.makedirs("./visualizations_of_augmentation/" + class_ + "/") 135 | if with_fa: 136 | imsave("./visualizations_of_augmentation/" + class_ + "/" + str(count), np.transpose(np.stack([arr[0], arr[0], arr[1]]), (1,2,0))) 137 | else: 138 | imsave("./visualizations_of_augmentation/" + class_ + "/" + str(count), np.transpose(np.stack([arr, arr, arr]), (1,2,0))) 139 | 140 | 141 | np.save(targetDir + class_ + "/" + file[:-4] + "aug" + str(round), arr) 142 | count += 1 143 | print(count) 144 | except: 145 | print("something is wrong in try, details:", sys.exc_info()[2]) 146 | if not os.path.exists("./error_of_augmentation/" + class_ + "/"): 147 | os.makedirs("./error_of_augmentation/" + class_ + "/") 148 | np.save("./error_of_augmentation/" + class_ + "/" + str(count), arr) 149 | if count > targetNumber: 150 | break 151 | print(count) 152 | 153 | def window_augmentation(wwidth, wcen): 154 | if wcen == 2047 and wwidth == 4096: 155 | return wwidth, wcen 156 | else: 157 | new_wcen = np.random.randint(-100, 300) 158 | new_wwidth = np.random.randint(-200, 300) 159 | wwidth += new_wwidth 160 | wcen += new_wcen 161 | return wwidth, wcen 162 | 163 | if __name__ == "__main__": 164 | 165 | print("Data augmentation") 166 | for pos in ["Spiculated","Circumscribed", "Indistinct"]: 167 | augment_numpy_images( 168 | path="/usr/xtmp/mammo/npdata/datasetname_with_fa/train/", 169 | targetNumber=5000, 170 | targetDir="/usr/xtmp/mammo/npdata/datasetname_with_fa/train_augmented_5000/", 171 | rot=True, 172 | with_fa=True) 173 | 174 | -------------------------------------------------------------------------------- /global_analysis.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data 3 | import torchvision.transforms as transforms 4 | import torchvision.datasets as datasets 5 | from dataHelper import DatasetFolder 6 | import numpy as np 7 | import cv2 8 | import matplotlib.pyplot as plt 9 | 10 | import re 11 | from settings import prototype_activation_function_in_numpy 12 | import os 13 | 14 | from helpers import makedir 15 | import model 16 | import find_nearest 17 | import last_layer 18 | import train_and_test as tnt 19 | 20 | from preprocess import preprocess_input_function 21 | from highlighting_precision import hp 22 | 23 | import argparse 24 | 25 | # Usage: python3 global_analysis.py -modeldir='./saved_models/' -model='' 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('-gpuid', nargs=1, type=str, default='0') 28 | parser.add_argument('-modeldir', nargs=1, type=str) 29 | parser.add_argument('-model', nargs=1, type=str) 30 | parser.add_argument('-test_dir', nargs=1, type=str) 31 | parser.add_argument('-push_dir', nargs=1, type=str) 32 | #parser.add_argument('-dataset', nargs=1, type=str, default='cub200') 33 | args = parser.parse_args() 34 | 35 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpuid[0] 36 | load_model_dir = args.modeldir[0] 37 | load_model_name = args.model[0] 38 | 39 | load_model_path = os.path.join(load_model_dir, load_model_name) 40 | epoch_number_str = re.search(r'\d+', load_model_name).group(0) 41 | start_epoch_number = int(epoch_number_str) 42 | 43 | # load the model 44 | print('load model from ' + load_model_path) 45 | ppnet = torch.load(load_model_path) 46 | # print('convert to maxpool logic') 47 | # ppnet.set_topk_k(1) 48 | ppnet = ppnet.cuda() 49 | ppnet_multi = torch.nn.DataParallel(ppnet) 50 | 51 | img_size = ppnet_multi.module.img_size 52 | 53 | # load the data 54 | # must use unaugmented (original) dataset 55 | train_dir = args.push_dir[0] 56 | test_dir = args.test_dir[0] 57 | 58 | batch_size = 100 59 | 60 | # train set: do not normalize 61 | train_dataset = DatasetFolder( 62 | train_dir, 63 | augmentation=False, 64 | loader=np.load, 65 | extensions=("npy",), 66 | transform = transforms.Compose([ 67 | torch.from_numpy, 68 | ])) 69 | train_loader = torch.utils.data.DataLoader( 70 | train_dataset, batch_size=batch_size, shuffle=True, 71 | num_workers=4, pin_memory=False) 72 | # test set 73 | test_dataset =DatasetFolder( 74 | test_dir, 75 | loader=np.load, 76 | extensions=("npy",), 77 | transform = transforms.Compose([ 78 | torch.from_numpy, 79 | ])) 80 | test_loader = torch.utils.data.DataLoader( 81 | test_dataset, batch_size=batch_size, shuffle=False, 82 | num_workers=4, pin_memory=False) 83 | 84 | root_dir_for_saving_train_images = os.path.join(load_model_dir, 85 | load_model_name.split('.pth')[0] + '_nearest_train') 86 | root_dir_for_saving_test_images = os.path.join(load_model_dir, 87 | load_model_name.split('.pth')[0] + '_nearest_test') 88 | makedir(root_dir_for_saving_train_images) 89 | makedir(root_dir_for_saving_test_images) 90 | 91 | # save prototypes in original images 92 | load_img_dir = os.path.join(load_model_dir, 'img') 93 | prototype_info = np.load(os.path.join(load_img_dir, 'epoch-'+str(start_epoch_number), 'bb'+str(start_epoch_number)+'.npy')) 94 | def save_prototype_original_img_with_bbox(fname, epoch, index, 95 | bbox_height_start, bbox_height_end, 96 | bbox_width_start, bbox_width_end, color=(0, 255, 255)): 97 | p_img_bgr = cv2.imread(os.path.join(load_img_dir, 'epoch-'+str(epoch), 'prototype-img-original'+str(index)+'.png')) 98 | cv2.rectangle(p_img_bgr, (bbox_width_start, bbox_height_start), (bbox_width_end-1, bbox_height_end-1), 99 | color, thickness=2) 100 | p_img_rgb = p_img_bgr[...,::-1] 101 | p_img_rgb = np.float32(p_img_rgb) / 255 102 | #plt.imshow(p_img_rgb) 103 | #plt.axis('off') 104 | plt.imsave(fname, p_img_rgb) 105 | 106 | for j in range(ppnet.num_prototypes): 107 | makedir(os.path.join(root_dir_for_saving_train_images, str(j))) 108 | makedir(os.path.join(root_dir_for_saving_test_images, str(j))) 109 | save_prototype_original_img_with_bbox(fname=os.path.join(root_dir_for_saving_train_images, str(j), 110 | 'prototype_in_original_pimg.png'), 111 | epoch=start_epoch_number, 112 | index=j, 113 | bbox_height_start=prototype_info[j][1], 114 | bbox_height_end=prototype_info[j][2], 115 | bbox_width_start=prototype_info[j][3], 116 | bbox_width_end=prototype_info[j][4], 117 | color=(0, 255, 255)) 118 | save_prototype_original_img_with_bbox(fname=os.path.join(root_dir_for_saving_test_images, str(j), 119 | 'prototype_in_original_pimg.png'), 120 | epoch=start_epoch_number, 121 | index=j, 122 | bbox_height_start=prototype_info[j][1], 123 | bbox_height_end=prototype_info[j][2], 124 | bbox_width_start=prototype_info[j][3], 125 | bbox_width_end=prototype_info[j][4], 126 | color=(0, 255, 255)) 127 | 128 | k = 5 129 | 130 | print(find_nearest.find_k_nearest_patches_to_prototypes( 131 | dataloader=train_loader, # pytorch dataloader (must be unnormalized in [0,1]) 132 | prototype_network_parallel=ppnet_multi, # pytorch network with prototype_vectors 133 | k=k+1, 134 | preprocess_input_function=preprocess_input_function, # normalize if needed 135 | full_save=True, 136 | root_dir_for_saving_images=root_dir_for_saving_train_images, 137 | prototype_activation_function_in_numpy=prototype_activation_function_in_numpy, 138 | log=print)) 139 | 140 | print(find_nearest.find_k_nearest_patches_to_prototypes( 141 | dataloader=test_loader, # pytorch dataloader (must be unnormalized in [0,1]) 142 | prototype_network_parallel=ppnet_multi, # pytorch network with prototype_vectors 143 | k=k, 144 | preprocess_input_function=preprocess_input_function, # normalize if needed 145 | full_save=True, 146 | root_dir_for_saving_images=root_dir_for_saving_test_images, 147 | prototype_activation_function_in_numpy=prototype_activation_function_in_numpy, 148 | log=print)) 149 | 150 | #activation precisions by proto 151 | per_proto_lhp = np.asarray(hp(test_dir, load_model_path, per_proto=True)) 152 | per_proto_fhp = np.asarray(hp('/usr/xtmp/mammo/Lo1136i_finer/by_margin/test/', load_model_path, per_proto=True)) 153 | print("Avg lesion activation precision: ", per_proto_lhp) 154 | print("Avg fine activation precision: ", per_proto_fhp) 155 | 156 | #class connection weights 157 | print("last layer trasnpose: \n", last_layer.show_last_layer_connections_T(ppnet)) 158 | 159 | #activation precision averages 160 | avg_lhp = hp(test_dir, load_model_path) 161 | avg_fhp = hp('/usr/xtmp/mammo/Lo1136i_finer/by_margin/test/', load_model_path) 162 | pm_lhp = 1.96 * per_proto_lhp.std(axis=0)[1] / np.sqrt(per_proto_lhp.shape[0]) 163 | pm_fhp = 1.96 * per_proto_fhp.std(axis=0)[1] / np.sqrt(per_proto_fhp.shape[0]) 164 | print("Avg lesion activation precision: ", avg_lhp,\ 165 | " pm ", pm_lhp, ". \n", avg_lhp-pm_lhp, " to ", avg_lhp+pm_lhp) 166 | print("Avg fine activation precision: ", avg_fhp,\ 167 | " pm ", pm_fhp, ". \n", avg_fhp-pm_fhp, " to ", avg_fhp+pm_fhp) 168 | 169 | print("see analysis in ", root_dir_for_saving_train_images) 170 | print("see analysis in ", root_dir_for_saving_test_images) -------------------------------------------------------------------------------- /our_vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | # from .utils import load_state_dict_from_url 4 | from typing import Union, List, Dict, Any, cast 5 | 6 | 7 | __all__ = [ 8 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn', 9 | 'vgg19_bn', 'vgg19', 10 | ] 11 | 12 | 13 | model_urls = { 14 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 15 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 16 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 17 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 18 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 19 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 20 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 21 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', 22 | } 23 | 24 | 25 | class VGG(nn.Module): 26 | 27 | def __init__( 28 | self, 29 | features: nn.Module, 30 | num_classes: int = 1000, 31 | init_weights: bool = True 32 | ) -> None: 33 | super(VGG, self).__init__() 34 | self.features = features 35 | self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) 36 | self.classifier = nn.Sequential( 37 | nn.Linear(512 * 7 * 7, 512), 38 | nn.ReLU(True), 39 | nn.Dropout(), 40 | nn.Linear(512, 512), 41 | nn.ReLU(True), 42 | nn.Dropout(), 43 | nn.Linear(512, 3), 44 | ) 45 | if init_weights: 46 | self._initialize_weights() 47 | 48 | def forward(self, x: torch.Tensor) -> torch.Tensor: 49 | x = self.features(x) 50 | x = self.avgpool(x) 51 | x = torch.flatten(x, 1) 52 | x = self.classifier(x) 53 | return x 54 | 55 | def _initialize_weights(self) -> None: 56 | for m in self.modules(): 57 | if isinstance(m, nn.Conv2d): 58 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 59 | if m.bias is not None: 60 | nn.init.constant_(m.bias, 0) 61 | elif isinstance(m, nn.BatchNorm2d): 62 | nn.init.constant_(m.weight, 1) 63 | nn.init.constant_(m.bias, 0) 64 | elif isinstance(m, nn.Linear): 65 | nn.init.normal_(m.weight, 0, 0.01) 66 | nn.init.constant_(m.bias, 0) 67 | 68 | 69 | def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential: 70 | layers: List[nn.Module] = [] 71 | in_channels = 3 72 | for v in cfg: 73 | if v == 'M': 74 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 75 | else: 76 | v = cast(int, v) 77 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 78 | if batch_norm: 79 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 80 | else: 81 | layers += [conv2d, nn.ReLU(inplace=True)] 82 | in_channels = v 83 | return nn.Sequential(*layers) 84 | 85 | 86 | cfgs: Dict[str, List[Union[str, int]]] = { 87 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 88 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 89 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'], 90 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 91 | } 92 | 93 | 94 | def _vgg(arch: str, cfg: str, batch_norm: bool, pretrained: bool, progress: bool, **kwargs: Any) -> VGG: 95 | if pretrained: 96 | kwargs['init_weights'] = False 97 | model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) 98 | if pretrained: 99 | state_dict = load_state_dict_from_url(model_urls[arch], 100 | progress=progress) 101 | model.load_state_dict(state_dict) 102 | return model 103 | 104 | 105 | def vgg11(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 106 | r"""VGG 11-layer model (configuration "A") from 107 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._ 108 | Args: 109 | pretrained (bool): If True, returns a model pre-trained on ImageNet 110 | progress (bool): If True, displays a progress bar of the download to stderr 111 | """ 112 | return _vgg('vgg11', 'A', False, pretrained, progress, **kwargs) 113 | 114 | 115 | def vgg11_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 116 | r"""VGG 11-layer model (configuration "A") with batch normalization 117 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._ 118 | Args: 119 | pretrained (bool): If True, returns a model pre-trained on ImageNet 120 | progress (bool): If True, displays a progress bar of the download to stderr 121 | """ 122 | return _vgg('vgg11_bn', 'A', True, pretrained, progress, **kwargs) 123 | 124 | 125 | def vgg13(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 126 | r"""VGG 13-layer model (configuration "B") 127 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._ 128 | Args: 129 | pretrained (bool): If True, returns a model pre-trained on ImageNet 130 | progress (bool): If True, displays a progress bar of the download to stderr 131 | """ 132 | return _vgg('vgg13', 'B', False, pretrained, progress, **kwargs) 133 | 134 | 135 | def vgg13_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 136 | r"""VGG 13-layer model (configuration "B") with batch normalization 137 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._ 138 | Args: 139 | pretrained (bool): If True, returns a model pre-trained on ImageNet 140 | progress (bool): If True, displays a progress bar of the download to stderr 141 | """ 142 | return _vgg('vgg13_bn', 'B', True, pretrained, progress, **kwargs) 143 | 144 | 145 | def vgg16(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 146 | r"""VGG 16-layer model (configuration "D") 147 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._ 148 | Args: 149 | pretrained (bool): If True, returns a model pre-trained on ImageNet 150 | progress (bool): If True, displays a progress bar of the download to stderr 151 | """ 152 | return _vgg('vgg16', 'D', False, pretrained, progress, **kwargs) 153 | 154 | 155 | def vgg16_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 156 | r"""VGG 16-layer model (configuration "D") with batch normalization 157 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._ 158 | Args: 159 | pretrained (bool): If True, returns a model pre-trained on ImageNet 160 | progress (bool): If True, displays a progress bar of the download to stderr 161 | """ 162 | return _vgg('vgg16_bn', 'D', True, pretrained, progress, **kwargs) 163 | 164 | 165 | def vgg19(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 166 | r"""VGG 19-layer model (configuration "E") 167 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._ 168 | Args: 169 | pretrained (bool): If True, returns a model pre-trained on ImageNet 170 | progress (bool): If True, displays a progress bar of the download to stderr 171 | """ 172 | return _vgg('vgg19', 'E', False, pretrained, progress, **kwargs) 173 | 174 | 175 | def vgg19_bn(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> VGG: 176 | r"""VGG 19-layer model (configuration 'E') with batch normalization 177 | `"Very Deep Convolutional Networks For Large-Scale Image Recognition" `._ 178 | Args: 179 | pretrained (bool): If True, returns a model pre-trained on ImageNet 180 | progress (bool): If True, displays a progress bar of the download to stderr 181 | """ 182 | return _vgg('vgg19_bn', 'E', True, pretrained, progress, **kwargs) -------------------------------------------------------------------------------- /gradcam.py: -------------------------------------------------------------------------------- 1 | ### Adapted from https://github.com/stefannc/GradCAM-Pytorch/blob/07fd6ece5010f7c1c9fbcc8155a60023819111d7/gradcam.py retrieved Mar 4 2021 ##### 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | from gradcam_utils import find_alexnet_layer, find_vgg_layer, find_vgg_us_layer, find_resnet_layer, find_densenet_layer, find_squeezenet_layer 7 | 8 | 9 | class GradCAM(object): 10 | """Calculate GradCAM salinecy map. 11 | A simple example: 12 | # initialize a model, model_dict and gradcam 13 | resnet = torchvision.models.resnet101(pretrained=True) 14 | resnet.eval() 15 | model_dict = dict(model_type='resnet', arch=resnet, layer_name='layer4', input_size=(224, 224)) 16 | gradcam = GradCAM(model_dict) 17 | # get an image and normalize with mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) 18 | img = load_img() 19 | normed_img = normalizer(img) 20 | # get a GradCAM saliency map on the class index 10. 21 | mask, logit = gradcam(normed_img, class_idx=10) 22 | # make heatmap from mask and synthesize saliency map using heatmap and img 23 | heatmap, cam_result = visualize_cam(mask, img) 24 | Args: 25 | model_dict (dict): a dictionary that contains 'model_type', 'arch', layer_name', 'input_size'(optional) as keys. 26 | verbose (bool): whether to print output size of the saliency map givien 'layer_name' and 'input_size' in model_dict. 27 | """ 28 | def __init__(self, model_dict, verbose=False): 29 | model_type = model_dict['type'] 30 | layer_name = model_dict['layer_name'] 31 | self.model_arch = model_dict['arch'] 32 | 33 | self.gradients = dict() 34 | self.activations = dict() 35 | def backward_hook(module, grad_input, grad_output): 36 | self.gradients['value'] = grad_output[0] 37 | return None 38 | def forward_hook(module, input, output): 39 | self.activations['value'] = output 40 | return None 41 | 42 | if 'vgg_us' in model_type.lower(): 43 | target_layer = find_vgg_us_layer(self.model_arch, layer_name) 44 | elif 'vgg' in model_type.lower(): 45 | target_layer = find_vgg_layer(self.model_arch, layer_name) 46 | elif 'resnet' in model_type.lower(): 47 | target_layer = find_resnet_layer(self.model_arch, layer_name) 48 | elif 'densenet' in model_type.lower(): 49 | target_layer = find_densenet_layer(self.model_arch, layer_name) 50 | elif 'alexnet' in model_type.lower(): 51 | target_layer = find_alexnet_layer(self.model_arch, layer_name) 52 | elif 'squeezenet' in model_type.lower(): 53 | target_layer = find_squeezenet_layer(self.model_arch, layer_name) 54 | 55 | target_layer.register_forward_hook(forward_hook) 56 | target_layer.register_backward_hook(backward_hook) 57 | 58 | if verbose: 59 | try: 60 | input_size = model_dict['input_size'] 61 | except KeyError: 62 | print("please specify size of input image in model_dict. e.g. {'input_size':(224, 224)}") 63 | pass 64 | else: 65 | device = 'cuda' if next(self.model_arch.parameters()).is_cuda else 'cpu' 66 | self.model_arch(torch.zeros(1, 3, *(input_size), device=device)) 67 | print('saliency_map size :', self.activations['value'].shape[2:]) 68 | 69 | 70 | def forward(self, input, class_idx=None, retain_graph=False): 71 | """ 72 | Args: 73 | input: input image with shape of (1, 3, H, W) 74 | class_idx (int): class index for calculating GradCAM. 75 | If not specified, the class index that makes the highest model prediction score will be used. 76 | Return: 77 | mask: saliency map of the same spatial dimension with input 78 | logit: model output 79 | """ 80 | # print("input size from gradcam.py: ", input.size()) 81 | b, c, h, w = input.size() 82 | 83 | logit = self.model_arch(input) 84 | if class_idx is None: 85 | score = logit[:, logit.max(1)[-1]].squeeze() 86 | else: 87 | score = logit[:, class_idx].squeeze() 88 | 89 | self.model_arch.zero_grad() 90 | score.backward(retain_graph=retain_graph) 91 | gradients = self.gradients['value'] 92 | activations = self.activations['value'] 93 | b, k, u, v = gradients.size() 94 | 95 | alpha = gradients.view(b, k, -1).mean(2) 96 | #alpha = F.relu(gradients.view(b, k, -1)).mean(2) 97 | weights = alpha.view(b, k, 1, 1) 98 | 99 | saliency_map = (weights*activations).sum(1, keepdim=True) 100 | saliency_map = F.relu(saliency_map) 101 | saliency_map = F.upsample(saliency_map, size=(h, w), mode='bilinear', align_corners=False) 102 | saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max() 103 | saliency_map = (saliency_map - saliency_map_min).div(saliency_map_max - saliency_map_min).data 104 | 105 | return saliency_map, logit 106 | 107 | def __call__(self, input, class_idx=None, retain_graph=False): 108 | return self.forward(input, class_idx, retain_graph) 109 | 110 | 111 | class GradCAMpp(GradCAM): 112 | """Calculate GradCAM++ salinecy map. 113 | A simple example: 114 | # initialize a model, model_dict and gradcampp 115 | resnet = torchvision.models.resnet101(pretrained=True) 116 | resnet.eval() 117 | model_dict = dict(model_type='resnet', arch=resnet, layer_name='layer4', input_size=(224, 224)) 118 | gradcampp = GradCAMpp(model_dict) 119 | # get an image and normalize with mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) 120 | img = load_img() 121 | normed_img = normalizer(img) 122 | # get a GradCAM saliency map on the class index 10. 123 | mask, logit = gradcampp(normed_img, class_idx=10) 124 | # make heatmap from mask and synthesize saliency map using heatmap and img 125 | heatmap, cam_result = visualize_cam(mask, img) 126 | Args: 127 | model_dict (dict): a dictionary that contains 'model_type', 'arch', layer_name', 'input_size'(optional) as keys. 128 | verbose (bool): whether to print output size of the saliency map givien 'layer_name' and 'input_size' in model_dict. 129 | """ 130 | def __init__(self, model_dict, verbose=False): 131 | super(GradCAMpp, self).__init__(model_dict, verbose) 132 | 133 | def forward(self, input, class_idx=None, retain_graph=False): 134 | """ 135 | Args: 136 | input: input image with shape of (1, 3, H, W) 137 | class_idx (int): class index for calculating GradCAM. 138 | If not specified, the class index that makes the highest model prediction score will be used. 139 | Return: 140 | mask: saliency map of the same spatial dimension with input 141 | logit: model output 142 | """ 143 | b, c, h, w = input.size() 144 | 145 | logit = self.model_arch(input) 146 | if class_idx is None: 147 | score = logit[:, logit.max(1)[-1]].squeeze() 148 | else: 149 | score = logit[:, class_idx].squeeze() 150 | 151 | self.model_arch.zero_grad() 152 | score.backward(retain_graph=retain_graph) 153 | gradients = self.gradients['value'] # dS/dA 154 | activations = self.activations['value'] # A 155 | b, k, u, v = gradients.size() 156 | 157 | alpha_num = gradients.pow(2) 158 | alpha_denom = gradients.pow(2).mul(2) + \ 159 | activations.mul(gradients.pow(3)).view(b, k, u*v).sum(-1, keepdim=True).view(b, k, 1, 1) 160 | alpha_denom = torch.where(alpha_denom != 0.0, alpha_denom, torch.ones_like(alpha_denom)) 161 | 162 | alpha = alpha_num.div(alpha_denom+1e-7) 163 | positive_gradients = F.relu(score.exp()*gradients) # ReLU(dY/dA) == ReLU(exp(S)*dS/dA)) 164 | weights = (alpha*positive_gradients).view(b, k, u*v).sum(-1).view(b, k, 1, 1) 165 | 166 | saliency_map = (weights*activations).sum(1, keepdim=True) 167 | saliency_map = F.relu(saliency_map) 168 | saliency_map = F.upsample(saliency_map, size=(224, 224), mode='bilinear', align_corners=False) 169 | saliency_map_min, saliency_map_max = saliency_map.min(), saliency_map.max() 170 | saliency_map = (saliency_map-saliency_map_min).div(saliency_map_max-saliency_map_min).data 171 | 172 | return saliency_map, logit -------------------------------------------------------------------------------- /delong.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.stats 3 | from scipy import stats 4 | 5 | # Ref: https://stackoverflow.com/a/53180614/7521428 6 | 7 | # Original creator comment included below. 8 | """ 9 | Created on Tue Nov 6 10:06:52 2018 10 | 11 | @author: yandexdataschool 12 | 13 | Original Code found in: 14 | https://github.com/yandexdataschool/roc_comparison 15 | 16 | updated: Raul Sanchez-Vazquez 17 | """ 18 | 19 | # AUC comparison adapted from 20 | # https://github.com/Netflix/vmaf/ 21 | def compute_midrank(x): 22 | """Computes midranks. 23 | Args: 24 | x - a 1D numpy array 25 | Returns: 26 | array of midranks 27 | """ 28 | J = np.argsort(x) 29 | Z = x[J] 30 | N = len(x) 31 | T = np.zeros(N, dtype=np.float) 32 | i = 0 33 | while i < N: 34 | j = i 35 | while j < N and Z[j] == Z[i]: 36 | j += 1 37 | T[i:j] = 0.5*(i + j - 1) 38 | i = j 39 | T2 = np.empty(N, dtype=np.float) 40 | # Note(kazeevn) +1 is due to Python using 0-based indexing 41 | # instead of 1-based in the AUC formula in the paper 42 | T2[J] = T + 1 43 | return T2 44 | 45 | 46 | def compute_midrank_weight(x, sample_weight): 47 | """Computes midranks. 48 | Args: 49 | x - a 1D numpy array 50 | Returns: 51 | array of midranks 52 | """ 53 | J = np.argsort(x) 54 | Z = x[J] 55 | cumulative_weight = np.cumsum(sample_weight[J]) 56 | N = len(x) 57 | T = np.zeros(N, dtype=np.float) 58 | i = 0 59 | while i < N: 60 | j = i 61 | while j < N and Z[j] == Z[i]: 62 | j += 1 63 | T[i:j] = cumulative_weight[i:j].mean() 64 | i = j 65 | T2 = np.empty(N, dtype=np.float) 66 | T2[J] = T 67 | return T2 68 | 69 | 70 | def fastDeLong(predictions_sorted_transposed, label_1_count, sample_weight): 71 | if sample_weight is None: 72 | return fastDeLong_no_weights(predictions_sorted_transposed, label_1_count) 73 | else: 74 | return fastDeLong_weights(predictions_sorted_transposed, label_1_count, sample_weight) 75 | 76 | 77 | def fastDeLong_weights(predictions_sorted_transposed, label_1_count, sample_weight): 78 | """ 79 | The fast version of DeLong's method for computing the covariance of 80 | unadjusted AUC. 81 | Args: 82 | predictions_sorted_transposed: a 2D numpy.array[n_classifiers, n_examples] 83 | sorted such as the examples with label "1" are first 84 | Returns: 85 | (AUC value, DeLong covariance) 86 | Reference: 87 | @article{sun2014fast, 88 | title={Fast Implementation of DeLong's Algorithm for 89 | Comparing the Areas Under Correlated Receiver Oerating Characteristic Curves}, 90 | author={Xu Sun and Weichao Xu}, 91 | journal={IEEE Signal Processing Letters}, 92 | volume={21}, 93 | number={11}, 94 | pages={1389--1393}, 95 | year={2014}, 96 | publisher={IEEE} 97 | } 98 | """ 99 | # Short variables are named as they are in the paper 100 | m = label_1_count 101 | n = predictions_sorted_transposed.shape[1] - m 102 | positive_examples = predictions_sorted_transposed[:, :m] 103 | negative_examples = predictions_sorted_transposed[:, m:] 104 | k = predictions_sorted_transposed.shape[0] 105 | 106 | tx = np.empty([k, m], dtype=np.float) 107 | ty = np.empty([k, n], dtype=np.float) 108 | tz = np.empty([k, m + n], dtype=np.float) 109 | for r in range(k): 110 | tx[r, :] = compute_midrank_weight(positive_examples[r, :], sample_weight[:m]) 111 | ty[r, :] = compute_midrank_weight(negative_examples[r, :], sample_weight[m:]) 112 | tz[r, :] = compute_midrank_weight(predictions_sorted_transposed[r, :], sample_weight) 113 | total_positive_weights = sample_weight[:m].sum() 114 | total_negative_weights = sample_weight[m:].sum() 115 | pair_weights = np.dot(sample_weight[:m, np.newaxis], sample_weight[np.newaxis, m:]) 116 | total_pair_weights = pair_weights.sum() 117 | aucs = (sample_weight[:m]*(tz[:, :m] - tx)).sum(axis=1) / total_pair_weights 118 | v01 = (tz[:, :m] - tx[:, :]) / total_negative_weights 119 | v10 = 1. - (tz[:, m:] - ty[:, :]) / total_positive_weights 120 | sx = np.cov(v01) 121 | sy = np.cov(v10) 122 | delongcov = sx / m + sy / n 123 | return aucs, delongcov 124 | 125 | 126 | def fastDeLong_no_weights(predictions_sorted_transposed, label_1_count): 127 | """ 128 | The fast version of DeLong's method for computing the covariance of 129 | unadjusted AUC. 130 | Args: 131 | predictions_sorted_transposed: a 2D numpy.array[n_classifiers, n_examples] 132 | sorted such as the examples with label "1" are first 133 | Returns: 134 | (AUC value, DeLong covariance) 135 | Reference: 136 | @article{sun2014fast, 137 | title={Fast Implementation of DeLong's Algorithm for 138 | Comparing the Areas Under Correlated Receiver Oerating 139 | Characteristic Curves}, 140 | author={Xu Sun and Weichao Xu}, 141 | journal={IEEE Signal Processing Letters}, 142 | volume={21}, 143 | number={11}, 144 | pages={1389--1393}, 145 | year={2014}, 146 | publisher={IEEE} 147 | } 148 | """ 149 | # Short variables are named as they are in the paper 150 | m = label_1_count 151 | n = predictions_sorted_transposed.shape[1] - m 152 | positive_examples = predictions_sorted_transposed[:, :m] 153 | negative_examples = predictions_sorted_transposed[:, m:] 154 | k = predictions_sorted_transposed.shape[0] 155 | 156 | tx = np.empty([k, m], dtype=np.float) 157 | ty = np.empty([k, n], dtype=np.float) 158 | tz = np.empty([k, m + n], dtype=np.float) 159 | for r in range(k): 160 | tx[r, :] = compute_midrank(positive_examples[r, :]) 161 | ty[r, :] = compute_midrank(negative_examples[r, :]) 162 | tz[r, :] = compute_midrank(predictions_sorted_transposed[r, :]) 163 | aucs = tz[:, :m].sum(axis=1) / m / n - float(m + 1.0) / 2.0 / n 164 | v01 = (tz[:, :m] - tx[:, :]) / n 165 | v10 = 1.0 - (tz[:, m:] - ty[:, :]) / m 166 | sx = np.cov(v01) 167 | sy = np.cov(v10) 168 | delongcov = sx / m + sy / n 169 | return aucs, delongcov 170 | 171 | 172 | def calc_pvalue(aucs, sigma): 173 | """Computes log(10) of p-values. 174 | Args: 175 | aucs: 1D array of AUCs 176 | sigma: AUC DeLong covariances 177 | Returns: 178 | log10(pvalue) 179 | """ 180 | l = np.array([[1, -1]]) 181 | z = np.abs(np.diff(aucs)) / np.sqrt(np.dot(np.dot(l, sigma), l.T)) 182 | return np.log10(2) + scipy.stats.norm.logsf(z, loc=0, scale=1) / np.log(10) 183 | 184 | 185 | def compute_ground_truth_statistics(ground_truth, sample_weight): 186 | assert np.array_equal(np.unique(ground_truth), [0, 1]) 187 | order = (-ground_truth).argsort() 188 | label_1_count = int(ground_truth.sum()) 189 | if sample_weight is None: 190 | ordered_sample_weight = None 191 | else: 192 | ordered_sample_weight = sample_weight[order] 193 | 194 | return order, label_1_count, ordered_sample_weight 195 | 196 | 197 | def delong_roc_variance(ground_truth, predictions, sample_weight=None): 198 | """ 199 | Computes ROC AUC variance for a single set of predictions 200 | Args: 201 | ground_truth: np.array of 0 and 1 202 | predictions: np.array of floats of the probability of being class 1 203 | """ 204 | order, label_1_count, ordered_sample_weight = compute_ground_truth_statistics( 205 | ground_truth, sample_weight) 206 | predictions_sorted_transposed = predictions[np.newaxis, order] 207 | aucs, delongcov = fastDeLong(predictions_sorted_transposed, label_1_count, ordered_sample_weight) 208 | assert len(aucs) == 1, "There is a bug in the code, please forward this to the developers" 209 | return aucs[0], delongcov 210 | 211 | def print_delong_AUROCs(y_groundtruth, y_predictions): 212 | alpha = .95 213 | y_p = np.asarray(y_predictions).reshape(-1) 214 | y_true = np.asarray(y_groundtruth).reshape(-1) 215 | 216 | auc, auc_cov = delong_roc_variance( 217 | y_true, 218 | y_p) 219 | 220 | auc_std = np.sqrt(auc_cov) 221 | lower_upper_q = np.abs(np.array([0, 1]) - (1 - alpha) / 2) 222 | 223 | ci = stats.norm.ppf( 224 | lower_upper_q, 225 | loc=auc, 226 | scale=auc_std) 227 | 228 | ci[ci > 1] = 1 229 | 230 | print('Delong CIs follow.') 231 | print('AUC:', auc) 232 | print('AUC COV:', auc_cov) 233 | print('95% AUC CI:', ci) 234 | 235 | return auc, auc_cov, ci -------------------------------------------------------------------------------- /dataHelper.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | import os 4 | import pandas as pd 5 | import argparse 6 | import sys 7 | import random 8 | import png 9 | from matplotlib.pyplot import imsave, imread 10 | import matplotlib 11 | from PIL import Image 12 | import cv2 13 | matplotlib.use("Agg") 14 | import torchvision.datasets as datasets 15 | from skimage.transform import resize 16 | import ast 17 | import pickle 18 | import csv 19 | import pydicom as dcm 20 | import Augmentor 21 | from tqdm import tqdm 22 | import pathlib 23 | from torch import randint, manual_seed 24 | from copy import copy 25 | from collections import defaultdict 26 | 27 | def make_dataset(dir, class_to_idx, extensions=None, is_valid_file=None): 28 | images = [] 29 | dir = os.path.expanduser(dir) 30 | if not ((extensions is None) ^ (is_valid_file is None)): 31 | raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time") 32 | 33 | for target in sorted(class_to_idx.keys()): 34 | d = os.path.join(dir, target) 35 | if not os.path.isdir(d): 36 | continue 37 | for root, _, fnames in sorted(os.walk(d)): 38 | for fname in sorted(fnames): 39 | path = os.path.join(root, fname) 40 | item = (path, class_to_idx[target]) 41 | images.append(item) 42 | return images 43 | 44 | 45 | def random_flip(input, axis, with_fa=False): 46 | ran = random.random() 47 | if ran > 0.5: 48 | if with_fa: 49 | axis += 1 50 | return np.flip(input, axis=axis) 51 | else: 52 | return input 53 | 54 | 55 | def random_crop(input, with_fa=False): 56 | ran = random.random() 57 | if ran > 0.2: 58 | # find a random place to be the left upper corner of the crop 59 | if with_fa: 60 | rx = int(random.random() * input.shape[1] // 10) 61 | ry = int(random.random() * input.shape[2] // 10) 62 | return input[:, rx: rx + int(input.shape[1] * 9 // 10), ry: ry + int(input.shape[2] * 9 // 10)] 63 | else: 64 | rx = int(random.random() * input.shape[0] // 10) 65 | ry = int(random.random() * input.shape[1] // 10) 66 | return input[rx: rx + int(input.shape[0] * 9 // 10), ry: ry + int(input.shape[1] * 9 // 10)] 67 | else: 68 | return input 69 | 70 | 71 | def random_rotate_90(input, with_fa=False): 72 | ran = random.random() 73 | if ran > 0.5: 74 | if with_fa: 75 | return np.rot90(input, axes=(1,2)) 76 | return np.rot90(input) 77 | else: 78 | return input 79 | 80 | 81 | def random_rotation(x, chance, with_fa=False): 82 | ran = random.random() 83 | if with_fa: 84 | img = Image.fromarray(x[0]) 85 | mask = Image.fromarray(x[1]) 86 | if ran > 1 - chance: 87 | # create black edges 88 | angle = np.random.randint(0, 90) 89 | img = img.rotate(angle=angle, expand=1) 90 | mask = mask.rotate(angle=angle, expand=1, fillcolor=1) 91 | return np.stack([np.asarray(img), np.asarray(mask)]) 92 | else: 93 | return np.stack([np.asarray(img), np.asarray(mask)]) 94 | img = Image.fromarray(x) 95 | if ran > 1 - chance: 96 | # create black edges 97 | angle = np.random.randint(0, 90) 98 | img = img.rotate(angle=angle, expand=1) 99 | return np.asarray(img) 100 | else: 101 | return np.asarray(img) 102 | 103 | 104 | class DatasetFolder(datasets.DatasetFolder): 105 | def __init__(self, root, loader, augmentation=False, extensions=None, transform=None, 106 | target_transform=None, is_valid_file=None, target_size=(224, 224)): 107 | 108 | super(DatasetFolder, self).__init__(root, loader, ("npy",), 109 | transform=transform, 110 | target_transform=target_transform, ) 111 | classes, class_to_idx = self._find_classes(self.root) 112 | samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file) 113 | if len(samples) == 0: 114 | raise (RuntimeError("Found 0 files in subfolders of: " + self.root + "\n" 115 | "Supported extensions are: " + ",".join( 116 | extensions))) 117 | self.loader = loader 118 | self.extensions = extensions 119 | self.classes = classes 120 | self.class_to_idx = class_to_idx 121 | self.samples = samples 122 | self.augment = augmentation 123 | self.target_size = target_size 124 | self.targets = [s[1] for s in samples] 125 | 126 | def _find_classes(self, dir): 127 | if sys.version_info >= (3, 5): 128 | # Faster and available in Python 3.5 and above 129 | classes = [d.name for d in os.scandir(dir) if d.is_dir()] 130 | else: 131 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 132 | classes.sort() 133 | class_to_idx = {classes[i]: i for i in range(len(classes))} 134 | return classes, class_to_idx 135 | 136 | def __getitem__(self, index): 137 | path, target = self.samples[index] 138 | patient_id = path.split("/")[-1][:-4] 139 | sample = self.loader(path) 140 | if len(sample.shape) == 3: 141 | if self.target_size: 142 | sample = np.stack([resize(sample[0], self.target_size), resize(sample[1], self.target_size)]) 143 | temp = [sample[0], sample[0], sample[0], sample[1]] 144 | else: 145 | if self.target_size: 146 | sample = resize(sample, self.target_size) 147 | if self.augment: 148 | sample = random_rotation(sample, 0.7) 149 | temp = [sample, sample, sample] 150 | n = np.stack(temp) 151 | if self.transform is not None: 152 | sample = self.transform(n) 153 | if self.target_transform is not None: 154 | target = self.target_transform(target) 155 | # print("after transform", sample.shape) 156 | return sample.float(), target, patient_id 157 | 158 | 159 | class DatasetFolder_WithReplacement(datasets.DatasetFolder): 160 | def __init__(self, root, loader, augmentation=False, extensions=None, transform=None, 161 | target_transform=None, is_valid_file=None, target_size=(224, 224)): 162 | 163 | super(DatasetFolder_WithReplacement, self).__init__(root, loader, ("npy",), 164 | transform=transform, 165 | target_transform=target_transform, ) 166 | classes, class_to_idx = self._find_classes(self.root) 167 | samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file) 168 | if len(samples) == 0: 169 | raise (RuntimeError("Found 0 files in subfolders of: " + self.root + "\n" 170 | "Supported extensions are: " + ",".join( 171 | extensions))) 172 | self.loader = loader 173 | self.extensions = extensions 174 | self.classes = classes 175 | self.class_to_idx = class_to_idx 176 | self.samples = samples 177 | self.augment = augmentation 178 | self.target_size = target_size 179 | self.targets = [s[1] for s in samples] 180 | 181 | def _find_classes(self, dir): 182 | if sys.version_info >= (3, 5): 183 | # Faster and available in Python 3.5 and above 184 | classes = [d.name for d in os.scandir(dir) if d.is_dir()] 185 | else: 186 | classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] 187 | classes.sort() 188 | class_to_idx = {classes[i]: i for i in range(len(classes))} 189 | return classes, class_to_idx 190 | 191 | def __getitem__(self, index): 192 | index = randint(0, len(self.samples), (1,))[0] #pull with replacement 193 | path, target = self.samples[index] 194 | patient_id = path.split("/")[-1][:-4] 195 | sample = self.loader(path) 196 | if len(sample.shape) == 3: 197 | if self.target_size: 198 | sample = np.stack([resize(sample[0], self.target_size), resize(sample[1], self.target_size)]) 199 | 200 | temp = [sample[0], sample[0], sample[0], sample[1]] 201 | else: 202 | if self.target_size: 203 | sample = resize(sample, self.target_size) 204 | if self.augment: 205 | sample = random_rotation(sample, 0.7) 206 | temp = [sample, sample, sample] 207 | n = np.stack(temp) 208 | if self.transform is not None: 209 | sample = self.transform(n) 210 | if self.target_transform is not None: 211 | target = self.target_transform(target) 212 | # print("after transform", sample.shape) 213 | return sample.float(), target, patient_id -------------------------------------------------------------------------------- /vanilla_vgg.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | import matplotlib.pyplot as plt 3 | from vgg_features import vgg11_features, vgg11_bn_features, vgg13_features, vgg13_bn_features, vgg16_features, vgg16_bn_features,\ 4 | vgg19_features, vgg19_bn_features 5 | import argparse 6 | import torch.nn as nn 7 | from dataHelper import DatasetFolder 8 | from torchvision import transforms 9 | import torch 10 | from torch import optim 11 | from torch.utils.tensorboard import SummaryWriter 12 | import numpy as np 13 | from sklearn.metrics import roc_auc_score 14 | import os 15 | import random 16 | 17 | # build model 18 | class Vanilla_VGG(nn.Module): 19 | def __init__(self, myfeatures, num_classes): 20 | super(Vanilla_VGG, self).__init__() 21 | 22 | self.features = myfeatures 23 | self.avgpool = nn.AdaptiveAvgPool2d((7,7)) 24 | self.num_classes = num_classes 25 | self.classifier = nn.Sequential( 26 | nn.Linear(512*7*7, 4096), 27 | nn.ReLU(True), 28 | nn.Dropout(), 29 | nn.Linear(4096, 4096), 30 | nn.ReLU(True), 31 | nn.Dropout(), 32 | nn.Linear(4096, num_classes), 33 | nn.LogSoftmax(dim=0) 34 | ) 35 | 36 | 37 | def forward(self, x): 38 | x = self.features(x) 39 | x = self.avgpool(x) 40 | x = torch.flatten(x, 1) 41 | x = self.classifier(x) 42 | return x 43 | 44 | def main(): 45 | 46 | matplotlib.use("Agg") 47 | 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument("-model", type=str) 50 | parser.add_argument("-train_dir", type=str, default="/usr/project/xtmp/mammo/") 51 | parser.add_argument("-test_dir", type=str, default="/usr/project/xtmp/mammo/") 52 | parser.add_argument("-name", type=str) 53 | parser.add_argument("-lr", type=lambda x: float(x)) 54 | parser.add_argument("-wd", type=lambda x: float(x)) 55 | parser.add_argument("-num_classes", type=lambda x: int(x)) 56 | args = parser.parse_args() 57 | model_name = args.model 58 | train_dir = args.train_dir 59 | test_dir = args.test_dir 60 | task_name = args.name 61 | num_classes = args.num_classes 62 | 63 | save_loc = '/usr/xtmp/mammo/saved_models/vanilla/' 64 | os.makedirs(save_loc + task_name, exist_ok = True) 65 | lr = args.lr 66 | wd = args.wd 67 | print(wd, lr) 68 | print("Saving to: ", save_loc + task_name) 69 | 70 | random_seed_number = 12 71 | print("Random seed: ", random_seed_number) 72 | torch.manual_seed(random_seed_number) 73 | torch.cuda.manual_seed(random_seed_number) 74 | np.random.seed(random_seed_number) 75 | random.seed(random_seed_number) 76 | torch.backends.cudnn.enabled=False 77 | torch.backends.cudnn.deterministic=True 78 | 79 | writer = SummaryWriter() 80 | 81 | base_architecture_to_features = {'vgg11': vgg11_features, 82 | 'vgg11_bn': vgg11_bn_features, 83 | 'vgg13': vgg13_features, 84 | 'vgg13_bn': vgg13_bn_features, 85 | 'vgg16': vgg16_features, 86 | 'vgg16_bn': vgg16_bn_features, 87 | 'vgg19': vgg19_features, 88 | 'vgg19_bn': vgg19_bn_features} 89 | 90 | features = base_architecture_to_features[model_name](pretrained=True) 91 | 92 | model = Vanilla_VGG(features, num_classes) 93 | 94 | # load data 95 | # train set 96 | train_dataset = DatasetFolder( 97 | train_dir, 98 | augmentation=False, 99 | loader=np.load, 100 | extensions=("npy",), 101 | target_size=(224, 224), 102 | transform = transforms.Compose([ 103 | torch.from_numpy, 104 | ])) 105 | trainloader = torch.utils.data.DataLoader( 106 | train_dataset, batch_size=100, shuffle=True, 107 | num_workers=4, pin_memory=False) 108 | 109 | # test set 110 | test_dataset =DatasetFolder( 111 | test_dir, 112 | loader=np.load, 113 | target_size=(224, 224), 114 | extensions=("npy",), 115 | transform = transforms.Compose([ 116 | torch.from_numpy, 117 | ])) 118 | testloader = torch.utils.data.DataLoader( 119 | test_dataset, batch_size=100, shuffle=False, 120 | num_workers=4, pin_memory=False) 121 | 122 | 123 | # start training 124 | epochs = 250 125 | 126 | criterion = nn.CrossEntropyLoss() 127 | optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wd) 128 | 129 | device = torch.device("cuda") 130 | model.to(device) 131 | 132 | train_losses = [] 133 | test_losses = [] 134 | train_auc = [] 135 | test_auc = [] 136 | curr_best = 0 137 | 138 | 139 | for epoch in range(epochs): 140 | # train 141 | confusion_matrix = np.zeros((num_classes, num_classes)) 142 | total_output = [] 143 | total_one_hot_label = [] 144 | running_loss = 0 145 | model.train() 146 | for inputs, labels, id in trainloader: 147 | inputs, labels = inputs.to(device), labels.to(device) 148 | optimizer.zero_grad() 149 | logps = model.forward(inputs) 150 | loss = criterion(logps, labels) 151 | loss.backward() 152 | optimizer.step() 153 | running_loss += loss.item() 154 | one_hot_label = np.zeros(shape=(len(labels), num_classes)) 155 | for k in range(len(labels)): 156 | one_hot_label[k][labels[k].item()] = 1 157 | # roc_auc_score() 158 | total_output.extend(logps.cpu().detach().numpy()) 159 | total_one_hot_label.extend(one_hot_label) 160 | # confusion matrix 161 | _, predicted = torch.max(logps.data, 1) 162 | for t_idx, t in enumerate(labels): 163 | confusion_matrix[predicted[t_idx]][t] += 1 #row is predicted, col is true 164 | # if predicted[t_idx] == t: # correct label 165 | # confusion_matrix[t][t] += 1 166 | # elif t == 0 and predicted[t_idx] == 1: 167 | # confusion_matrix[1] += 1 # false positives 168 | # elif t == 1 and predicted[t_idx] == 0: 169 | # confusion_matrix[2] += 1 # false negative 170 | # else: 171 | # confusion_matrix[3] += 1 172 | 173 | auc_score = roc_auc_score(np.array(total_one_hot_label), np.array(total_output)) 174 | 175 | train_losses.append(running_loss / len(trainloader)) 176 | train_auc.append(auc_score) 177 | print("=======================================================") 178 | print("\t at epoch {}".format(epoch)) 179 | print("\t train loss is {}".format(train_losses[-1])) 180 | print("\t train auc is {}".format(auc_score)) 181 | print('\tthe confusion matrix is: \n{0}'.format(confusion_matrix)) 182 | # test 183 | confusion_matrix = np.zeros((num_classes, num_classes)) 184 | test_loss = 0 185 | total_output = [] 186 | total_one_hot_label = [] 187 | model.eval() 188 | with torch.no_grad(): 189 | for inputs, labels, id in testloader: 190 | inputs, labels = inputs.to(device), labels.to(device) 191 | logps = model.forward(inputs) 192 | batch_loss = criterion(logps, labels) 193 | test_loss += batch_loss.item() 194 | one_hot_label = np.zeros(shape=(len(labels), num_classes)) 195 | for k in range(len(labels)): 196 | one_hot_label[k][labels[k].item()] = 1 197 | # roc_auc_score() 198 | total_output.extend(logps.cpu().numpy()) 199 | total_one_hot_label.extend(one_hot_label) 200 | # confusion matrix 201 | _, predicted = torch.max(logps.data, 1) 202 | for t_idx, t in enumerate(labels): 203 | confusion_matrix[predicted[t_idx]][t] += 1 #row is predicted, col is true 204 | # if predicted[t_idx] == t and predicted[t_idx] == 1: # true positive 205 | # confusion_matrix[0] += 1 206 | # elif t == 0 and predicted[t_idx] == 1: 207 | # confusion_matrix[1] += 1 # false positives 208 | # elif t == 1 and predicted[t_idx] == 0: 209 | # confusion_matrix[2] += 1 # false negative 210 | # else: 211 | # confusion_matrix[3] += 1 212 | auc_score = roc_auc_score(np.array(total_one_hot_label), np.array(total_output)) 213 | test_losses.append(test_loss / len(testloader)) 214 | test_auc.append(auc_score) 215 | print("===========================") 216 | if auc_score > curr_best: 217 | curr_best = auc_score 218 | print("\t test loss is {}".format(test_losses[-1])) 219 | print("\t test auc is {}".format(auc_score)) 220 | print("\t current best is {}".format(curr_best)) 221 | print('\tthe confusion matrix is: \n{0}'.format(confusion_matrix)) 222 | print("=======================================================") 223 | 224 | # save model 225 | if auc_score > 0: 226 | torch.save(model, save_loc + task_name + "/" + str(auc_score) + "_at_epoch_" + str(epoch)) 227 | 228 | # plot graphs 229 | plt.plot(train_losses, "b", label="train") 230 | plt.plot(test_losses, "r", label="test") 231 | #plt.ylim(0, 4) 232 | plt.legend() 233 | plt.savefig(save_loc + task_name + '/train_test_loss_vanilla' + ".png") 234 | plt.close() 235 | 236 | plt.plot(train_auc, "b", label="train") 237 | plt.plot(test_auc, "r", label="test") 238 | #plt.ylim(0.4, 1) 239 | plt.legend() 240 | plt.savefig(save_loc + task_name + '/train_test_auc_vanilla' + ".png") 241 | plt.close() 242 | 243 | writer.close() 244 | 245 | if __name__=="__main__": 246 | main() -------------------------------------------------------------------------------- /local_analysis_vis.py: -------------------------------------------------------------------------------- 1 | import matplotlib 2 | matplotlib.use('Agg') 3 | import matplotlib.pyplot as plt 4 | from mpl_toolkits.axes_grid1 import ImageGrid 5 | import matplotlib.gridspec as gridspec 6 | from PIL import Image 7 | import numpy as np 8 | import os 9 | import argparse 10 | import re 11 | import shutil 12 | 13 | classname_dict = dict() 14 | classname_dict[0] = "circumscribed" 15 | classname_dict[1] = "indistinct" 16 | classname_dict[2] = "spiculated" 17 | 18 | def main(): 19 | # get dir 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('-local_analysis_directory', nargs=1, type=str, default='0') 22 | args = parser.parse_args() 23 | 24 | source_dir = args.local_analysis_directory[0] 25 | 26 | os.makedirs(os.path.join(source_dir, 'visualizations_of_expl'), exist_ok=True) 27 | 28 | pred, truth = read_local_analysis_log(os.path.join(source_dir + 'local_analysis.log')) 29 | 30 | anno_opts_cen = dict(xy=(0.4, 0.5), xycoords='axes fraction', 31 | va='center', ha='center') 32 | anno_opts_symb = dict(xy=(1, 0.5), xycoords='axes fraction', 33 | va='center', ha='center') 34 | anno_opts_sum = dict(xy=(0, -0.1), xycoords='axes fraction', 35 | va='center', ha='left') 36 | 37 | ###### all classes, one expl 38 | fig = plt.figure(constrained_layout=False) 39 | fig.set_size_inches(28, 12) 40 | 41 | ncols, nrows = 7, 3 42 | spec = gridspec.GridSpec(ncols=ncols, nrows=nrows, figure=fig) 43 | 44 | f_axes = [] 45 | for row in range(nrows): 46 | f_axes.append([]) 47 | for col in range(ncols): 48 | f_axes[-1].append(fig.add_subplot(spec[row, col])) 49 | 50 | plt.rcParams.update({'font.size': 14}) 51 | 52 | for ax_num, ax in enumerate(f_axes[0]): 53 | if ax_num == 0: 54 | ax.set_title("Test image", fontdict=None, loc='left', color = "k") 55 | elif ax_num == 1: 56 | ax.set_title("Test image activation\nby prototype", fontdict=None, loc='left', color = "k") 57 | elif ax_num == 2: 58 | ax.set_title("Prototype", fontdict=None, loc='left', color = "k") 59 | elif ax_num == 3: 60 | ax.set_title("Self-activation of\nprototype", fontdict=None, loc='left', color = "k") 61 | elif ax_num == 4: 62 | ax.set_title("Similarity score", fontdict=None, loc='left', color = "k") 63 | elif ax_num == 5: 64 | ax.set_title("Class connection", fontdict=None, loc='left', color = "k") 65 | elif ax_num == 6: 66 | ax.set_title("Contribution", fontdict=None, loc='left', color = "k") 67 | else: 68 | pass 69 | 70 | plt.rcParams.update({'font.size': 22}) 71 | 72 | for ax in [f_axes[r][4] for r in range(nrows)]: 73 | ax.annotate('x', **anno_opts_symb) 74 | 75 | for ax in [f_axes[r][5] for r in range(nrows)]: 76 | ax.annotate('=', **anno_opts_symb) 77 | 78 | # get and plot data from source directory 79 | 80 | orig_img = Image.open(os.path.join(source_dir + 'original_img.png')) 81 | 82 | for ax in [f_axes[r][0] for r in range(nrows)]: 83 | ax.imshow(orig_img) 84 | ax.get_xaxis().set_ticks([]) 85 | ax.get_yaxis().set_ticks([]) 86 | 87 | top_p_dir = os.path.join(source_dir + 'most_activated_prototypes') 88 | for top_p in range(3): 89 | # put info in place 90 | p_info_file = open(os.path.join(top_p_dir, f'top-{top_p+1}_activated_prototype.txt'), 'r') 91 | sim_score, cc_dict, class_str, top_cc_str = read_info(p_info_file) 92 | p_info_file.close() 93 | for ax in [f_axes[top_p][4]]: 94 | ax.annotate(sim_score, **anno_opts_cen) 95 | ax.set_axis_off() 96 | for ax in [f_axes[top_p][5]]: 97 | ax.annotate(top_cc_str + "\n" + class_str, **anno_opts_cen) 98 | ax.set_axis_off() 99 | for ax in [f_axes[top_p][6]]: 100 | tc = float(top_cc_str) * float(sim_score) 101 | ax.annotate('{0:.3f}'.format(tc) + "\n" + class_str, **anno_opts_cen) 102 | ax.set_axis_off() 103 | # put images in place 104 | p_img = Image.open(os.path.join(top_p_dir, f'top-{top_p+1}_activated_prototype_full_size.png')) 105 | for ax in [f_axes[top_p][2]]: 106 | ax.imshow(p_img) 107 | ax.get_xaxis().set_ticks([]) 108 | ax.get_yaxis().set_ticks([]) 109 | p_act_img = Image.open(os.path.join(top_p_dir, f'top-{top_p+1}_activated_prototype_self_act.png')) 110 | for ax in [f_axes[top_p][3]]: 111 | ax.imshow(p_act_img) 112 | ax.get_xaxis().set_ticks([]) 113 | ax.get_yaxis().set_ticks([]) 114 | act_img = Image.open(os.path.join(top_p_dir, f'prototype_activation_map_by_top-{top_p+1}_prototype_normed.png')) 115 | for ax in [f_axes[top_p][1]]: 116 | ax.imshow(act_img) 117 | ax.get_xaxis().set_ticks([]) 118 | ax.get_yaxis().set_ticks([]) 119 | #summary 120 | f_axes[2][4].annotate(f"This {classname_dict[int(truth)]} lesion is classified as {classname_dict[int(pred)]}.", **anno_opts_sum) 121 | 122 | save_loc1 = os.path.join(source_dir, 'visualizations_of_expl') + f'/all_class.png' 123 | plt.savefig(save_loc1, bbox_inches='tight', pad_inches=0) 124 | os.makedirs('./visualizations_of_expl/', exist_ok=True) 125 | save_loc2 = './visualizations_of_expl/' + str(source_dir.replace('/', '__'))[len('__usr__xtmp__IAIABL__saved_models__0129_pushonall_topkk=9_fa=0.001_random=4__pruned_prototypes_epoch50_k6_pt3__'):] + f'all_class.png' 126 | shutil.copy2(save_loc1, save_loc2) 127 | print(f"Saved in {save_loc2}") 128 | return 129 | 130 | def read_local_analysis_log(file_loc): 131 | log_file = open(file_loc, 'r') 132 | for _ in range(30): 133 | line = log_file.readline() 134 | if line[0:len("Predicted: ")] == "Predicted: ": 135 | pred = line[len("Predicted: "):] 136 | elif line[0:len("Actual: ")] == "Actual: ": 137 | actual = line[len("Actual: "):] 138 | # pred = log_file.readline()[len("Predicted: "):] 139 | # actual = log_file.readline()[len("Actual: "):] 140 | log_file.close() 141 | return pred, actual 142 | 143 | 144 | def read_info(info_file, per_class=False): 145 | sim_score_line = info_file.readline() 146 | connection_line = info_file.readline() 147 | proto_index_line = info_file.readline() 148 | cc_0_line = info_file.readline() 149 | cc_1_line = info_file.readline() 150 | cc_2_line = info_file.readline() 151 | 152 | sim_score = sim_score_line[len("similarity: "):-1] 153 | if per_class: 154 | cc = connection_line[len('last layer connection: '):-1] 155 | else: 156 | cc = connection_line[len('last layer connection with predicted class: '):-1] 157 | circ_cc_str = cc_0_line[len('proto connection to class 0:tensor('):-(len(", device='cuda:0', grad_fn=)")+1)] 158 | circ_cc = float(circ_cc_str) 159 | indst_cc_str = cc_1_line[len('proto connection to class 1:tensor('):-(len(", device='cuda:0', grad_fn=)")+1)] 160 | indst_cc = float(indst_cc_str) 161 | spic_cc_str = cc_2_line[len('proto connection to class 2:tensor('):-(len(", device='cuda:0', grad_fn=)")+1)] 162 | spic_cc = float(spic_cc_str) 163 | 164 | cc_dict = dict() 165 | cc_dict[0] = circ_cc 166 | cc_dict[1] = indst_cc 167 | cc_dict[2] = spic_cc 168 | class_of_p = max(cc_dict, key=lambda k: cc_dict[k]) 169 | top_cc = cc_dict[class_of_p] 170 | 171 | class_str = classname_dict[class_of_p] 172 | if class_of_p == 0: 173 | top_cc_str = circ_cc_str 174 | elif class_of_p == 1: 175 | top_cc_str = indst_cc_str 176 | elif class_of_p == 2: 177 | top_cc_str = spic_cc_str 178 | else: 179 | print("Error. The maximum value class is not found.") 180 | 181 | return sim_score, cc_dict, class_str, top_cc_str 182 | 183 | def test(): 184 | 185 | im = Image.open('./visualizations_of_expl/' + 'original_img.png') 186 | 187 | fig = plt.figure(constrained_layout=False) 188 | fig.set_size_inches(28, 12) 189 | 190 | ncols, nrows = 7, 3 191 | spec = gridspec.GridSpec(ncols=ncols, nrows=nrows, figure=fig) 192 | 193 | f_axes = [] 194 | for row in range(nrows): 195 | f_axes.append([]) 196 | for col in range(ncols): 197 | f_axes[-1].append(fig.add_subplot(spec[row, col])) 198 | 199 | plt.rcParams.update({'font.size': 15}) 200 | 201 | for ax_num, ax in enumerate(f_axes[0]): 202 | if ax_num == 0: 203 | ax.set_title("Test image", fontdict=None, loc='left', color = "k") 204 | elif ax_num == 1: 205 | ax.set_title("Test image activation by prototype", fontdict=None, loc='left', color = "k") 206 | elif ax_num == 2: 207 | ax.set_title("Prototype", fontdict=None, loc='left', color = "k") 208 | elif ax_num == 3: 209 | ax.set_title("Self-activation of prototype", fontdict=None, loc='left', color = "k") 210 | elif ax_num == 4: 211 | ax.set_title("Similarity score", fontdict=None, loc='left', color = "k") 212 | elif ax_num == 5: 213 | ax.set_title("Class connection", fontdict=None, loc='left', color = "k") 214 | elif ax_num == 6: 215 | ax.set_title("Contribution", fontdict=None, loc='left', color = "k") 216 | else: 217 | pass 218 | 219 | plt.rcParams.update({'font.size': 22}) 220 | 221 | for ax in [f_axes[r][0] for r in range(nrows)]: 222 | ax.imshow(im) 223 | ax.get_xaxis().set_ticks([]) 224 | ax.get_yaxis().set_ticks([]) 225 | 226 | 227 | anno_opts = dict(xy=(0.4, 0.5), xycoords='axes fraction', 228 | va='center', ha='center') 229 | 230 | anno_opts_symb = dict(xy=(1, 0.5), xycoords='axes fraction', 231 | va='center', ha='center') 232 | 233 | for ax in [f_axes[r][s] for r in range(nrows) for s in range(4,7)]: 234 | ax.annotate('Number', **anno_opts) 235 | ax.set_axis_off() 236 | 237 | for ax in [f_axes[r][4] for r in range(nrows)]: 238 | ax.annotate('x', **anno_opts_symb) 239 | 240 | for ax in [f_axes[r][5] for r in range(nrows)]: 241 | ax.annotate('=', **anno_opts_symb) 242 | 243 | os.makedirs('./visualizations_of_expl/', exist_ok=True) 244 | plt.savefig('./visualizations_of_expl/' + 'test.png') 245 | 246 | # Refs: https://stackoverflow.com/questions/40846492/how-to-add-text-to-each-image-using-imagegrid 247 | # https://stackoverflow.com/questions/41793931/plotting-images-side-by-side-using-matplotlib 248 | 249 | if __name__ == "__main__": 250 | main() -------------------------------------------------------------------------------- /vgg_features.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.utils.model_zoo as model_zoo 3 | 4 | model_urls = { 5 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth', 6 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth', 7 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth', 8 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth', 9 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth', 10 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth', 11 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth', 12 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth', 13 | } 14 | 15 | model_dir = './pretrained_models' 16 | 17 | cfg = { 18 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 19 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 20 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'end'], 21 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'], 22 | } 23 | 24 | class VGG_features(nn.Module): 25 | 26 | def __init__(self, cfg, batch_norm=False, init_weights=True): 27 | super(VGG_features, self).__init__() 28 | 29 | self.batch_norm = batch_norm 30 | 31 | self.kernel_sizes = [] 32 | self.strides = [] 33 | self.paddings = [] 34 | 35 | self.features = self._make_layers(cfg, batch_norm) 36 | 37 | if init_weights: 38 | self._initialize_weights() 39 | 40 | def forward(self, x): 41 | x = self.features(x) 42 | return x 43 | 44 | def _initialize_weights(self): 45 | for m in self.modules(): 46 | if isinstance(m, nn.Conv2d): 47 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 48 | if m.bias is not None: 49 | nn.init.constant_(m.bias, 0) 50 | elif isinstance(m, nn.BatchNorm2d): 51 | nn.init.constant_(m.weight, 1) 52 | nn.init.constant_(m.bias, 0) 53 | elif isinstance(m, nn.Linear): 54 | nn.init.normal_(m.weight, 0, 0.01) 55 | nn.init.constant_(m.bias, 0) 56 | 57 | def _make_layers(self, cfg, batch_norm): 58 | 59 | self.n_layers = 0 60 | 61 | layers = [] 62 | in_channels = 3 63 | for v in cfg: 64 | if v == 'M': 65 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 66 | 67 | self.kernel_sizes.append(2) 68 | self.strides.append(2) 69 | self.paddings.append(0) 70 | 71 | elif v != "end": 72 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 73 | if batch_norm: 74 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 75 | else: 76 | layers += [conv2d, nn.ReLU(inplace=True)] 77 | 78 | self.n_layers += 1 79 | 80 | self.kernel_sizes.append(3) 81 | self.strides.append(1) 82 | self.paddings.append(1) 83 | 84 | in_channels = v 85 | 86 | return nn.Sequential(*layers) 87 | 88 | def conv_info(self): 89 | return self.kernel_sizes, self.strides, self.paddings 90 | 91 | def num_layers(self): 92 | ''' 93 | the number of conv layers in the network 94 | ''' 95 | return self.n_layers 96 | 97 | def __repr__(self): 98 | template = 'VGG{}, batch_norm={}' 99 | return template.format(self.num_layers() + 3, 100 | self.batch_norm) 101 | 102 | 103 | 104 | def vgg11_features(pretrained=False, **kwargs): 105 | """VGG 11-layer model (configuration "A") 106 | 107 | Args: 108 | pretrained (bool): If True, returns a model pre-trained on ImageNet 109 | """ 110 | if pretrained: 111 | kwargs['init_weights'] = False 112 | model = VGG_features(cfg['A'], batch_norm=False, **kwargs) 113 | if pretrained: 114 | my_dict = model_zoo.load_url(model_urls['vgg11'], model_dir=model_dir) 115 | keys_to_remove = set() 116 | for key in my_dict: 117 | if key.startswith('classifier'): 118 | keys_to_remove.add(key) 119 | for key in keys_to_remove: 120 | del my_dict[key] 121 | model.load_state_dict(my_dict, strict=False) 122 | return model 123 | 124 | 125 | def vgg11_bn_features(pretrained=False, **kwargs): 126 | """VGG 11-layer model (configuration "A") with batch normalization 127 | 128 | Args: 129 | pretrained (bool): If True, returns a model pre-trained on ImageNet 130 | """ 131 | if pretrained: 132 | kwargs['init_weights'] = False 133 | model = VGG_features(cfg['A'], batch_norm=True, **kwargs) 134 | if pretrained: 135 | my_dict = model_zoo.load_url(model_urls['vgg11_bn'], model_dir=model_dir) 136 | keys_to_remove = set() 137 | for key in my_dict: 138 | if key.startswith('classifier'): 139 | keys_to_remove.add(key) 140 | for key in keys_to_remove: 141 | del my_dict[key] 142 | model.load_state_dict(my_dict, strict=False) 143 | return model 144 | 145 | 146 | def vgg13_features(pretrained=False, **kwargs): 147 | """VGG 13-layer model (configuration "B") 148 | 149 | Args: 150 | pretrained (bool): If True, returns a model pre-trained on ImageNet 151 | """ 152 | if pretrained: 153 | kwargs['init_weights'] = False 154 | model = VGG_features(cfg['B'], batch_norm=False, **kwargs) 155 | if pretrained: 156 | my_dict = model_zoo.load_url(model_urls['vgg13'], model_dir=model_dir) 157 | keys_to_remove = set() 158 | for key in my_dict: 159 | if key.startswith('classifier'): 160 | keys_to_remove.add(key) 161 | for key in keys_to_remove: 162 | del my_dict[key] 163 | model.load_state_dict(my_dict, strict=False) 164 | return model 165 | 166 | 167 | def vgg13_bn_features(pretrained=False, **kwargs): 168 | """VGG 13-layer model (configuration "B") with batch normalization 169 | 170 | Args: 171 | pretrained (bool): If True, returns a model pre-trained on ImageNet 172 | """ 173 | if pretrained: 174 | kwargs['init_weights'] = False 175 | model = VGG_features(cfg['B'], batch_norm=True, **kwargs) 176 | if pretrained: 177 | my_dict = model_zoo.load_url(model_urls['vgg13_bn'], model_dir=model_dir) 178 | keys_to_remove = set() 179 | for key in my_dict: 180 | if key.startswith('classifier'): 181 | keys_to_remove.add(key) 182 | for key in keys_to_remove: 183 | del my_dict[key] 184 | model.load_state_dict(my_dict, strict=False) 185 | return model 186 | 187 | 188 | def vgg16_features(pretrained=False, **kwargs): 189 | """VGG 16-layer model (configuration "D") 190 | 191 | Args: 192 | pretrained (bool): If True, returns a model pre-trained on ImageNet 193 | """ 194 | if pretrained: 195 | kwargs['init_weights'] = False 196 | model = VGG_features(cfg['D'], batch_norm=False, **kwargs) 197 | if pretrained: 198 | my_dict = model_zoo.load_url(model_urls['vgg16'], model_dir=model_dir) 199 | keys_to_remove = set() 200 | for key in my_dict: 201 | if key.startswith('classifier'): 202 | keys_to_remove.add(key) 203 | for key in keys_to_remove: 204 | del my_dict[key] 205 | model.load_state_dict(my_dict, strict=False) 206 | return model 207 | 208 | 209 | def vgg16_bn_features(pretrained=False, **kwargs): 210 | """VGG 16-layer model (configuration "D") with batch normalization 211 | 212 | Args: 213 | pretrained (bool): If True, returns a model pre-trained on ImageNet 214 | """ 215 | if pretrained: 216 | kwargs['init_weights'] = False 217 | model = VGG_features(cfg['D'], batch_norm=True, **kwargs) 218 | if pretrained: 219 | my_dict = model_zoo.load_url(model_urls['vgg16_bn'], model_dir=model_dir) 220 | keys_to_remove = set() 221 | for key in my_dict: 222 | if key.startswith('classifier'): 223 | keys_to_remove.add(key) 224 | for key in keys_to_remove: 225 | del my_dict[key] 226 | model.load_state_dict(my_dict, strict=False) 227 | return model 228 | 229 | 230 | def vgg19_features(pretrained=False, **kwargs): 231 | """VGG 19-layer model (configuration "E") 232 | 233 | Args: 234 | pretrained (bool): If True, returns a model pre-trained on ImageNet 235 | """ 236 | if pretrained: 237 | kwargs['init_weights'] = False 238 | model = VGG_features(cfg['E'], batch_norm=False, **kwargs) 239 | if pretrained: 240 | my_dict = model_zoo.load_url(model_urls['vgg19'], model_dir=model_dir) 241 | keys_to_remove = set() 242 | for key in my_dict: 243 | if key.startswith('classifier'): 244 | keys_to_remove.add(key) 245 | for key in keys_to_remove: 246 | del my_dict[key] 247 | model.load_state_dict(my_dict, strict=False) 248 | return model 249 | 250 | 251 | def vgg19_bn_features(pretrained=False, **kwargs): 252 | """VGG 19-layer model (configuration 'E') with batch normalization 253 | 254 | Args: 255 | pretrained (bool): If True, returns a model pre-trained on ImageNet 256 | """ 257 | if pretrained: 258 | kwargs['init_weights'] = False 259 | model = VGG_features(cfg['E'], batch_norm=True, **kwargs) 260 | if pretrained: 261 | my_dict = model_zoo.load_url(model_urls['vgg19_bn'], model_dir=model_dir) 262 | keys_to_remove = set() 263 | for key in my_dict: 264 | if key.startswith('classifier'): 265 | keys_to_remove.add(key) 266 | for key in keys_to_remove: 267 | del my_dict[key] 268 | model.load_state_dict(my_dict, strict=False) 269 | return model 270 | 271 | 272 | if __name__ == '__main__': 273 | 274 | # vgg11_f = vgg11_features(pretrained=True) 275 | # print(vgg11_f) 276 | # 277 | # vgg11_bn_f = vgg11_bn_features(pretrained=True) 278 | # print(vgg11_bn_f) 279 | # 280 | # vgg13_f = vgg13_features(pretrained=True) 281 | # print(vgg13_f) 282 | # 283 | # vgg13_bn_f = vgg13_bn_features(pretrained=True) 284 | # print(vgg13_bn_f) 285 | 286 | vgg16_f = vgg16_features(pretrained=True) 287 | print(vgg16_f) 288 | 289 | # vgg16_bn_f = vgg16_bn_features(pretrained=True) 290 | # print(vgg16_bn_f) 291 | # 292 | # vgg19_f = vgg19_features(pretrained=True) 293 | # print(vgg19_f) 294 | # 295 | # vgg19_bn_f = vgg19_bn_features(pretrained=True) 296 | # print(vgg19_bn_f) 297 | -------------------------------------------------------------------------------- /gradcam_utils.py: -------------------------------------------------------------------------------- 1 | ### Adapted from https://github.com/stefannc/GradCAM-Pytorch/blob/07fd6ece5010f7c1c9fbcc8155a60023819111d7/example.ipynb retrieved Mar 3 2021 ##### 2 | 3 | import cv2 4 | import numpy as np 5 | import torch 6 | 7 | def visualize_cam(mask, img): 8 | """Make heatmap from mask and synthesize GradCAM result image using heatmap and img. 9 | Args: 10 | mask (torch.tensor): mask shape of (1, 1, H, W) and each element has value in range [0, 1] 11 | img (torch.tensor): img shape of (1, 3, H, W) and each pixel value is in range [0, 1] 12 | 13 | Return: 14 | heatmap (torch.tensor): heatmap img shape of (3, H, W) 15 | result (torch.tensor): synthesized GradCAM result of same shape with heatmap. 16 | """ 17 | mask = mask.cpu() 18 | heatmap = cv2.applyColorMap(np.uint8(255 * mask.squeeze()), cv2.COLORMAP_JET) 19 | heatmap = torch.from_numpy(heatmap).permute(2, 0, 1).float().div(255) 20 | b, g, r = heatmap.split(1) 21 | heatmap = torch.cat([r, g, b]) 22 | 23 | result = heatmap+img.cpu() 24 | result = result.div(result.max()).squeeze() 25 | 26 | return heatmap, result 27 | 28 | 29 | def find_resnet_layer(arch, target_layer_name): 30 | """Find resnet layer to calculate GradCAM and GradCAM++ 31 | 32 | Args: 33 | arch: default torchvision densenet models 34 | target_layer_name (str): the name of layer with its hierarchical information. please refer to usages below. 35 | target_layer_name = 'conv1' 36 | target_layer_name = 'layer1' 37 | target_layer_name = 'layer1_basicblock0' 38 | target_layer_name = 'layer1_basicblock0_relu' 39 | target_layer_name = 'layer1_bottleneck0' 40 | target_layer_name = 'layer1_bottleneck0_conv1' 41 | target_layer_name = 'layer1_bottleneck0_downsample' 42 | target_layer_name = 'layer1_bottleneck0_downsample_0' 43 | target_layer_name = 'avgpool' 44 | target_layer_name = 'fc' 45 | 46 | Return: 47 | target_layer: found layer. this layer will be hooked to get forward/backward pass information. 48 | """ 49 | if 'layer' in target_layer_name: 50 | hierarchy = target_layer_name.split('_') 51 | layer_num = int(hierarchy[0].lstrip('layer')) 52 | if layer_num == 1: 53 | target_layer = arch.layer1 54 | elif layer_num == 2: 55 | target_layer = arch.layer2 56 | elif layer_num == 3: 57 | target_layer = arch.layer3 58 | elif layer_num == 4: 59 | target_layer = arch.layer4 60 | else: 61 | raise ValueError('unknown layer : {}'.format(target_layer_name)) 62 | 63 | if len(hierarchy) >= 2: 64 | bottleneck_num = int(hierarchy[1].lower().lstrip('bottleneck').lstrip('basicblock')) 65 | target_layer = target_layer[bottleneck_num] 66 | 67 | if len(hierarchy) >= 3: 68 | target_layer = target_layer._modules[hierarchy[2]] 69 | 70 | if len(hierarchy) == 4: 71 | target_layer = target_layer._modules[hierarchy[3]] 72 | 73 | else: 74 | target_layer = arch._modules[target_layer_name] 75 | 76 | return target_layer 77 | 78 | 79 | def find_densenet_layer(arch, target_layer_name): 80 | """Find densenet layer to calculate GradCAM and GradCAM++ 81 | 82 | Args: 83 | arch: default torchvision densenet models 84 | target_layer_name (str): the name of layer with its hierarchical information. please refer to usages below. 85 | target_layer_name = 'features' 86 | target_layer_name = 'features_transition1' 87 | target_layer_name = 'features_transition1_norm' 88 | target_layer_name = 'features_denseblock2_denselayer12' 89 | target_layer_name = 'features_denseblock2_denselayer12_norm1' 90 | target_layer_name = 'features_denseblock2_denselayer12_norm1' 91 | target_layer_name = 'classifier' 92 | 93 | Return: 94 | target_layer: found layer. this layer will be hooked to get forward/backward pass information. 95 | """ 96 | 97 | hierarchy = target_layer_name.split('_') 98 | target_layer = arch._modules[hierarchy[0]] 99 | 100 | if len(hierarchy) >= 2: 101 | target_layer = target_layer._modules[hierarchy[1]] 102 | 103 | if len(hierarchy) >= 3: 104 | target_layer = target_layer._modules[hierarchy[2]] 105 | 106 | if len(hierarchy) == 4: 107 | target_layer = target_layer._modules[hierarchy[3]] 108 | 109 | return target_layer 110 | 111 | 112 | def find_vgg_layer(arch, target_layer_name): 113 | """Find vgg layer to calculate GradCAM and GradCAM++ 114 | 115 | Args: 116 | arch: default torchvision densenet models 117 | target_layer_name (str): the name of layer with its hierarchical information. please refer to usages below. 118 | target_layer_name = 'features' 119 | target_layer_name = 'features_42' 120 | target_layer_name = 'classifier' 121 | target_layer_name = 'classifier_0' 122 | 123 | Return: 124 | target_layer: found layer. this layer will be hooked to get forward/backward pass information. 125 | """ 126 | hierarchy = target_layer_name.split('_') 127 | 128 | if len(hierarchy) >= 1: 129 | target_layer = arch.features 130 | 131 | if len(hierarchy) == 2: 132 | # print(f'int(hierarchy[1]) = {int(hierarchy[1])}') 133 | # print(f'target_layer[int(hierarchy[1])] = {target_layer[int(hierarchy[1])]}') 134 | # print(f'vgg type(target_layer) = {type(target_layer)}') 135 | # print(f'vgg target_layer.features = {target_layer.features}') 136 | # print(f'vgg target_layer[int(hierarchy[1])] = {target_layer[int(hierarchy[1])]}') 137 | target_layer = target_layer[int(hierarchy[1])] 138 | 139 | return target_layer 140 | 141 | def find_vgg_us_layer(arch, target_layer_name): 142 | """Find vgg layer to calculate GradCAM and GradCAM++. _us refers to integrating with IAIA-BL code structure. 143 | 144 | Args: 145 | arch: default torchvision densenet models 146 | target_layer_name (str): the name of layer with its hierarchical information. please refer to usages below. 147 | target_layer_name = 'features' 148 | target_layer_name = 'features_42' 149 | target_layer_name = 'classifier' 150 | target_layer_name = 'classifier_0' 151 | 152 | Return: 153 | target_layer: found layer. this layer will be hooked to get forward/backward pass information. 154 | """ 155 | hierarchy = target_layer_name.split('_') 156 | 157 | # print(f'arch.classifier = {arch.classifier}') 158 | 159 | if len(hierarchy) >= 1: 160 | target_layer = arch.features.features 161 | 162 | if len(hierarchy) == 2: 163 | # print(f'6 = {6}') 164 | # print(f'vgg_us type(target_layer) = {type(target_layer)}') 165 | # print(f'vgg_us type(target_layer.features) = {type(target_layer.features)}') 166 | # print(f'vgg_us type(arch.features.features) = {type(arch.features.features)}') 167 | # print(f'vgg_us target_layer = {target_layer}') 168 | # print(f'vgg_us target_layer.features[6] = {target_layer.features[6]}') 169 | target_layer = target_layer[6] 170 | # target_layer = arch.classifier[1] 171 | 172 | return target_layer 173 | 174 | 175 | def find_alexnet_layer(arch, target_layer_name): 176 | """Find alexnet layer to calculate GradCAM and GradCAM++ 177 | 178 | Args: 179 | arch: default torchvision densenet models 180 | target_layer_name (str): the name of layer with its hierarchical information. please refer to usages below. 181 | target_layer_name = 'features' 182 | target_layer_name = 'features_0' 183 | target_layer_name = 'classifier' 184 | target_layer_name = 'classifier_0' 185 | 186 | Return: 187 | target_layer: found layer. this layer will be hooked to get forward/backward pass information. 188 | """ 189 | hierarchy = target_layer_name.split('_') 190 | 191 | if len(hierarchy) >= 1: 192 | target_layer = arch.features 193 | 194 | if len(hierarchy) == 2: 195 | target_layer = target_layer[int(hierarchy[1])] 196 | 197 | return target_layer 198 | 199 | 200 | def find_squeezenet_layer(arch, target_layer_name): 201 | """Find squeezenet layer to calculate GradCAM and GradCAM++ 202 | 203 | Args: 204 | arch: default torchvision densenet models 205 | target_layer_name (str): the name of layer with its hierarchical information. please refer to usages below. 206 | target_layer_name = 'features_12' 207 | target_layer_name = 'features_12_expand3x3' 208 | target_layer_name = 'features_12_expand3x3_activation' 209 | 210 | Return: 211 | target_layer: found layer. this layer will be hooked to get forward/backward pass information. 212 | """ 213 | hierarchy = target_layer_name.split('_') 214 | target_layer = arch._modules[hierarchy[0]] 215 | 216 | if len(hierarchy) >= 2: 217 | target_layer = target_layer._modules[hierarchy[1]] 218 | 219 | if len(hierarchy) == 3: 220 | target_layer = target_layer._modules[hierarchy[2]] 221 | 222 | elif len(hierarchy) == 4: 223 | target_layer = target_layer._modules[hierarchy[2]+'_'+hierarchy[3]] 224 | 225 | return target_layer 226 | 227 | 228 | def denormalize(tensor, mean, std): 229 | if not tensor.ndimension() == 4: 230 | raise TypeError('tensor should be 4D') 231 | 232 | mean = torch.FloatTensor(mean).view(1, 3, 1, 1).expand_as(tensor).to(tensor.device) 233 | std = torch.FloatTensor(std).view(1, 3, 1, 1).expand_as(tensor).to(tensor.device) 234 | 235 | return tensor.mul(std).add(mean) 236 | 237 | 238 | def normalize(tensor, mean, std): 239 | if not tensor.ndimension() == 4: 240 | raise TypeError('tensor should be 4D') 241 | 242 | mean = torch.FloatTensor(mean).view(1, 3, 1, 1).expand_as(tensor).to(tensor.device) 243 | std = torch.FloatTensor(std).view(1, 3, 1, 1).expand_as(tensor).to(tensor.device) 244 | 245 | return tensor.sub(mean).div(std) 246 | 247 | 248 | class Normalize(object): 249 | def __init__(self, mean, std): 250 | self.mean = mean 251 | self.std = std 252 | 253 | def __call__(self, tensor): 254 | return self.do(tensor) 255 | 256 | def do(self, tensor): 257 | return normalize(tensor, self.mean, self.std) 258 | 259 | def undo(self, tensor): 260 | return denormalize(tensor, self.mean, self.std) 261 | 262 | def __repr__(self): 263 | return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std) -------------------------------------------------------------------------------- /resnet_features.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | model_urls = { 6 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 7 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 8 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 9 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 10 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 11 | } 12 | 13 | model_dir = './pretrained_models' 14 | 15 | def conv3x3(in_planes, out_planes, stride=1): 16 | """3x3 convolution with padding""" 17 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 18 | padding=1, bias=False) 19 | 20 | 21 | def conv1x1(in_planes, out_planes, stride=1): 22 | """1x1 convolution""" 23 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 24 | 25 | 26 | class BasicBlock(nn.Module): 27 | # class attribute 28 | expansion = 1 29 | num_layers = 2 30 | 31 | def __init__(self, inplanes, planes, stride=1, downsample=None): 32 | super(BasicBlock, self).__init__() 33 | # only conv with possibly not 1 stride 34 | self.conv1 = conv3x3(inplanes, planes, stride) 35 | self.bn1 = nn.BatchNorm2d(planes) 36 | self.relu = nn.ReLU(inplace=True) 37 | self.conv2 = conv3x3(planes, planes) 38 | self.bn2 = nn.BatchNorm2d(planes) 39 | 40 | # if stride is not 1 then self.downsample cannot be None 41 | self.downsample = downsample 42 | self.stride = stride 43 | 44 | def forward(self, x): 45 | identity = x 46 | 47 | out = self.conv1(x) 48 | out = self.bn1(out) 49 | out = self.relu(out) 50 | 51 | out = self.conv2(out) 52 | out = self.bn2(out) 53 | 54 | if self.downsample is not None: 55 | identity = self.downsample(x) 56 | 57 | # the residual connection 58 | out += identity 59 | out = self.relu(out) 60 | 61 | return out 62 | 63 | def block_conv_info(self): 64 | block_kernel_sizes = [3, 3] 65 | block_strides = [self.stride, 1] 66 | block_paddings = [1, 1] 67 | 68 | return block_kernel_sizes, block_strides, block_paddings 69 | 70 | 71 | class Bottleneck(nn.Module): 72 | # class attribute 73 | expansion = 4 74 | num_layers = 3 75 | 76 | def __init__(self, inplanes, planes, stride=1, downsample=None): 77 | super(Bottleneck, self).__init__() 78 | self.conv1 = conv1x1(inplanes, planes) 79 | self.bn1 = nn.BatchNorm2d(planes) 80 | # only conv with possibly not 1 stride 81 | self.conv2 = conv3x3(planes, planes, stride) 82 | self.bn2 = nn.BatchNorm2d(planes) 83 | self.conv3 = conv1x1(planes, planes * self.expansion) 84 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 85 | self.relu = nn.ReLU(inplace=True) 86 | 87 | # if stride is not 1 then self.downsample cannot be None 88 | self.downsample = downsample 89 | self.stride = stride 90 | 91 | def forward(self, x): 92 | identity = x 93 | 94 | out = self.conv1(x) 95 | out = self.bn1(out) 96 | out = self.relu(out) 97 | 98 | out = self.conv2(out) 99 | out = self.bn2(out) 100 | out = self.relu(out) 101 | 102 | out = self.conv3(out) 103 | out = self.bn3(out) 104 | 105 | if self.downsample is not None: 106 | identity = self.downsample(x) 107 | 108 | out += identity 109 | out = self.relu(out) 110 | 111 | return out 112 | 113 | def block_conv_info(self): 114 | block_kernel_sizes = [1, 3, 1] 115 | block_strides = [1, self.stride, 1] 116 | block_paddings = [0, 1, 0] 117 | 118 | return block_kernel_sizes, block_strides, block_paddings 119 | 120 | 121 | class ResNet_features(nn.Module): 122 | ''' 123 | the convolutional layers of ResNet 124 | the average pooling and final fully convolutional layer is removed 125 | ''' 126 | 127 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False): 128 | super(ResNet_features, self).__init__() 129 | 130 | self.inplanes = 64 131 | 132 | # the first convolutional layer before the structured sequence of blocks 133 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 134 | bias=False) 135 | self.bn1 = nn.BatchNorm2d(64) 136 | self.relu = nn.ReLU(inplace=True) 137 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 138 | # comes from the first conv and the following max pool 139 | self.kernel_sizes = [7, 3] 140 | self.strides = [2, 2] 141 | self.paddings = [3, 1] 142 | 143 | # the following layers, each layer is a sequence of blocks 144 | self.block = block 145 | self.layers = layers 146 | self.layer1 = self._make_layer(block=block, planes=64, num_blocks=self.layers[0]) 147 | self.layer2 = self._make_layer(block=block, planes=128, num_blocks=self.layers[1], stride=2) 148 | self.layer3 = self._make_layer(block=block, planes=256, num_blocks=self.layers[2], stride=2) 149 | self.layer4 = self._make_layer(block=block, planes=512, num_blocks=self.layers[3], stride=2) 150 | 151 | # initialize the parameters 152 | for m in self.modules(): 153 | if isinstance(m, nn.Conv2d): 154 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 155 | elif isinstance(m, nn.BatchNorm2d): 156 | nn.init.constant_(m.weight, 1) 157 | nn.init.constant_(m.bias, 0) 158 | 159 | # Zero-initialize the last BN in each residual branch, 160 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 161 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 162 | if zero_init_residual: 163 | for m in self.modules(): 164 | if isinstance(m, Bottleneck): 165 | nn.init.constant_(m.bn3.weight, 0) 166 | elif isinstance(m, BasicBlock): 167 | nn.init.constant_(m.bn2.weight, 0) 168 | 169 | def _make_layer(self, block, planes, num_blocks, stride=1): 170 | downsample = None 171 | if stride != 1 or self.inplanes != planes * block.expansion: 172 | downsample = nn.Sequential( 173 | conv1x1(self.inplanes, planes * block.expansion, stride), 174 | nn.BatchNorm2d(planes * block.expansion), 175 | ) 176 | 177 | layers = [] 178 | # only the first block has downsample that is possibly not None 179 | layers.append(block(self.inplanes, planes, stride, downsample)) 180 | 181 | self.inplanes = planes * block.expansion 182 | for _ in range(1, num_blocks): 183 | layers.append(block(self.inplanes, planes)) 184 | 185 | # keep track of every block's conv size, stride size, and padding size 186 | for each_block in layers: 187 | block_kernel_sizes, block_strides, block_paddings = each_block.block_conv_info() 188 | self.kernel_sizes.extend(block_kernel_sizes) 189 | self.strides.extend(block_strides) 190 | self.paddings.extend(block_paddings) 191 | 192 | return nn.Sequential(*layers) 193 | 194 | def forward(self, x): 195 | x = self.conv1(x) 196 | x = self.bn1(x) 197 | x = self.relu(x) 198 | x = self.maxpool(x) 199 | 200 | x = self.layer1(x) 201 | x = self.layer2(x) 202 | x = self.layer3(x) 203 | x = self.layer4(x) 204 | 205 | return x 206 | 207 | def conv_info(self): 208 | return self.kernel_sizes, self.strides, self.paddings 209 | 210 | def num_layers(self): 211 | ''' 212 | the number of conv layers in the network, not counting the number 213 | of bypass layers 214 | ''' 215 | 216 | return (self.block.num_layers * self.layers[0] 217 | + self.block.num_layers * self.layers[1] 218 | + self.block.num_layers * self.layers[2] 219 | + self.block.num_layers * self.layers[3] 220 | + 1) 221 | 222 | 223 | def __repr__(self): 224 | template = 'resnet{}_features' 225 | return template.format(self.num_layers() + 1) 226 | 227 | def resnet18_features(pretrained=False, **kwargs): 228 | """Constructs a ResNet-18 model. 229 | Args: 230 | pretrained (bool): If True, returns a model pre-trained on ImageNet 231 | """ 232 | model = ResNet_features(BasicBlock, [2, 2, 2, 2], **kwargs) 233 | if pretrained: 234 | my_dict = model_zoo.load_url(model_urls['resnet18'], model_dir=model_dir) 235 | my_dict.pop('fc.weight') 236 | my_dict.pop('fc.bias') 237 | model.load_state_dict(my_dict, strict=False) 238 | return model 239 | 240 | 241 | def resnet34_features(pretrained=False, **kwargs): 242 | """Constructs a ResNet-34 model. 243 | Args: 244 | pretrained (bool): If True, returns a model pre-trained on ImageNet 245 | """ 246 | model = ResNet_features(BasicBlock, [3, 4, 6, 3], **kwargs) 247 | if pretrained: 248 | my_dict = model_zoo.load_url(model_urls['resnet34'], model_dir=model_dir) 249 | my_dict.pop('fc.weight') 250 | my_dict.pop('fc.bias') 251 | model.load_state_dict(my_dict, strict=False) 252 | return model 253 | 254 | 255 | def resnet50_features(pretrained=False, **kwargs): 256 | """Constructs a ResNet-50 model. 257 | Args: 258 | pretrained (bool): If True, returns a model pre-trained on ImageNet 259 | """ 260 | model = ResNet_features(Bottleneck, [3, 4, 6, 3], **kwargs) 261 | if pretrained: 262 | my_dict = model_zoo.load_url(model_urls['resnet50'], model_dir=model_dir) 263 | my_dict.pop('fc.weight') 264 | my_dict.pop('fc.bias') 265 | model.load_state_dict(my_dict, strict=False) 266 | return model 267 | 268 | 269 | def resnet101_features(pretrained=False, **kwargs): 270 | """Constructs a ResNet-101 model. 271 | Args: 272 | pretrained (bool): If True, returns a model pre-trained on ImageNet 273 | """ 274 | model = ResNet_features(Bottleneck, [3, 4, 23, 3], **kwargs) 275 | if pretrained: 276 | my_dict = model_zoo.load_url(model_urls['resnet101'], model_dir=model_dir) 277 | my_dict.pop('fc.weight') 278 | my_dict.pop('fc.bias') 279 | model.load_state_dict(my_dict, strict=False) 280 | return model 281 | 282 | 283 | def resnet152_features(pretrained=False, **kwargs): 284 | """Constructs a ResNet-152 model. 285 | Args: 286 | pretrained (bool): If True, returns a model pre-trained on ImageNet 287 | """ 288 | model = ResNet_features(Bottleneck, [3, 8, 36, 3], **kwargs) 289 | if pretrained: 290 | my_dict = model_zoo.load_url(model_urls['resnet152'], model_dir=model_dir) 291 | my_dict.pop('fc.weight') 292 | my_dict.pop('fc.bias') 293 | model.load_state_dict(my_dict, strict=False) 294 | return model 295 | 296 | 297 | if __name__ == '__main__': 298 | 299 | r18_features = resnet18_features(pretrained=True) 300 | print(r18_features) 301 | 302 | r34_features = resnet34_features(pretrained=True) 303 | print(r34_features) 304 | 305 | r50_features = resnet50_features(pretrained=True) 306 | print(r50_features) 307 | 308 | r101_features = resnet101_features(pretrained=True) 309 | print(r101_features) 310 | 311 | r152_features = resnet152_features(pretrained=True) 312 | print(r152_features) 313 | -------------------------------------------------------------------------------- /train_and_test.py: -------------------------------------------------------------------------------- 1 | import time 2 | import torch 3 | from sklearn.metrics import roc_auc_score 4 | import numpy as np 5 | import pandas as pd 6 | import csv 7 | from helpers import list_of_distances, make_one_hot 8 | 9 | def _train_or_test(model, dataloader, optimizer=None, class_specific=True, use_l1_mask=True, 10 | coefs=None, log=print, save_logits=False, finer_loader=None): 11 | ''' 12 | model: the multi-gpu model 13 | dataloader: 14 | optimizer: if None, will be test evaluation 15 | ''' 16 | is_train = optimizer is not None 17 | start = time.time() 18 | n_examples = 0 19 | n_correct = 0 20 | n_batches = 0 21 | total_output = [] 22 | total_one_hot_label = [] 23 | confusion_matrix = [0,0,0,0] 24 | total_cross_entropy = 0 25 | total_cluster_cost = 0 26 | # separation cost is meaningful only for class_specific 27 | total_separation_cost = 0 28 | total_avg_separation_cost = 0 29 | total_fa_cost = 0 30 | with_fa = False # intialization, see line 41 31 | 32 | for i, (image, label, patient_id) in enumerate(dataloader): 33 | # get one batch from finer datatloader 34 | if finer_loader: 35 | finer_image, finer_label, _ = next(iter(finer_loader)) 36 | # print(image.shape) 37 | image = torch.cat((image, finer_image)) 38 | label = torch.cat((label, finer_label)) 39 | # print(image.shape) 40 | if image.shape[1] == 4: 41 | with_fa = True 42 | fine_annotation = image[:, 3:4, :, :] 43 | image = image[:, 0:3, :, :] #(no view, create slice) 44 | elif image.shape[1] == 3: 45 | fine_annotation = torch.zeros(size=(image.shape[0], 1, image.shape[2], image.shape[3])) #means everything can be relevant 46 | image = image 47 | fine_annotation = fine_annotation.cuda() 48 | input = image.cuda() 49 | target = label.cuda() 50 | 51 | # torch.enable_grad() has no effect outside of no_grad() 52 | grad_req = torch.enable_grad() if is_train else torch.no_grad() 53 | with grad_req: 54 | # nn.Module has implemented __call__() function 55 | # so no need to call .forward 56 | output, min_distances, upsampled_activation = model(input) 57 | # compute loss 58 | cross_entropy = torch.nn.functional.cross_entropy(output, target) 59 | 60 | # only save to csv on test 61 | if not is_train and save_logits: 62 | _output_scores = [",".join([str(score) for score in scores.cpu().numpy()]) for scores in output] 63 | write_file = './logit_csvs/0218_training_3_class_margin_logits.csv' 64 | with open(write_file, mode='a') as logit_file: 65 | logit_writer = csv.writer(logit_file, delimiter=',', quotechar='"', quoting=csv.QUOTE_MINIMAL) 66 | for _index in range(len(patient_id)): 67 | logit_writer.writerow([patient_id[_index], _output_scores[_index]]) 68 | log(f'Wrote to {write_file}.') 69 | 70 | if class_specific: 71 | max_dist = (model.module.prototype_shape[1] 72 | * model.module.prototype_shape[2] 73 | * model.module.prototype_shape[3]) 74 | 75 | # prototypes_of_correct_class is a tensor of shape batch_size * num_prototypes 76 | # calculate cluster cost 77 | prototypes_of_correct_class = torch.t(model.module.prototype_class_identity[:,label]).cuda() 78 | inverted_distances, _ = torch.max((max_dist - min_distances) * prototypes_of_correct_class, dim=1) 79 | cluster_cost = torch.mean(max_dist - inverted_distances) 80 | # print("before change") 81 | 82 | # calculate separation cost 83 | prototypes_of_wrong_class = 1 - prototypes_of_correct_class 84 | inverted_distances_to_nontarget_prototypes, _ = \ 85 | torch.max((max_dist - min_distances) * prototypes_of_wrong_class, dim=1) 86 | separation_cost = torch.mean(max_dist - inverted_distances_to_nontarget_prototypes) 87 | # print("after change") 88 | 89 | # calculate avg cluster cost 90 | avg_separation_cost = \ 91 | torch.sum(min_distances * prototypes_of_wrong_class, dim=1) / torch.sum(prototypes_of_wrong_class, dim=1) 92 | avg_separation_cost = torch.mean(avg_separation_cost) 93 | 94 | if use_l1_mask: 95 | l1_mask = 1 - torch.t(model.module.prototype_class_identity).cuda() 96 | l1 = (model.module.last_layer.weight * l1_mask).norm(p=1) 97 | else: 98 | l1 = model.module.last_layer.weight.norm(p=1) 99 | 100 | #fine annotation loss 101 | fine_annotation_cost = 0 102 | if with_fa: 103 | proto_num_per_class = model.module.num_prototypes // model.module.num_classes 104 | all_white_mask = torch.ones(image.shape[2], image.shape[3]).cuda() 105 | for index in range(image.shape[0]): 106 | fine_annotation_cost += torch.norm(upsampled_activation[index, :label[index] * proto_num_per_class] * (1 * all_white_mask)) + \ 107 | torch.norm(upsampled_activation[index, label[index] * proto_num_per_class : (label[index] + 1) * proto_num_per_class] * (1 * fine_annotation[index])) + \ 108 | torch.norm(upsampled_activation[index, (label[index]+1) * proto_num_per_class:] * (1 * all_white_mask)) 109 | 110 | 111 | else: 112 | min_distance, _ = torch.min(min_distances, dim=1) 113 | # label=0 negative, label=1 positive, minimize cluster loss maximize separation loss 114 | # all prototypes are positive 115 | positive_sample_index = torch.flatten(torch.nonzero(label)).tolist() 116 | negative_sample_index = torch.flatten(torch.nonzero(label == 0)).tolist() 117 | if len(positive_sample_index) > 0: 118 | positive_proto_distance = min_distance[positive_sample_index] 119 | else: 120 | positive_proto_distance = torch.zeros(1) 121 | 122 | if len(negative_sample_index) > 0: 123 | negative_proto_distance = min_distance[negative_sample_index] 124 | else: 125 | negative_proto_distance = torch.zeros(1) 126 | 127 | cluster_cost = torch.mean(positive_proto_distance) 128 | separation_cost = torch.mean(negative_proto_distance) 129 | l1 = model.module.last_layer.weight.norm(p=1) 130 | 131 | # evaluation statistics 132 | _, predicted = torch.max(output.data, 1) 133 | n_examples += target.size(0) 134 | n_correct += (predicted == target).sum().item() 135 | 136 | # confusion matrix 137 | for t_idx, t in enumerate(target): 138 | if predicted[t_idx] == t and predicted[t_idx] == 1: # true positive 139 | confusion_matrix[0] += 1 140 | elif t == 0 and predicted[t_idx] == 1: 141 | confusion_matrix[1] += 1 # false positives 142 | elif t == 1 and predicted[t_idx] == 0: 143 | confusion_matrix[2] += 1 # false negative 144 | else: 145 | confusion_matrix[3] += 1 146 | 147 | # one hot label for AUC 148 | one_hot_label = np.zeros(shape=(len(target), model.module.num_classes)) 149 | for k in range(len(target)): 150 | one_hot_label[k][target[k].item()] = 1 151 | 152 | prob = torch.nn.functional.softmax(output, dim=1) 153 | total_output.extend(prob.data.cpu().numpy()) 154 | total_one_hot_label.extend(one_hot_label) 155 | # one hot label for AUC 156 | 157 | n_batches += 1 158 | total_cross_entropy += cross_entropy.item() 159 | total_cluster_cost += cluster_cost.item() 160 | total_separation_cost += separation_cost.item() 161 | total_fa_cost += fine_annotation_cost 162 | if class_specific: 163 | total_avg_separation_cost += avg_separation_cost.item() 164 | 165 | # compute gradient and do SGD step 166 | if is_train: 167 | if coefs is not None: 168 | loss = (coefs['crs_ent'] * cross_entropy 169 | + coefs['clst'] * cluster_cost 170 | + coefs['sep'] * separation_cost 171 | + coefs['l1'] * l1 172 | + coefs['fine'] * fine_annotation_cost) 173 | else: 174 | loss = cross_entropy + 0.8 * cluster_cost - 0.08 * separation_cost + 1e-4 * l1 175 | 176 | optimizer.zero_grad() 177 | loss.backward() 178 | optimizer.step() 179 | 180 | del input 181 | del target 182 | del output 183 | del predicted 184 | del min_distances 185 | 186 | end = time.time() 187 | 188 | log('\ttime: \t{0}'.format(end - start)) 189 | log('\tcross ent: \t{0}'.format(total_cross_entropy / n_batches)) 190 | log('\tcluster: \t{0}'.format(total_cluster_cost / n_batches)) 191 | log('\tseparation:\t{0}'.format(total_separation_cost / n_batches)) 192 | log('\tfine annotation:\t{0}'.format(total_fa_cost / n_batches)) 193 | if class_specific: 194 | log('\tavg separation:\t{0}'.format(total_avg_separation_cost / n_batches)) 195 | 196 | avg_auc = 0 197 | for auc_idx in range(len(total_one_hot_label[0])): 198 | avg_auc += roc_auc_score(np.array(total_one_hot_label)[:, auc_idx], np.array(total_output)[:, auc_idx]) / len(total_one_hot_label[0]) 199 | log("\tauc score for class {} is: \t\t{}".format(auc_idx, 200 | roc_auc_score(np.array(total_one_hot_label)[:, auc_idx], np.array(total_output)[:, auc_idx]))) 201 | 202 | log('\taccu: \t\t{0}%'.format(n_correct / n_examples * 100)) 203 | log('\tl1: \t\t{0}'.format(model.module.last_layer.weight.norm(p=1).item())) 204 | p = model.module.prototype_vectors.view(model.module.num_prototypes, -1).cpu() 205 | with torch.no_grad(): 206 | p_avg_pair_dist = torch.mean(list_of_distances(p, p)) 207 | log('\tp dist pair: \t{0}'.format(p_avg_pair_dist.item())) 208 | log('\tthe confusion matrix is: \t\t{0}'.format(confusion_matrix)) 209 | 210 | return avg_auc 211 | 212 | 213 | def train(model, dataloader, optimizer, class_specific=False, coefs=None, log=print, finer_loader=None): 214 | assert(optimizer is not None) 215 | 216 | log('\ttrain') 217 | model.train() 218 | return _train_or_test(model=model, dataloader=dataloader, optimizer=optimizer, 219 | class_specific=class_specific, coefs=coefs, log=log, finer_loader=finer_loader) 220 | 221 | 222 | def test(model, dataloader, class_specific=False, log=print, save_logits=False): 223 | log('\ttest') 224 | model.eval() 225 | return _train_or_test(model=model, dataloader=dataloader, optimizer=None, 226 | class_specific=class_specific, log=log, save_logits=save_logits) 227 | 228 | 229 | def last_only(model, log=print): 230 | for p in model.module.features.parameters(): 231 | p.requires_grad = False 232 | for p in model.module.add_on_layers.parameters(): 233 | p.requires_grad = False 234 | model.module.prototype_vectors.requires_grad = False 235 | for p in model.module.last_layer.parameters(): 236 | p.requires_grad = True 237 | 238 | log('\tlast layer') 239 | 240 | 241 | def warm_only(model, log=print): 242 | for p in model.module.features.parameters(): 243 | p.requires_grad = False 244 | for p in model.module.add_on_layers.parameters(): 245 | p.requires_grad = True 246 | model.module.prototype_vectors.requires_grad = True 247 | for p in model.module.last_layer.parameters(): 248 | p.requires_grad = True 249 | 250 | log('\twarm') 251 | 252 | 253 | def joint(model, log=print): 254 | for p in model.module.features.parameters(): 255 | p.requires_grad = True 256 | for p in model.module.add_on_layers.parameters(): 257 | p.requires_grad = True 258 | model.module.prototype_vectors.requires_grad = True 259 | for p in model.module.last_layer.parameters(): 260 | p.requires_grad = True 261 | 262 | log('\tjoint') 263 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import matplotlib.pyplot as plt 4 | import matplotlib 5 | import numpy as np 6 | matplotlib.use("Agg") 7 | import torch 8 | import torch.utils.data 9 | # import torch.utils.data.distributed 10 | import torchvision.transforms as transforms 11 | import torchvision.datasets as datasets 12 | 13 | import argparse 14 | import re 15 | from dataHelper import DatasetFolder 16 | from helpers import makedir 17 | import model 18 | import push 19 | import prune 20 | import train_and_test as tnt 21 | import save 22 | from log import create_logger 23 | from preprocess import mean, std, preprocess_input_function 24 | import random 25 | 26 | parser = argparse.ArgumentParser() 27 | parser.add_argument('-gpuid', nargs=1, type=str, default='0') # python3 main.py -gpuid=0,1,2,3 28 | parser.add_argument('-experiment_run', nargs=1, type=str, default='0') 29 | parser.add_argument("-latent", nargs=1, type=int, default=32) 30 | parser.add_argument("-last_layer_weight", nargs=1, type=int, default=None) 31 | parser.add_argument("-fa_coeff", nargs=1, type=float, default=None) 32 | parser.add_argument("-model", type=str) 33 | parser.add_argument("-base", type=str) 34 | parser.add_argument("-train_dir", type=str) 35 | parser.add_argument("-test_dir", type=str) 36 | parser.add_argument("-push_dir", type=str) 37 | parser.add_argument('-finer_dir', type=str) 38 | parser.add_argument("-random_seed", nargs=1, type=int) 39 | parser.add_argument("-topk_k", nargs=1, type=int) 40 | args = parser.parse_args() 41 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpuid[0] 42 | latent_shape = args.latent[0] 43 | experiment_run = args.experiment_run[0] 44 | load_model_dir = args.model 45 | base_architecture = args.base 46 | last_layer_weight = args.last_layer_weight[0] 47 | fa_coeff_manual = args.fa_coeff 48 | topk_k = args.topk_k[0] 49 | 50 | random_seed_number = args.random_seed[0] 51 | torch.manual_seed(random_seed_number) 52 | torch.cuda.manual_seed(random_seed_number) 53 | np.random.seed(random_seed_number) 54 | random.seed(random_seed_number) 55 | torch.backends.cudnn.enabled=False 56 | torch.backends.cudnn.deterministic=True 57 | 58 | # book keeping namings and code 59 | from settings import img_size, prototype_shape, num_classes, \ 60 | prototype_activation_function, add_on_layers_type, prototype_activation_function_in_numpy 61 | 62 | if not base_architecture: 63 | from settings import base_architecture 64 | 65 | base_architecture_type = re.match('^[a-z]*', base_architecture).group(0) 66 | 67 | prototype_shape = (prototype_shape[0], latent_shape, prototype_shape[2], prototype_shape[3]) 68 | print("Protoype shape: ", prototype_shape) 69 | 70 | model_dir = '/usr/xtmp/mammo/saved_models/' + base_architecture + '/' + experiment_run + '/' 71 | print("saving models to: ", model_dir) 72 | makedir(model_dir) 73 | shutil.copy(src=os.path.join(os.getcwd(), __file__), dst=model_dir) 74 | shutil.copy(src=os.path.join(os.getcwd(), 'settings.py'), dst=model_dir) 75 | shutil.copy(src=os.path.join(os.getcwd(), base_architecture_type + '_features.py'), dst=model_dir) 76 | shutil.copy(src=os.path.join(os.getcwd(), 'model.py'), dst=model_dir) 77 | shutil.copy(src=os.path.join(os.getcwd(), 'train_and_test.py'), dst=model_dir) 78 | log, logclose = create_logger(log_filename=os.path.join(model_dir, 'train.log')) 79 | img_dir = os.path.join(model_dir, 'img') 80 | makedir(img_dir) 81 | weight_matrix_filename = 'outputL_weights' 82 | prototype_img_filename_prefix = 'prototype-img' 83 | prototype_self_act_filename_prefix = 'prototype-self-act' 84 | proto_bound_boxes_filename_prefix = 'bb' 85 | 86 | # load the data 87 | from settings import train_dir, test_dir, train_push_dir, \ 88 | train_batch_size, test_batch_size, train_push_batch_size 89 | 90 | normalize = transforms.Normalize(mean=mean, 91 | std=std) 92 | 93 | if args.train_dir: 94 | print("inputting training dir") 95 | train_dir = args.train_dir 96 | if args.test_dir: 97 | test_dir = args.test_dir 98 | if args.push_dir: 99 | train_push_dir = args.push_dir 100 | if args.finer_dir: 101 | finer_annotation_dir = args.finer_dir 102 | print("fine annotation set location: ", finer_annotation_dir) 103 | else: 104 | finer_annotation_dir = None 105 | finer_train_loader = None 106 | 107 | # all datasets 108 | # train set 109 | train_dataset = DatasetFolder( 110 | train_dir, 111 | augmentation=False, 112 | loader=np.load, 113 | extensions=("npy",), 114 | transform = transforms.Compose([ 115 | torch.from_numpy, 116 | ])) 117 | train_loader = torch.utils.data.DataLoader( 118 | train_dataset, batch_size=train_batch_size, shuffle=True, 119 | num_workers=4, pin_memory=False) 120 | 121 | # finer train set 122 | if finer_annotation_dir: 123 | finer_train_dataset = DatasetFolder( 124 | finer_annotation_dir, 125 | augmentation=False, 126 | loader=np.load, 127 | extensions=('npy',), 128 | transform = transforms.Compose([ 129 | torch.from_numpy, 130 | ])) 131 | finer_train_loader = torch.utils.data.DataLoader( 132 | finer_train_dataset, batch_size=10, shuffle=True, num_workers=4, pin_memory=False) 133 | 134 | # push set 135 | train_push_dataset = DatasetFolder( 136 | root = train_push_dir, 137 | loader = np.load, 138 | extensions=("npy",), 139 | transform = transforms.Compose([ 140 | torch.from_numpy, 141 | ])) 142 | train_push_loader = torch.utils.data.DataLoader( 143 | train_push_dataset, batch_size=train_push_batch_size, shuffle=False, 144 | num_workers=4, pin_memory=False) 145 | 146 | # test set 147 | test_dataset =DatasetFolder( 148 | test_dir, 149 | loader=np.load, 150 | extensions=("npy",), 151 | transform = transforms.Compose([ 152 | torch.from_numpy, 153 | ])) 154 | test_loader = torch.utils.data.DataLoader( 155 | test_dataset, batch_size=test_batch_size, shuffle=False, 156 | num_workers=4, pin_memory=False) 157 | 158 | 159 | # we should look into distributed sampler more carefully at torch.utils.data.distributed.DistributedSampler(train_dataset) 160 | log('training set location: {0}'.format(train_dir)) 161 | log('training set size: {0}'.format(len(train_loader.dataset))) 162 | log('push set location: {0}'.format(train_push_dir)) 163 | log('push set size: {0}'.format(len(train_push_loader.dataset))) 164 | log('test set location: {0}'.format(test_dir)) 165 | log('test set size: {0}'.format(len(test_loader.dataset))) 166 | log('batch size: {0}'.format(train_batch_size)) 167 | log("Using topk_k coeff from bash args: {0}, which is {1:.4}%".format(topk_k, float(topk_k)*100./(14*14))) # for prototype size 1x1 on 14x14 grid experminents 168 | 169 | from settings import class_specific 170 | # construct the model 171 | if load_model_dir: 172 | ppnet = torch.load(load_model_dir) 173 | log('starting from model: {0}'.format(load_model_dir)) 174 | else: 175 | ppnet = model.construct_PPNet(base_architecture=base_architecture, 176 | pretrained=True, img_size=img_size, 177 | prototype_shape=prototype_shape, 178 | topk_k=topk_k, 179 | num_classes=num_classes, 180 | prototype_activation_function=prototype_activation_function, 181 | add_on_layers_type=add_on_layers_type, 182 | last_layer_weight=last_layer_weight, 183 | class_specific=class_specific) 184 | 185 | #if prototype_activation_function == 'linear': 186 | # ppnet.set_last_layer_incorrect_connection(incorrect_strength=0) 187 | ppnet = ppnet.cuda() 188 | ppnet_multi = torch.nn.DataParallel(ppnet) 189 | 190 | # define optimizer 191 | from settings import joint_optimizer_lrs, joint_lr_step_size 192 | joint_optimizer_specs = \ 193 | [{'params': ppnet.features.parameters(), 'lr': joint_optimizer_lrs['features'], 'weight_decay': 1e-3}, # bias are now also being regularized 194 | {'params': ppnet.add_on_layers.parameters(), 'lr': joint_optimizer_lrs['add_on_layers'], 'weight_decay': 1e-3}, 195 | {'params': ppnet.prototype_vectors, 'lr': joint_optimizer_lrs['prototype_vectors']}, 196 | ] 197 | joint_optimizer = torch.optim.Adam(joint_optimizer_specs) 198 | joint_lr_scheduler = torch.optim.lr_scheduler.StepLR(joint_optimizer, step_size=joint_lr_step_size, gamma=0.1) 199 | 200 | from settings import warm_optimizer_lrs 201 | warm_optimizer_specs = \ 202 | [{'params': ppnet.add_on_layers.parameters(), 'lr': warm_optimizer_lrs['add_on_layers'], 'weight_decay': 1e-3}, 203 | {'params': ppnet.prototype_vectors, 'lr': warm_optimizer_lrs['prototype_vectors']}, 204 | ] 205 | warm_optimizer = torch.optim.Adam(warm_optimizer_specs) 206 | 207 | from settings import last_layer_optimizer_lr 208 | last_layer_optimizer_specs = [{'params': ppnet.last_layer.parameters(), 'lr': last_layer_optimizer_lr}] 209 | last_layer_optimizer = torch.optim.Adam(last_layer_optimizer_specs) 210 | 211 | # weighting of different training losses 212 | from settings import coefs 213 | 214 | # for fa adjustment training only 215 | if not (fa_coeff_manual==None): 216 | coefs['fine'] = fa_coeff_manual[0] 217 | print("Using fa coeff from bash args: {}".format(coefs['fine'])) 218 | else: 219 | print("Using fa coeff from settings: {}".format(coefs['fine'])) 220 | 221 | # number of training epochs, number of warm epochs, push start epoch, push epochs 222 | from settings import num_train_epochs, num_warm_epochs, push_start, push_epochs 223 | 224 | # train the model 225 | log('start training') 226 | import copy 227 | 228 | train_auc = [] 229 | test_auc = [] 230 | currbest, best_epoch = 0, -1 231 | 232 | for epoch in range(num_train_epochs): 233 | log('epoch: \t{0}'.format(epoch)) 234 | 235 | if epoch < num_warm_epochs: 236 | tnt.warm_only(model=ppnet_multi, log=log) 237 | _ = tnt.train(model=ppnet_multi, dataloader=train_loader, optimizer=warm_optimizer, 238 | class_specific=class_specific, coefs=coefs, log=log, finer_loader=finer_train_loader) 239 | else: 240 | tnt.joint(model=ppnet_multi, log=log) 241 | _ = tnt.train(model=ppnet_multi, dataloader=train_loader, optimizer=joint_optimizer, 242 | class_specific=class_specific, coefs=coefs, log=log, finer_loader=finer_train_loader) 243 | joint_lr_scheduler.step() 244 | 245 | auc = tnt.test(model=ppnet_multi, dataloader=test_loader, 246 | class_specific=class_specific, log=log) 247 | save.save_model_w_condition(model=ppnet, model_dir=model_dir, model_name=str(epoch) + 'nopush', accu=auc, 248 | target_accu=0.00, log=log) 249 | 250 | train_auc.append(_) 251 | if currbest < auc: 252 | currbest = auc 253 | best_epoch = epoch 254 | log("\tcurrent best auc is: \t\t{} at epoch {}".format(currbest, best_epoch)) 255 | test_auc.append(auc) 256 | plt.plot(train_auc, "b", label="train") 257 | plt.plot(test_auc, "r", label="test") 258 | plt.ylim(0.4, 1) 259 | plt.legend() 260 | plt.savefig(model_dir + 'train_test_auc.png') 261 | plt.close() 262 | 263 | 264 | if epoch >= push_start and epoch in push_epochs: 265 | push.push_prototypes( 266 | train_push_loader, # pytorch dataloader (must be unnormalized in [0,1]) 267 | prototype_network_parallel=ppnet_multi, # pytorch network with prototype_vectors 268 | class_specific=class_specific, 269 | preprocess_input_function=preprocess_input_function, # normalize if needed 270 | prototype_layer_stride=1, 271 | root_dir_for_saving_prototypes=img_dir, # if not None, prototypes will be saved here 272 | epoch_number=epoch, # if not provided, prototypes saved previously will be overwritten 273 | prototype_img_filename_prefix=prototype_img_filename_prefix, 274 | prototype_self_act_filename_prefix=prototype_self_act_filename_prefix, 275 | proto_bound_boxes_filename_prefix=proto_bound_boxes_filename_prefix, 276 | save_prototype_class_identity=True, 277 | log=log, 278 | prototype_activation_function_in_numpy=prototype_activation_function_in_numpy) 279 | accu = tnt.test(model=ppnet_multi, dataloader=test_loader, 280 | class_specific=class_specific, log=log) 281 | save.save_model_w_condition(model=ppnet, model_dir=model_dir, model_name=str(epoch) + 'push', accu=accu, 282 | target_accu=0.00, log=log) 283 | 284 | if prototype_activation_function != 'linear': 285 | tnt.last_only(model=ppnet_multi, log=log) 286 | for i in range(10): 287 | log('iteration: \t{0}'.format(i)) 288 | _ = tnt.train(model=ppnet_multi, dataloader=train_loader, optimizer=last_layer_optimizer, 289 | class_specific=class_specific, coefs=coefs, log=log, finer_loader=finer_train_loader) 290 | auc = tnt.test(model=ppnet_multi, dataloader=test_loader, 291 | class_specific=class_specific, log=log) 292 | save.save_model_w_condition(model=ppnet, model_dir=model_dir, model_name=str(epoch) + '_' + str(i) + 'push', accu=auc, 293 | target_accu=0.00, log=log) 294 | train_auc.append(_) 295 | test_auc.append(auc) 296 | 297 | if currbest < auc: 298 | currbest = auc 299 | best_epoch = epoch 300 | 301 | plt.plot(train_auc, "b", label="train") 302 | plt.plot(test_auc, "r", label="test") 303 | plt.ylim(0.4, 1) 304 | plt.legend() 305 | plt.savefig(model_dir + 'train_test_auc' + ".png") 306 | plt.close() 307 | 308 | logclose() 309 | 310 | -------------------------------------------------------------------------------- /gradcam_APs.py: -------------------------------------------------------------------------------- 1 | ### Adapted from https://github.com/stefannc/GradCAM-Pytorch/blob/07fd6ece5010f7c1c9fbcc8155a60023819111d7/example.ipynb retrieved Mar 3 2021 ##### 2 | 3 | ## cell 1: imports 4 | import os 5 | import PIL 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | import torch.nn as nn 10 | import torchvision.models as models 11 | from torchvision.utils import make_grid, save_image 12 | 13 | from gradcam_utils import visualize_cam, Normalize 14 | from gradcam import GradCAM, GradCAMpp 15 | 16 | import torchvision.transforms as transforms 17 | from vanilla_vgg import Vanilla_VGG 18 | from dataHelper import DatasetFolder, DatasetFolder_WithReplacement 19 | from skimage.transform import resize 20 | import our_vgg 21 | from collections import defaultdict 22 | import argparse 23 | 24 | ## argparsing 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument("-save_loc", type=str) 27 | args = parser.parse_args() 28 | 29 | ## get our mammo img 30 | test_dir = '/usr/xtmp/IAIABL/Lo1136i/test/' 31 | test_dataset = DatasetFolder( 32 | test_dir, 33 | augmentation=False, 34 | loader=np.load, 35 | extensions=("npy",), 36 | transform=transforms.Compose([ 37 | torch.from_numpy, 38 | ]) 39 | ) 40 | test_loader = torch.utils.data.DataLoader( 41 | test_dataset, batch_size=1, shuffle=True, 42 | num_workers=4, pin_memory=False) 43 | sample_img, target, patient_id = next(iter(test_loader)) 44 | print(patient_id) 45 | normed_torch_img = sample_img.cuda() 46 | torch_img = normed_torch_img 47 | 48 | 49 | ## cell 4: load model 50 | # vgg = models.vgg16(pretrained=True) 51 | model_path = '/usr/xtmp/IAIABL/saved_models/vanilla/0125_vanilla_3margin_vgg16_latent512_baseline3_random=4/0.9384582045743842_at_epoch_136' 52 | vgg_us = torch.load(model_path) 53 | vgg_us.eval(), vgg_us.cuda(); 54 | 55 | state_dict = vgg_us.state_dict() 56 | 57 | for key in list(state_dict.keys()): 58 | state_dict[key.replace('features.features.', 'features.')] = state_dict.pop(key) 59 | 60 | vgg_l = our_vgg.vgg16() 61 | vgg_l.load_state_dict(state_dict) 62 | vgg_l.eval(), vgg_l.cuda(); 63 | 64 | # Ref: https://stackoverflow.com/q/54846905/7521428 65 | # print("### OUR MODEL ###") 66 | # l = [module for module in vgg_us.modules() if type(module) != nn.Sequential] 67 | # print(l) 68 | 69 | print("### USUAL VGG MODEL ###") 70 | vgg = models.vgg16(pretrained=True) 71 | vgg.eval(), vgg.cuda(); 72 | # l = [module for module in vgg.modules() if type(module) != nn.Sequential] 73 | # print(l) 74 | 75 | cam_dict = dict() 76 | 77 | vgg_model_dict = dict(type='vgg', arch=vgg, layer_name='features_29', input_size=(224, 224)) 78 | vgg_gradcam = GradCAM(vgg_model_dict, True) 79 | vgg_gradcampp = GradCAMpp(vgg_model_dict, True) 80 | cam_dict['vgg'] = [vgg_gradcam, vgg_gradcampp] 81 | 82 | vgg_model_dict = dict(type='vgg', arch=vgg, layer_name='features_6', input_size=(224, 224)) 83 | vgg_gradcam = GradCAM(vgg_model_dict, True) 84 | vgg_gradcampp = GradCAMpp(vgg_model_dict, True) 85 | cam_dict['vgg_layer6'] = [vgg_gradcam, vgg_gradcampp] 86 | 87 | vgg_us_model_dict = dict(type='vgg_us', arch=vgg_us, layer_name='features_29', input_size=(224, 224)) 88 | vgg_us_gradcam = GradCAM(vgg_us_model_dict, True) 89 | vgg_us_gradcampp = GradCAMpp(vgg_us_model_dict, True) 90 | cam_dict['vgg_us'] = [vgg_us_gradcam, vgg_us_gradcampp] 91 | 92 | vgg_l_model_dict = dict(type='vgg', arch=vgg_l, layer_name='features_29', input_size=(224, 224)) 93 | vgg_l_gradcam = GradCAM(vgg_l_model_dict, True) 94 | vgg_l_gradcampp = GradCAMpp(vgg_l_model_dict, True) 95 | cam_dict['vgg_l'] = [vgg_l_gradcam, vgg_l_gradcampp] 96 | 97 | ## cell 5: make image grid 98 | images = [] 99 | for gradcam, gradcam_pp in cam_dict.values(): 100 | mask, _ = gradcam(normed_torch_img) 101 | # print("Min of mask is: ", torch.min(mask)) 102 | # print("Max of mask is: ", torch.max(mask)) 103 | heatmap, result = visualize_cam(mask, torch_img) 104 | 105 | mask_pp, _ = gradcam_pp(normed_torch_img) 106 | # print("Max of mask_pp is: ", torch.max(mask_pp)) 107 | heatmap_pp, result_pp = visualize_cam(mask_pp, torch_img) 108 | 109 | images.append(torch.stack([torch_img.squeeze().cpu(), heatmap, heatmap_pp, result, result_pp], 0)) 110 | 111 | image_grid = make_grid(torch.cat(images, 0), nrow=5) 112 | 113 | ## cell 6: save image grid 114 | output_path = args.save_loc 115 | 116 | os.makedirs(output_path[:-8]) 117 | 118 | save_image(image_grid, output_path) 119 | 120 | ## from AP generator 121 | def activation_precision(dataloader, # can be train, test, train_finer, test_finer 122 | model, 123 | gradcam, 124 | num_classes=3, 125 | preprocess_input_function=None, 126 | log=print, 127 | debug_mode=True, 128 | per_class=False): 129 | 130 | #assert dataloader loads with fourth channel 131 | #assert dataloader batch size of 1 132 | 133 | precisions = [] 134 | 135 | per_class_hp = defaultdict(list) 136 | 137 | for idx, (search_batch_input, search_y, patient_id) in enumerate(dataloader): 138 | if debug_mode: 139 | print('batch {}'.format(idx)) 140 | if preprocess_input_function is not None: 141 | # print('preprocessing input for pushing ...') 142 | # search_batch = copy.deepcopy(search_batch_input) 143 | search_batch = preprocess_input_function(search_batch_input[:, :3, : , :]) 144 | else: 145 | search_batch = search_batch_input 146 | 147 | search_batch = search_batch_input[:, :3, : , :] 148 | fine_anno = 1 - search_batch_input[:, 3:, : , :] 149 | 150 | if debug_mode: 151 | print("search_batch:", search_batch.shape) 152 | print("fine_anno:", fine_anno.shape) 153 | print("search_y:", search_y.detach().cpu().numpy()[0]) 154 | print("search_y.shape, sy[0]:", search_y.shape, search_y[0]) 155 | print("fine_anno[0][0][0][0]: ", fine_anno[0][0][0][0]) 156 | print("fine_anno[0][0][122][122]: ", fine_anno[0][0][122][122]) 157 | 158 | 159 | with torch.no_grad(): 160 | search_batch = search_batch.cuda() 161 | fine_anno = fine_anno.cuda() 162 | 163 | proto_acts, _ = gradcam(search_batch, class_idx=search_y.detach().cpu().numpy()[0]) 164 | 165 | if debug_mode: 166 | print("proto_acts:", proto_acts.shape) 167 | 168 | hps = fine_anno * proto_acts 169 | 170 | if debug_mode: 171 | print("hps:", hps.shape) 172 | 173 | fine_anno_ = np.copy(fine_anno.detach().cpu().numpy()) 174 | 175 | percentile = 95 176 | 177 | activation_map_ = proto_acts.cpu() 178 | threshold = np.percentile(activation_map_, percentile) 179 | mask = np.ones(activation_map_.shape) 180 | mask[activation_map_ < threshold] = 0 181 | 182 | if idx==0 and debug_mode: 183 | print("act_map:", activation_map_.shape) 184 | print("mask:", mask.shape) 185 | print("fine_anno_:", fine_anno_.shape) 186 | denom = np.sum(mask) 187 | num = np.sum(mask * fine_anno_[0][0]) 188 | 189 | if idx==0 and debug_mode: 190 | print(f"act. prec. for first image is: {num/denom}") 191 | precisions.append(num/denom) 192 | per_class_hp[search_y.detach().cpu().numpy()[0]].append(num/denom) 193 | 194 | 195 | if per_class: 196 | per_class_hp_list = [] 197 | for k, v in per_class_hp.items(): 198 | per_class_hp_list.append((k, np.average(np.asarray(v)))) 199 | per_class_hp_list.sort(key=lambda x: x[0]) 200 | return per_class_hp_list 201 | else: 202 | return np.average(np.asarray(precisions)) 203 | 204 | 205 | ## call the AP func for a single check 206 | 207 | test_dir = '/usr/xtmp/IAIABL/Lo1136i_finer/by_margin/test/' 208 | test_dataset = DatasetFolder( 209 | test_dir, 210 | augmentation=False, 211 | loader=np.load, 212 | extensions=("npy",), 213 | transform=transforms.Compose([ 214 | torch.from_numpy, 215 | ]) 216 | ) 217 | test_loader = torch.utils.data.DataLoader( 218 | test_dataset, batch_size=1, shuffle=True, 219 | num_workers=4, pin_memory=False) 220 | 221 | num_classes = len(test_dataset.classes) 222 | 223 | print("fine-scale activation precision for gradCAM is: ", 224 | activation_precision(dataloader=test_loader, # can be train, test, train_finer, test_finer 225 | model=vgg_l, 226 | gradcam=vgg_l_gradcam, 227 | num_classes=num_classes, 228 | preprocess_input_function=None, 229 | log=print, 230 | debug_mode=False, 231 | per_class=False) 232 | ) 233 | 234 | print("fine-scale activation precision for gradCAM++ is: ", 235 | activation_precision(dataloader=test_loader, # can be train, test, train_finer, test_finer 236 | model=vgg_l, 237 | gradcam=vgg_l_gradcampp, 238 | num_classes=num_classes, 239 | preprocess_input_function=None, 240 | log=print, 241 | debug_mode=False, 242 | per_class=False) 243 | ) 244 | 245 | test_dir = '/usr/xtmp/IAIABL/Lo1136i_with_fa/test/' 246 | test_dataset = DatasetFolder( 247 | test_dir, 248 | augmentation=False, 249 | loader=np.load, 250 | extensions=("npy",), 251 | transform=transforms.Compose([ 252 | torch.from_numpy, 253 | ]) 254 | ) 255 | test_loader = torch.utils.data.DataLoader( 256 | test_dataset, batch_size=1, shuffle=True, 257 | num_workers=4, pin_memory=False) 258 | 259 | print("lesion-scale activation precision for gradCAM is: ", 260 | activation_precision(dataloader=test_loader, # can be train, test, train_finer, test_finer 261 | model=vgg_l, 262 | gradcam=vgg_l_gradcam, 263 | num_classes=num_classes, 264 | preprocess_input_function=None, 265 | log=print, 266 | debug_mode=False, 267 | per_class=False) 268 | ) 269 | 270 | print("lesion-scale activation precision for gradCAM++ is: ", 271 | activation_precision(dataloader=test_loader, # can be train, test, train_finer, test_finer 272 | model=vgg_l, 273 | gradcam=vgg_l_gradcampp, 274 | num_classes=num_classes, 275 | preprocess_input_function=None, 276 | log=print, 277 | debug_mode=False, 278 | per_class=False) 279 | ) 280 | 281 | 282 | ## bootstrapped AP function calls 283 | 284 | f_test_dir = '/usr/xtmp/IAIABL/Lo1136i_finer/by_margin/test/' 285 | l_test_dir = '/usr/xtmp/IAIABL/Lo1136i_with_fa/test/' 286 | 287 | for test_dir in [f_test_dir, l_test_dir]: 288 | print(f'for data in {test_dir}') 289 | test_dataset = DatasetFolder_WithReplacement( 290 | test_dir, 291 | augmentation=False, 292 | loader=np.load, 293 | extensions=("npy",), 294 | transform=transforms.Compose([ 295 | torch.from_numpy, 296 | ]) 297 | ) 298 | 299 | test_batch_size = len(test_dataset.samples) # we decided to use a sample size equal to the size of the test set. 300 | CI = 0.95 # confidence interval 301 | num_iterations = 10 #DEAR REVIEWERS: In the implementation presented in the paper this value was 5000, but to make the demo faster I reduced this value. 302 | aps = [0]*num_iterations # doing this instead of a list append marginally improves computational efficiency 303 | aps_pp = [0]*num_iterations 304 | 305 | for i in range(num_iterations): 306 | test_loader = torch.utils.data.DataLoader( 307 | test_dataset, batch_size=1, shuffle=True, 308 | num_workers=4, pin_memory=False) 309 | aps[i] = activation_precision(dataloader=test_loader, # can be train, test, train_finer, test_finer 310 | model=vgg_l, 311 | gradcam=vgg_l_gradcam, 312 | num_classes=num_classes, 313 | preprocess_input_function=None, 314 | log=print, 315 | debug_mode=False, 316 | per_class=False) 317 | aps_pp[i] = activation_precision(dataloader=test_loader, # can be train, test, train_finer, test_finer 318 | model=vgg_l, 319 | gradcam=vgg_l_gradcampp, 320 | num_classes=num_classes, 321 | preprocess_input_function=None, 322 | log=print, 323 | debug_mode=False, 324 | per_class=False) 325 | 326 | vois = zip([aps, aps_pp],\ 327 | ['GradCAM AP', 'GradCAM++ AP']) 328 | lower, upper = 100 * ( (1.0 - CI)/2. ), 100 * ( 1.0 - ((1.0 - CI)/2.) ) 329 | for valueofinterest, valueofinterest_str in vois: 330 | voi_mean = np.mean(np.asarray(valueofinterest)) 331 | voi_std = np.std(np.asarray(valueofinterest)) 332 | voi_lower, voi_upper = np.percentile(valueofinterest, [lower, upper]) 333 | print(f"Final mean {valueofinterest_str} is {voi_mean}, std {voi_std} {CI*100}% confidence iterval accuracy is {voi_lower} to {voi_upper} with {num_iterations} iterations.") -------------------------------------------------------------------------------- /densenet_features.py: -------------------------------------------------------------------------------- 1 | import re 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.utils.model_zoo as model_zoo 6 | from collections import OrderedDict 7 | 8 | 9 | model_urls = { 10 | 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth', 11 | 'densenet169': 'https://download.pytorch.org/models/densenet169-b2777c0a.pth', 12 | 'densenet201': 'https://download.pytorch.org/models/densenet201-c1103571.pth', 13 | 'densenet161': 'https://download.pytorch.org/models/densenet161-8d451a50.pth', 14 | } 15 | model_dir = './pretrained_models' 16 | 17 | 18 | class _DenseLayer(nn.Sequential): 19 | 20 | num_layers = 2 21 | 22 | def __init__(self, num_input_features, growth_rate, bn_size, drop_rate): 23 | super(_DenseLayer, self).__init__() 24 | self.add_module('norm1', nn.BatchNorm2d(num_input_features)), 25 | self.add_module('relu1', nn.ReLU(inplace=True)), 26 | self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * 27 | growth_rate, kernel_size=1, stride=1, bias=False)), 28 | self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)), 29 | self.add_module('relu2', nn.ReLU(inplace=True)), 30 | self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate, 31 | kernel_size=3, stride=1, padding=1, bias=False)), 32 | self.drop_rate = drop_rate 33 | 34 | def forward(self, x): 35 | new_features = super(_DenseLayer, self).forward(x) 36 | if self.drop_rate > 0: 37 | new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) 38 | 39 | # channelwise concatenation 40 | return torch.cat([x, new_features], 1) 41 | 42 | def layer_conv_info(self): 43 | layer_kernel_sizes = [1, 3] 44 | layer_strides = [1, 1] 45 | layer_paddings = [0, 1] 46 | 47 | return layer_kernel_sizes, layer_strides, layer_paddings 48 | 49 | 50 | class _DenseBlock(nn.Sequential): 51 | def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate): 52 | super(_DenseBlock, self).__init__() 53 | self.block_kernel_sizes = [] 54 | self.block_strides = [] 55 | self.block_paddings = [] 56 | 57 | for i in range(num_layers): 58 | layer = _DenseLayer(num_input_features + i * growth_rate, growth_rate, bn_size, drop_rate) 59 | layer_kernel_sizes, layer_strides, layer_paddings = layer.layer_conv_info() 60 | self.block_kernel_sizes.extend(layer_kernel_sizes) 61 | self.block_strides.extend(layer_strides) 62 | self.block_paddings.extend(layer_paddings) 63 | self.add_module('denselayer%d' % (i + 1), layer) 64 | 65 | self.num_layers = _DenseLayer.num_layers * num_layers 66 | 67 | def block_conv_info(self): 68 | return self.block_kernel_sizes, self.block_strides, self.block_paddings 69 | 70 | 71 | class _Transition(nn.Sequential): 72 | 73 | num_layers = 1 74 | 75 | def __init__(self, num_input_features, num_output_features): 76 | super(_Transition, self).__init__() 77 | self.add_module('norm', nn.BatchNorm2d(num_input_features)) 78 | self.add_module('relu', nn.ReLU(inplace=True)) 79 | self.add_module('conv', nn.Conv2d(num_input_features, num_output_features, 80 | kernel_size=1, stride=1, bias=False)) 81 | self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2)) # AvgPool2d has no padding 82 | 83 | def block_conv_info(self): 84 | return [1, 2], [1, 2], [0, 0] 85 | 86 | 87 | class DenseNet_features(nn.Module): 88 | r"""Densenet-BC model class, based on 89 | `"Densely Connected Convolutional Networks" `_ 90 | 91 | Args: 92 | growth_rate (int) - how many filters to add each layer (`k` in paper) 93 | block_config (list of 4 ints) - how many layers in each pooling block 94 | num_init_features (int) - the number of filters to learn in the first convolution layer 95 | bn_size (int) - multiplicative factor for number of bottle neck layers 96 | (i.e. bn_size * k features in the bottleneck layer) 97 | drop_rate (float) - dropout rate after each dense layer 98 | num_classes (int) - number of classification classes 99 | """ 100 | 101 | def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16), 102 | num_init_features=64, bn_size=4, drop_rate=0, num_classes=1000): 103 | 104 | super(DenseNet_features, self).__init__() 105 | self.kernel_sizes = [] 106 | self.strides = [] 107 | self.paddings = [] 108 | 109 | self.n_layers = 0 110 | 111 | # First convolution 112 | self.features = nn.Sequential(OrderedDict([ 113 | ('conv0', nn.Conv2d(in_channels=3, out_channels=num_init_features, kernel_size=7, stride=2, padding=3, bias=False)), 114 | ('norm0', nn.BatchNorm2d(num_init_features)), 115 | ('relu0', nn.ReLU(inplace=True)), 116 | ('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)), 117 | ])) 118 | 119 | self.kernel_sizes.extend([7, 3]) 120 | self.strides.extend([2, 2]) 121 | self.paddings.extend([3, 1]) 122 | 123 | # Each denseblock 124 | num_features = num_init_features 125 | for i, num_layers in enumerate(block_config): 126 | block = _DenseBlock(num_layers=num_layers, num_input_features=num_features, 127 | bn_size=bn_size, growth_rate=growth_rate, drop_rate=drop_rate) 128 | self.n_layers += block.num_layers 129 | 130 | block_kernel_sizes, block_strides, block_paddings = block.block_conv_info() 131 | self.kernel_sizes.extend(block_kernel_sizes) 132 | self.strides.extend(block_strides) 133 | self.paddings.extend(block_paddings) 134 | 135 | self.features.add_module('denseblock%d' % (i + 1), block) 136 | num_features = num_features + num_layers * growth_rate 137 | if i != len(block_config) - 1: 138 | trans = _Transition(num_input_features=num_features, num_output_features=num_features // 2) 139 | 140 | self.n_layers += trans.num_layers 141 | 142 | block_kernel_sizes, block_strides, block_paddings = trans.block_conv_info() 143 | self.kernel_sizes.extend(block_kernel_sizes) 144 | self.strides.extend(block_strides) 145 | self.paddings.extend(block_paddings) 146 | 147 | self.features.add_module('transition%d' % (i + 1), trans) 148 | num_features = num_features // 2 149 | 150 | # Final batch norm 151 | self.features.add_module('norm5', nn.BatchNorm2d(num_features)) 152 | self.features.add_module('final_relu', nn.ReLU(inplace=True)) 153 | 154 | # Official init from torch repo. 155 | for m in self.modules(): 156 | if isinstance(m, nn.Conv2d): 157 | nn.init.kaiming_normal_(m.weight) 158 | elif isinstance(m, nn.BatchNorm2d): 159 | nn.init.constant_(m.weight, 1) 160 | nn.init.constant_(m.bias, 0) 161 | elif isinstance(m, nn.Linear): 162 | nn.init.constant_(m.bias, 0) 163 | 164 | def forward(self, x): 165 | return self.features(x) 166 | 167 | def conv_info(self): 168 | return self.kernel_sizes, self.strides, self.paddings 169 | 170 | def num_layers(self): 171 | return self.n_layers 172 | 173 | def __repr__(self): 174 | template = 'densenet{}_features' 175 | return template.format((self.num_layers() + 2)) 176 | 177 | 178 | def densenet121_features(pretrained=False, **kwargs): 179 | r"""Densenet-121 model from 180 | `"Densely Connected Convolutional Networks" `_ 181 | 182 | Args: 183 | pretrained (bool): If True, returns a model pre-trained on ImageNet 184 | """ 185 | model = DenseNet_features(num_init_features=64, growth_rate=32, block_config=(6, 12, 24, 16), 186 | **kwargs) 187 | if pretrained: 188 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 189 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 190 | # They are also in the checkpoints in model_urls. This pattern is used 191 | # to find such keys. 192 | pattern = re.compile( 193 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 194 | state_dict = model_zoo.load_url(model_urls['densenet121'], model_dir=model_dir) 195 | for key in list(state_dict.keys()): 196 | ''' 197 | example 198 | key 'features.denseblock4.denselayer24.norm.2.running_var' 199 | res.group(1) 'features.denseblock4.denselayer24.norm' 200 | res.group(2) '2.running_var' 201 | new_key 'features.denseblock4.denselayer24.norm2.running_var' 202 | ''' 203 | res = pattern.match(key) 204 | if res: 205 | new_key = res.group(1) + res.group(2) 206 | state_dict[new_key] = state_dict[key] 207 | del state_dict[key] 208 | 209 | del state_dict['classifier.weight'] 210 | del state_dict['classifier.bias'] 211 | model.load_state_dict(state_dict) 212 | return model 213 | 214 | 215 | def densenet169_features(pretrained=False, **kwargs): 216 | r"""Densenet-169 model from 217 | `"Densely Connected Convolutional Networks" `_ 218 | 219 | Args: 220 | pretrained (bool): If True, returns a model pre-trained on ImageNet 221 | """ 222 | model = DenseNet_features(num_init_features=64, growth_rate=32, block_config=(6, 12, 32, 32), 223 | **kwargs) 224 | if pretrained: 225 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 226 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 227 | # They are also in the checkpoints in model_urls. This pattern is used 228 | # to find such keys. 229 | pattern = re.compile( 230 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 231 | state_dict = model_zoo.load_url(model_urls['densenet169'], model_dir=model_dir) 232 | for key in list(state_dict.keys()): 233 | ''' 234 | example 235 | key 'features.denseblock4.denselayer24.norm.2.running_var' 236 | res.group(1) 'features.denseblock4.denselayer24.norm' 237 | res.group(2) '2.running_var' 238 | new_key 'features.denseblock4.denselayer24.norm2.running_var' 239 | ''' 240 | res = pattern.match(key) 241 | if res: 242 | new_key = res.group(1) + res.group(2) 243 | state_dict[new_key] = state_dict[key] 244 | del state_dict[key] 245 | 246 | del state_dict['classifier.weight'] 247 | del state_dict['classifier.bias'] 248 | model.load_state_dict(state_dict) 249 | return model 250 | 251 | 252 | def densenet201_features(pretrained=False, **kwargs): 253 | r"""Densenet-201 model from 254 | `"Densely Connected Convolutional Networks" `_ 255 | 256 | Args: 257 | pretrained (bool): If True, returns a model pre-trained on ImageNet 258 | """ 259 | model = DenseNet_features(num_init_features=64, growth_rate=32, block_config=(6, 12, 48, 32), 260 | **kwargs) 261 | if pretrained: 262 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 263 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 264 | # They are also in the checkpoints in model_urls. This pattern is used 265 | # to find such keys. 266 | pattern = re.compile( 267 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 268 | state_dict = model_zoo.load_url(model_urls['densenet201'], model_dir=model_dir) 269 | for key in list(state_dict.keys()): 270 | ''' 271 | example 272 | key 'features.denseblock4.denselayer24.norm.2.running_var' 273 | res.group(1) 'features.denseblock4.denselayer24.norm' 274 | res.group(2) '2.running_var' 275 | new_key 'features.denseblock4.denselayer24.norm2.running_var' 276 | ''' 277 | res = pattern.match(key) 278 | if res: 279 | new_key = res.group(1) + res.group(2) 280 | state_dict[new_key] = state_dict[key] 281 | del state_dict[key] 282 | 283 | del state_dict['classifier.weight'] 284 | del state_dict['classifier.bias'] 285 | model.load_state_dict(state_dict) 286 | 287 | return model 288 | 289 | 290 | def densenet161_features(pretrained=False, **kwargs): 291 | r"""Densenet-161 model from 292 | `"Densely Connected Convolutional Networks" `_ 293 | 294 | Args: 295 | pretrained (bool): If True, returns a model pre-trained on ImageNet 296 | """ 297 | model = DenseNet_features(num_init_features=96, growth_rate=48, block_config=(6, 12, 36, 24), 298 | **kwargs) 299 | if pretrained: 300 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 301 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 302 | # They are also in the checkpoints in model_urls. This pattern is used 303 | # to find such keys. 304 | pattern = re.compile( 305 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 306 | 307 | 308 | state_dict = model_zoo.load_url(model_urls['densenet161'], model_dir=model_dir) 309 | for key in list(state_dict.keys()): 310 | ''' 311 | example 312 | key 'features.denseblock4.denselayer24.norm.2.running_var' 313 | res.group(1) 'features.denseblock4.denselayer24.norm' 314 | res.group(2) '2.running_var' 315 | new_key 'features.denseblock4.denselayer24.norm2.running_var' 316 | ''' 317 | res = pattern.match(key) 318 | if res: 319 | new_key = res.group(1) + res.group(2) 320 | state_dict[new_key] = state_dict[key] 321 | del state_dict[key] 322 | 323 | 324 | del state_dict['classifier.weight'] 325 | del state_dict['classifier.bias'] 326 | model.load_state_dict(state_dict) 327 | 328 | return model 329 | 330 | if __name__ == '__main__': 331 | 332 | d161 = densenet161_features(True) 333 | print(d161) 334 | 335 | d201 = densenet201_features(True) 336 | print(d201) 337 | 338 | d169 = densenet169_features(True) 339 | print(d169) 340 | 341 | d121 = densenet121_features(True) 342 | print(d121) 343 | -------------------------------------------------------------------------------- /find_nearest.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | import heapq 5 | 6 | import matplotlib.pyplot as plt 7 | import os 8 | import copy 9 | import time 10 | 11 | import cv2 12 | 13 | from receptive_field import compute_rf_prototype 14 | from helpers import makedir, find_high_activation_crop 15 | 16 | def imsave_with_bbox(fname, img_rgb, bbox_height_start, bbox_height_end, 17 | bbox_width_start, bbox_width_end, color=(0, 255, 255)): 18 | img_bgr_uint8 = cv2.cvtColor(np.uint8(255*img_rgb), cv2.COLOR_RGB2BGR) 19 | cv2.rectangle(img_bgr_uint8, (bbox_width_start, bbox_height_start), (bbox_width_end-1, bbox_height_end-1), 20 | color, thickness=2) 21 | img_rgb_uint8 = img_bgr_uint8[...,::-1] 22 | img_rgb_float = np.float32(img_rgb_uint8) / 255 23 | #plt.imshow(img_rgb_float) 24 | #plt.axis('off') 25 | plt.imsave(fname, img_rgb_float) 26 | 27 | class ImagePatch: 28 | 29 | def __init__(self, patch, label, distance, 30 | original_img=None, act_pattern=None, patch_indices=None, conv_output=None): 31 | self.patch = patch 32 | self.label = label 33 | self.conv_output = conv_output 34 | self.negative_distance = -distance 35 | 36 | self.original_img = original_img 37 | self.act_pattern = act_pattern 38 | self.patch_indices = patch_indices 39 | 40 | def __lt__(self, other): 41 | return self.negative_distance < other.negative_distance 42 | 43 | 44 | class ImagePatchInfo: 45 | 46 | def __init__(self, label, distance): 47 | self.label = label 48 | self.negative_distance = -distance 49 | 50 | def __lt__(self, other): 51 | return self.negative_distance < other.negative_distance 52 | 53 | 54 | # find the nearest patches in the dataset to each prototype 55 | def find_k_nearest_patches_to_prototypes(dataloader, # pytorch dataloader (must be unnormalized in [0,1]) 56 | prototype_network_parallel, # pytorch network with prototype_vectors 57 | k=5, 58 | preprocess_input_function=None, # normalize if needed 59 | full_save=False, # save all the images 60 | root_dir_for_saving_images='./nearest', 61 | log=print, 62 | prototype_activation_function_in_numpy=None): 63 | prototype_network_parallel.eval() 64 | ''' 65 | full_save=False will only return the class identity of the closest 66 | patches, but it will not save anything. 67 | ''' 68 | log('find nearest patches') 69 | start = time.time() 70 | n_prototypes = prototype_network_parallel.module.num_prototypes 71 | 72 | prototype_shape = prototype_network_parallel.module.prototype_shape 73 | max_dist = prototype_shape[1] * prototype_shape[2] * prototype_shape[3] 74 | 75 | protoL_rf_info = prototype_network_parallel.module.proto_layer_rf_info 76 | 77 | heaps = [] 78 | # allocate an array of n_prototypes number of heaps 79 | for _ in range(n_prototypes): 80 | # a heap in python is just a maintained list 81 | heaps.append([]) 82 | 83 | for idx, (search_batch_input, search_y, patient_id) in enumerate(dataloader): 84 | print('batch {}'.format(idx)) 85 | if preprocess_input_function is not None: 86 | # print('preprocessing input for pushing ...') 87 | # search_batch = copy.deepcopy(search_batch_input) 88 | search_batch = preprocess_input_function(search_batch_input[:, :3, : , :]) 89 | 90 | else: 91 | search_batch = search_batch_input 92 | 93 | with torch.no_grad(): 94 | search_batch = search_batch.cuda() 95 | protoL_input_torch, proto_dist_torch = \ 96 | prototype_network_parallel.module.push_forward(search_batch) 97 | 98 | #protoL_input_ = np.copy(protoL_input_torch.detach().cpu().numpy()) 99 | proto_dist_ = np.copy(proto_dist_torch.detach().cpu().numpy()) 100 | 101 | for img_idx, distance_map in enumerate(proto_dist_): 102 | for j in range(n_prototypes): 103 | # find the closest patches in this batch to prototype j 104 | closest_patch_distance_to_prototype_j = np.amin(distance_map[j]) 105 | 106 | if full_save: 107 | closest_patch_indices_in_distance_map_j = \ 108 | list(np.unravel_index(np.argmin(distance_map[j],axis=None), 109 | distance_map[j].shape)) 110 | closest_patch_indices_in_distance_map_j = [0] + closest_patch_indices_in_distance_map_j 111 | closest_patch_indices_in_img = \ 112 | compute_rf_prototype(search_batch.size(2), 113 | closest_patch_indices_in_distance_map_j, 114 | protoL_rf_info) 115 | closest_patch = \ 116 | search_batch_input[img_idx, :, 117 | closest_patch_indices_in_img[1]:closest_patch_indices_in_img[2], 118 | closest_patch_indices_in_img[3]:closest_patch_indices_in_img[4]] 119 | closest_patch = closest_patch.numpy() 120 | closest_patch = np.transpose(closest_patch, (1, 2, 0)) 121 | 122 | original_img = search_batch_input[img_idx].numpy() 123 | original_img = np.transpose(original_img, (1, 2, 0)) 124 | 125 | if prototype_network_parallel.module.prototype_activation_function == 'log': 126 | act_pattern = np.log((distance_map[j] + 1)/(distance_map[j] + prototype_network_parallel.module.epsilon)) 127 | elif prototype_network_parallel.module.prototype_activation_function == 'linear': 128 | act_pattern = max_dist - distance_map[j] 129 | else: 130 | act_pattern = prototype_activation_function_in_numpy(distance_map[j]) 131 | 132 | # 4 numbers: height_start, height_end, width_start, width_end 133 | patch_indices = closest_patch_indices_in_img[1:5] 134 | 135 | # construct the closest patch object 136 | closest_patch = ImagePatch(patch=closest_patch, 137 | label=search_y[img_idx], 138 | conv_output=protoL_input_torch[img_idx], 139 | distance=closest_patch_distance_to_prototype_j, 140 | original_img=original_img, 141 | act_pattern=act_pattern, 142 | patch_indices=patch_indices) 143 | else: 144 | closest_patch = ImagePatchInfo(label=search_y[img_idx], 145 | distance=closest_patch_distance_to_prototype_j) 146 | 147 | 148 | # add to the j-th heap 149 | if len(heaps[j]) < k: 150 | heapq.heappush(heaps[j], closest_patch) 151 | else: 152 | # heappushpop runs more efficiently than heappush 153 | # followed by heappop 154 | heapq.heappushpop(heaps[j], closest_patch) 155 | 156 | # after looping through the dataset every heap will 157 | # have the k closest prototypes 158 | for j in range(n_prototypes): 159 | # finally sort the heap; the heap only contains the k closest 160 | # but they are not ranked yet 161 | heaps[j].sort() 162 | heaps[j] = heaps[j][::-1] 163 | 164 | if full_save: 165 | 166 | dir_for_saving_images = os.path.join(root_dir_for_saving_images, 167 | str(j)) 168 | makedir(dir_for_saving_images) 169 | 170 | labels = [] 171 | 172 | for i, patch in enumerate(heaps[j]): 173 | # save the activation pattern of the original image where the patch comes from 174 | np.save(os.path.join(dir_for_saving_images, 175 | 'nearest-' + str(i+1) + '_act.npy'), 176 | patch.act_pattern) 177 | 178 | # save the original image where the patch comes from 179 | plt.imsave(fname=os.path.join(dir_for_saving_images, 180 | 'nearest-' + str(i+1) + '_original.png'), 181 | arr=patch.original_img[:,:,0:3], 182 | vmin=0.0, 183 | vmax=1.0) 184 | 185 | # save the original image with the fa 186 | try: 187 | if patch.original_img.shape[2]==4: 188 | plt.imsave(fname=os.path.join(dir_for_saving_images, 189 | 'nearest-' + str(i+1) + '_original_with_fa.png'), 190 | arr=patch.original_img[:,:,1:4], 191 | vmin=0.0, 192 | vmax=1.0) 193 | except Exception as e: 194 | print("Exception in original img with fa: ", e) 195 | 196 | # overlay (upsampled) activation on original image and save the result 197 | img_size = patch.original_img.shape[0] 198 | upsampled_act_pattern = cv2.resize(patch.act_pattern, 199 | dsize=(img_size, img_size), 200 | interpolation=cv2.INTER_LINEAR) 201 | rescaled_act_pattern = upsampled_act_pattern - np.amin(upsampled_act_pattern) 202 | rescaled_act_pattern = rescaled_act_pattern / np.amax(rescaled_act_pattern) 203 | heatmap = cv2.applyColorMap(np.uint8(255*rescaled_act_pattern), cv2.COLORMAP_JET) 204 | heatmap = np.float32(heatmap) / 255 205 | heatmap = heatmap[...,::-1] 206 | overlayed_original_img = 0.5 * patch.original_img[:,:,0:3] + 0.3 * heatmap 207 | plt.imsave(fname=os.path.join(dir_for_saving_images, 208 | 'nearest-' + str(i+1) + '_original_with_heatmap.png'), 209 | arr=overlayed_original_img, 210 | vmin=0.0, 211 | vmax=1.0) 212 | 213 | # overlay activation on original image with fa and save the result 214 | try: 215 | if patch.original_img.shape[2]==4: 216 | img_size = patch.original_img.shape[0] 217 | upsampled_act_pattern = cv2.resize(patch.act_pattern, 218 | dsize=(img_size, img_size), 219 | interpolation=cv2.INTER_LINEAR) 220 | rescaled_act_pattern = upsampled_act_pattern - np.amin(upsampled_act_pattern) 221 | rescaled_act_pattern = rescaled_act_pattern / np.amax(rescaled_act_pattern) 222 | heatmap = cv2.applyColorMap(np.uint8(255*rescaled_act_pattern), cv2.COLORMAP_JET) 223 | heatmap = np.float32(heatmap) / 255 224 | heatmap = heatmap[...,::-1] 225 | overlayed_original_img = 0.5 * patch.original_img[:,:,1:4] + 0.3 * heatmap 226 | plt.imsave(fname=os.path.join(dir_for_saving_images, 227 | 'nearest-' + str(i+1) + '_original_with_fa_and_heatmap.png'), 228 | arr=overlayed_original_img, 229 | vmin=0.0, 230 | vmax=1.0) 231 | except Exception as e: 232 | print("Exception in overlay activation on original img with fa: ", e) 233 | 234 | # if different from original image, save the patch (i.e. receptive field) 235 | if patch.patch.shape[0] != img_size or patch.patch.shape[1] != img_size: 236 | np.save(os.path.join(dir_for_saving_images, 237 | 'nearest-' + str(i+1) + '_receptive_field_indices.npy'), 238 | patch.patch_indices) 239 | plt.imsave(fname=os.path.join(dir_for_saving_images, 240 | 'nearest-' + str(i+1) + '_receptive_field.png'), 241 | arr=patch.patch, 242 | vmin=0.0, 243 | vmax=1.0) 244 | # save the receptive field patch with heatmap 245 | overlayed_patch = overlayed_original_img[patch.patch_indices[0]:patch.patch_indices[1], 246 | patch.patch_indices[2]:patch.patch_indices[3], 0:3] 247 | plt.imsave(fname=os.path.join(dir_for_saving_images, 248 | 'nearest-' + str(i+1) + '_receptive_field_with_heatmap.png'), 249 | arr=overlayed_patch, 250 | vmin=0.0, 251 | vmax=1.0) 252 | 253 | # save the highly activated patch 254 | high_act_patch_indices = find_high_activation_crop(upsampled_act_pattern) 255 | high_act_patch = patch.original_img[high_act_patch_indices[0]:high_act_patch_indices[1], 256 | high_act_patch_indices[2]:high_act_patch_indices[3], 0:3] 257 | np.save(os.path.join(dir_for_saving_images, 258 | 'nearest-' + str(i+1) + '_high_act_patch_indices.npy'), 259 | high_act_patch_indices) 260 | plt.imsave(fname=os.path.join(dir_for_saving_images, 261 | 'nearest-' + str(i+1) + '_high_act_patch.png'), 262 | arr=high_act_patch, 263 | vmin=0.0, 264 | vmax=1.0) 265 | # save the original image with bounding box showing high activation patch 266 | imsave_with_bbox(fname=os.path.join(dir_for_saving_images, 267 | 'nearest-' + str(i+1) + '_high_act_patch_in_original_img.png'), 268 | img_rgb=patch.original_img[:,:,0:3], 269 | bbox_height_start=high_act_patch_indices[0], 270 | bbox_height_end=high_act_patch_indices[1], 271 | bbox_width_start=high_act_patch_indices[2], 272 | bbox_width_end=high_act_patch_indices[3], color=(0, 255, 255)) 273 | 274 | labels = np.array([patch.label for patch in heaps[j]]) 275 | np.save(os.path.join(dir_for_saving_images, 'class_id.npy'), 276 | labels) 277 | 278 | labels_all_prototype = np.array([[patch.label for patch in heaps[j]] for j in range(n_prototypes)]) 279 | neg_dists_all_prototype = np.array([[patch.negative_distance for patch in heaps[j]] for j in range(n_prototypes)]) 280 | 281 | if full_save: 282 | np.save(os.path.join(root_dir_for_saving_images, 'full_class_id.npy'), 283 | labels_all_prototype) 284 | #print(labels_all_prototype) 285 | np.save(os.path.join(root_dir_for_saving_images, 'full_class_neg_distance.npy'), 286 | neg_dists_all_prototype) 287 | end = time.time() 288 | log('\tfind nearest patches time: \t{0}'.format(end - start)) 289 | 290 | return labels_all_prototype 291 | --------------------------------------------------------------------------------