├── .gitignore
├── Dockerfile
├── LICENSE
├── README.md
├── assets
│   ├── omega_distribution-progan.jpg
│   ├── omega_distribution-progan.pdf
│   └── r50_patches.jpeg
├── fmap_relevances
│   ├── efb0
│   │   └── progan_val_0.5
│   │       ├── progan_val-fake.csv
│   │       └── progan_val-real.csv
│   └── resnet50
│       └── progan_val_0.5
│           ├── progan_val-fake.csv
│           └── progan_val-real.csv
├── requirements.txt
└── src
    ├── activation_histograms
    │   └── get_activation_values.py
    ├── cr_ud
    │   ├── datasets.py
    │   └── trainer.py
    ├── extract_max_activation_patches.py
    ├── find_color_conditional_percentage.py
    ├── find_topk.py
    ├── fmap_ranking
    │   ├── rank_fmaps_ud_efb0.py
    │   └── rank_fmaps_ud_r50.py
    ├── get_max_activation_rankings.py
    ├── gradcam.py
    ├── grayscale_detector_sensitivity_whisker_plots.py
    ├── grayscale_sensitivity_analysis.py
    ├── grayscale_sensitivity_whisker_plots.py
    ├── imagenet_lrp.py
    ├── imagenet_lrp_efb0.py
    ├── lrp
    │   ├── ef_lrp_general.py
    │   ├── ef_wrapper.py
    │   ├── resnet_lrp_general.py
    │   └── resnet_wrapper.py
    ├── median_test_activation_histograms.py
    ├── patch_collage.py
    ├── patch_extraction
    │   ├── extract_lrp_max_patches_using_filenames_efb0.py
    │   ├── extract_lrp_max_patches_using_filenames_resnet50.py
    │   └── get_top_activated_images.py
    ├── rank_fmaps.py
    ├── sensitivity_assessment
    │   ├── ap_sensitivity.py
    │   ├── color.py
    │   └── transferability.py
    ├── transfer_sensitivity_analysis.py
    ├── ud_lrp.py
    ├── ud_lrp_efb0.py
    └── utils
        ├── dataset_helpers.py
        ├── general.py
        ├── heatmap_helpers.py
        └── mask_fmaps.py

/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g.
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | 132 | # json file to store wandb API key 133 | env.json 134 | 135 | weights/ 136 | results/ 137 | *.ipynb 138 | datasets/ 139 | output/ 140 | misc/ -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | ARG BASE_IMAGE=nvcr.io/nvidia/pytorch:20.12-py3 2 | FROM $BASE_IMAGE 3 | 4 | RUN pip install colour==0.1.5 5 | RUN pip install efficientnet-pytorch==0.7.1 6 | RUN pip install grad-cam==1.3.7 7 | RUN pip install matplotlib==3.3.4 8 | RUN pip install numpy 9 | RUN pip install opencv-python==4.5.5.62 10 | RUN pip install pandas==1.1.5 11 | RUN pip install scikit-learn==0.24.2 12 | RUN pip install scipy==1.5.4 13 | RUN pip install seaborn==0.11.2 14 | RUN pip install termcolor 15 | RUN pip install tqdm==4.62.3 16 | RUN pip3 install torch==1.8.2 torchvision==0.9.2 --extra-index-url https://download.pytorch.org/whl/lts/1.8/cu111 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Keshigeyan Chandrasegaran, Ngoc-Trung Tran, Alexander Binder, Ngai-Man Cheung 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /assets/omega_distribution-progan.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sutd-visual-computing-group/transferable-forensic-features/3e547351131b58672e85d09cecd314a4adddfb7d/assets/omega_distribution-progan.jpg -------------------------------------------------------------------------------- /assets/omega_distribution-progan.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sutd-visual-computing-group/transferable-forensic-features/3e547351131b58672e85d09cecd314a4adddfb7d/assets/omega_distribution-progan.pdf -------------------------------------------------------------------------------- /assets/r50_patches.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sutd-visual-computing-group/transferable-forensic-features/3e547351131b58672e85d09cecd314a4adddfb7d/assets/r50_patches.jpeg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | colour==0.1.5 2 | efficientnet-pytorch==0.7.1 3 | grad-cam==1.3.7 4 | matplotlib==3.3.4 5 | numpy 6 | opencv-python==4.5.5.62 7 | pandas==1.1.5 8 | scikit-learn==0.24.2 9 | scipy==1.5.4 10 | seaborn==0.11.2 11 | termcolor 12 | torch==1.8.2+cu111 13 | torchvision==0.9.2+cu111 14 | tqdm==4.62.3 15 | -------------------------------------------------------------------------------- /src/activation_histograms/get_activation_values.py: -------------------------------------------------------------------------------- 1 | 2 | # Import base libraries 3 | import os, sys, math, gc 4 | import PIL 5 | import copy 6 | import random, json 7 | from collections import OrderedDict 8 | 9 | # Import scientific computing libraries 10 | import numpy as np 11 | 12 | # Import torch and dependencies 13 | import torch 14 | from torchvision import models, transforms 15 | 16 | # Import utils 17 | from utils.heatmap_helpers import * 18 | from utils.dataset_helpers import * 19 | from utils.general import * 20 | 21 | # Import other libraries 22 | import matplotlib.pyplot as plt 23 | import pandas as pd 24 | from scipy import stats 25 | 26 | import seaborn as sns 27 | 28 | 29 | # Keep penultimate features as a global variable so that the forward hook writes into it 30 | penultimate_fts = None 31 | 32 | 33 | def get_penultimate_fts(self, input, output): 34 | global penultimate_fts 35 | penultimate_fts = output 36 | return None 37 | 38 | 39 | 40 | 41 | def get_fmap_activations(model, dls, device): 42 | global penultimate_fts 43 | penultimate_fts = None 44 | assert penultimate_fts is None 45 | 46 | model.eval() # Set model to eval mode 47 | 48 | m = torch.nn.AdaptiveMaxPool2d((1, 1)) # Look at maximum activation features 49 | 50 | # Store fts in a global array 51 | all_features = None 52 | fnames = [] 53 | 54 | with torch.no_grad(): 55 | for dl in dls: 56 | for batch_idx, data in enumerate(dl): 57 | 58 | x = data['image'].to(device) 59 | y = data['label'] 60 | fname = data['filename'] 61 | output = model(x) 62 | assert torch.is_tensor(penultimate_fts) 63 | 64 | fnames.extend(fname) 65 | 66 | if all_features is None: 67 | all_features = (m(penultimate_fts.data.clone().cpu()) ).numpy().squeeze() 68 | 69 | else: 70 | features = (m(penultimate_fts.data.clone().cpu())
).numpy().squeeze() 71 | all_features = np.concatenate((all_features, features), axis=0) 72 | 73 | return all_features, fnames 74 | 75 | 76 | 77 | def linspace(start, stop, step=0.5): 78 | """ 79 | Like np.linspace but uses step instead of num 80 | This is inclusive to stop, so if start=1, stop=3, step=0.5 81 | Output is: array([1., 1.5, 2., 2.5, 3.]) 82 | """ 83 | return np.linspace(start, stop, int((stop - start) / step + 1)) 84 | 85 | 86 | def get_activation_values(feature_map_name, parent_dir, arch, gan_name, have_classes, aug, bsize, transform, num_instances=None): 87 | 88 | ## ------------Set Parameters-------- 89 | # Define device and other parameters 90 | device = torch.device('cuda:0') 91 | 92 | # model and weights 93 | if arch == 'resnet50': 94 | weightfn = './weights/resnet50/blur_jpg_prob{}.pth'.format(aug) 95 | model = get_resnet50_universal_detector(weightfn).to(device) 96 | 97 | elif arch == 'efb0': 98 | weightfn = './weights/efb0/blur_jpg_prob{}.pth'.format(aug) 99 | model = get_efb0_universal_detector(weightfn).to(device) 100 | 101 | 102 | # Get feature map and layer 103 | feature_map_idx = int(feature_map_name.split('.#')[-1].split('(')[0]) 104 | layertobeattached = feature_map_name.split("#")[0][:-1] 105 | print(layertobeattached) 106 | 107 | # Attach hook to original model 108 | new_handles = [] 109 | for ind,(name,module) in enumerate(model.named_modules()): 110 | if name ==layertobeattached: 111 | print('name: {}'.format(name) ) 112 | h=module.register_forward_hook( get_penultimate_fts ) 113 | new_handles.append(h) 114 | 115 | 116 | #for idx, clss in enumerate(clsses): 117 | root_dir = os.path.join(parent_dir, gan_name) 118 | 119 | # Obtain D 120 | if have_classes: 121 | dl_fake = get_classwise_dataloader(root_dir, have_classes, num_instances , transform, bsize, onlyreal=False, onlyfake=True) 122 | 123 | else: 124 | dl = get_dataloader(root_dir, have_classes, num_instances , transform, bsize, onlyreal=False, onlyfake=True) 125 | dl_fake = [dl] 126 | 127 | #clss_specific_fts_real, y = get_class_specific_penultimate_fts(model, dl_real, device) 128 | fts_fake, fnames = get_fmap_activations(model, dl_fake, device) 129 | fts_fake = fts_fake[:, feature_map_idx] 130 | 131 | #print(clss_specific_fts_fake, fnames) 132 | 133 | # Sort and get the top activated 100 images 134 | top_images = len(fnames) 135 | max_act_vals = np.asarray(fts_fake) 136 | idx = (-max_act_vals.copy()).argsort()[:top_images] 137 | 138 | # Save df 139 | df = pd.DataFrame() 140 | df['name'] = [fnames[i] for i in idx] 141 | df['max_act'] = [ max_act_vals[i] for i in idx] 142 | #df.to_csv("image_rankings_progan/{}.csv".format(feature_map_name), index=None) 143 | 144 | return df, fts_fake -------------------------------------------------------------------------------- /src/cr_ud/datasets.py: -------------------------------------------------------------------------------- 1 | # The base script is borrowed from https://github.com/peterwang512/CNNDetection 2 | 3 | import cv2 4 | import numpy as np 5 | import torchvision.datasets as datasets 6 | import torchvision.transforms as transforms 7 | import torchvision.transforms.functional as TF 8 | from random import random, choice 9 | from io import BytesIO 10 | from PIL import Image 11 | from PIL import ImageFile 12 | from scipy.ndimage.filters import gaussian_filter 13 | 14 | 15 | ImageFile.LOAD_TRUNCATED_IMAGES = True 16 | 17 | def dataset_folder(opt, root): 18 | if opt.mode == 'binary': 19 | print("Using binary mode...") 20 | return binary_dataset(opt, root) 21 | 
if opt.mode == 'filename': 22 | return FileNameDataset(opt, root) 23 | raise ValueError('opt.mode needs to be binary or filename.') 24 | 25 | 26 | def binary_dataset(opt, root): 27 | if opt.isTrain: 28 | crop_func = transforms.RandomCrop(opt.cropSize) 29 | elif opt.no_crop: 30 | crop_func = transforms.Lambda(lambda img: img) 31 | else: 32 | crop_func = transforms.CenterCrop(opt.cropSize) 33 | 34 | if opt.isTrain and not opt.no_flip: 35 | flip_func = transforms.RandomHorizontalFlip() 36 | else: 37 | flip_func = transforms.Lambda(lambda img: img) 38 | if not opt.isTrain and opt.no_resize: 39 | rz_func = transforms.Lambda(lambda img: img) 40 | else: 41 | rz_func = transforms.Lambda(lambda img: custom_resize(img, opt)) 42 | 43 | print("Training with random grayscaling with p=50%...") 44 | dset = datasets.ImageFolder( 45 | root, 46 | transforms.Compose([ 47 | rz_func, 48 | transforms.Lambda(lambda img: data_augment(img, opt)), 49 | crop_func, 50 | flip_func, 51 | #transforms.Grayscale(3), # Test grayscale only detector 52 | transforms.RandomGrayscale(p=0.5), 53 | transforms.ToTensor(), 54 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 55 | ])) 56 | return dset 57 | 58 | 59 | class FileNameDataset(datasets.ImageFolder): 60 | def name(self): 61 | return 'FileNameDataset' 62 | 63 | def __init__(self, opt, root): 64 | self.opt = opt 65 | super().__init__(root) 66 | 67 | def __getitem__(self, index): 68 | # Loading sample 69 | path, target = self.samples[index] 70 | return path 71 | 72 | 73 | def data_augment(img, opt): 74 | img = np.array(img) 75 | 76 | if random() < opt.blur_prob: 77 | sig = sample_continuous(opt.blur_sig) 78 | gaussian_blur(img, sig) 79 | 80 | if random() < opt.jpg_prob: 81 | method = sample_discrete(opt.jpg_method) 82 | qual = sample_discrete(opt.jpg_qual) 83 | img = jpeg_from_key(img, qual, method) 84 | 85 | return Image.fromarray(img) 86 | 87 | 88 | def sample_continuous(s): 89 | if len(s) == 1: 90 | return s[0] 91 | if len(s) == 2: 92 | rg = s[1] - s[0] 93 | return random() * rg + s[0] 94 | raise ValueError("Length of iterable s should be 1 or 2.") 95 | 96 | 97 | def sample_discrete(s): 98 | if len(s) == 1: 99 | return s[0] 100 | return choice(s) 101 | 102 | 103 | def gaussian_blur(img, sigma): 104 | gaussian_filter(img[:,:,0], output=img[:,:,0], sigma=sigma) 105 | gaussian_filter(img[:,:,1], output=img[:,:,1], sigma=sigma) 106 | gaussian_filter(img[:,:,2], output=img[:,:,2], sigma=sigma) 107 | 108 | 109 | def cv2_jpg(img, compress_val): 110 | img_cv2 = img[:,:,::-1] 111 | encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), compress_val] 112 | result, encimg = cv2.imencode('.jpg', img_cv2, encode_param) 113 | decimg = cv2.imdecode(encimg, 1) 114 | return decimg[:,:,::-1] 115 | 116 | 117 | def pil_jpg(img, compress_val): 118 | out = BytesIO() 119 | img = Image.fromarray(img) 120 | img.save(out, format='jpeg', quality=compress_val) 121 | img = Image.open(out) 122 | # load from memory before ByteIO closes 123 | img = np.array(img) 124 | out.close() 125 | return img 126 | 127 | 128 | jpeg_dict = {'cv2': cv2_jpg, 'pil': pil_jpg} 129 | def jpeg_from_key(img, compress_val, key): 130 | method = jpeg_dict[key] 131 | return method(img, compress_val) 132 | 133 | 134 | rz_dict = {'bilinear': Image.BILINEAR, 135 | 'bicubic': Image.BICUBIC, 136 | 'lanczos': Image.LANCZOS, 137 | 'nearest': Image.NEAREST} 138 | def custom_resize(img, opt): 139 | interp = sample_discrete(opt.rz_interp) 140 | return TF.resize(img, opt.loadSize, interpolation=rz_dict[interp]) 141 | 
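A minimal usage sketch for the blur/JPEG augmentation above. It assumes it runs inside this module (so data_augment is in scope); the opt attribute names mirror the CNNDetection-style options that data_augment reads, while the concrete values and file names below are illustrative assumptions, not the repository's training configuration.

from types import SimpleNamespace
from PIL import Image

# Illustrative options: 50% chance of Gaussian blur with sigma sampled
# uniformly from [0, 3], and 50% chance of JPEG re-compression with a
# randomly chosen codec ('cv2' or 'pil') and quality in [30, 100].
opt = SimpleNamespace(
    blur_prob=0.5, blur_sig=[0.0, 3.0],
    jpg_prob=0.5, jpg_method=['cv2', 'pil'],
    jpg_qual=list(range(30, 101)),
)

img = Image.open('example.png').convert('RGB')  # hypothetical input image
aug = data_augment(img, opt)                    # blur and/or JPEG with the probabilities above
aug.save('example_augmented.png')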
-------------------------------------------------------------------------------- /src/cr_ud/trainer.py: -------------------------------------------------------------------------------- 1 | # The base script is borrowed from https://github.com/peterwang512/CNNDetection 2 | 3 | import functools 4 | import torch 5 | import torch.nn as nn 6 | from networks.resnet import resnet50 7 | from networks.base_model import BaseModel, init_weights 8 | 9 | 10 | from efficientnet_pytorch import EfficientNet 11 | # https://github.com/lukemelas/EfficientNet-PyTorch 12 | 13 | 14 | class Trainer(BaseModel): 15 | def name(self): 16 | return 'Trainer' 17 | 18 | def __init__(self, opt): 19 | super(Trainer, self).__init__(opt) 20 | 21 | if self.isTrain and not opt.continue_train: 22 | if opt.arch == "resnet50": 23 | self.model = resnet50(pretrained=True) 24 | self.model.fc = nn.Linear(2048, 1) 25 | torch.nn.init.normal_(self.model.fc.weight.data, 0.0, opt.init_gain) 26 | 27 | elif opt.arch == "efficientnet-b0": 28 | self.model = EfficientNet.from_pretrained('efficientnet-b0', num_classes=1) 29 | 30 | else: 31 | raise NotImplementedError 32 | 33 | print("Using {} with transfer learning...".format(opt.arch)) 34 | 35 | if not self.isTrain or opt.continue_train: 36 | if opt.arch == "resnet50": 37 | self.model = resnet50(num_classes=1) 38 | elif opt.arch == "efficientnet-b0": 39 | self.model = EfficientNet.from_name('efficientnet-b0', num_classes=1) 40 | else: 41 | raise NotImplementedError 42 | 43 | print("Using {} without transfer learning...".format(opt.arch)) 44 | 45 | if self.isTrain: 46 | self.loss_fn = nn.BCEWithLogitsLoss() 47 | # initialize optimizers 48 | if opt.optim == 'adam': 49 | self.optimizer = torch.optim.Adam(self.model.parameters(), 50 | lr=opt.lr, betas=(opt.beta1, 0.999)) 51 | elif opt.optim == 'sgd': 52 | self.optimizer = torch.optim.SGD(self.model.parameters(), 53 | lr=opt.lr, momentum=0.0, weight_decay=0) 54 | else: 55 | raise ValueError("optim should be [adam, sgd]") 56 | 57 | if not self.isTrain or opt.continue_train: 58 | self.load_networks(opt.epoch) 59 | self.model.to(opt.gpu_ids[0]) 60 | 61 | 62 | def adjust_learning_rate(self, min_lr=1e-6): 63 | for param_group in self.optimizer.param_groups: 64 | param_group['lr'] /= 10. 
65 | if param_group['lr'] < min_lr: 66 | return False 67 | return True 68 | 69 | def set_input(self, input): 70 | self.input = input[0].to(self.device) 71 | self.label = input[1].to(self.device).float() 72 | 73 | 74 | def forward(self): 75 | self.output = self.model(self.input) 76 | 77 | def get_loss(self): 78 | return self.loss_fn(self.output.squeeze(1), self.label) 79 | 80 | def optimize_parameters(self): 81 | self.forward() 82 | self.loss = self.loss_fn(self.output.squeeze(1), self.label) 83 | self.optimizer.zero_grad() 84 | self.loss.backward() 85 | self.optimizer.step() 86 | 87 | -------------------------------------------------------------------------------- /src/extract_max_activation_patches.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import imp 3 | import math 4 | from utils.general import get_all_channels 5 | 6 | parser = argparse.ArgumentParser(description='Rank feature maps of universal detectors...') 7 | 8 | # architecture 9 | parser.add_argument('--arch', type=str, required=True, choices=['resnet50', 'efb0']) 10 | 11 | # Data augmentation of model 12 | #parser.add_argument('--blur_jpg', type=float, required=True, choices=[0.1, 0.5]) 13 | parser.add_argument('--blur_jpg', type=str, required=True) 14 | 15 | # other metrics 16 | parser.add_argument('--bsize', type=int, default=16) 17 | 18 | # dataset 19 | parser.add_argument('--dataset_dir', type=str, default='/mnt/data/v2.0_CNN_synth_testset/') 20 | parser.add_argument('--gan_name', type=str, nargs='+', default='progan_val') 21 | parser.add_argument('--have_classes', type=int, default=1) 22 | 23 | # Images per class for identifying channels 24 | parser.add_argument('--num_instances', type=int, default=5) 25 | 26 | # topk 27 | parser.add_argument('--topk', type=int, default=117) 28 | 29 | 30 | def main(): 31 | args = parser.parse_args() 32 | 33 | if(type(args.gan_name)==str): 34 | args.gan_name = [args.gan_name,] 35 | 36 | args.have_classes = bool(args.have_classes) 37 | 38 | print(args.gan_name, args.have_classes) 39 | 40 | # Get feature map names 41 | topk_channels, lowk_channels, all_channels = get_all_channels( 42 | fake_csv_path="fmap_relevances/{}/progan_val_{}/progan_val-fake.csv".format(args.arch, args.blur_jpg), 43 | topk=args.topk) 44 | 45 | # No need to get for all channels, just get for color channels 46 | if args.arch == 'resnet50': 47 | 48 | from patch_extraction.extract_lrp_max_patches_using_filenames_resnet50 import get_high_activation_patches 49 | for gan_name in args.gan_name: 50 | for feature_map_name in topk_channels: 51 | get_high_activation_patches(feature_map_name, args.arch, gan_name, args.blur_jpg, args.bsize, num_instances=args.num_instances) 52 | 53 | elif args.arch == 'efb0': 54 | 55 | from patch_extraction.extract_lrp_max_patches_using_filenames_efb0 import get_high_activation_patches 56 | for gan_name in args.gan_name: 57 | for feature_map_name in topk_channels: 58 | get_high_activation_patches(feature_map_name, args.arch, gan_name, args.blur_jpg, args.bsize, num_instances=args.num_instances) 59 | 60 | 61 | if __name__=='__main__': 62 | main() -------------------------------------------------------------------------------- /src/find_color_conditional_percentage.py: -------------------------------------------------------------------------------- 1 | import argparse, os 2 | import imp 3 | import math 4 | from utils.general import get_all_channels 5 | from activation_histograms.get_activation_values import get_activation_values 6 | 7 | 
import torch 8 | from torchvision import models, transforms 9 | 10 | from scipy import stats 11 | # from scipy.stats import kstest 12 | 13 | import seaborn as sns 14 | import pandas as pd 15 | import matplotlib.pyplot as plt 16 | 17 | parser = argparse.ArgumentParser(description='Rank feature maps of universal detectors...') 18 | 19 | # architecture 20 | parser.add_argument('--arch', type=str, required=True, choices=['resnet50', 'efb0']) 21 | 22 | # Data augmentation of model 23 | parser.add_argument('--blur_jpg', type=str, required=True) 24 | 25 | # dataset 26 | parser.add_argument('--gan_name', type=str, nargs='+', default='progan_val') 27 | 28 | # topk 29 | parser.add_argument('--topk', type=int, default=114) 30 | 31 | 32 | def main(): 33 | args = parser.parse_args() 34 | 35 | if(type(args.gan_name)==str): 36 | args.gan_name = [args.gan_name,] 37 | 38 | for gan_name in args.gan_name: 39 | csv_path = 'output/median_test/{}_{}/{}.csv'.format(args.arch, args.blur_jpg, gan_name) 40 | df = pd.read_csv(csv_path) 41 | 42 | # Keep only the topk most relevant channels 43 | df = df.head(args.topk) 44 | 45 | total_top_channels = df.shape[0] 46 | 47 | df = df [ df['p_value'] <= 0.05] 48 | print(df) 49 | total_color_channels = df.shape[0] 50 | 51 | print(total_top_channels, total_color_channels) 52 | print("color-conditional channel %", gan_name, (total_color_channels / total_top_channels)*100.0) 53 | 54 | 55 | 56 | 57 | 58 | if __name__=='__main__': 59 | main() -------------------------------------------------------------------------------- /src/find_topk.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from sensitivity_assessment.ap_sensitivity import ap_sensitivity_analysis 3 | import math 4 | 5 | parser = argparse.ArgumentParser(description='Rank feature maps of universal detectors...') 6 | 7 | # architecture 8 | parser.add_argument('--arch', type=str, required=True, choices=['resnet50', 'efb0']) 9 | 10 | # Data augmentation of model 11 | # parser.add_argument('--blur_jpg', type=float, required=True, choices=[0.1, 0.5]) 12 | parser.add_argument('--blur_jpg', type=str, required=True) 13 | 14 | # other metrics 15 | parser.add_argument('--bsize', type=int, default=16) 16 | 17 | # dataset 18 | parser.add_argument('--dataset_dir', type=str, default='/mnt/data/v2.0_CNN_synth_testset/') 19 | parser.add_argument('--gan_name', type=str, default='progan_val') 20 | parser.add_argument('--have_classes', type=bool, default=True) # NOTE: argparse type=bool treats any non-empty string as True; pass 0/1 ints as the sibling scripts do 21 | 22 | # Images per class for identifying channels 23 | parser.add_argument('--num_instances', type=int, default=5) 24 | 25 | # topk 26 | parser.add_argument('--topk', type=int, default=5) 27 | 28 | 29 | def main(): 30 | args = parser.parse_args() 31 | topk_list = generate_bisect_candidates(args.topk, steps=1) 32 | print("Generated topk search list using bisected intervals => ", topk_list) 33 | 34 | topk_list = [args.topk] # Override: evaluate only the requested topk value 35 | ap_sensitivity_analysis(args.arch, 36 | args.dataset_dir, args.gan_name, args.blur_jpg, 37 | args.have_classes, args.bsize, num_instances=args.num_instances, topk_list=topk_list) 38 | 39 | 40 | 41 | def generate_bisect_candidates(topk, steps): 42 | list_of_vals = [0, topk] 43 | 44 | for i in range(steps): 45 | num_intervals = len(list_of_vals) - 1 46 | new_vals = [ int(math.ceil((list_of_vals[j] + list_of_vals[j+1])/2)) for j in range(num_intervals) ] 47 | list_of_vals.extend(new_vals) 48 | list_of_vals.sort() 49 | 50 | 51 | return list_of_vals 52 | 53 | 54 | if __name__=='__main__': 55 | main()
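A small worked example of the bisection schedule above (a sketch; the outputs were traced by hand from generate_bisect_candidates): steps=1 splits [0, topk] at its midpoint, and every additional step splits each existing interval again.

# >>> generate_bisect_candidates(114, steps=1)
# [0, 57, 114]
# >>> generate_bisect_candidates(114, steps=2)
# [0, 29, 57, 86, 114]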
-------------------------------------------------------------------------------- /src/fmap_ranking/rank_fmaps_ud_efb0.py: -------------------------------------------------------------------------------- 1 | # This piece of code requires high memory (RAM) for execution as relevance scores for every feature map 2 | # are stored in memory before aggregating. 3 | 4 | 5 | # Import base libraries 6 | import os, sys, math, gc 7 | 8 | import PIL 9 | import copy 10 | import random, json 11 | from collections import OrderedDict 12 | 13 | # Import scientific computing libraries 14 | import numpy as np 15 | 16 | # Import torch and dependencies 17 | import torch 18 | from torchvision import models, transforms 19 | import torchvision 20 | 21 | 22 | from efficientnet_pytorch import EfficientNet 23 | # from efficientnet_pytorch.utils import load_pretrained_weights 24 | 25 | # Import LRP modules 26 | from utils.heatmap_helpers import * 27 | from lrp.ef_lrp_general import * 28 | from lrp.ef_wrapper import * 29 | 30 | # Import utils modules 31 | from utils.dataset_helpers import * 32 | from utils.general import * 33 | 34 | # Import other libraries 35 | import matplotlib.pyplot as plt 36 | import pandas as pd 37 | from scipy import stats 38 | 39 | 40 | def get_wrapped_efficientnet_b0(weightpath, key, device): 41 | """ 42 | Get the LRP-wrapped EfficientNet-B0 model loaded onto the device. 43 | 44 | Args: 45 | weightpath : path of berkeley classifier weights 46 | key : LRP key 47 | device : cuda or cpu to store the model 48 | 49 | Returns the LRP-wrapped EfficientNet_canonized pytorch object 50 | """ 51 | 52 | if key == 'beta0': 53 | #beta0 54 | lrp_params_def1={ 55 | 'conv2d_ignorebias': True, 56 | 'eltwise_eps': 1e-6, 57 | 'linear_eps': 1e-6, 58 | 'pooling_eps': 1e-6, 59 | 'use_zbeta': False , 60 | } 61 | 62 | lrp_layer2method={ 63 | 'Swish': relu_wrapper_fct, 64 | 'nn.BatchNorm2d': relu_wrapper_fct, 65 | 'nn.Conv2d': Conv2dDynamicSamePadding_beta0_wrapper_fct, 66 | 'nn.Linear': linearlayer_eps_wrapper_fct, 67 | 'nn.AdaptiveAvgPool2d': adaptiveavgpool2d_wrapper_fct, 68 | 'sum_stacked2': eltwisesum_stacked2_eps_wrapper_fct, 69 | } 70 | 71 | elif key == 'beta1': 72 | raise NotImplementedError("LRP parameters for key 'beta1' are not defined in this script") 73 | elif key == 'betaada': 74 | raise NotImplementedError("LRP parameters for key 'betaada' are not defined in this script") 75 | else: 76 | raise NotImplementedError("Unknown key", key) 77 | 78 | 79 | model0 = EfficientNet.from_name('efficientnet-b0', num_classes=1, image_size=None,) 80 | somedict = torch.load(weightpath) 81 | model0.load_state_dict( somedict['model'] ) 82 | model0.eval() 83 | 84 | model_e = EfficientNet_canonized.from_pretrained('efficientnet-b0', num_classes=1, image_size=None, dropout_rate= 0.0 , drop_connect_rate=0.0) 85 | 86 | model_e.copyfromefficientnet( model0, lrp_params_def1, lrp_layer2method) 87 | model_e.to(device) 88 | 89 | return model_e 90 | 91 | 92 | def get_lrp_explanations_for_batch(model, 93 | imagetensor, label, 94 | relfname , save, outpath, minus_fx=False): 95 | """ 96 | Get LRP explanations for a batch of samples.
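Relevances are computed by back-propagating f(x) (or -f(x) when minus_fx=True, i.e. for real images) through the LRP-wrapped model; per-feature-map relevances are then read off the modules' registered backward hooks.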
97 | 98 | Args: 99 | model : pytorch model 100 | imagetensor : images 101 | label : label 102 | relfname : filenames 103 | save : If True, save LRP explanations 104 | outpath : output path to save pixel space explanations 105 | 106 | Get dict with positive relevances for the sample 107 | """ 108 | model.eval() 109 | 110 | all_lrp_explanations = [] 111 | 112 | os.makedirs(outpath, exist_ok=True) 113 | 114 | if imagetensor.grad is not None: 115 | imagetensor.grad.zero_() 116 | 117 | imagetensor.requires_grad=True # gxinp needs it here 118 | 119 | with torch.enable_grad(): 120 | outputs = model(imagetensor) 121 | 122 | with torch.no_grad(): 123 | probs = outputs.clone().sigmoid().flatten() 124 | print(probs) 125 | preds_labels = torch.where(probs>=0.5, 1.0, 0.0).long() 126 | correct_pred_indices = torch.where(torch.eq(preds_labels, label))[0] 127 | 128 | if not correct_pred_indices.size(0) > 0: 129 | return all_lrp_explanations 130 | 131 | #Propagate the signals for the correctly predicted samples for LRP (We should get the same LRP results if we use all samples as well.) 132 | with torch.enable_grad(): 133 | if minus_fx: 134 | z = torch.sum( -outputs[correct_pred_indices, :] ) # Explain -f(x) if images are real 135 | 136 | else: 137 | z = torch.sum( outputs[correct_pred_indices, :] ) # Explain f(x) if images are fake 138 | 139 | with torch.no_grad(): 140 | z.backward(retain_graph=True) 141 | rel = imagetensor.grad.data.clone() 142 | 143 | for b in range(imagetensor.shape[0]): 144 | # Check for correct preds and skip incorrect preds. Look for high subject GAN confidence samples 145 | cond = (probs[b].item() >= 0.90 and label[b].item() == 1) or (probs[b].item() <= 0.10 and label[b].item() == 0) 146 | 147 | if not cond: 148 | continue 149 | 150 | fn = relfname[b] 151 | lrp_explanations = {} 152 | lrp_explanations['relfname'] = relfname[b] 153 | lrp_explanations['prob'] = probs[b].item() 154 | 155 | for i, (name, mod) in enumerate(model.named_modules()): 156 | if hasattr(mod, 'relfromoutput'): 157 | v = getattr(mod, 'relfromoutput') 158 | 159 | ftrelevances = v[b,:] 160 | 161 | # Save feature relevances to LRP explanations dict. Move to cpu since data is big. 
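# ftrelevances has shape (C, H, W): one spatial relevance map per channel of this conv layer, captured for sample b by the backward hook.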
162 | lrp_explanations[name] = ftrelevances.detach().cpu() 163 | 164 | # Save LRP explanations with images as png files and also lrp explanations only as .pt files 165 | if save: 166 | if label[b].item() == 0: 167 | vis_dir_name = os.path.join(outpath, "visualization", "0_real") 168 | vis_fname = os.path.join( vis_dir_name, fn.replace('/', '-') ).replace('.png', '.pdf') 169 | os.makedirs(vis_dir_name, exist_ok=True) 170 | 171 | save_img_lrp_overlay_only_positive(rel[b].to('cpu'), imagetensor[b].to('cpu'), 172 | title="Label: {}, prob :{:.3f}".format( label[b].item(), probs[b].item() ), 173 | q=100, outname=vis_fname ) 174 | 175 | # Store LRP values 176 | lrp_dir_name = os.path.join(outpath, "lrp", "0_real") 177 | lrp_fname = os.path.join(lrp_dir_name, fn.replace('/', '-') + '.pt') 178 | os.makedirs(lrp_dir_name, exist_ok=True) 179 | torch.save(torch.sum(rel, dim=0).cpu(), lrp_fname) 180 | 181 | else: 182 | vis_dir_name = os.path.join(outpath, "visualization", "1_fake") 183 | vis_fname = os.path.join( vis_dir_name, fn.replace('/', '-') ).replace('.png', '.pdf') 184 | os.makedirs(vis_dir_name, exist_ok=True) 185 | 186 | save_img_lrp_overlay_only_positive(rel[b].to('cpu'), imagetensor[b].to('cpu'), 187 | title="Label: {}, prob :{:.3f}".format( label[b].item(), probs[b].item() ), 188 | q=100, outname=vis_fname) 189 | 190 | lrp_dir_name = os.path.join(outpath, "lrp", "1_fake") 191 | lrp_fname = os.path.join(lrp_dir_name, fn.replace('/', '-') + '.pt') 192 | os.makedirs(lrp_dir_name, exist_ok=True) 193 | torch.save(torch.sum(rel, dim=0).cpu(), lrp_fname) 194 | 195 | 196 | all_lrp_explanations.append(lrp_explanations) 197 | 198 | torch.cuda.empty_cache() 199 | gc.collect() 200 | 201 | del ftrelevances 202 | 203 | return all_lrp_explanations 204 | 205 | 206 | def get_all_lrp_positive_explanations(model, dataloader, device, outpath, save, minus_fx): 207 | """ 208 | Get all LRP explanations for one folder. 209 | 210 | Args: 211 | model : resnet50 pytorch model 212 | dataloader : pytorch dataloader 213 | device : 214 | outpath : output path to save visualization and lrp numpy files 215 | save : If set to True, save 216 | minus_fx : If set to True, we will use -f(x) signal to calculate relevances 217 | 218 | Returns all LRP explanations 219 | """ 220 | # Global variable to store feature map information 221 | all_lrp_explanations = [] 222 | 223 | # > Explain prediction 224 | for index, data in enumerate(dataloader): 225 | # Get image tensor, filename, stub and labels 226 | imagetensors = data['image'].to(device) 227 | fnames = data['filename'] 228 | relfnames = data['relfilestub'] 229 | labels = data['label'].to(device) 230 | 231 | # Get LRP explanations 232 | lrp_explanations = get_lrp_explanations_for_batch(model, imagetensors, labels, relfnames, save, outpath, minus_fx) 233 | 234 | # Get LRP heatmap for all layers 235 | all_lrp_explanations.extend(lrp_explanations) 236 | 237 | torch.cuda.empty_cache() 238 | gc.collect() 239 | del lrp_explanations 240 | 241 | return all_lrp_explanations 242 | 243 | 244 | def normalize_relevances(lrp_explanations, only_positive=False, eps=1e-6): 245 | """ 246 | Sample wise Normalize LRP explanations. We use two schemes. 247 | Scheme 1: Use only positive relevances for normalization 248 | Scheme 2 (Used in the submissiom): Use absolute relevances for normalization 249 | Args: 250 | lrp_explanations : Relevance values 251 | only_positive : If set to True, use scheme 1. 
If False, use scheme 2 252 | eps : An epsilon value to avoid division by zero 253 | Returns: 254 | Samplewise-normalized lrp explanations 255 | """ 256 | layer_names = lrp_explanations[0].keys() 257 | assert only_positive is False 258 | 259 | for layer_name in layer_names: 260 | if layer_name in ['relfname', 'prob', 'f(x)', 'y']: 261 | continue 262 | 263 | for i in range(len(lrp_explanations)): 264 | if only_positive: 265 | normalization_constant = torch.sum( torch.maximum( torch.tensor([0.0]), lrp_explanations[i][layer_name] ) ) 266 | 267 | if torch.allclose(normalization_constant, torch.tensor([0.0]), atol=1e-8): 268 | lrp_explanations[i][layer_name] = torch.zeros(size=lrp_explanations[i][layer_name].size()) 269 | 270 | else: 271 | lrp_explanations[i][layer_name] = torch.div( torch.maximum( torch.tensor([0.0]), lrp_explanations[i][layer_name] ), 272 | normalization_constant ) 273 | assert torch.allclose( torch.sum(lrp_explanations[i][layer_name]), torch.tensor([1.0]) ), "{}:{}".format(normalization_constant.item(), layer_name) 274 | 275 | 276 | else: 277 | normalization_constant = torch.sum( torch.abs( lrp_explanations[i][layer_name]) ) # Total amount of "evidence" 278 | #print(layer_name, lrp_explanations[i][layer_name].size(), normalization_constant.item(), torch.sum( lrp_explanations[i][layer_name] ).item()) 279 | 280 | # Now look at the ratio of positive evidence given the total absolute evidence. 281 | lrp_explanations[i][layer_name] = torch.div( torch.maximum( torch.tensor([0.0]), lrp_explanations[i][layer_name] ), 282 | normalization_constant ) 283 | 284 | return lrp_explanations 285 | 286 | 287 | def calculate_channelwise_stats(normalized_lrp_explanations): 288 | """ 289 | Aggregate the samplewise-normalized relevances into channelwise statistics used to rank the topk discriminative channels 290 | 291 | Args: 292 | normalized_lrp_explanations : the samplewise-normalized relevances \hat{R}; returns the channelwise stats 293 | """ 294 | 295 | layer_names = normalized_lrp_explanations[0].keys() 296 | channelwise_stats = {} 297 | 298 | for layer_name in layer_names: 299 | if layer_name in ['relfname', 'prob']: 300 | continue 301 | 302 | list_norm = [ normalized_lrp_explanations[i][layer_name] for i in range(len(normalized_lrp_explanations)) ] 303 | stack_norm = torch.stack(list_norm) 304 | 305 | average_fmap_relevances = torch.sum( torch.mean(stack_norm, dim=0), dim=(1, 2)) # We obtain the averaged feature map, then sum over h, w 306 | 307 | for channel_idx in range(average_fmap_relevances.shape[0]): 308 | channelwise_stats["{}.#{}(T={})".format(layer_name, channel_idx, average_fmap_relevances.shape[0])] = average_fmap_relevances[channel_idx] 309 | 310 | return channelwise_stats 311 | 312 | 313 | def pipeline(model_e, dl, device, outpath, save, 314 | num_instances, 315 | minus_fx, normalize_only_using_positive, topk): 316 | """ 317 | Pipeline to run the overall algorithm for real or fake images 318 | 319 | Args: 320 | model_e : LRP-wrapped EfficientNet-B0 321 | dl : dataloader 322 | device : cuda/ cpu 323 | outpath : output path to save the channelwise stats 324 | save : If set to True, all LRP relevances are saved as .pt files 325 | minus_fx: Needs to be set to True for real images 326 | normalize_only_using_positive : If set to True, scheme 1 else scheme 2. 327 | topk : number of topk feature maps to return.
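Returns: An OrderedDict mapping the topk feature maps (keyed as layer.#channel(T=...)) to their mean relevance, plus the number of explanations used; (None, None) if no sample qualified.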
328 | """ 329 | # Get LRP explanations 330 | blockPrint() 331 | all_lrp_explanations = get_all_lrp_positive_explanations(model_e, dl, device, outpath, save, minus_fx) 332 | enablePrint() 333 | 334 | if len(all_lrp_explanations) == 0: 335 | print("> No valid LRP explanations collected (no confident, correctly predicted samples)") 336 | return None, None 337 | 338 | # Normalize 339 | norm_lrp = normalize_relevances(all_lrp_explanations, only_positive=normalize_only_using_positive) 340 | 341 | # Calculate channelwise stats 342 | channelwise_stats = calculate_channelwise_stats(norm_lrp) 343 | 344 | # Calculate top k channels 345 | channelwise_stats = sorted(channelwise_stats.items(), key=lambda x: x[1], reverse=True) 346 | topk_channelwise_stats = channelwise_stats[:topk] 347 | 348 | final_channelwise_stats = OrderedDict() 349 | for i in range(len(topk_channelwise_stats)): 350 | key, value = topk_channelwise_stats[i][0], topk_channelwise_stats[i][1] 351 | final_channelwise_stats[key] = value 352 | 353 | #print(final_channelwise_stats) 354 | return final_channelwise_stats, len(all_lrp_explanations) 355 | 356 | 357 | 358 | def save_results_as_csv(channelwise_stats, key1, key2, save_name): 359 | """ 360 | Save as CSV. Preliminary analysis can be done in a spreadsheet before automation 361 | 362 | Args: 363 | channelwise_stats : dict of stats 364 | key1 : name of layers 365 | key2 : Real or fake 366 | """ 367 | df = pd.DataFrame(columns=[key1, key2]) 368 | df[key1] = channelwise_stats.keys() 369 | df[key2] = [i.item() for i in channelwise_stats.values()] 370 | df.to_csv('{}.csv'.format(save_name), index=False) 371 | return df 372 | 373 | 374 | def rank_fmaps(parent_dir, gan_name, 375 | aug, have_classes, 376 | bsize, 377 | num_instances_real, num_instances_fake, 378 | save_pt_files, 379 | normalize_only_using_positive = False, 380 | topk=None): 381 | ## ------------Set Parameters-------- 382 | # Define device and other parameters 383 | device = torch.device('cuda:0') 384 | 385 | # LRP model keys and weights 386 | key = 'beta0' # 'beta0' 'beta1' , 'betaada' 387 | weightfn = './weights/efb0/blur_jpg_prob{}.pth'.format(aug) 388 | 389 | # Directories 390 | root_dir = os.path.join(parent_dir, gan_name) 391 | outpath_channelwise = os.path.join('./fmap_relevances/', 'efb0', '{}_{}'.format(gan_name, aug)) 392 | outpath = './output/efb0/{}_{}/'.format(gan_name, aug) # For saving .pt files if required 393 | 394 | ## ------------End of Parameters-------- 395 | 396 | # Model 397 | blockPrint() 398 | model_e = get_wrapped_efficientnet_b0(weightfn, key, device) 399 | enablePrint() 400 | print("> LRP-wrapped EfficientNet-B0 loaded successfully") 401 | 402 | def writeintomodule_bwhook(self,grad_input, grad_output): 403 | # grad_output is what arrives from above; its shape equals that of the module's output 404 | setattr(self,'relfromoutput', grad_output[0]) 405 | 406 | # Register hook 407 | for i, (name,mod) in enumerate(model_e.named_modules()): 408 | #print(i,name ) 409 | # if (('conv' in name) and ('module' not in name)) or (name in ['layer1', 'layer2', 'layer3', 'layer4']) 410 | 411 | if ((('conv' in name) or ('downsample.0' in name)) and ('module' not in name)): 412 | # print(name, 'ok') 413 | mod.register_full_backward_hook(writeintomodule_bwhook) # modify to full_backward_hook 414 | 415 | print("> All backward hooks registered successfully") 416 | 417 | # Dataset (use a RandomResizedCrop so that boundary artifacts on images are not ignored) 418 | # Note: you may see very small numerical changes when repeating this experiment due to the random resized crop, 419 | # but the resulting ranking and topk feature maps should be the same. 420 | 421 | transform = transforms.Compose([ 422 | torchvision.transforms.RandomResizedCrop(224, 423 | scale=(0.99, 1.0), ratio=(0.99, 1.00), interpolation=torchvision.transforms.InterpolationMode.BILINEAR), 424 | transforms.ToTensor(), 425 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 426 | ]) 427 | 428 | # Obtain D_real and D_fake 429 | dl_real = get_dataloader(root_dir, have_classes, num_instances_real , transform, bsize, onlyreal=True, onlyfake=False) 430 | dl_fake = get_dataloader(root_dir, have_classes, num_instances_fake , transform, bsize, onlyreal=False, onlyfake=True) 431 | print("> Dataloaders loaded successfully") 432 | 433 | # Pass to the overall algorithm pipeline to obtain C_real_topk, C_fake_topk 434 | print("> Calculating feature map relevances (This will take some time)") 435 | real_channelwise_stats, num_real_samples = pipeline(model_e, dl_real, device, outpath, save=save_pt_files, 436 | num_instances=num_instances_real, 437 | minus_fx=True, normalize_only_using_positive=normalize_only_using_positive, topk=topk) 438 | model_e.eval() 439 | fake_channelwise_stats, num_fake_samples = pipeline(model_e, dl_fake, device, outpath, save=save_pt_files, 440 | num_instances=num_instances_fake, 441 | minus_fx=False, normalize_only_using_positive=normalize_only_using_positive, topk=topk) 442 | 443 | # Save as csv for separate analysis 444 | print("> Ranking feature map relevances") 445 | outpath_channelwise = os.path.join(outpath_channelwise) 446 | os.makedirs(outpath_channelwise, exist_ok=True) 447 | _ = save_results_as_csv(real_channelwise_stats, 'key', 'mean_relevance', os.path.join(outpath_channelwise, 448 | '{}-real'.format(gan_name)) ) 449 | _ = save_results_as_csv(fake_channelwise_stats, 'key', 'mean_relevance', os.path.join(outpath_channelwise, 450 | '{}-fake'.format(gan_name)) ) 451 | 452 | print( "> Completed..."
) 453 | print( "> #real used = {}, #fake used = {}".format(num_real_samples, num_fake_samples) ) 454 | 455 | return 456 | 457 | 458 | if __name__=='__main__': 459 | rank_fmaps() 460 | -------------------------------------------------------------------------------- /src/get_max_activation_rankings.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import imp 3 | from patch_extraction.get_top_activated_images import get_activation_rankings 4 | import math 5 | from utils.general import get_all_channels 6 | 7 | parser = argparse.ArgumentParser(description='Rank feature maps of universal detectors...') 8 | 9 | # architecture 10 | parser.add_argument('--arch', type=str, required=True, choices=['resnet50', 'efb0']) 11 | 12 | # Data augmentation of model 13 | #parser.add_argument('--blur_jpg', type=float, required=True, choices=[0.1, 0.5]) 14 | parser.add_argument('--blur_jpg', type=str, required=True) 15 | 16 | # other metrics 17 | parser.add_argument('--bsize', type=int, default=16) 18 | 19 | # dataset 20 | parser.add_argument('--dataset_dir', type=str, default='/mnt/data/v2.0_CNN_synth_testset/') 21 | parser.add_argument('--gan_name', type=str, nargs='+', default='progan_val') 22 | parser.add_argument('--have_classes', type=int, default=1) 23 | 24 | # Images per class for identifying channels 25 | parser.add_argument('--num_instances', type=int, default=5) 26 | 27 | # topk 28 | parser.add_argument('--topk', type=int, default=114) 29 | 30 | 31 | def main(): 32 | args = parser.parse_args() 33 | 34 | if(type(args.gan_name)==str): 35 | args.gan_name = [args.gan_name,] 36 | 37 | args.have_classes = bool(args.have_classes) 38 | 39 | print(args.gan_name, args.have_classes) 40 | 41 | # Get feature map names 42 | topk_channels, lowk_channels, all_channels = get_all_channels( 43 | fake_csv_path="fmap_relevances/{}/progan_val_{}/progan_val-fake.csv".format(args.arch, args.blur_jpg), 44 | topk=args.topk) 45 | 46 | for gan_name in args.gan_name: 47 | for feature_map_name in topk_channels: 48 | get_activation_rankings(feature_map_name, args.arch, args.dataset_dir, 49 | gan_name, args.blur_jpg, args.have_classes, args.bsize, num_instances=None) 50 | 51 | 52 | if __name__=='__main__': 53 | main() -------------------------------------------------------------------------------- /src/gradcam.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import numpy as np 4 | import torch 5 | from torchvision import models 6 | from pytorch_grad_cam import GradCAM, \ 7 | ScoreCAM, \ 8 | GradCAMPlusPlus, \ 9 | AblationCAM, \ 10 | XGradCAM, \ 11 | EigenCAM, \ 12 | EigenGradCAM, \ 13 | LayerCAM, \ 14 | FullGrad 15 | from pytorch_grad_cam import GuidedBackpropReLUModel 16 | from pytorch_grad_cam.utils.image import show_cam_on_image, \ 17 | deprocess_image, \ 18 | preprocess_image 19 | from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget 20 | 21 | 22 | import torch, os 23 | from PIL import Image 24 | from torchvision import transforms 25 | import numpy as np 26 | import matplotlib.pyplot as plt 27 | 28 | from utils.heatmap_helpers import * 29 | 30 | from efficientnet_pytorch import EfficientNet 31 | 32 | 33 | import argparse 34 | parser = argparse.ArgumentParser(description='Rank feature maps of universal detectors...') 35 | 36 | # architecture 37 | parser.add_argument('--arch', type=str, required=True, choices=['resnet50', 'efb0']) 38 | 39 | # architecture 40 | 
parser.add_argument('--classifier', type=str, required=True, choices=['imagenet', 'ud']) 41 | 42 | args = parser.parse_args() 43 | 44 | def load_img_as_tensor(path): 45 | pil_image = Image.open(path).convert("RGB") 46 | 47 | pil_transform = transforms.Compose([ 48 | transforms.Resize(256), 49 | transforms.CenterCrop(224), 50 | #transforms.ToTensor(), 51 | #transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) 52 | ]) 53 | 54 | ud_transforms = transforms.Compose([ 55 | transforms.Resize(256), 56 | transforms.CenterCrop(224), 57 | transforms.ToTensor(), 58 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 59 | ]) 60 | return pil_transform(pil_image), ud_transforms(pil_image).unsqueeze(0) 61 | 62 | 63 | def get_ud_model_r50(path): 64 | model = models.resnet50(pretrained=False, num_classes=1) 65 | model.load_state_dict(torch.load(path)['model']) 66 | model.eval() 67 | return model 68 | 69 | 70 | def get_imagenet_model_r50(): 71 | model = models.resnet50(pretrained=True) 72 | model.eval() 73 | return model 74 | 75 | 76 | def get_ud_model_efb0(path): 77 | model0 = EfficientNet.from_name('efficientnet-b0', num_classes=1, image_size=None) 78 | somedict = torch.load(path) 79 | model0.load_state_dict( somedict['model'] ) 80 | model0.eval() 81 | return model0 82 | 83 | 84 | def get_imagenet_model_efb0(): 85 | model0 = EfficientNet.from_pretrained('efficientnet-b0', image_size=None) 86 | model0.eval() 87 | return model0 88 | 89 | 90 | def center_crop(image, h=224, w=224): 91 | center = [ image.shape[0] // 2, image.shape[1] // 2 ] 92 | #print(center) 93 | x = int(center[1] - w/2) 94 | y = int(center[0] - h/2) 95 | #print(x, y) 96 | return image[y:y+h, x:x+w] 97 | 98 | 99 | def my_deprocess_image(img, savename): 100 | """ see https://github.com/jacobgil/keras-grad-cam/blob/master/grad-cam.py#L65 """ 101 | 102 | img = img - np.mean(img) 103 | img = img / (np.std(img) + 1e-5) 104 | img = np.clip(img, 0, None) # consider only positive values 105 | 106 | 107 | #red_channel_sum = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) # Get the grayscale values 108 | red_channel_sum = (img[:, :, 0] + img[:, :, 1] + img[:, :, 2]) 109 | img[:, :, 0] = red_channel_sum/np.max(red_channel_sum) # Copy the grayscale value to red channel 110 | img[:, :, 1] = np.zeros((224, 224)) # Only red color is allowed 111 | img[:, :, 2] = np.zeros((224, 224)) # Only red color is allowed 112 | img = np.clip(img, 0, 1) 113 | 114 | cmap = plt.cm.seismic 115 | plt.imsave("{}".format(savename), img[:, :, 0], cmap=cmap, vmin=-1.0, vmax=1.0, format='pdf') 116 | 117 | 118 | def get_heatmaps( 119 | image_path, 120 | rgb_img, 121 | input_tensor, 122 | prob, 123 | model, 124 | method, 125 | use_cuda, 126 | gan_name, 127 | type, 128 | clss, 129 | arch, 130 | classifier): 131 | 132 | methods = \ 133 | {"gradcam": GradCAM, 134 | "scorecam": ScoreCAM, 135 | "gradcam++": GradCAMPlusPlus, 136 | "ablationcam": AblationCAM, 137 | "xgradcam": XGradCAM, 138 | "eigencam": EigenCAM, 139 | "eigengradcam": EigenGradCAM, 140 | "layercam": LayerCAM, 141 | "fullgrad": FullGrad} 142 | 143 | 144 | if arch == 'resnet50': 145 | target_layers = [model.layer4] 146 | else: 147 | target_layers = [model._conv_head] 148 | 149 | #print(target_layers) 150 | rgb_img = np.float32(rgb_img.copy()) / 255 151 | 152 | 153 | # We have to specify the target we want to generate 154 | # the Class Activation Maps for. 155 | # If targets is None, the highest scoring category (for every member in the batch) will be used. 
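# For the one-logit universal detectors used here, that argmax resolves to the single real-vs-fake logit (index 0).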
156 | # You can target specific categories by 157 | # targets = [e.g ClassifierOutputTarget(281)] 158 | targets = None 159 | #targets = [ClassifierOutputTarget(218)] # sorrel 160 | #targets = [ClassifierOutputTarget(625)] # lifeboat 161 | 162 | # Using the with statement ensures the context is freed, and you can 163 | # recreate different CAM objects in a loop. 164 | cam_algorithm = methods[method] 165 | with cam_algorithm(model=model, 166 | target_layers=target_layers, 167 | use_cuda=True) as cam: 168 | 169 | # AblationCAM and ScoreCAM have batched implementations. 170 | # You can override the internal batch size for faster computation. 171 | cam.batch_size = 32 172 | grayscale_cam = cam(input_tensor=input_tensor, 173 | targets=targets, 174 | # aug_smooth=args.aug_smooth, 175 | # eigen_smooth=args.eigen_smooth 176 | ) 177 | 178 | # Here grayscale_cam has only one image in the batch 179 | grayscale_cam = grayscale_cam[0, :] 180 | 181 | cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=False, colormap=cv2.COLORMAP_OCEAN) 182 | 183 | cam_image = cv2.cvtColor(cam_image, cv2.COLOR_RGB2BGR) 184 | 185 | gb_model = GuidedBackpropReLUModel(model=model, use_cuda=use_cuda) 186 | gb = gb_model(input_tensor, target_category=None) 187 | 188 | cam_mask = cv2.merge([grayscale_cam, grayscale_cam, grayscale_cam]) 189 | cam_gb = deprocess_image(cam_mask * gb) 190 | gb = deprocess_image(gb) 191 | 192 | img_name = image_path.split('/')[-1].split(".")[0] + "-p={:.3f}".format(prob) 193 | save_dir = os.path.join("./output/hms/gradcam_heatmaps_{}_{}".format(classifier, arch ), gan_name, clss, type, img_name, method) 194 | os.makedirs(save_dir, exist_ok=True) 195 | 196 | plt.imsave(os.path.join(save_dir, "image.pdf"), rgb_img, format='pdf') 197 | 198 | plt.imsave(os.path.join(save_dir, "cam.pdf"), cv2.cvtColor(cam_image, cv2.COLOR_BGR2RGB), format='pdf') 199 | 200 | plt.imsave(os.path.join(save_dir, "gb.pdf"), cv2.cvtColor(gb, cv2.COLOR_BGR2RGB), format='pdf') 201 | 202 | plt.imsave(os.path.join(save_dir, "cam_gb.pdf"), cv2.cvtColor(cam_gb, cv2.COLOR_BGR2RGB), format='pdf') 203 | 204 | my_deprocess_image(cam_mask * gb, os.path.join(save_dir, "cam_gb_paper.pdf") ) 205 | 206 | # My heatmap 207 | # Combine rgb_image + (CAM * GB) 208 | gb_heatmap = cam_mask * gb 209 | rgb_image = rgb_img 210 | 211 | save_img_guided_gradcam_overlay_only_positive( gb_heatmap, rgb_image, q=100, title=None, outname=os.path.join(save_dir, "cam_gb_paper_final.pdf")) 212 | 213 | 214 | 215 | 216 | def main(): 217 | ## ------------Set Parameters-------- 218 | # Define device and other parameters 219 | device = torch.device('cuda:0') 220 | method = "gradcam" 221 | use_cuda=True 222 | 223 | # Directories 224 | parent_dir = '/mnt/data/CNN_synth_testset/' # Use our version 225 | gan_and_classes = {} 226 | 227 | # GAN names and classes (You can define the GAN and the corresponding classes) 228 | #gan_and_classes['progan_test'] = os.listdir(os.path.join(parent_dir, 'progan_test')) 229 | #gan_and_classes['progan_test'] = ['boat'] 230 | #gan_and_classes['biggan'] = ['bird'] 231 | gan_and_classes['stylegan2'] = ['church', 'car', 'cat'] 232 | #gan_and_classes['stylegan'] = ['car'] 233 | #gan_and_classes['cyclegan'] = ['horse'] 234 | #gan_and_classes['stargan'] = ['person'] 235 | #gan_and_classes['gaugan'] = ['mscoco'] 236 | 237 | 238 | # model and weights 239 | if args.arch == 'resnet50': 240 | if args.classifier == 'ud': 241 | weightfn = './weights/resnet50/blur_jpg_prob{}.pth'.format(0.5) 242 | model = get_ud_model_r50(weightfn).to(device) 
243 | else: 244 | model = get_imagenet_model_r50().to(device) 245 | 246 | elif args.arch == 'efb0': 247 | if args.classifier == 'ud': 248 | weightfn = './weights/efb0/blur_jpg_prob{}.pth'.format(0.5) 249 | model = get_ud_model_efb0(weightfn).to(device) 250 | else: 251 | model = get_imagenet_model_efb0().to(device) 252 | 253 | # print(model) 254 | 255 | 256 | for gan_name in gan_and_classes: 257 | for clss in gan_and_classes[gan_name]: 258 | for type in ['1_fake']: 259 | print(gan_name, clss) 260 | root_dir = os.path.join(parent_dir, gan_name, clss, type) 261 | 262 | sample_paths = [os.path.join(root_dir, i) for i in os.listdir(root_dir)] 263 | sample_paths.sort() 264 | 265 | for img_path in sample_paths[:500]: 266 | rgb_img, input_tensor = load_img_as_tensor(img_path) 267 | prob = (model(input_tensor.to(device))).sigmoid().item() 268 | #prob = torch.max(model(input_tensor.to(device))).item() 269 | 270 | 271 | get_heatmaps( 272 | img_path, 273 | rgb_img, 274 | input_tensor, 275 | prob, 276 | model, 277 | method, use_cuda, 278 | gan_name, type, clss, 279 | args.arch, args.classifier) 280 | 281 | 282 | 283 | if __name__=='__main__': 284 | main() -------------------------------------------------------------------------------- /src/grayscale_detector_sensitivity_whisker_plots.py: -------------------------------------------------------------------------------- 1 | import argparse, os 2 | from turtle import position 3 | 4 | from click import style 5 | from sensitivity_assessment.color import all_metrics_sensitivity_analysis 6 | import math 7 | 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | 11 | plt.rcParams["font.family"] = "Times New Roman" 12 | 13 | parser = argparse.ArgumentParser(description='Rank feature maps of universal detectors...') 14 | 15 | # architecture 16 | parser.add_argument('--arch', type=str, required=True, choices=['resnet50', 'efb0']) 17 | 18 | # Data augmentation of model 19 | #parser.add_argument('--blur_jpg', type=float, required=True, choices=[0.1, 0.5]) 20 | 21 | parser.add_argument('--blur_jpg', type=str, required=True) 22 | 23 | # other metrics 24 | parser.add_argument('--bsize', type=int, default=16) 25 | 26 | # dataset 27 | parser.add_argument('--dataset_dir', type=str, default='/mnt/data/v2.0_CNN_synth_testset/') 28 | parser.add_argument('--gan_name', type=str, nargs='+', default='progan_val') 29 | parser.add_argument('--have_classes', type=int, default=1) 30 | 31 | # Images per class for identifying channels 32 | parser.add_argument('--num_instances', type=int, default=5) 33 | 34 | 35 | 36 | def main(): 37 | args = parser.parse_args() 38 | 39 | if(type(args.gan_name)==str): 40 | args.gan_name = [args.gan_name,] 41 | 42 | args.have_classes = bool(args.have_classes) 43 | 44 | print(args.gan_name, args.have_classes) 45 | 46 | for gan_name in args.gan_name: 47 | (y_pred, y_true), (y_pred_grayscale, y_true_grayscale) = all_metrics_sensitivity_analysis(args.arch, 48 | args.dataset_dir, gan_name, args.blur_jpg, 49 | args.have_classes, args.bsize, num_instances=None) 50 | 51 | save_loc = os.path.join('output', 'whisker_plots', args.arch, str(args.blur_jpg)) 52 | os.makedirs(save_loc, exist_ok=True) 53 | 54 | # Plot and save 55 | plot_box_whisker_plots(y_pred, y_true, y_pred_grayscale, y_true_grayscale, save_loc, gan_name) 56 | 57 | 58 | 59 | def plot_box_whisker_plots(y_pred, y_true, y_pred_grayscale, y_true_grayscale, save_loc, gan_name): 60 | 61 | # Get GAN values 62 | y_pred_gan = y_pred[ y_true == 1 ] 63 | y_pred_gan_grayscale = y_pred_grayscale[ 

    # sanity check
    ind1 = np.argwhere(y_true == 1).flatten()
    ind2 = np.argwhere(y_true_grayscale == 1).flatten()
    assert set(list(ind1)) == set(list(ind2))

    # Focus only on samples with >= 20% prob
    y_pred_gan_grayscale = y_pred_gan_grayscale[ y_pred_gan >= 0.20 ]
    y_pred_gan = y_pred_gan[ y_pred_gan >= 0.20 ]
    print(y_pred_gan.shape, y_pred_gan_grayscale.shape)

    # Look at percentage
    y_pred_gan *= 100.0
    y_pred_gan_grayscale *= 100.0

    # Create plot
    plt.rcParams['axes.xmargin'] = 0
    plt.rc('font', weight='bold')
    fig = plt.figure(figsize=(8, 10))
    ax = fig.add_subplot(111)

    # Create data for boxplot
    data = [y_pred_gan, y_pred_gan_grayscale]
    m1, m2 = np.median(y_pred_gan), np.median(y_pred_gan_grayscale)
    #print(y_pred_gan.mean(), y_pred_gan_grayscale.mean(), (m1-m2)/2)

    # Define locations for annotations (fixed y position for the annotation text)
    pos = [0.25, 0.40]
    y_loc = 50
    x_loc = (pos[0] + pos[1])/2 - 0.01
    #print( np.median(y_pred_gan), np.median(y_pred_gan_grayscale) )

    # Create boxplot
    bplot = ax.boxplot(data, showfliers=True, positions=pos, widths=[0.1, 0.1], labels=['Baseline', 'Grayscale'],
                       manage_ticks=False, patch_artist=True, medianprops=dict(color="r", linewidth=5))
    #ax.axhline(m1, linestyle='--')
    #ax.axhline(m2, linestyle='--')

    # Annotate the drop in median prob
    text_for_drop = "{:.1f}%\ndrop".format(m1 - m2)
    ax.annotate('', xy=(x_loc, m2), xytext=(x_loc, m1), horizontalalignment="center",
                arrowprops=dict(arrowstyle='<->', lw=5, shrinkA=0, shrinkB=0, color='r'))
    ax.annotate(text_for_drop, xy=(x_loc, m2), xytext=(x_loc+0.05, y_loc), horizontalalignment="center",
                color='r', weight='bold', fontsize=50)

    # Now make it aesthetic
    ax.set_xlim(0.19, 0.46)
    ax.set_xticks(pos)
    ax.set_xticklabels(['Baseline', 'Grayscale'], fontsize=50, weight='bold')
    ax.set_yticks([0, 20, 40, 60, 80, 100])
    ax.grid(visible=True)
    ax.tick_params(axis='y', labelsize=40)
    ax.set_ylabel("Probability (%)", fontsize=50, weight='bold')
    plt.tight_layout()

    # Add colors
    colors = ['#83c5be', '#b7b7a4', ]
    for patch, color in zip(bplot['boxes'], colors):
        patch.set_facecolor(color)

    # Save
    plt.savefig('{}/{}.pdf'.format(save_loc, gan_name), format='pdf', dpi=1200)
    #plt.show()


if __name__=='__main__':
    main()
-------------------------------------------------------------------------------- /src/grayscale_sensitivity_analysis.py: --------------------------------------------------------------------------------
import argparse
from sensitivity_assessment.color import all_metrics_sensitivity_analysis

parser = argparse.ArgumentParser(description='Grayscale sensitivity analysis of universal detectors...')

# architecture
parser.add_argument('--arch', type=str, required=True, choices=['resnet50', 'efb0'])

# Data augmentation of model
#parser.add_argument('--blur_jpg', type=float, required=True, choices=[0.1, 0.5])
parser.add_argument('--blur_jpg', type=str, required=True)

# other metrics
parser.add_argument('--bsize', type=int, default=16)

# dataset
parser.add_argument('--dataset_dir', type=str, default='/mnt/data/v2.0_CNN_synth_testset/')
parser.add_argument('--gan_name', type=str, nargs='+', default='progan_val')
parser.add_argument('--have_classes', type=int, default=1)

# Images per class for identifying channels
parser.add_argument('--num_instances', type=int, default=5)


def main():
    args = parser.parse_args()

    if type(args.gan_name) == str:
        args.gan_name = [args.gan_name,]

    args.have_classes = bool(args.have_classes)

    print(args.gan_name, args.have_classes)

    for gan_name in args.gan_name:
        all_metrics_sensitivity_analysis(args.arch,
                                         args.dataset_dir, gan_name, args.blur_jpg,
                                         args.have_classes, args.bsize, num_instances=None)


if __name__=='__main__':
    main()
-------------------------------------------------------------------------------- /src/grayscale_sensitivity_whisker_plots.py: --------------------------------------------------------------------------------
import argparse, os

from sensitivity_assessment.color import all_metrics_sensitivity_analysis

import matplotlib.pyplot as plt
import numpy as np

plt.rcParams["font.family"] = "Times New Roman"

parser = argparse.ArgumentParser(description='Whisker plots of universal-detector sensitivity to grayscale inputs...')

# architecture
parser.add_argument('--arch', type=str, required=True, choices=['resnet50', 'efb0'])

# Data augmentation of model
#parser.add_argument('--blur_jpg', type=float, required=True, choices=[0.1, 0.5])
parser.add_argument('--blur_jpg', type=str, required=True)

# other metrics
parser.add_argument('--bsize', type=int, default=16)

# dataset
parser.add_argument('--dataset_dir', type=str, default='/mnt/data/v2.0_CNN_synth_testset/')
parser.add_argument('--gan_name', type=str, nargs='+', default='progan_val')
parser.add_argument('--have_classes', type=int, default=1)

# Images per class for identifying channels
parser.add_argument('--num_instances', type=int, default=5)


def main():
    args = parser.parse_args()

    if type(args.gan_name) == str:
        args.gan_name = [args.gan_name,]

    args.have_classes = bool(args.have_classes)

    print(args.gan_name, args.have_classes)

    for gan_name in args.gan_name:
        (y_pred, y_true), (y_pred_grayscale, y_true_grayscale) = all_metrics_sensitivity_analysis(args.arch,
                                            args.dataset_dir, gan_name, args.blur_jpg,
                                            args.have_classes, args.bsize, num_instances=None)

        save_loc = os.path.join('output', 'whisker_plots', args.arch, str(args.blur_jpg))
        os.makedirs(save_loc, exist_ok=True)

        # Plot and save
        plot_box_whisker_plots(y_pred, y_true, y_pred_grayscale, y_true_grayscale, save_loc, gan_name)


def plot_box_whisker_plots(y_pred, y_true, y_pred_grayscale, y_true_grayscale, save_loc, gan_name):

    # Get GAN values
    y_pred_gan = y_pred[ y_true == 1 ]
    y_pred_gan_grayscale = y_pred_grayscale[ y_true_grayscale == 1 ]

    # sanity check
    ind1 = np.argwhere(y_true == 1).flatten()
    ind2 = np.argwhere(y_true_grayscale == 1).flatten()
    assert set(list(ind1)) == set(list(ind2))

    # Focus only on samples with >= 20% prob
    y_pred_gan_grayscale = y_pred_gan_grayscale[ y_pred_gan >= 0.20 ]
    y_pred_gan = y_pred_gan[ y_pred_gan >= 0.20 ]
    print(y_pred_gan.shape, y_pred_gan_grayscale.shape)

    # Look at percentage
    y_pred_gan *= 100.0
    y_pred_gan_grayscale *= 100.0

    # Create plot
    plt.rcParams['axes.xmargin'] = 0
    plt.rc('font', weight='bold')
    fig = plt.figure(figsize=(8, 10))
    ax = fig.add_subplot(111)

    # Create data for boxplot
    data = [y_pred_gan, y_pred_gan_grayscale]
    m1, m2 = np.median(y_pred_gan), np.median(y_pred_gan_grayscale)
    #print(y_pred_gan.mean(), y_pred_gan_grayscale.mean(), (m1-m2)/2)

    # Define locations for annotations
    pos = [0.25, 0.40]
    y_loc = (m1-m2)/2 + m2 + 10
    #y_loc = 80
    x_loc = (pos[0] + pos[1])/2 - 0.01
    #print( np.median(y_pred_gan), np.median(y_pred_gan_grayscale) )

    # Create boxplot
    bplot = ax.boxplot(data, showfliers=True, positions=pos, widths=[0.1, 0.1], labels=['Baseline', 'Grayscale'],
                       manage_ticks=False, patch_artist=True, medianprops=dict(color="r", linewidth=5))
    #ax.axhline(m1, linestyle='--')
    #ax.axhline(m2, linestyle='--')

    # Annotate the drop in median prob
    text_for_drop = "{:.1f}%\ndrop".format(m1 - m2)
    ax.annotate('', xy=(x_loc, m2), xytext=(x_loc, m1), horizontalalignment="center",
                arrowprops=dict(arrowstyle='<->', lw=5, shrinkA=0, shrinkB=0, color='r'))
    ax.annotate(text_for_drop, xy=(x_loc, m2), xytext=(x_loc+0.05, y_loc), horizontalalignment="center",
                color='r', weight='bold', fontsize=50)

    # Now make it aesthetic
    ax.set_xlim(0.19, 0.46)
    ax.set_xticks(pos)
    ax.set_xticklabels(['Baseline', 'Grayscale'], fontsize=50, weight='bold')
    ax.set_yticks([0, 20, 40, 60, 80, 100])
    ax.grid(visible=True)
    ax.tick_params(axis='y', labelsize=40)
    ax.set_ylabel("Probability (%)", fontsize=50, weight='bold')
    plt.tight_layout()

    # Add colors
    colors = ['#83c5be', '#b7b7a4', ]
    for patch, color in zip(bplot['boxes'], colors):
        patch.set_facecolor(color)

    # Save
    plt.savefig('{}/{}.pdf'.format(save_loc, gan_name), format='pdf', dpi=1200)
    #plt.show()


if __name__=='__main__':
    main()
-------------------------------------------------------------------------------- /src/imagenet_lrp.py: --------------------------------------------------------------------------------
# Import generic libraries
import os, sys, math, gc
import PIL

# Import scientific computing libraries
import numpy as np

# Import torch and dependencies
import torch
from torchvision import models, transforms

# Import other libraries
import matplotlib.pyplot as plt

# Import lrp modules
from lrp.resnet_lrp_general import *
from lrp.resnet_wrapper import *

# Import utils
from utils.heatmap_helpers import *
from utils.dataset_helpers import *
from utils import *

# Import other modules
import copy
from collections import OrderedDict
import pandas as pd


def get_wrapped_resnet50_imagenet(weightpath, key, device):
    """
    Get a wrapped (LRP-ready) ResNet50 model loaded onto the device. Written by Alexander Binder.
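    The `key` argument selects the Conv2d LRP rule set ('beta0', 'beta1' or
    'betaada'). For this ImageNet variant the checkpoint path is currently
    unused (the torchvision pretrained weights are loaded). A minimal call,
    with illustrative values taken from main() below:

        device = torch.device('cuda:0')
        model_e = get_wrapped_resnet50_imagenet('./weights/blur_jpg_prob0.5.pth',
                                                'beta0', device)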
33 | 34 | Args: 35 | weightspath : path of berkeley classifier weights 36 | key : LRP key 37 | device : cuda or cpu to store the model 38 | 39 | Returns resnet50 pytorch object 40 | """ 41 | 42 | if key == 'beta0': 43 | #beta0 44 | lrp_params_def1={ 45 | 'conv2d_ignorebias': True, 46 | 'eltwise_eps': 1e-6, 47 | 'linear_eps': 1e-6, 48 | 'pooling_eps': 1e-6, 49 | 'use_zbeta': True , 50 | } 51 | 52 | lrp_layer2method={ 53 | 'nn.ReLU': relu_wrapper_fct, 54 | 'nn.BatchNorm2d': relu_wrapper_fct, 55 | 'nn.Conv2d': conv2d_beta0_wrapper_fct, 56 | 'nn.Linear': linearlayer_eps_wrapper_fct, 57 | 'nn.AdaptiveAvgPool2d': adaptiveavgpool2d_wrapper_fct, 58 | 'nn.MaxPool2d': maxpool2d_wrapper_fct, 59 | 'sum_stacked2': eltwisesum_stacked2_eps_wrapper_fct, 60 | } 61 | 62 | elif key == 'beta1': 63 | #beta1 64 | lrp_params_def1={ 65 | 'conv2d_ignorebias': True, 66 | 'eltwise_eps': 1e-6, 67 | 'linear_eps': 1e-6, 68 | 'pooling_eps': 1e-6, 69 | 'use_zbeta': True , 70 | 'conv2d_beta': 1.0, 71 | } 72 | 73 | lrp_layer2method={ 74 | 'nn.ReLU': relu_wrapper_fct, 75 | 'nn.BatchNorm2d': relu_wrapper_fct, 76 | 'nn.Conv2d': conv2d_betaany_wrapper_fct, 77 | 'nn.Linear': linearlayer_eps_wrapper_fct, 78 | 'nn.AdaptiveAvgPool2d': adaptiveavgpool2d_wrapper_fct, 79 | 'nn.MaxPool2d': maxpool2d_wrapper_fct, 80 | 'sum_stacked2': eltwisesum_stacked2_eps_wrapper_fct, 81 | } 82 | 83 | elif key == 'betaada': 84 | # betaada 85 | lrp_params_def1={ 86 | 'conv2d_ignorebias': True, 87 | 'eltwise_eps': 1e-6, 88 | 'linear_eps': 1e-6, 89 | 'pooling_eps': 1e-6, 90 | 'use_zbeta': True , 91 | 'conv2d_maxbeta': 4.5, 92 | } 93 | 94 | lrp_layer2method={ 95 | 'nn.ReLU': relu_wrapper_fct, 96 | 'nn.BatchNorm2d': relu_wrapper_fct, 97 | 'nn.Conv2d': conv2d_betaadaptive_wrapper_fct, 98 | 'nn.Linear': linearlayer_eps_wrapper_fct, 99 | 'nn.AdaptiveAvgPool2d': adaptiveavgpool2d_wrapper_fct, 100 | 'nn.MaxPool2d': maxpool2d_wrapper_fct, 101 | 'sum_stacked2': eltwisesum_stacked2_eps_wrapper_fct, 102 | } 103 | else: 104 | raise NotImplementedError("Unknown key", key) 105 | 106 | model0 = models.resnet50(pretrained=True) 107 | #model0.fc = nn.Linear(2048,1) 108 | 109 | #somedict = torch.load(weightpath) 110 | #model0.load_state_dict( somedict['model'] ) 111 | 112 | model_e = resnet50_canonized(pretrained=False) 113 | model_e.copyfromresnet(model0, lrp_params = lrp_params_def1, lrp_layer2method = lrp_layer2method ) 114 | model_e = model_e.to(device) 115 | 116 | return model_e 117 | 118 | 119 | 120 | def get_lrp_explanations_for_batch(model, 121 | imagetensor, label, 122 | relfname , save, outpath, minus_fx=False): 123 | """ 124 | Get LRP explanations for a single sample. 
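    Note: despite the name, this runs on a whole batch. The backpropagated
    scalar below sums `outputs[:, predclasses]`, which for a batch of size B
    selects a BxB block (every sample's logit at every predicted class); a
    strictly per-sample selection would be, for example:

        z = outputs.gather(1, predclasses.unsqueeze(1)).sum()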
125 | 126 | Args: 127 | model : pytorch model 128 | imagetensor : images 129 | label : label 130 | relfname : filenames 131 | save : If True, save LRP explanations 132 | outpath : output path to save pixel space explanations 133 | 134 | Get dict with positive relevances for the sample 135 | """ 136 | model.eval() 137 | 138 | all_lrp_explanations = [] 139 | 140 | os.makedirs(outpath, exist_ok=True) 141 | 142 | if imagetensor.grad is not None: 143 | imagetensor.grad.zero_() 144 | 145 | imagetensor.requires_grad=True # gxinp needs it here 146 | 147 | with torch.enable_grad(): 148 | outputs = model(imagetensor) 149 | 150 | with torch.no_grad(): 151 | # probs = outputs.sigmoid().flatten() 152 | # preds_labels = torch.where(probs>0.5, 1.0, 0.0).long() 153 | # correct_pred_indices = torch.where(torch.eq(preds_labels, label))[0] 154 | #print(correct_pred_indices, correct_pred_indices.size(0)) 155 | #print(outputs, outputs[correct_pred_indices, :]) 156 | 157 | probs = outputs.softmax(dim=1).flatten() 158 | #preds_labels = torch.where(probs>0.5, 1.0, 0.0).long() 159 | #correct_pred_indices = torch.where(torch.eq(preds_labels, label))[0] 160 | _, predclasses = torch.max(outputs, 1) 161 | 162 | # if not correct_pred_indices.size(0) > 0: 163 | # return all_lrp_explanations 164 | 165 | #Propagate the signals for the correctly predicted samples for LRP (We should get the same LRP results if we use all samples as well.) 166 | with torch.enable_grad(): 167 | if minus_fx: 168 | z = torch.sum( -outputs[:, predclasses] ) # Explain -f(x) if images are real 169 | 170 | else: 171 | z = torch.sum( outputs[:, predclasses] ) # Explain f(x) if images are fake 172 | 173 | with torch.no_grad(): 174 | z.backward(retain_graph=True) 175 | rel = imagetensor.grad.data.clone() 176 | 177 | for b in range(imagetensor.shape[0]): 178 | # Check for correct preds and skip incorrect preds 179 | # cond = (probs[b].item() >= 0.5 and label[b].item() == 1) or (probs[b].item() < 0.5 and label[b].item() == 0) 180 | 181 | # if not cond: 182 | # continue 183 | 184 | fn = relfname[b] 185 | lrp_explanations = {} 186 | lrp_explanations['relfname'] = relfname[b] 187 | lrp_explanations['prob'] = probs[b].item() 188 | 189 | for i, (name, mod) in enumerate(model.named_modules()): 190 | if hasattr(mod, 'relfromoutput'): 191 | v = getattr(mod, 'relfromoutput') 192 | #print(i, name, v.shape) # conv rel map 193 | 194 | ftrelevances = v[b,:] 195 | 196 | # take only positives 197 | #ftrelevances[ftrelevances<0] = 0 198 | 199 | # Save feature relevances to LRP explanations dict. Move to cpu since data is big. 
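                # (If re-enabled, keep these on CPU: a single 56x56 float32
                #  map is 56*56*4 = 12544 bytes, and there are thousands of
                #  feature maps per image across the network.)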
200 | #lrp_explanations[name] = ftrelevances.detach().cpu() 201 | 202 | # All LRP explanations 203 | if label[b].item() == 0: 204 | vis_dir_name = os.path.join(outpath, "visualization", "0_real") 205 | vis_fname = os.path.join( vis_dir_name, fn.replace('/', '-').replace('.png', '-p={:.3f}.pdf'.format(probs[b].item())) ) 206 | vis_fname = vis_fname.replace('.JPEG', '.pdf') 207 | print(vis_fname) 208 | os.makedirs(vis_dir_name, exist_ok=True) 209 | 210 | save_img_lrp_overlay_only_positive(rel[b].to('cpu'), imagetensor[b].to('cpu'), 211 | title="Label: {}, prob :{:.3f}".format( label[b].item(), probs[b].item() ), 212 | q=100, outname=vis_fname ) 213 | 214 | # Store LRP values 215 | if save: 216 | lrp_dir_name = os.path.join(outpath, "lrp", "0_real") 217 | lrp_fname = os.path.join(lrp_dir_name, fn.replace('/', '-') + '.pt') 218 | os.makedirs(lrp_dir_name, exist_ok=True) 219 | torch.save(torch.sum(rel[b], dim=0).cpu(), lrp_fname) 220 | 221 | else: 222 | vis_dir_name = os.path.join(outpath, "visualization", "1_fake") 223 | vis_fname = os.path.join( vis_dir_name, fn.replace('/', '-').replace('.png', '-p={:.3f}.pdf'.format(probs[b].item())) ) 224 | vis_fname = vis_fname.replace('.JPEG', '.pdf') 225 | os.makedirs(vis_dir_name, exist_ok=True) 226 | 227 | save_img_lrp_overlay_only_positive(rel[b].to('cpu'), imagetensor[b].to('cpu'), 228 | title="Label: {}, prob :{:.3f}".format( label[b].item(), probs[b].item() ), 229 | q=100, outname=vis_fname) 230 | 231 | if save: 232 | lrp_dir_name = os.path.join(outpath, "lrp", "1_fake") 233 | lrp_fname = os.path.join(lrp_dir_name, fn.replace('/', '-') + '.pt') 234 | os.makedirs(lrp_dir_name, exist_ok=True) 235 | torch.save(torch.sum(rel[b], dim=0).cpu(), lrp_fname) 236 | 237 | #all_lrp_explanations.append(lrp_explanations) 238 | 239 | torch.cuda.empty_cache() 240 | gc.collect() 241 | 242 | # del ftrelevances 243 | 244 | return all_lrp_explanations 245 | 246 | 247 | def get_all_lrp_positive_explanations(model, dataloader, device, outpath, save, minus_fx): 248 | """ 249 | Get all LRP explanations for one folder. 
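    Each batch yielded by `dataloader` is expected to be a dict of the form
    (inferred from the loop body):

        {'image': FloatTensor[B,3,H,W], 'label': LongTensor[B],
         'filename': list[str], 'relfilestub': list[str]}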
250 | 251 | Args: 252 | model : resnet50 pytorch model 253 | dataloader : pytorch dataloader 254 | device : 255 | outpath : output path to save visualization and lrp numpy files 256 | save : If set to True, save 257 | minus_fx : If set to True, we will use -f(x) signal to calculate relevances 258 | 259 | Returns all LRP explanations 260 | """ 261 | # Global variable to store feature map information 262 | all_lrp_explanations = [] 263 | 264 | # > Explain prediction 265 | for index, data in enumerate(dataloader): 266 | # Get image tensor, filename, stub and labels 267 | imagetensors = data['image'].to(device) 268 | fnames = data['filename'] 269 | relfnames = data['relfilestub'] 270 | labels = data['label'].to(device) 271 | 272 | # Get LRP explanations 273 | lrp_explanations = get_lrp_explanations_for_batch(model, imagetensors, labels, relfnames, save, outpath, minus_fx) 274 | 275 | # Get LRP heatmap for all layers 276 | all_lrp_explanations.extend(lrp_explanations) 277 | 278 | torch.cuda.empty_cache() 279 | gc.collect() 280 | del lrp_explanations 281 | 282 | return all_lrp_explanations 283 | 284 | 285 | 286 | def pipeline(model_e, dl, device, outpath, save, 287 | num_instances, 288 | minus_fx): 289 | """ 290 | Pipeline to run overall algorithm for real or fake images 291 | 292 | Args: 293 | model_e : Wrapped resnet50 294 | dl : dataloader 295 | device : cuda/ cpu 296 | outpath : output path to save the channelwise stats 297 | save : If set to True, all LRP relevances are saved as .pt files 298 | minus_fx: Needs to be set to True for real images 299 | normalize_using_only_positive : If set to True, scheme 1 else scheme 2. 300 | topk : #topk feature maps to return. 301 | """ 302 | # Get LRP explanations 303 | all_lrp_explanations = get_all_lrp_positive_explanations(model_e, dl, device, outpath, save, minus_fx)[:num_instances] 304 | 305 | #print(final_channelwise_stats) 306 | return all_lrp_explanations 307 | 308 | 309 | 310 | def main(): 311 | ## ------------Set Parameters-------- 312 | # Define device and other parameters 313 | device = torch.device('cuda:0') 314 | bsize = 16 315 | 316 | # LRP model keys and weights 317 | key = 'beta0' # 'beta0' 'beta1' , 'betaada' 318 | weightfn = './weights/blur_jpg_prob0.5.pth' 319 | 320 | # Directories 321 | #parent_dir = '/mnt/workspace/projects/deepfake_classifiers_interpretability/samples/' 322 | parent_dir = '/mnt/data/CNN_synth_testset/' # Use our version 323 | have_classes = False 324 | gan_and_classes = {} 325 | # gan_and_classes['biggan'] = ['beer_bottle', 'monitor', 'vase', 'table_lamp', 326 | # 'hummingbird', 'church', 'egyptian_cat', 'welsh_springer_spaniel'] 327 | 328 | #gan_and_classes['progan_test'] = os.listdir(os.path.join(parent_dir, 'progan_test')) 329 | #gan_and_classes['progan_test'] = ['boat'] 330 | #gan_and_classes['biggan'] = [ 'bird'] 331 | gan_and_classes['stylegan2'] = ['church', 'cat', 'car'] 332 | #gan_and_classes['cyclegan'] = ['horse'] 333 | #gan_and_classes['stylegan'] = ['car'] 334 | #gan_and_classes['cyclegan'] = ['apple', 'orange'] 335 | #gan_and_classes['stargan'] = ['person'] 336 | #gan_and_classes['gaugan'] = ['mscoco'] 337 | 338 | #gan_and_classes['san'] = [''] 339 | #gan_and_classes['stylegan2'] = ['car', 'cat', 'horse', 'church'] 340 | #gan_and_classes['cyclegan'] = ['horse'] 341 | #gan_and_classes['stargan'] = [''] 342 | 343 | for gan_name in gan_and_classes: 344 | for clss in gan_and_classes[gan_name]: 345 | print(gan_name, clss) 346 | root_dir = os.path.join(parent_dir, gan_name, clss) 347 | #clss='sr' 348 | 
            outpath = './hms/imagenet_lrp_heatmaps/{}/{}/'.format(gan_name, clss)
            save_pt_files = False # No need to save .pt files.
            num_instances_real, num_instances_fake = 0, 500 # Use 0 real and 500 fake samples for analysis
            ## ------------End of Parameters--------


            # Model
            model_e = get_wrapped_resnet50_imagenet(weightfn, key, device)

            def writeintomodule_bwhook(self, grad_input, grad_output):
                # grad_output is what arrives from above; its shape equals the module's output
                setattr(self, 'relfromoutput', grad_output[0])

            # Register hook
            for i, (name, mod) in enumerate(model_e.named_modules()):
                #print(i, name)
                if ('conv' in name) and ('module' not in name):
                    #print('ok')
                    mod.register_backward_hook(writeintomodule_bwhook)

            # Dataset (same transforms as Wang et al.: resize to 256, then 224 center crop)
            transform = transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
            ])

            # Obtain D_real and D_fake
            dl_real = get_dataloader(root_dir, have_classes, num_instances_real, transform, bsize, onlyreal=True, onlyfake=False)
            dl_fake = get_dataloader(root_dir, have_classes, num_instances_fake, transform, bsize, onlyreal=False, onlyfake=True)

            # Pass to the overall algorithm pipeline to obtain C_real_topk, C_fake_topk
            real_lrp_explanations = pipeline(model_e, dl_real, device, outpath, save=save_pt_files,
                                             num_instances=num_instances_real,
                                             minus_fx=False)
            fake_lrp_explanations = pipeline(model_e, dl_fake, device, outpath, save=save_pt_files,
                                             num_instances=num_instances_fake,
                                             minus_fx=False)


    return real_lrp_explanations, fake_lrp_explanations


if __name__=='__main__':
    main()
-------------------------------------------------------------------------------- /src/imagenet_lrp_efb0.py: --------------------------------------------------------------------------------
# Import generic libraries
import os, sys, math, gc
import PIL

# Import scientific computing libraries
import numpy as np

# Import torch and dependencies
import torch
from torchvision import models, transforms

# Import other libraries
import matplotlib.pyplot as plt

from efficientnet_pytorch import EfficientNet
# from efficientnet_pytorch.utils import load_pretrained_weights

# Import LRP modules
from lrp.ef_lrp_general import *
from lrp.ef_wrapper import *

# Import utils
from utils.heatmap_helpers import *
from utils.dataset_helpers import *
from utils import *

# Import other modules
import copy
from collections import OrderedDict
import pandas as pd


import argparse
parser = argparse.ArgumentParser(description='LRP heatmaps for EfficientNet-B0 classifiers...')

# architecture
parser.add_argument('--arch', type=str, required=True, choices=['resnet50', 'efb0'])

# classifier
parser.add_argument('--classifier', type=str, required=True, choices=['imagenet', 'ud'])

args = parser.parse_args()


def get_wrapped_efficientnet_b0(weightpath, key, device, classifier):
    """
    Get a wrapped (LRP-ready) EfficientNet-B0 model loaded onto the device.
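    A minimal call, with illustrative values matching main() below
    (`classifier` is 'ud' or 'imagenet'):

        model_e = get_wrapped_efficientnet_b0('./weights/efb0/blur_jpg_prob0.5.pth',
                                              'beta0', torch.device('cuda:0'), 'ud')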

    Args:
        weightpath : path of the universal-detector checkpoint (used when classifier == 'ud')
        key : LRP key
        device : cuda or cpu to store the model
        classifier : 'ud' or 'imagenet'

    Returns an EfficientNet-B0 pytorch object wrapped for LRP.
    """

    if key == 'beta0':
        #beta0
        lrp_params_def1={
            'conv2d_ignorebias': True,
            'eltwise_eps': 1e-6,
            'linear_eps': 1e-6,
            'pooling_eps': 1e-6,
            'use_zbeta': False ,
        }

        lrp_layer2method={
            'Swish': relu_wrapper_fct,
            'nn.BatchNorm2d': relu_wrapper_fct,
            'nn.Conv2d': Conv2dDynamicSamePadding_beta0_wrapper_fct,
            'nn.Linear': linearlayer_eps_wrapper_fct,
            'nn.AdaptiveAvgPool2d': adaptiveavgpool2d_wrapper_fct,
            'sum_stacked2': eltwisesum_stacked2_eps_wrapper_fct,
        }

    elif key == 'beta1':
        # not implemented for EfficientNet-B0; fail loudly so lrp_params_def1
        # is never used while undefined
        raise NotImplementedError("key 'beta1' is not implemented for efb0")
    elif key == 'betaada':
        raise NotImplementedError("key 'betaada' is not implemented for efb0")
    else:
        raise NotImplementedError("Unknown key", key)

    if classifier == 'ud':

        model0 = EfficientNet.from_name('efficientnet-b0', num_classes=1, image_size=None,)
        somedict = torch.load(weightpath)
        model0.load_state_dict( somedict['model'] )
        model0.eval()
        #load_pretrained_weights(model0, 'efficientnet-b0', weightpath)
        #model0.set_swish( memory_efficient=False)

        model_e = EfficientNet_canonized.from_pretrained('efficientnet-b0', num_classes=1, image_size=None, dropout_rate= 0.0 , drop_connect_rate=0.0)
        #print(model_e)
        #model_e.set_swish( memory_efficient=False)

        model_e.copyfromefficientnet( model0, lrp_params_def1, lrp_layer2method)
        model_e.to(device)

        return model_e

    else:
        model0 = EfficientNet.from_pretrained('efficientnet-b0', image_size=None)
        model0.eval()
        #load_pretrained_weights(model0, 'efficientnet-b0', weightpath)
        #model0.set_swish( memory_efficient=False)

        model_e = EfficientNet_canonized.from_pretrained('efficientnet-b0', num_classes=1000, image_size=None, dropout_rate= 0.0 , drop_connect_rate=0.0)
        #print(model_e)
        #model_e.set_swish( memory_efficient=False)

        model_e.copyfromefficientnet( model0, lrp_params_def1, lrp_layer2method)
        model_e.to(device)

        return model_e


def get_lrp_explanations_for_batch(model,
                                   imagetensor, label,
                                   relfname, save, outpath, minus_fx=False):
    """
    Get LRP explanations for a batch of samples.
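    Note: `minus_fx` flips the sign of the explained output, i.e. relevances
    are computed for -f(x) instead of f(x); the pipeline below uses this for
    real images of the universal detector. An illustrative call:

        expl = get_lrp_explanations_for_batch(model_e, images, labels, fnames,
                                              save=False, outpath='./out',
                                              minus_fx=True)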
125 | 126 | Args: 127 | model : pytorch model 128 | imagetensor : images 129 | label : label 130 | relfname : filenames 131 | save : If True, save LRP explanations 132 | outpath : output path to save pixel space explanations 133 | 134 | Get dict with positive relevances for the sample 135 | """ 136 | model.eval() 137 | 138 | all_lrp_explanations = [] 139 | 140 | os.makedirs(outpath, exist_ok=True) 141 | 142 | if imagetensor.grad is not None: 143 | imagetensor.grad.zero_() 144 | 145 | imagetensor.requires_grad=True # gxinp needs it here 146 | 147 | with torch.enable_grad(): 148 | outputs = model(imagetensor) 149 | 150 | with torch.no_grad(): 151 | # probs = outputs.sigmoid().flatten() 152 | # preds_labels = torch.where(probs>0.5, 1.0, 0.0).long() 153 | # correct_pred_indices = torch.where(torch.eq(preds_labels, label))[0] 154 | #print(correct_pred_indices, correct_pred_indices.size(0)) 155 | #print(outputs, outputs[correct_pred_indices, :]) 156 | 157 | # if not correct_pred_indices.size(0) > 0: 158 | # return all_lrp_explanations 159 | 160 | 161 | # For imagenet: 162 | probs = outputs.softmax(dim=1).flatten() 163 | #preds_labels = torch.where(probs>0.5, 1.0, 0.0).long() 164 | #correct_pred_indices = torch.where(torch.eq(preds_labels, label))[0] 165 | _, predclasses = torch.max(outputs, 1) 166 | 167 | #Propagate the signals for the correctly predicted samples for LRP (We should get the same LRP results if we use all samples as well.) 168 | # with torch.enable_grad(): 169 | # if minus_fx: 170 | # z = torch.sum( -outputs[correct_pred_indices, :] ) # Explain -f(x) if images are real 171 | 172 | # else: 173 | # z = torch.sum( outputs[correct_pred_indices, :] ) # Explain f(x) if images are fake 174 | 175 | # For imagenet 176 | with torch.enable_grad(): 177 | if minus_fx: 178 | z = torch.sum( -outputs[:, predclasses] ) # Explain imagenet 179 | 180 | else: 181 | z = torch.sum( outputs[:, predclasses] ) # Explain imagenet 182 | 183 | with torch.no_grad(): 184 | z.backward(retain_graph=True) 185 | rel = imagetensor.grad.data.clone() 186 | 187 | for b in range(imagetensor.shape[0]): 188 | # Check for correct preds and skip incorrect preds 189 | # cond = (probs[b].item() >= 0.5 and label[b].item() == 1) or (probs[b].item() < 0.5 and label[b].item() == 0) 190 | 191 | # if not cond: 192 | # continue 193 | 194 | fn = relfname[b] 195 | lrp_explanations = {} 196 | lrp_explanations['relfname'] = relfname[b] 197 | lrp_explanations['prob'] = probs[b].item() 198 | 199 | for i, (name, mod) in enumerate(model.named_modules()): 200 | if hasattr(mod, 'relfromoutput'): 201 | v = getattr(mod, 'relfromoutput') 202 | #print(i, name, v.shape) # conv rel map 203 | 204 | ftrelevances = v[b,:] 205 | 206 | # take only positives 207 | #ftrelevances[ftrelevances<0] = 0 208 | 209 | # Save feature relevances to LRP explanations dict. Move to cpu since data is big. 
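                # ('relfromoutput' is set by the backward hook registered in
                #  main(); `v` has the conv feature map's shape, so
                #  `ftrelevances` is that map's relevance for sample b.)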
210 | #lrp_explanations[name] = ftrelevances.detach().cpu() 211 | 212 | # All LRP explanations 213 | if label[b].item() == 0: 214 | vis_dir_name = os.path.join(outpath, "visualization", "0_real") 215 | vis_fname = os.path.join( vis_dir_name, fn.replace('/', '-').replace('.png', '-p={:.3f}.pdf'.format(probs[b].item())) ) 216 | vis_fname = vis_fname.replace('.JPEG', '.pdf') 217 | print(vis_fname) 218 | os.makedirs(vis_dir_name, exist_ok=True) 219 | 220 | save_img_lrp_overlay_only_positive(rel[b].to('cpu'), imagetensor[b].to('cpu'), 221 | title="Label: {}, prob :{:.3f}".format( label[b].item(), probs[b].item() ), 222 | q=100, outname=vis_fname ) 223 | 224 | # Store LRP values 225 | if save: 226 | lrp_dir_name = os.path.join(outpath, "lrp", "0_real") 227 | lrp_fname = os.path.join(lrp_dir_name, fn.replace('/', '-') + '.pt') 228 | os.makedirs(lrp_dir_name, exist_ok=True) 229 | torch.save(torch.sum(rel[b], dim=0).cpu(), lrp_fname) 230 | 231 | else: 232 | vis_dir_name = os.path.join(outpath, "visualization", "1_fake") 233 | vis_fname = os.path.join( vis_dir_name, fn.replace('/', '-').replace('.png', '-p={:.3f}.pdf'.format(probs[b].item())) ) 234 | vis_fname = vis_fname.replace('.JPEG', '.pdf') 235 | os.makedirs(vis_dir_name, exist_ok=True) 236 | 237 | save_img_lrp_overlay_only_positive(rel[b].to('cpu'), imagetensor[b].to('cpu'), 238 | title="Label: {}, prob :{:.3f}".format( label[b].item(), probs[b].item() ), 239 | q=100, outname=vis_fname) 240 | 241 | if save: 242 | lrp_dir_name = os.path.join(outpath, "lrp", "1_fake") 243 | lrp_fname = os.path.join(lrp_dir_name, fn.replace('/', '-') + '.pt') 244 | os.makedirs(lrp_dir_name, exist_ok=True) 245 | torch.save(torch.sum(rel[b], dim=0).cpu(), lrp_fname) 246 | 247 | #all_lrp_explanations.append(lrp_explanations) 248 | 249 | torch.cuda.empty_cache() 250 | gc.collect() 251 | 252 | # del ftrelevances 253 | 254 | return all_lrp_explanations 255 | 256 | 257 | def get_all_lrp_positive_explanations(model, dataloader, device, outpath, save, minus_fx): 258 | """ 259 | Get all LRP explanations for one folder. 
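    Memory note: explanations are written out batch-by-batch, and
    torch.cuda.empty_cache() / gc.collect() run after every batch, so peak
    GPU memory stays close to a single batch's footprint.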
260 | 261 | Args: 262 | model : resnet50 pytorch model 263 | dataloader : pytorch dataloader 264 | device : 265 | outpath : output path to save visualization and lrp numpy files 266 | save : If set to True, save 267 | minus_fx : If set to True, we will use -f(x) signal to calculate relevances 268 | 269 | Returns all LRP explanations 270 | """ 271 | # Global variable to store feature map information 272 | all_lrp_explanations = [] 273 | 274 | # > Explain prediction 275 | for index, data in enumerate(dataloader): 276 | # Get image tensor, filename, stub and labels 277 | imagetensors = data['image'].to(device) 278 | fnames = data['filename'] 279 | relfnames = data['relfilestub'] 280 | labels = data['label'].to(device) 281 | 282 | # Get LRP explanations 283 | lrp_explanations = get_lrp_explanations_for_batch(model, imagetensors, labels, relfnames, save, outpath, minus_fx) 284 | 285 | # Get LRP heatmap for all layers 286 | all_lrp_explanations.extend(lrp_explanations) 287 | 288 | torch.cuda.empty_cache() 289 | gc.collect() 290 | del lrp_explanations 291 | 292 | return all_lrp_explanations 293 | 294 | 295 | 296 | def pipeline(model_e, dl, device, outpath, save, 297 | num_instances, 298 | minus_fx): 299 | """ 300 | Pipeline to run overall algorithm for real or fake images 301 | 302 | Args: 303 | model_e : Wrapped resnet50 304 | dl : dataloader 305 | device : cuda/ cpu 306 | outpath : output path to save the channelwise stats 307 | save : If set to True, all LRP relevances are saved as .pt files 308 | minus_fx: Needs to be set to True for real images 309 | normalize_using_only_positive : If set to True, scheme 1 else scheme 2. 310 | topk : #topk feature maps to return. 311 | """ 312 | # Get LRP explanations 313 | all_lrp_explanations = get_all_lrp_positive_explanations(model_e, dl, device, outpath, save, minus_fx)[:num_instances] 314 | 315 | #print(final_channelwise_stats) 316 | return all_lrp_explanations 317 | 318 | 319 | 320 | def main(): 321 | ## ------------Set Parameters-------- 322 | # Define device and other parameters 323 | device = torch.device('cuda:0') 324 | bsize = 16 325 | 326 | # LRP model keys and weights 327 | key = 'beta0' # 'beta0' 'beta1' , 'betaada' 328 | weightfn = './weights/{}/blur_jpg_prob0.5.pth'.format(args.arch) 329 | 330 | # Directories 331 | #parent_dir = '/mnt/workspace/projects/deepfake_classifiers_interpretability/samples/' 332 | parent_dir = '/mnt/data/CNN_synth_testset/' # Use our version 333 | have_classes = False 334 | gan_and_classes = {} 335 | # gan_and_classes['biggan'] = ['beer_bottle', 'monitor', 'vase', 'table_lamp', 336 | # 'hummingbird', 'church', 'egyptian_cat', 'welsh_springer_spaniel'] 337 | 338 | #gan_and_classes['progan_test'] = os.listdir(os.path.join(parent_dir, 'progan_test')) 339 | #gan_and_classes['progan_test'] = [ 'boat' ] 340 | #gan_and_classes['biggan'] = ['bird'] 341 | gan_and_classes['stylegan2'] = ['church', 'car', 'cat'] 342 | #gan_and_classes['stylegan'] = ['car'] 343 | #gan_and_classes['cyclegan'] = ['horse'] 344 | #gan_and_classes['stargan'] = ['person'] 345 | #gan_and_classes['gaugan'] = ['mscoco'] 346 | 347 | #gan_and_classes['san'] = [''] 348 | #gan_and_classes['stylegan2'] = ['car', 'cat', 'horse', 'church'] 349 | #gan_and_classes['cyclegan'] = ['horse'] 350 | #gan_and_classes['stargan'] = [''] 351 | 352 | 353 | for gan_name in gan_and_classes: 354 | for clss in gan_and_classes[gan_name]: 355 | print(gan_name, clss) 356 | root_dir = os.path.join(parent_dir, gan_name, clss) 357 | #clss='sr' 358 | outpath = 
'./hms/lrp_heatmaps_{}_{}/{}/{}/'.format(args.arch, args.classifier, gan_name, clss) 359 | save_pt_files = False # No need to save .pt files. 360 | num_instances_real, num_instances_fake = 1, 500 # Use 1000 real and fake samples for analysis 361 | ## ------------End of Parameters-------- 362 | 363 | 364 | # Model 365 | model_e = get_wrapped_efficientnet_b0(weightfn, key, device, args.classifier) 366 | 367 | def writeintomodule_bwhook(self,grad_input, grad_output): 368 | #gradoutput is what arrives from above, shape id eq to output 369 | setattr(self,'relfromoutput', grad_output[0]) 370 | 371 | # Register hook 372 | for i, (name,mod) in enumerate(model_e.named_modules()): 373 | #print(i,nm) 374 | if ('conv' in name) and ('module' not in name): 375 | #print('ok') 376 | mod.register_backward_hook(writeintomodule_bwhook) 377 | 378 | # Dataset (Use same transforms as Wang et. al without Center cropping) 379 | transform = transforms.Compose([ 380 | transforms.Resize(256), 381 | transforms.CenterCrop(224), 382 | transforms.ToTensor(), 383 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 384 | ]) 385 | 386 | # Obtain D_real and D_fake 387 | dl_real = get_dataloader(root_dir, have_classes, num_instances_real , transform, bsize, onlyreal=True, onlyfake=False) 388 | dl_fake = get_dataloader(root_dir, have_classes, num_instances_fake , transform, bsize, onlyreal=False, onlyfake=True) 389 | 390 | # Pass to the overall algorithm pipeline to obtain C_real_topk, C_fake_topk 391 | real_lrp_explanations = pipeline(model_e, dl_real, device, outpath, save=save_pt_files, 392 | num_instances=num_instances_real, 393 | minus_fx=True) 394 | fake_lrp_explanations = pipeline(model_e, dl_fake, device, outpath, save=save_pt_files, 395 | num_instances=num_instances_fake, 396 | minus_fx=False) 397 | 398 | 399 | 400 | 401 | return real_lrp_explanations, fake_lrp_explanations 402 | 403 | 404 | if __name__=='__main__': 405 | main() 406 | -------------------------------------------------------------------------------- /src/lrp/resnet_wrapper.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | from torch.optim import lr_scheduler 7 | 8 | 9 | import torchvision 10 | from torchvision import datasets, models, transforms, utils 11 | from torch.utils.data import Dataset, DataLoader 12 | 13 | import numpy as np 14 | 15 | 16 | #from vocparseclslabels import PascalVOC 17 | 18 | from typing import Callable, Optional 19 | 20 | # from heatmaphelpers import * 21 | from lrp.resnet_lrp_general import * 22 | 23 | try: 24 | from torch.hub import load_state_dict_from_url 25 | except ImportError: 26 | from torch.utils.model_zoo import load_url as load_state_dict_from_url 27 | 28 | #resnet internals for conv-bn fusion 29 | from torchvision.models.resnet import BasicBlock,Bottleneck,ResNet 30 | 31 | 32 | class BasicBlock_fused(BasicBlock): 33 | expansion = 1 34 | 35 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 36 | base_width=64, dilation=1, norm_layer=None): 37 | super(BasicBlock_fused, self).__init__(inplanes, planes, stride, downsample, groups, 38 | base_width, dilation, norm_layer) 39 | 40 | #own 41 | self.elt=sum_stacked2() # eltwisesum2() 42 | 43 | def forward(self, x): 44 | identity = x 45 | 46 | out = self.conv1(x) 47 | out = self.bn1(out) 48 | out = self.relu(out) 49 | 50 | out = self.conv2(out) 51 | out = self.bn2(out) 52 | 53 | if self.downsample is not None: 54 | 
identity = self.downsample(x) 55 | 56 | #out += identity 57 | #out = self.relu(out) 58 | 59 | out = self.elt( torch.stack([out,identity], dim=0) ) #self.elt(out,identity) 60 | out = self.relu(out) 61 | 62 | return out 63 | 64 | class Bottleneck_fused(Bottleneck): 65 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 66 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 67 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 68 | # This variant is also known as ResNet V1.5 and improves accuracy according to 69 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 70 | 71 | expansion = 4 72 | 73 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 74 | base_width=64, dilation=1, norm_layer=None): 75 | super(Bottleneck_fused, self).__init__(inplanes, planes, stride, downsample, groups, 76 | base_width, dilation, norm_layer) 77 | 78 | #own 79 | self.elt=sum_stacked2() # eltwisesum2() 80 | 81 | 82 | def forward(self, x): 83 | identity = x 84 | 85 | out = self.conv1(x) 86 | out = self.bn1(out) 87 | out = self.relu(out) 88 | 89 | out = self.conv2(out) 90 | out = self.bn2(out) 91 | out = self.relu(out) 92 | 93 | out = self.conv3(out) 94 | out = self.bn3(out) 95 | 96 | if self.downsample is not None: 97 | identity = self.downsample(x) 98 | 99 | #out += identity 100 | out = self.elt( torch.stack([out,identity], dim=0) ) #self.elt(out,identity) 101 | out = self.relu(out) 102 | 103 | return out 104 | 105 | class ResNet_canonized(ResNet): 106 | 107 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 108 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 109 | norm_layer=None): 110 | super(ResNet_canonized, self).__init__(block, layers, num_classes, zero_init_residual, 111 | groups, width_per_group, replace_stride_with_dilation, 112 | norm_layer) 113 | 114 | ###################### 115 | # change 116 | ###################### 117 | #own 118 | #self.avgpool = nn.AvgPool2d(kernel_size=7,stride=7 ) #nn.AdaptiveAvgPool2d((1, 1)) 119 | 120 | 121 | 122 | # runs in your current module to find the object layer3.1.conv2, and replaces it by the obkect stored in value (see success=iteratset(self,components,value) as initializer, can be modified to run in another class when replacing that self) 123 | def setbyname(self,name,value): 124 | 125 | def iteratset(obj,components,value): 126 | 127 | if not hasattr(obj,components[0]): 128 | return False 129 | elif len(components)==1: 130 | setattr(obj,components[0],value) 131 | #print('found!!', components[0]) 132 | #exit() 133 | return True 134 | else: 135 | nextobj=getattr(obj,components[0]) 136 | return iteratset(nextobj,components[1:],value) 137 | 138 | components=name.split('.') 139 | success=iteratset(self,components,value) 140 | return success 141 | 142 | def copyfromresnet(self,net, lrp_params, lrp_layer2method): 143 | assert( isinstance(net,ResNet)) 144 | 145 | 146 | # --copy linear 147 | # --copy conv2, while fusing bns 148 | # --reset bn 149 | 150 | # first conv, then bn, 151 | #means: when encounter bn, find the conv before -- implementation dependent 152 | 153 | 154 | updated_layers_names=[] 155 | 156 | last_src_module_name=None 157 | last_src_module=None 158 | 159 | for src_module_name, src_module in net.named_modules(): 160 | print('at src_module_name', src_module_name ) 161 | 162 | foundsth=False 163 | 164 | 165 | if 
isinstance(src_module, nn.Linear): 166 | #copy linear layers 167 | foundsth=True 168 | print('is Linear') 169 | #m = oneparam_wrapper_class( copy.deepcopy(src_module) , linearlayer_eps_wrapper_fct(), parameter1 = linear_eps ) 170 | wrapped = get_lrpwrapperformodule( copy.deepcopy(src_module) , lrp_params, lrp_layer2method) 171 | if False== self.setbyname(src_module_name, wrapped ): 172 | raise Modulenotfounderror("could not find module "+src_module_name+ " in target net to copy" ) 173 | updated_layers_names.append(src_module_name) 174 | # end of if 175 | 176 | 177 | if isinstance(src_module, nn.Conv2d): 178 | #store conv2d layers 179 | foundsth=True 180 | print('is Conv2d') 181 | last_src_module_name=src_module_name 182 | last_src_module=src_module 183 | # end of if 184 | 185 | if isinstance(src_module, nn.BatchNorm2d): 186 | # conv-bn chain 187 | foundsth=True 188 | print('is BatchNorm2d') 189 | 190 | if (True == lrp_params['use_zbeta']) and (last_src_module_name == 'conv1'): 191 | thisis_inputconv_andiwant_zbeta = True 192 | else: 193 | thisis_inputconv_andiwant_zbeta = False 194 | 195 | m = copy.deepcopy(last_src_module) 196 | m = bnafterconv_overwrite_intoconv(m , bn = src_module) 197 | # wrap conv 198 | wrapped = get_lrpwrapperformodule( m , lrp_params, lrp_layer2method, thisis_inputconv_andiwant_zbeta = thisis_inputconv_andiwant_zbeta ) 199 | 200 | if False== self.setbyname(last_src_module_name, wrapped ): 201 | raise Modulenotfounderror("could not find module "+nametofind+ " in target net to copy" ) 202 | 203 | updated_layers_names.append(last_src_module_name) 204 | 205 | # wrap batchnorm 206 | wrapped = get_lrpwrapperformodule( resetbn(src_module) , lrp_params, lrp_layer2method) 207 | if False== self.setbyname(src_module_name, wrapped ): 208 | raise Modulenotfounderror("could not find module "+src_module_name+ " in target net to copy" ) 209 | updated_layers_names.append(src_module_name) 210 | # end of if 211 | 212 | 213 | #if False== foundsth: 214 | # print('!untreated layer') 215 | print('\n') 216 | 217 | # sum_stacked2 is present only in the targetclass, so must iterate here 218 | for target_module_name, target_module in self.named_modules(): 219 | 220 | if isinstance(target_module, (nn.ReLU, nn.AdaptiveAvgPool2d, nn.MaxPool2d)): 221 | wrapped = get_lrpwrapperformodule( target_module , lrp_params, lrp_layer2method) 222 | 223 | if False== self.setbyname(target_module_name, wrapped ): 224 | raise Modulenotfounderror("could not find module "+src_module_name+ " in target net to copy" ) 225 | updated_layers_names.append(target_module_name) 226 | 227 | 228 | if isinstance(target_module, sum_stacked2 ): 229 | 230 | wrapped = get_lrpwrapperformodule( target_module , lrp_params, lrp_layer2method) 231 | if False== self.setbyname(target_module_name, wrapped ): 232 | raise Modulenotfounderror("could not find module "+target_module_name+ " in target net , impossible!" 
)
                updated_layers_names.append(target_module_name)

        for target_module_name, target_module in self.named_modules():
            if target_module_name not in updated_layers_names:
                print('not updated:', target_module_name)


def _resnet_canonized(arch, block, layers, pretrained, progress, **kwargs):
    model = ResNet_canonized(block, layers, **kwargs)
    if pretrained:
        raise Cannotloadmodelweightserror("explainable nn model wrapper was never meant to load dictionary weights; load into a standard model first, then instantiate this class from the standard model")
    return model


def resnet18_canonized(pretrained=False, progress=True, **kwargs):
    r"""ResNet-18 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet_canonized('resnet18', BasicBlock_fused, [2, 2, 2, 2], pretrained, progress, **kwargs)

def resnet50_canonized(pretrained=False, progress=True, **kwargs):
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet_canonized('resnet50', Bottleneck_fused, [3, 4, 6, 3], pretrained, progress, **kwargs)
-------------------------------------------------------------------------------- /src/median_test_activation_histograms.py: --------------------------------------------------------------------------------
import argparse, os
import math
from utils.general import get_all_channels
from activation_histograms.get_activation_values import get_activation_values

import torch
from torchvision import models, transforms

from scipy import stats
# from scipy.stats import kstest

import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

parser = argparse.ArgumentParser(description='Median test on max spatial activations of top-ranked feature maps...')

# architecture
parser.add_argument('--arch', type=str, required=True, choices=['resnet50', 'efb0'])

# Data augmentation of model
# parser.add_argument('--blur_jpg', type=float, required=True, choices=[0.1, 0.5])
parser.add_argument('--blur_jpg', type=str, required=True)

# other metrics
parser.add_argument('--bsize', type=int, default=16)

# dataset
parser.add_argument('--dataset_dir', type=str, default='/mnt/data/v2.0_CNN_synth_testset/')
parser.add_argument('--gan_name', type=str, nargs='+', default='progan_val')
parser.add_argument('--have_classes', type=int, default=1)

# Images per class for identifying channels
parser.add_argument('--num_instances', type=int, default=5)

# topk
parser.add_argument('--topk', type=int, default=114)


def plot_act_hist(feature_map_name, hist1, hist2, label1, label2, gan_name, save_dir):
    plt.rcParams['axes.xmargin'] = 0
    plt.rcParams['lines.linewidth'] = 10.0 # thicker lines
    fig = plt.figure(figsize=(13, 8))
    ax = fig.add_subplot(111)

    sns.kdeplot(data=hist1, ax=ax, label=label1)
    sns.kdeplot(data=hist2, ax=ax, label=label2)

    #plt.legend()
    # 
plt.xticks(linspace(0.0, 6.0, step=.5)) 52 | # #plt.ylim(0, 1.0) 53 | # plt.title(feature_map_name, fontsize=60) 54 | # plt.tight_layout() 55 | 56 | #ax.legend(loc="best", bbox_to_anchor=(1.0, 0.9), prop={'size': 44}) 57 | ax.set_title(feature_map_name.split('(')[0], fontsize=72, weight='bold') 58 | #ax.set_xticks(linspace(0.0, 6.0, step=2.0)) 59 | ax.set_ylabel("density", fontsize=72, weight='bold') 60 | ax.xaxis.set_tick_params(labelsize=68) 61 | ax.yaxis.set_tick_params(labelsize=68) 62 | #ax.set_xlim(0, 5) 63 | x_ticks = ax.xaxis.get_major_ticks() 64 | x_ticks[0].label1.set_visible(False) 65 | 66 | ax.grid(True) 67 | 68 | ax.set_xlabel("max spatial activation", fontsize=72, weight='bold' ) 69 | fig.tight_layout() 70 | #plt.show() 71 | 72 | fig.savefig("{}/{}.pdf".format(save_dir, gan_name), format="pdf", dpi=1200, bbox_inches='tight') 73 | 74 | # Plot legend seperately 75 | figsize = (20, 0.1) 76 | fig_leg = plt.figure(figsize=figsize) 77 | ax_leg = fig_leg.add_subplot(111) 78 | 79 | # add the legend from the previous axes 80 | ax_leg.legend(*ax.get_legend_handles_labels(), loc="upper center", mode = "expand", 81 | ncol = 2, frameon=False, fontsize=50) 82 | 83 | # hide the axes frame and the x/y labels 84 | ax_leg.axis('off') 85 | fig_leg.savefig('{}/legend.pdf'.format(save_dir), format='pdf', dpi=1200, bbox_inches='tight') 86 | #plt.show() 87 | plt.close() 88 | 89 | 90 | def main(): 91 | args = parser.parse_args() 92 | 93 | if(type(args.gan_name)==str): 94 | args.gan_name = [args.gan_name,] 95 | 96 | args.have_classes = bool(args.have_classes) 97 | 98 | print(args.gan_name, args.have_classes) 99 | 100 | # Get feature map names 101 | topk_channels, lowk_channels, all_channels = get_all_channels( 102 | fake_csv_path="fmap_relevances/{}/progan_val_{}/progan_val-fake.csv".format(args.arch, args.blur_jpg), 103 | topk=args.topk) 104 | 105 | transform = transforms.Compose([ 106 | #transforms.Resize(256), 107 | #transforms.CenterCrop(224), 108 | transforms.ToTensor(), 109 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 110 | ]) 111 | 112 | transform_grayscale = transforms.Compose([ 113 | #transforms.Resize(256), 114 | #transforms.CenterCrop(224), 115 | transforms.Grayscale(3), 116 | transforms.ToTensor(), 117 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 118 | ]) 119 | 120 | print(topk_channels) 121 | 122 | for gan_name in args.gan_name: 123 | # Store median test results 124 | df_median_test = pd.DataFrame(columns=['feature_map_name', 'p_value']) 125 | 126 | for feature_map_name in topk_channels: 127 | df, act_fake = get_activation_values(feature_map_name, args.dataset_dir, args.arch, gan_name, args.have_classes, args.blur_jpg, args.bsize, 128 | transform, num_instances=500) 129 | 130 | df, act_fake_grayscale = get_activation_values(feature_map_name, args.dataset_dir, args.arch, gan_name, args.have_classes, args.blur_jpg, args.bsize, 131 | transform_grayscale, num_instances=500) 132 | 133 | # Perform median test, returned tuple in stat, p, median, contingency table 134 | median_test = stats.median_test(act_fake, act_fake_grayscale) [1] 135 | 136 | record = {'feature_map_name': feature_map_name , 'p_value': median_test} 137 | print(record) 138 | df_median_test = df_median_test.append(record, ignore_index=True) 139 | 140 | save_loc = os.path.join('output', 'activation_histograms', args.arch, str(args.blur_jpg), feature_map_name) 141 | os.makedirs(save_loc, exist_ok=True) 142 | 143 | plot_act_hist(feature_map_name, act_fake, act_fake_grayscale, 
"Baseline", "Grayscale", gan_name, save_dir=save_loc) 144 | 145 | 146 | output_dir= 'output/median_test/{}_{}'.format(args.arch, args.blur_jpg) 147 | os.makedirs(output_dir, exist_ok=True) 148 | df_median_test.to_csv("{}/{}.csv".format(output_dir, gan_name), index=False) 149 | 150 | if __name__=='__main__': 151 | main() -------------------------------------------------------------------------------- /src/patch_collage.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import math 4 | from utils.general import get_all_channels 5 | from PIL import Image 6 | import matplotlib.pyplot as plt 7 | import cv2 8 | import numpy as np 9 | 10 | parser = argparse.ArgumentParser(description='Rank feature maps of universal detectors...') 11 | 12 | # architecture 13 | parser.add_argument('--arch', type=str, required=True, choices=['resnet50', 'efb0']) 14 | 15 | # Data augmentation of model 16 | # parser.add_argument('--blur_jpg', type=float, required=True, choices=[0.1, 0.5]) 17 | parser.add_argument('--blur_jpg', type=str, required=True) 18 | 19 | # dataset 20 | parser.add_argument('--gan_name', type=str, nargs='+', default='progan_val') 21 | 22 | # Images per class for identifying channels 23 | parser.add_argument('--num_instances', type=int, default=5) 24 | 25 | 26 | 27 | 28 | def create_collage(path_list, save_loc, gan, name): 29 | imgs = [ Image.open(i).resize((224, 224), Image.NEAREST) for i in path_list ][:5] 30 | #print(imgs) 31 | 32 | num_images = 5 33 | new_im = Image.new('RGB', (int(num_images*224), 224) ) 34 | 35 | # Add border, convert back to pil 36 | 37 | 38 | index = 0 39 | for i in range(0, int(num_images*224), 224): 40 | patch_np = np.asarray(imgs[i//224]) 41 | patch = cv2.copyMakeBorder(patch_np, 1, 1, 1, 1, cv2.BORDER_CONSTANT, value=[0, 0, 0]) 42 | patch = cv2.resize(patch, (224, 224)) 43 | patch = Image.fromarray(patch) 44 | 45 | new_im.paste(patch, (i, 0)) 46 | 47 | new_im.save('{}/{}.pdf'.format(save_loc, gan), quality=95, subsampling=0) 48 | 49 | 50 | 51 | def main(): 52 | args = parser.parse_args() 53 | 54 | if(type(args.gan_name)==str): 55 | args.gan_name = [args.gan_name,] 56 | 57 | #args.have_classes = bool(args.have_classes) 58 | print(args.gan_name) 59 | 60 | # Get feature map names 61 | topk_channels, lowk_channels, all_channels = get_all_channels( 62 | fake_csv_path="fmap_relevances/{}/progan_val_{}/progan_val-fake.csv".format(args.arch, args.blur_jpg), 63 | topk=27) 64 | 65 | for gan_name in args.gan_name: 66 | for feature_map_name in topk_channels: 67 | # feature_map_idx = int(feature_map_name.split('.#')[-1].split('(')[0]) 68 | # layertobeattached = feature_map_name.split("#")[0][:-1] 69 | feature_map_name = feature_map_name.split('(')[0] 70 | 71 | patch_parent_path = os.path.join('output', 'patches', args.arch, gan_name, 'fake', feature_map_name, 'visualization') 72 | print(patch_parent_path) 73 | paths = [ os.path.join(patch_parent_path, i) for i in os.listdir(patch_parent_path) ] 74 | paths.sort() 75 | 76 | save_loc = os.path.join('output', 'collages', args.arch, str(args.blur_jpg), feature_map_name) 77 | os.makedirs(save_loc, exist_ok=True) 78 | create_collage(paths, save_loc, gan_name, feature_map_name) 79 | 80 | 81 | 82 | 83 | if __name__=='__main__': 84 | main() -------------------------------------------------------------------------------- /src/patch_extraction/extract_lrp_max_patches_using_filenames_efb0.py: -------------------------------------------------------------------------------- 1 
| 2 | 3 | # from utils import mask_topk_channels 4 | import torch 5 | import torch.nn as nn 6 | 7 | import torchvision 8 | from torchvision import transforms 9 | from torch.utils.data import Dataset, DataLoader 10 | #import matplotlib.pyplot as plt 11 | 12 | from torch import Tensor 13 | 14 | import datetime 15 | import time 16 | import os 17 | import numpy as np 18 | 19 | import PIL.Image 20 | 21 | 22 | #LRP 23 | 24 | from lrp.ef_wrapper import * 25 | from utils.heatmap_helpers import * 26 | from utils.dataset_helpers import * 27 | from utils.general import * 28 | from utils.mask_fmaps import * 29 | 30 | from utils import * 31 | 32 | 33 | # Keep penultimate features as global varialble such that hook modifies these features 34 | penultimate_fts = None 35 | 36 | 37 | def get_penultimate_fts(self, input, output): 38 | global penultimate_fts 39 | penultimate_fts = output 40 | return None 41 | 42 | 43 | class Dataset_from_paths(torch.utils.data.Dataset): 44 | 45 | def __init__(self, img_paths, transform=None): 46 | self.transform = transform 47 | self.imgfilenames=img_paths 48 | 49 | def __len__(self): 50 | return len(self.imgfilenames) 51 | 52 | def __getitem__(self, idx): 53 | image = PIL.Image.open(self.imgfilenames[idx]).convert('RGB') 54 | label=1 55 | 56 | if self.transform: 57 | image = self.transform(image) 58 | 59 | sample = {'image': image, 'label': label, 'filename': self.imgfilenames[idx] } 60 | 61 | return sample 62 | 63 | 64 | def interpret_simtest(model, actual_model, dataloader, probs, 65 | device, savept, 66 | layer_name, channel_index, 67 | bias_inducer): 68 | #print(probs) 69 | if not os.path.isdir(savept): 70 | os.makedirs(savept) 71 | 72 | model.eval() 73 | fnames = [] 74 | rank = 0 75 | 76 | max_act_values = [] 77 | corr_filenames = [] 78 | 79 | for batch_idx, data in enumerate(dataloader): 80 | 81 | if (batch_idx%100==0) and (batch_idx>=100): 82 | print('at val batchindex: ',batch_idx) 83 | 84 | inputs = data['image'].to(device) 85 | labels = data['label'] 86 | if bias_inducer is not None: 87 | inputs=bias_inducer(inputs, labels) 88 | 89 | fnames.extend(data['filename']) 90 | inputs.requires_grad=True 91 | print('inputs.requires_grad',inputs.requires_grad) 92 | 93 | with torch.no_grad(): 94 | global penultimate_fts 95 | penultimate_fts = None 96 | assert(penultimate_fts == None) 97 | actual_activation = actual_model(inputs) 98 | #sys.exit() 99 | assert torch.is_tensor(penultimate_fts) 100 | 101 | with torch.enable_grad(): 102 | outputs = model(inputs) 103 | 104 | for b in range(outputs.shape[0]): 105 | rank += 1 106 | # Look for high activating samples 107 | maxpool_activation = torch.max(penultimate_fts[:, channel_index, :, :]) 108 | 109 | #print('inp',inputs.shape, outputs.shape) 110 | if inputs.grad is not None: 111 | inputs.grad.zero_() 112 | 113 | prez = outputs[b] 114 | mask = torch.zeros_like(prez) 115 | mask[outputs[b] > 0] = 1.0 116 | 117 | # Create prez 118 | print("prez size", prez.size()) 119 | prez2= prez*mask 120 | 121 | #find maxind 122 | vh,indh = torch.max(prez2,dim=0) 123 | vw,indw = torch.max(vh,dim=0) 124 | z=vh[indw] 125 | 126 | #find sum 127 | #print("prez size", prez2.size(), outputs.size()) 128 | #z=torch.sum(prez2) 129 | 130 | print('z val', z.item()) 131 | 132 | # if (labels[b].item() == 0): 133 | # z = -z 134 | 135 | with torch.no_grad(): 136 | z.backward() 137 | rel=inputs.grad.data.clone().detach().cpu() 138 | # rel= rel/torch.abs(torch.sum(rel)) 139 | #if outputs[b,n] Tensor: 177 | # # See note [TorchScript super()] 178 | # x = 
self.conv1(x) 179 | # x = self.bn1(x) 180 | # x = self.relu(x) 181 | # x = self.maxpool(x) 182 | 183 | # x = self.layer1(x) 184 | # x = self.layer2(x) 185 | # x = self.layer3(x) 186 | # x = self.layer4(x) 187 | 188 | # x = self.avgpool(x) 189 | # x = torch.flatten(x, 1) 190 | # x = self.fc(x) 191 | 192 | # Convolution layers 193 | #print('pre ex') 194 | x = self.extract_features(x) 195 | #print('post ex') 196 | # Pooling and final linear layer 197 | x = self._avg_pooling(x) 198 | if self._global_params.include_top: 199 | x = x.flatten(start_dim=1) 200 | x = self._dropout(x) 201 | x = self._fc(x) 202 | 203 | rettensor=None 204 | for ind,(name,module) in enumerate(self.named_modules()): 205 | if hasattr(module,'tempstorefeature'): 206 | rettensor = getattr(module,'tempstorefeature').clone() 207 | print('found tempstorefeature at',ind, name) 208 | #clean up chicken shit 209 | delattr(module,'tempstorefeature') 210 | 211 | if rettensor is not None: 212 | return rettensor 213 | 214 | print('no special feature map found') 215 | return x 216 | 217 | def forward(self, x: Tensor) -> Tensor: 218 | return self._forward_impl(x) 219 | 220 | 221 | 222 | def onlywritefeaturemap_hook(module, input_ , output, channelind): 223 | 224 | if channelind is None: 225 | module.tempstorefeature=output 226 | else: 227 | module.tempstorefeature=output[:,channelind] 228 | 229 | def hook_factory2(channelind): 230 | 231 | # define the function with the right signature to be created 232 | def ahook(module, input_, output): 233 | # instantiate it by taking a parametrized function, 234 | # and fill the parameters 235 | # return the filled function 236 | return onlywritefeaturemap_hook(module, input_, output, channelind = channelind) 237 | 238 | # return the hook function as if it were a string 239 | return ahook 240 | 241 | 242 | 243 | def get_probs(model, data_loader, device): 244 | y_true, y_pred = [], [] 245 | Hs, Ws = [], [] 246 | 247 | from tqdm import tqdm 248 | 249 | with torch.no_grad(): 250 | for datas in tqdm(data_loader): 251 | #print(datas['label'].size()) 252 | data = datas['image'].to(device) 253 | label = datas['label'] 254 | # for data, label in data_loader: 255 | Hs.append(data.shape[2]) 256 | Ws.append(data.shape[3]) 257 | 258 | y_true.extend(label.flatten().tolist()) 259 | data = data.cuda() 260 | y_pred.extend(model(data).sigmoid().flatten().tolist()) 261 | 262 | Hs, Ws = np.array(Hs), np.array(Ws) 263 | y_true, y_pred = np.array(y_true), np.array(y_pred).astype(np.float16) 264 | 265 | print(np.count_nonzero(y_pred[y_true==1]>=0.5)) 266 | 267 | return y_pred 268 | 269 | 270 | def get_high_activation_patches(feature_map_name, arch, gan_name, aug, bsize, num_instances=None): 271 | ## ------------Set Parameters-------- 272 | # Define device and other parameters 273 | device = torch.device('cuda:0') 274 | key = 'beta0' 275 | 276 | # model and weights 277 | if arch == 'resnet50': 278 | weightfn = './weights/resnet50/blur_jpg_prob{}.pth'.format(aug) 279 | model = get_resnet50_universal_detector(weightfn).to(device) 280 | 281 | elif arch == 'efb0': 282 | weightfn = './weights/efb0/blur_jpg_prob{}.pth'.format(aug) 283 | model = get_efb0_universal_detector(weightfn).to(device) 284 | 285 | # Define device and other parameters 286 | feature_map_idx = int(feature_map_name.split('.#')[-1].split('(')[0]) 287 | layertobeattached = feature_map_name.split("#")[0][:-1] 288 | #print(layertobeattached) 289 | 290 | # Use original transform as Wang et. 
al 291 | transform = transforms.Compose([ 292 | transforms.ToTensor(), 293 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 294 | ]) 295 | 296 | transform_crop = transforms.Compose([ 297 | transforms.CenterCrop(224), 298 | transforms.ToTensor(), 299 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 300 | ]) 301 | 302 | # Read csv 303 | df = pd.read_csv('output/activation_rankings/{}/{}_{}/{}.csv'.format(arch, gan_name, aug, feature_map_name)) 304 | img_paths = df['name'][:num_instances] 305 | ds = Dataset_from_paths(img_paths, transform=transform) 306 | dl = torch.utils.data.DataLoader(ds, batch_size= 1, shuffle=False) 307 | 308 | ds_ = Dataset_from_paths(img_paths, transform=transform_crop) 309 | dl_prob = torch.utils.data.DataLoader(ds_, batch_size= bsize, shuffle=False) 310 | 311 | # Load LRP wrapped model 312 | #model = get_resnet50_universal_detector(weightfn).to(device) 313 | 314 | lrp_params_def1={ 315 | 'conv2d_ignorebias': True, 316 | 'eltwise_eps': 1e-6, 317 | 'linear_eps': 1e-6, 318 | 'pooling_eps': 1e-6, 319 | 'use_zbeta': False , 320 | } 321 | 322 | lrp_layer2method={ 323 | 'Swish': relu_wrapper_fct, 324 | 'nn.BatchNorm2d': relu_wrapper_fct, 325 | 'nn.Conv2d': Conv2dDynamicSamePadding_beta0_wrapper_fct, 326 | 'nn.Linear': linearlayer_eps_wrapper_fct, 327 | 'nn.AdaptiveAvgPool2d': adaptiveavgpool2d_wrapper_fct, 328 | 'sum_stacked2': eltwisesum_stacked2_eps_wrapper_fct, 329 | } 330 | 331 | 332 | # model0 = EfficientNet.from_name('efficientnet-b0', num_classes=1, image_size=None,) 333 | # somedict = torch.load(weightfn) 334 | # model0.load_state_dict( somedict['model'] ) 335 | # model0.eval() 336 | 337 | blockPrint() 338 | model_e = EfficientNet_canonized_modfwd.from_pretrained('efficientnet-b0', num_classes=1, image_size=None, dropout_rate= 0.0 , drop_connect_rate=0.0) 339 | model_e.copyfromefficientnet( model, lrp_params_def1, lrp_layer2method) 340 | model_e.to(device) 341 | enablePrint() 342 | 343 | model.eval() 344 | model_e.eval() 345 | 346 | # Get output probs 347 | preds = get_probs(model, dl_prob, device) 348 | 349 | # --------------------------------------------- 350 | for n, m in model.named_modules(): 351 | m.auto_name = n 352 | 353 | for n, m in model_e.named_modules(): 354 | m.auto_name = n 355 | 356 | # Attach hooks to LRP model 357 | handles=[] 358 | for ind,(name,module) in enumerate(model_e.named_modules()): 359 | print('name: {}'.format(name) ) 360 | if name ==layertobeattached: 361 | h=module.register_forward_hook( hook_factory2( channelind = feature_map_idx )) 362 | handles.append(h) 363 | 364 | # Attach hook to original model 365 | new_handles = [] 366 | for ind,(name,module) in enumerate(model.named_modules()): 367 | if name ==layertobeattached: 368 | print('name: {}'.format(name) ) 369 | h=module.register_forward_hook( get_penultimate_fts ) 370 | new_handles.append(h) 371 | 372 | #sys.exit() 373 | 374 | # do this only for a subset by subsetting the dataloader 375 | # fwd hook copies module, return check through if a module has a feature 376 | save_suffix = 'fake' 377 | save_dir = os.path.join('./output/patches/', arch, gan_name, save_suffix, "{}.#{}".format(layertobeattached, feature_map_idx)) 378 | max_act_vals, corr_filenames = interpret_simtest(model_e, model, dataloader = dl, probs=preds, device = device, 379 | savept='{}/'.format(save_dir), 380 | layer_name= layertobeattached, channel_index=feature_map_idx, bias_inducer = None ) 381 | 382 | max_act_vals = np.asarray(max_act_vals) 383 | enablePrint() 384 | 385 | 
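# ---------------------------------------------------------------------------
# Usage sketch (added for clarity; not part of the original pipeline). This
# shows how get_high_activation_patches() above is meant to be driven. The
# feature-map name below is a hypothetical placeholder: real names come from
# the fmap_relevances CSVs and follow the "<layer>.#<idx>(T=<total>)" format
# parsed inside the function. It assumes ./weights/efb0/blur_jpg_prob0.5.pth
# and the activation-ranking CSVs (see get_top_activated_images.py) exist.
def _example_extract_efb0_patches():
    example_fmap = '_blocks.15._project_conv.#111(T=320)'  # hypothetical name
    get_high_activation_patches(feature_map_name=example_fmap,
                                arch='efb0',
                                gan_name='progan_val',
                                aug='0.5',
                                bsize=16,
                                num_instances=5)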
386 | # if __name__=='__main__': 387 | 388 | # # Read fmaps 389 | # df = pd.read_csv("progan_val-grayscale.csv") 390 | # df = df[df['p_value'] < 0.05] 391 | # print(df.shape) 392 | # print(df) 393 | # fmaps = list(df['feature_map_name'])[:] 394 | 395 | # # fmaps = [ 396 | # # '"layer4.2.conv1.#487(T=512)"', 397 | 398 | # # ] 399 | # for i in range(len(fmaps)): 400 | # feature_map_name = fmaps[i] 401 | # print(i, feature_map_name) 402 | # extract_patches( feature_map_name, None) 403 | -------------------------------------------------------------------------------- /src/patch_extraction/extract_lrp_max_patches_using_filenames_resnet50.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | # from utils import mask_topk_channels 4 | import torch 5 | import torch.nn as nn 6 | 7 | import torchvision 8 | from torchvision import transforms 9 | from torch.utils.data import Dataset, DataLoader 10 | #import matplotlib.pyplot as plt 11 | 12 | from torch import Tensor 13 | 14 | import datetime 15 | import time 16 | import os 17 | import numpy as np 18 | 19 | import PIL.Image 20 | 21 | 22 | #LRP 23 | 24 | from lrp.resnet_wrapper import * 25 | from utils.heatmap_helpers import * 26 | from utils.dataset_helpers import * 27 | from utils.general import * 28 | from utils.mask_fmaps import * 29 | 30 | from utils import * 31 | 32 | 33 | # Keep penultimate features as global varialble such that hook modifies these features 34 | penultimate_fts = None 35 | 36 | 37 | def get_penultimate_fts(self, input, output): 38 | global penultimate_fts 39 | penultimate_fts = output 40 | return None 41 | 42 | 43 | class Dataset_from_paths(torch.utils.data.Dataset): 44 | 45 | def __init__(self, img_paths, transform=None): 46 | self.transform = transform 47 | self.imgfilenames=img_paths 48 | 49 | def __len__(self): 50 | return len(self.imgfilenames) 51 | 52 | def __getitem__(self, idx): 53 | image = PIL.Image.open(self.imgfilenames[idx]).convert('RGB') 54 | label=1 55 | 56 | if self.transform: 57 | image = self.transform(image) 58 | 59 | sample = {'image': image, 'label': label, 'filename': self.imgfilenames[idx] } 60 | 61 | return sample 62 | 63 | 64 | def interpret_simtest(model, actual_model, dataloader, probs, 65 | device, savept, 66 | layer_name, channel_index, 67 | bias_inducer): 68 | #print(probs) 69 | if not os.path.isdir(savept): 70 | os.makedirs(savept) 71 | 72 | model.eval() 73 | fnames = [] 74 | rank = 0 75 | 76 | max_act_values = [] 77 | corr_filenames = [] 78 | 79 | for batch_idx, data in enumerate(dataloader): 80 | 81 | if (batch_idx%100==0) and (batch_idx>=100): 82 | print('at val batchindex: ',batch_idx) 83 | 84 | inputs = data['image'].to(device) 85 | labels = data['label'] 86 | if bias_inducer is not None: 87 | inputs=bias_inducer(inputs, labels) 88 | 89 | fnames.extend(data['filename']) 90 | inputs.requires_grad=True 91 | print('inputs.requires_grad',inputs.requires_grad) 92 | 93 | with torch.no_grad(): 94 | global penultimate_fts 95 | penultimate_fts = None 96 | assert(penultimate_fts == None) 97 | actual_activation = actual_model(inputs) 98 | #sys.exit() 99 | assert torch.is_tensor(penultimate_fts) 100 | 101 | with torch.enable_grad(): 102 | outputs = model(inputs) 103 | 104 | for b in range(outputs.shape[0]): 105 | rank += 1 106 | # Look for high activating samples 107 | maxpool_activation = torch.max(penultimate_fts[:, channel_index, :, :]) 108 | 109 | #print('inp',inputs.shape, outputs.shape) 110 | if inputs.grad is not None: 111 | inputs.grad.zero_() 112 | 113 | 
prez = outputs[b] 114 | mask = torch.zeros_like(prez) 115 | mask[outputs[b] > 0] = 1.0 116 | 117 | # Create prez 118 | print("prez size", prez.size()) 119 | prez2= prez*mask 120 | 121 | #find maxind 122 | vh,indh = torch.max(prez2,dim=0) 123 | vw,indw = torch.max(vh,dim=0) 124 | z=vh[indw] 125 | 126 | #find sum 127 | #print("prez size", prez2.size(), outputs.size()) 128 | #z=torch.sum(prez2) 129 | 130 | print('z val', z.item()) 131 | 132 | # if (labels[b].item() == 0): 133 | # z = -z 134 | 135 | with torch.no_grad(): 136 | z.backward() 137 | rel=inputs.grad.data.clone().detach().cpu() 138 | # rel= rel/torch.abs(torch.sum(rel)) 139 | #if outputs[b,n] Tensor: 177 | # See note [TorchScript super()] 178 | x = self.conv1(x) 179 | x = self.bn1(x) 180 | x = self.relu(x) 181 | x = self.maxpool(x) 182 | 183 | x = self.layer1(x) 184 | x = self.layer2(x) 185 | x = self.layer3(x) 186 | x = self.layer4(x) 187 | 188 | x = self.avgpool(x) 189 | x = torch.flatten(x, 1) 190 | x = self.fc(x) 191 | 192 | rettensor=None 193 | for ind,(name,module) in enumerate(self.named_modules()): 194 | if hasattr(module,'tempstorefeature'): 195 | rettensor = getattr(module,'tempstorefeature').clone() 196 | print('found tempstorefeature at',ind, name) 197 | #clean up chicken shit 198 | delattr(module,'tempstorefeature') 199 | 200 | if rettensor is not None: 201 | return rettensor 202 | 203 | print('no special feature map found') 204 | return x 205 | 206 | def forward(self, x: Tensor) -> Tensor: 207 | return self._forward_impl(x) 208 | 209 | 210 | def _resnet_canonized_modfwd(arch, block, layers, pretrained, progress, **kwargs): 211 | model = ResNet_canonized_modfwd(block, layers, **kwargs) 212 | if pretrained: 213 | raise Cannotloadmodelweightserror("explainable nn model wrapper was never meant to load dictionary weights, load into standard model first, then instatiate this class from the standard model") 214 | return model 215 | 216 | 217 | def resnet18_canonized_modfwd(pretrained=False, progress=True, **kwargs): 218 | r"""ResNet-18 model from 219 | `"Deep Residual Learning for Image Recognition" `_ 220 | Args: 221 | pretrained (bool): If True, returns a model pre-trained on ImageNet 222 | progress (bool): If True, displays a progress bar of the download to stderr 223 | """ 224 | return _resnet_canonized_modfwd('resnet18', BasicBlock_fused, [2, 2, 2, 2], pretrained, progress, **kwargs) 225 | 226 | def resnet50_canonized_modfwd(pretrained=False, progress=True, **kwargs): 227 | r"""ResNet-50 model from 228 | `"Deep Residual Learning for Image Recognition" `_ 229 | Args: 230 | pretrained (bool): If True, returns a model pre-trained on ImageNet 231 | progress (bool): If True, displays a progress bar of the download to stderr 232 | """ 233 | return _resnet_canonized_modfwd('resnet50', Bottleneck_fused, [3, 4, 6, 3], pretrained, progress, **kwargs) 234 | 235 | 236 | 237 | 238 | def onlywritefeaturemap_hook(module, input_ , output, channelind): 239 | 240 | if channelind is None: 241 | module.tempstorefeature=output 242 | else: 243 | module.tempstorefeature=output[:,channelind] 244 | 245 | def hook_factory2(channelind): 246 | 247 | # define the function with the right signature to be created 248 | def ahook(module, input_, output): 249 | # instantiate it by taking a parametrized function, 250 | # and fill the parameters 251 | # return the filled function 252 | return onlywritefeaturemap_hook(module, input_, output, channelind = channelind) 253 | 254 | # return the hook function as if it were a string 255 | return ahook 256 | 257 | 
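# ---------------------------------------------------------------------------
# Added illustration: what hook_factory2 leaves behind. A small self-contained
# sketch; the torchvision ResNet-50 here is a stand-in for demonstration, not
# the universal-detector weights used elsewhere in this file. After a single
# forward pass, the hooked module carries the selected channel's feature map
# in its .tempstorefeature attribute (the canonized model's forward collects
# and then deletes that attribute).
def _demo_hook_factory2():
    m = torchvision.models.resnet50(pretrained=False).eval()
    target = dict(m.named_modules())['layer4.2.conv1']   # 512 output channels
    handle = target.register_forward_hook(hook_factory2(channelind=487))
    with torch.no_grad():
        _ = m(torch.randn(1, 3, 224, 224))
    fmap = target.tempstorefeature   # shape (1, 7, 7): channel 487 only
    handle.remove()
    return fmap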
258 | 259 | def get_probs(model, data_loader, device): 260 | y_true, y_pred = [], [] 261 | Hs, Ws = [], [] 262 | 263 | from tqdm import tqdm 264 | 265 | with torch.no_grad(): 266 | for datas in tqdm(data_loader): 267 | #print(datas['label'].size()) 268 | data = datas['image'].to(device) 269 | label = datas['label'] 270 | # for data, label in data_loader: 271 | Hs.append(data.shape[2]) 272 | Ws.append(data.shape[3]) 273 | 274 | y_true.extend(label.flatten().tolist()) 275 | data = data.cuda() 276 | y_pred.extend(model(data).sigmoid().flatten().tolist()) 277 | 278 | Hs, Ws = np.array(Hs), np.array(Ws) 279 | y_true, y_pred = np.array(y_true), np.array(y_pred).astype(np.float16) 280 | 281 | print(np.count_nonzero(y_pred[y_true==1]>=0.5)) 282 | 283 | return y_pred 284 | 285 | 286 | def get_high_activation_patches(feature_map_name, arch, gan_name, aug, bsize, num_instances=None): 287 | ## ------------Set Parameters-------- 288 | # Define device and other parameters 289 | device = torch.device('cuda:0') 290 | key = 'beta0' 291 | 292 | # model and weights 293 | if arch == 'resnet50': 294 | weightfn = './weights/resnet50/blur_jpg_prob{}.pth'.format(aug) 295 | model = get_resnet50_universal_detector(weightfn).to(device) 296 | 297 | elif arch == 'efb0': 298 | weightfn = './weights/efb0/blur_jpg_prob{}.pth'.format(aug) 299 | model = get_efb0_universal_detector(weightfn).to(device) 300 | 301 | # Define device and other parameters 302 | feature_map_idx = int(feature_map_name.split('.#')[-1].split('(')[0]) 303 | layertobeattached = feature_map_name.split("#")[0][:-1] 304 | #print(layertobeattached) 305 | 306 | # Use original transform as Wang et. al 307 | transform = transforms.Compose([ 308 | transforms.ToTensor(), 309 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 310 | ]) 311 | 312 | transform_crop = transforms.Compose([ 313 | transforms.CenterCrop(224), 314 | transforms.ToTensor(), 315 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 316 | ]) 317 | 318 | # Read csv 319 | df = pd.read_csv('output/activation_rankings/{}/{}_{}/{}.csv'.format(arch, gan_name, aug, feature_map_name)) 320 | img_paths = df['name'][:num_instances] 321 | ds = Dataset_from_paths(img_paths, transform=transform) 322 | dl = torch.utils.data.DataLoader(ds, batch_size= 1, shuffle=False) 323 | 324 | ds_ = Dataset_from_paths(img_paths, transform=transform_crop) 325 | dl_prob = torch.utils.data.DataLoader(ds_, batch_size= bsize, shuffle=False) 326 | 327 | # Load LRP wrapped model 328 | #model = get_resnet50_universal_detector(weightfn).to(device) 329 | 330 | lrp_params_def1={ 331 | 'conv2d_ignorebias': True, 332 | 'eltwise_eps': 1e-6, 333 | 'linear_eps': 1e-6, 334 | 'pooling_eps': 1e-6, 335 | 'use_zbeta': True , 336 | } 337 | 338 | lrp_layer2method={ 339 | 'nn.ReLU': relu_wrapper_fct, 340 | 'nn.BatchNorm2d': relu_wrapper_fct, 341 | 'nn.Conv2d': conv2d_beta0_wrapper_fct, 342 | 'nn.Linear': linearlayer_eps_wrapper_fct, 343 | 'nn.AdaptiveAvgPool2d': adaptiveavgpool2d_wrapper_fct, 344 | 'nn.MaxPool2d': maxpool2d_wrapper_fct, 345 | 'sum_stacked2': eltwisesum_stacked2_eps_wrapper_fct, 346 | } 347 | 348 | model_e = resnet50_canonized_modfwd(pretrained=False) 349 | model_e.copyfromresnet(model, lrp_params=lrp_params_def1, lrp_layer2method = lrp_layer2method) 350 | model_e = model_e.to(device) 351 | 352 | model.eval() 353 | model_e.eval() 354 | 355 | # Get output probs 356 | preds = get_probs(model, dl_prob, device) 357 | 358 | # --------------------------------------------- 359 | for n, m in 
model.named_modules(): 360 | m.auto_name = n 361 | 362 | for n, m in model_e.named_modules(): 363 | m.auto_name = n 364 | 365 | # Attach hooks to LRP model 366 | handles=[] 367 | for ind,(name,module) in enumerate(model_e.named_modules()): 368 | print('name: {}'.format(name) ) 369 | if name ==layertobeattached: 370 | h=module.register_forward_hook( hook_factory2( channelind = feature_map_idx )) 371 | handles.append(h) 372 | 373 | # Attach hook to original model 374 | new_handles = [] 375 | for ind,(name,module) in enumerate(model.named_modules()): 376 | if name ==layertobeattached: 377 | print('name: {}'.format(name) ) 378 | h=module.register_forward_hook( get_penultimate_fts ) 379 | new_handles.append(h) 380 | 381 | #sys.exit() 382 | 383 | # do this only for a subset by subsetting the dataloader 384 | # fwd hook copies module, return check through if a module has a feature 385 | save_suffix = 'fake' 386 | save_dir = os.path.join('./output/patches/', arch, gan_name, save_suffix, "{}.#{}".format(layertobeattached, feature_map_idx)) 387 | max_act_vals, corr_filenames = interpret_simtest(model_e, model, dataloader = dl, probs=preds, device = device, 388 | savept='{}/'.format(save_dir), 389 | layer_name= layertobeattached, channel_index=feature_map_idx, bias_inducer = None ) 390 | 391 | max_act_vals = np.asarray(max_act_vals) 392 | enablePrint() 393 | 394 | 395 | # if __name__=='__main__': 396 | 397 | # # Read fmaps 398 | # df = pd.read_csv("progan_val-grayscale.csv") 399 | # df = df[df['p_value'] < 0.05] 400 | # print(df.shape) 401 | # print(df) 402 | # fmaps = list(df['feature_map_name'])[:] 403 | 404 | # # fmaps = [ 405 | # # '"layer4.2.conv1.#487(T=512)"', 406 | 407 | # # ] 408 | # for i in range(len(fmaps)): 409 | # feature_map_name = fmaps[i] 410 | # print(i, feature_map_name) 411 | # extract_patches( feature_map_name, None) 412 | -------------------------------------------------------------------------------- /src/patch_extraction/get_top_activated_images.py: -------------------------------------------------------------------------------- 1 | 2 | # Import base libraries 3 | import os, sys, math, gc 4 | import PIL 5 | import copy 6 | import random, json 7 | from collections import OrderedDict 8 | 9 | # Import scientific computing libraries 10 | import numpy as np 11 | 12 | # Import torch and dependencies 13 | import torch 14 | from torchvision import models, transforms 15 | 16 | # Import utils 17 | from utils.heatmap_helpers import * 18 | from utils.dataset_helpers import * 19 | from utils.general import * 20 | 21 | # Import other libraries 22 | import matplotlib.pyplot as plt 23 | import pandas as pd 24 | from scipy import stats 25 | # from scipy imistats import ktest 26 | 27 | import seaborn as sns 28 | 29 | 30 | # Keep penultimate features as global varialble such that hook modifies these features 31 | penultimate_fts = None 32 | 33 | 34 | def get_penultimate_fts(self, input, output): 35 | global penultimate_fts 36 | penultimate_fts = output 37 | return None 38 | 39 | 40 | def get_class_specific_penultimate_fts(model, dls, device): 41 | global penultimate_fts 42 | penultimate_fts = None 43 | assert(penultimate_fts == None) 44 | 45 | model.eval() # Set model to eval mode 46 | m = torch.nn.AdaptiveMaxPool2d((1, 1)) # Look at maximum activation features 47 | 48 | # Store fts in a global array 49 | all_features = None 50 | fnames = [] 51 | 52 | with torch.no_grad(): 53 | for dl in dls: 54 | for batch_idx, data in enumerate(dl): 55 | x = data['image'].to(device) 56 | y = data['label'] 57 | 
fname = data['filename'] 58 | output = model(x) 59 | assert torch.is_tensor(penultimate_fts) 60 | 61 | fnames.extend(fname) 62 | 63 | if all_features is None: 64 | all_features = (m(penultimate_fts.data.clone().cpu()) ).numpy().squeeze() 65 | 66 | else: 67 | features = (m(penultimate_fts.data.clone().cpu()) ).numpy().squeeze() 68 | all_features = np.concatenate((all_features, features), axis=0) 69 | 70 | return all_features, fnames 71 | 72 | 73 | 74 | def get_activation_rankings(feature_map_name, arch, parent_dir, gan_name, aug, have_classes, bsize, num_instances=None): 75 | 76 | ## ------------Set Parameters-------- 77 | # Define device and other parameters 78 | device = torch.device('cuda:0') 79 | 80 | # model and weights 81 | if arch == 'resnet50': 82 | weightfn = './weights/resnet50/blur_jpg_prob{}.pth'.format(aug) 83 | model = get_resnet50_universal_detector(weightfn).to(device) 84 | 85 | elif arch == 'efb0': 86 | weightfn = './weights/efb0/blur_jpg_prob{}.pth'.format(aug) 87 | model = get_efb0_universal_detector(weightfn).to(device) 88 | 89 | # Define device and other parameters 90 | feature_map_idx = int(feature_map_name.split('.#')[-1].split('(')[0]) 91 | layertobeattached = feature_map_name.split("#")[0][:-1] 92 | #print(layertobeattached) 93 | 94 | # Attach hook to original model 95 | new_handles = [] 96 | for ind,(name,module) in enumerate(model.named_modules()): 97 | if name ==layertobeattached: 98 | print('name: {}'.format(name) ) 99 | h=module.register_forward_hook( get_penultimate_fts ) 100 | new_handles.append(h) 101 | 102 | # Use original transform as Wang et. al 103 | transform = transforms.Compose([ 104 | transforms.ToTensor(), 105 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 106 | ]) 107 | 108 | #for idx, clss in enumerate(clsses): 109 | root_dir = os.path.join(parent_dir, gan_name) 110 | 111 | if have_classes: 112 | dls = get_classwise_dataloader(root_dir, have_classes, max_num=num_instances, transform=transform, bsize=bsize, onlyreal=False, onlyfake=True) 113 | else: 114 | dl = get_dataloader(root_dir, have_classes, max_num=num_instances, transform=transform, bsize=bsize, onlyreal=False, onlyfake=True) 115 | dls = [dl] 116 | 117 | clss_specific_fts_fake, fnames = get_class_specific_penultimate_fts(model, dls, device) 118 | clss_specific_fts_fake = clss_specific_fts_fake[:, feature_map_idx] 119 | #print(clss_specific_fts_fake, fnames) 120 | 121 | # Sort and get the top activated images 122 | max_act_vals = np.asarray(clss_specific_fts_fake) 123 | idx = (-max_act_vals.copy()).argsort() 124 | 125 | # Save df 126 | df = pd.DataFrame() 127 | df['name'] = [fnames[i] for i in idx] 128 | df['max_act'] = [ max_act_vals[i] for i in idx] 129 | 130 | # Create output directory and save 131 | output_dir = os.path.join( 'output', 'activation_rankings', arch, '{}_{}'.format(gan_name, aug) ) 132 | os.makedirs(output_dir, exist_ok=True) 133 | df.to_csv("{}/{}.csv".format(output_dir, feature_map_name), index=None) 134 | 135 | 136 | 137 | 138 | # if __name__=='__main__': 139 | 140 | # # Read fmaps 141 | # df = pd.read_csv("fmaps.csv") 142 | # fmaps = list(df['fmap'])[76:] 143 | 144 | # for i in range(len(fmaps)): 145 | # feature_map_name = fmaps[i][:-1] 146 | # print(feature_map_name) 147 | # get_activation_rankings(feature_map_name) -------------------------------------------------------------------------------- /src/rank_fmaps.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | parser = 
argparse.ArgumentParser(description='Rank feature maps of universal detectors...') 5 | 6 | # architecture 7 | parser.add_argument('--arch', type=str, required=True, choices=['resnet50', 'efb0']) 8 | 9 | # Data augmentation of model 10 | # parser.add_argument('--blur_jpg', type=float, required=True, choices=[0.1, 0.5]) 11 | parser.add_argument('--blur_jpg', type=str, required=True) 12 | 13 | # other metrics 14 | parser.add_argument('--bsize', type=int, default=16) 15 | 16 | # dataset 17 | parser.add_argument('--dataset_dir', type=str, default='/mnt/data/v2.0_CNN_synth_testset/') 18 | parser.add_argument('--gan_name', type=str, default='progan_val') 19 | parser.add_argument('--have_classes', type=bool, default=True) 20 | 21 | # Images per class for identifying channels 22 | parser.add_argument('--num_real', type=int, default=5) 23 | parser.add_argument('--num_fake', type=int, default=5) 24 | 25 | # Saving 26 | parser.add_argument('--save_pt_files', type=bool, default=True) 27 | 28 | def main(): 29 | args = parser.parse_args() 30 | 31 | if args.arch == 'resnet50': 32 | from fmap_ranking.rank_fmaps_ud_r50 import rank_fmaps 33 | 34 | rank_fmaps(args.dataset_dir, args.gan_name, 35 | args.blur_jpg, args.have_classes, 36 | args.bsize, 37 | args.num_real, args.num_fake, 38 | args.save_pt_files, 39 | normalize_only_using_positive = False, 40 | topk=None) 41 | 42 | elif args.arch == 'efb0': 43 | from fmap_ranking.rank_fmaps_ud_efb0 import rank_fmaps 44 | 45 | rank_fmaps(args.dataset_dir, args.gan_name, 46 | args.blur_jpg, args.have_classes, 47 | args.bsize, 48 | args.num_real, args.num_fake, 49 | args.save_pt_files, 50 | normalize_only_using_positive = False, 51 | topk=None) 52 | 53 | 54 | 55 | 56 | if __name__=='__main__': 57 | main() -------------------------------------------------------------------------------- /src/sensitivity_assessment/ap_sensitivity.py: -------------------------------------------------------------------------------- 1 | # Import generic libraries 2 | import os, sys, math, gc 3 | import PIL 4 | 5 | # Import scientific computing libraries 6 | import numpy as np 7 | 8 | # Import torch and dependencies 9 | import torch 10 | from torchvision import models, transforms 11 | 12 | # Import other libraries 13 | import matplotlib.pyplot as plt 14 | 15 | # Import utils 16 | from utils.heatmap_helpers import * 17 | from utils.dataset_helpers import * 18 | from utils.general import * 19 | 20 | 21 | import copy 22 | from collections import OrderedDict 23 | import pandas as pd 24 | import random, json 25 | from tqdm import tqdm 26 | from utils.mask_fmaps import * 27 | 28 | 29 | 30 | def ap_sensitivity_analysis(arch, parent_dir, gan_name, aug, have_classes, bsize, num_instances=None, topk_list=None): 31 | ## ------------Set Parameters-------- 32 | # Define device and other parameters 33 | device = torch.device('cuda:0') 34 | 35 | # model and weights 36 | if arch == 'resnet50': 37 | weightfn = './weights/resnet50/blur_jpg_prob{}.pth'.format(aug) 38 | model = get_resnet50_universal_detector(weightfn).to(device) 39 | 40 | elif arch == 'efb0': 41 | weightfn = './weights/efb0/blur_jpg_prob{}.pth'.format(aug) 42 | model = get_efb0_universal_detector(weightfn).to(device) 43 | 44 | root_dir = os.path.join(parent_dir, gan_name) 45 | 46 | # Dataset (Use same transforms as Wang et. 
al 47 | transform = transforms.Compose([ 48 | transforms.ToTensor(), 49 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 50 | ]) 51 | 52 | # Obtain D 53 | if have_classes: 54 | dl = get_classwise_dataloader(root_dir, have_classes, num_instances , transform, bsize, onlyreal=False, onlyfake=False) 55 | 56 | else: 57 | dl = get_dataloader(root_dir, have_classes, num_instances , transform, bsize, onlyreal=False, onlyfake=False) 58 | dl = [dl] 59 | 60 | df_ap = pd.DataFrame() 61 | 62 | # Original : Independent of topk 63 | original_results, (y_pred, y_true) = get_ap_and_acc(model, device, dl, threshold=0.5) 64 | 65 | # Collect topk filters 66 | for topk in topk_list: 67 | topk_channels, lowk_channels, all_channels = get_all_channels( 68 | fake_csv_path="fmap_relevances/{}/{}_{}/{}-fake.csv".format(arch, gan_name, aug, gan_name), 69 | topk=topk) 70 | 71 | print("Topk set to ", topk) 72 | assert topk == len(topk_channels), "mismatch" 73 | 74 | # Create model replicas by masking topk filters 75 | model_topk_masked, _ = mask_target_channels(copy.deepcopy(model), topk_channels) 76 | models_randomk_masked = [ mask_random_channels(copy.deepcopy(model), topk, topk_channels, all_channels )[0] for i in range(5)] 77 | model_lowk_masked, _ = mask_target_channels(copy.deepcopy(model), lowk_channels) 78 | 79 | # Topk : Now get 2 sets of results for each setup 80 | topk_masked_results, (y_pred, y_true) = get_ap_and_acc(model_topk_masked, device, dl, 0.5) 81 | 82 | # Randomk : Now get 2 sets of results for each setup 83 | randomk_masked_results, (y_pred, y_true) = get_ap_and_acc_random(models_randomk_masked, device, dl, 0.5) 84 | 85 | # Lowk : Now get 2 sets of results for each setup 86 | lowk_masked_results, (y_pred, y_true) = get_ap_and_acc(model_lowk_masked, device, dl, 0.5) 87 | 88 | record_ap = {'gan_name/topk': "{}/{}".format(gan_name, topk), 89 | 90 | 'original_ap': original_results[0], 91 | 'topk_masked_ap': topk_masked_results[0], 92 | 'randomk_masked_ap': randomk_masked_results[0], 93 | 'lowk_masked_ap': lowk_masked_results[0], 94 | 95 | } 96 | 97 | df_ap = df_ap.append(record_ap, ignore_index=True, sort=False)[list(record_ap.keys())] 98 | 99 | 100 | output_dir = os.path.join('output', 'k_vs_ap', arch, "{}_{}".format( gan_name, aug)) 101 | os.makedirs(output_dir, exist_ok=True) 102 | df_ap.to_csv('{}/{}-AP-#{}-samples-crop224.csv'.format(output_dir, gan_name, dl[0].dataset.__len__()), index=None) 103 | 104 | 105 | -------------------------------------------------------------------------------- /src/sensitivity_assessment/color.py: -------------------------------------------------------------------------------- 1 | # Import generic libraries 2 | import os, sys, math, gc 3 | import PIL 4 | 5 | # Import scientific computing libraries 6 | import numpy as np 7 | 8 | # Import torch and dependencies 9 | import torch 10 | from torchvision import models, transforms 11 | 12 | # Import other libraries 13 | import matplotlib.pyplot as plt 14 | 15 | from utils.heatmap_helpers import * 16 | 17 | 18 | from utils.dataset_helpers import * 19 | from utils.general import * 20 | import copy 21 | from collections import OrderedDict 22 | import pandas as pd 23 | import random, json 24 | from scipy import stats 25 | from sklearn.metrics import average_precision_score, precision_recall_curve, accuracy_score 26 | from tqdm import tqdm 27 | from utils.mask_fmaps import * 28 | 29 | from sklearn.metrics import roc_curve 30 | 31 | 32 | def all_metrics_sensitivity_analysis(arch, parent_dir, gan_name, aug, 
have_classes, bsize, num_instances=None): 33 | ## ------------Set Parameters-------- 34 | # Define device and other parameters 35 | device = torch.device('cuda:0') 36 | 37 | # model and weights 38 | if arch == 'resnet50': 39 | weightfn = './weights/resnet50/blur_jpg_prob{}.pth'.format(aug) 40 | model = get_resnet50_universal_detector(weightfn).to(device) 41 | 42 | elif arch == 'efb0': 43 | weightfn = './weights/efb0/blur_jpg_prob{}.pth'.format(aug) 44 | model = get_efb0_universal_detector(weightfn).to(device) 45 | 46 | root_dir = os.path.join(parent_dir, gan_name) 47 | 48 | # Dataset (Use same transforms as Wang et. al 49 | transform = transforms.Compose([ 50 | transforms.ToTensor(), 51 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 52 | ]) 53 | 54 | transform_grayscale = transforms.Compose([ 55 | transforms.Grayscale(3), 56 | transforms.ToTensor(), 57 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 58 | ]) 59 | 60 | 61 | # Obtain D 62 | if have_classes: 63 | dl = get_classwise_dataloader(root_dir, have_classes, num_instances , transform, bsize, onlyreal=False, onlyfake=False) 64 | dl_grayscale = get_classwise_dataloader(root_dir, have_classes, num_instances , transform_grayscale, bsize, onlyreal=False, onlyfake=False) 65 | 66 | else: 67 | dl = get_dataloader(root_dir, have_classes, num_instances , transform, bsize, onlyreal=False, onlyfake=False) 68 | dl_grayscale = get_dataloader(root_dir, have_classes, num_instances , transform_grayscale, bsize, onlyreal=False, onlyfake=False) 69 | dl = [dl] 70 | dl_grayscale = [ dl_grayscale ] 71 | 72 | df_ap = pd.DataFrame() 73 | df_uncalibrated_acc = pd.DataFrame() 74 | df_calibrated_acc = pd.DataFrame() 75 | 76 | # Calibrate GAN 77 | y_pred, y_true = get_probs(model, dl, device) 78 | original_calibrated_threshold = get_calibrated_thres(y_true, y_pred) 79 | 80 | # Original : Now get 2 sets of results for each setup (independent of topk) 81 | original_results, (y_pred, y_true) = get_ap_and_acc(model, device, dl, threshold=0.5) 82 | original_results_calibrated, (_, _) = get_ap_and_acc_with_new_threshold(y_pred=y_pred, y_true=y_true, threshold=original_calibrated_threshold, prefix='Baseline') 83 | 84 | # Grayscale : Now get 2 sets of results for each setup (independent of topk) 85 | grayscale_results, (y_pred_grayscale, y_true_grayscale) = get_ap_and_acc(model, device, dl_grayscale, threshold=0.5) 86 | grayscale_results_calibrated, (_, _) = get_ap_and_acc_with_new_threshold(y_pred=y_pred_grayscale, y_true=y_true_grayscale, threshold=original_calibrated_threshold, prefix='Grayscale') 87 | 88 | record_ap = {'gan_name': "{}".format(gan_name), 89 | 'original_ap': original_results[0], 90 | 'grayscale_ap': grayscale_results[0], 91 | 92 | } 93 | 94 | record_uncalibrated_acc = {'gan_name': "{}".format(gan_name), 95 | 96 | 'original_r_acc': original_results[1], 97 | 'original_f_acc': original_results[2], 98 | 'original_acc': original_results[3], 99 | 'original_y_real_mean': original_results[4], 100 | 'original_y_fake_mean': original_results[5], 101 | 102 | 'original_y_real_std': original_results[6], 103 | 'original_y_fake_std': original_results[7], 104 | 105 | 106 | 'grayscale_r_acc': grayscale_results[1], 107 | 'grayscale_f_acc': grayscale_results[2], 108 | 'grayscale_acc': grayscale_results[3], 109 | 'grayscale_y_real_mean': grayscale_results[4], 110 | 'grayscale_y_fake_mean': grayscale_results[5], 111 | 112 | 'grayscale_y_real_std': grayscale_results[6], 113 | 'grayscale_y_fake_std': grayscale_results[7], 114 | 115 | 
} 116 | 117 | 118 | record_calibrated_acc = {'gan_name/topk': "{}".format(gan_name), 119 | 'original_threshold_calibrated' : original_calibrated_threshold, 120 | 121 | 'original_r_acc_calibrated': original_results_calibrated[1], 122 | 'original_f_acc_calibrated': original_results_calibrated[2], 123 | 'original_acc_calibrated': original_results_calibrated[3], 124 | 125 | 'grayscale_r_acc_calibrated': grayscale_results_calibrated[1], 126 | 'grayscale_f_acc_calibrated': grayscale_results_calibrated[2], 127 | 'grayscale_acc_calibrated': grayscale_results_calibrated[3], 128 | 129 | } 130 | 131 | 132 | # Append to csv 133 | df_ap = df_ap.append(record_ap, ignore_index=True, sort=False)[list(record_ap.keys())] 134 | df_uncalibrated_acc = df_uncalibrated_acc.append(record_uncalibrated_acc, ignore_index=True, sort=False)[list(record_uncalibrated_acc.keys())] 135 | df_calibrated_acc = df_calibrated_acc.append(record_calibrated_acc, ignore_index=True, sort=False)[list(record_calibrated_acc.keys())] 136 | 137 | df_ap[df_ap.select_dtypes(include=['number']).columns] *= 100.0 138 | df_uncalibrated_acc[df_uncalibrated_acc.select_dtypes(include=['number']).columns] *= 100.0 139 | df_calibrated_acc[df_calibrated_acc.select_dtypes(include=['number']).columns] *= 100.0 140 | 141 | output_dir = os.path.join('output', 'grayscale-transfer', arch, '{}_{}'.format(gan_name, aug) ) 142 | os.makedirs(output_dir, exist_ok=True) 143 | df_ap.to_csv('{}/{}-AP-#{}-samples-no-crop.csv'.format(output_dir, gan_name, dl[0].dataset.__len__()), index=None) 144 | df_uncalibrated_acc.to_csv('{}/{}-UNCALIBRATED_ACC-#{}-samples-no-crop.csv'.format(output_dir, gan_name, dl[0].dataset.__len__()), index=None) 145 | df_calibrated_acc.to_csv('{}/{}-CALIBRATED_ACC-#{}-samples-no-crop.csv'.format(output_dir, gan_name, dl[0].dataset.__len__()), index=None) 146 | 147 | return (y_pred, y_true), (y_pred_grayscale, y_true_grayscale) 148 | 149 | if __name__=='__main__': 150 | pass 151 | 152 | -------------------------------------------------------------------------------- /src/sensitivity_assessment/transferability.py: -------------------------------------------------------------------------------- 1 | # Import generic libraries 2 | import os, sys, math, gc 3 | import PIL 4 | 5 | # Import scientific computing libraries 6 | import numpy as np 7 | 8 | # Import torch and dependencies 9 | import torch 10 | from torchvision import models, transforms 11 | 12 | # Import other libraries 13 | import matplotlib.pyplot as plt 14 | 15 | 16 | from utils.heatmap_helpers import * 17 | 18 | 19 | from utils.dataset_helpers import * 20 | from utils.general import * 21 | import copy 22 | from collections import OrderedDict 23 | import pandas as pd 24 | import random, json 25 | from scipy import stats 26 | from sklearn.metrics import average_precision_score, precision_recall_curve, accuracy_score 27 | from tqdm import tqdm 28 | from utils.mask_fmaps import * 29 | 30 | from sklearn.metrics import roc_curve 31 | 32 | from termcolor import colored 33 | 34 | 35 | def all_metrics_sensitivity_analysis(arch, parent_dir, gan_name, aug, have_classes, bsize, num_instances=None, topk_list=None): 36 | ## ------------Set Parameters-------- 37 | # Define device and other parameters 38 | device = torch.device('cuda:0') 39 | 40 | # model and weights 41 | if arch == 'resnet50': 42 | weightfn = './weights/resnet50/blur_jpg_prob{}.pth'.format(aug) 43 | model = get_resnet50_universal_detector(weightfn).to(device) 44 | 45 | elif arch == 'efb0': 46 | weightfn = 
'./weights/efb0/blur_jpg_prob{}.pth'.format(aug) 47 | model = get_efb0_universal_detector(weightfn).to(device) 48 | 49 | root_dir = os.path.join(parent_dir, gan_name) 50 | 51 | # Dataset (Use same transforms as Wang et. al 52 | transform = transforms.Compose([ 53 | # transforms.Resize(256), 54 | # transforms.CenterCrop(224), 55 | transforms.ToTensor(), 56 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 57 | ]) 58 | 59 | # Obtain dataloaders 60 | if have_classes: 61 | dl = get_classwise_dataloader(root_dir, have_classes, num_instances , transform, bsize, onlyreal=False, onlyfake=False) 62 | 63 | else: 64 | dl = get_dataloader(root_dir, have_classes, num_instances , transform, bsize, onlyreal=False, onlyfake=False) 65 | dl = [dl] 66 | 67 | df_ap = pd.DataFrame() 68 | df_uncalibrated_acc = pd.DataFrame() 69 | df_calibrated_acc = pd.DataFrame() 70 | 71 | # Calibrate threshold 72 | y_pred, y_true = get_probs(model, dl, device) 73 | 74 | if gan_name == 'progan': 75 | original_calibrated_threshold = 0.5 76 | else: 77 | original_calibrated_threshold = get_calibrated_thres(y_true, y_pred) 78 | 79 | # Original : Now get 2 sets of results for each setup (independent of topk) 80 | print(colored('\n> Sensitivity assessments using feature map dropout', 'cyan')) 81 | original_results, (y_pred, y_true) = get_ap_and_acc(model, device, dl, threshold=0.5) 82 | original_results_calibrated , (_, _) = get_ap_and_acc_with_new_threshold(y_pred=y_pred, y_true=y_true, threshold=original_calibrated_threshold, prefix='baseline') 83 | 84 | # Collect topk filters 85 | for top_index, topk in enumerate(topk_list): 86 | #print("Topk set to ", topk) 87 | topk_channels, lowk_channels, all_channels = get_all_channels( 88 | fake_csv_path="fmap_relevances/{}/progan_val_{}/progan_val-fake.csv".format(arch, aug), 89 | topk=topk) 90 | 91 | # Create model replicas by masking topk filters 92 | model_topk_masked, _ = mask_target_channels(copy.deepcopy(model), topk_channels) 93 | models_randomk_masked = [ mask_random_channels(copy.deepcopy(model), topk, topk_channels, all_channels )[0] for i in range(5)] 94 | model_lowk_masked, _ = mask_target_channels(copy.deepcopy(model), lowk_channels) 95 | 96 | print(colored('---------------------', 'cyan')) 97 | #print('---------------------') 98 | # Topk : Now get 2 sets of results for each setup 99 | topk_masked_results, (y_pred, y_true) = get_ap_and_acc(model_topk_masked, device, dl, 0.5) 100 | topk_masked_results_calibrated , (_, _) = get_ap_and_acc_with_new_threshold(y_pred=y_pred, y_true=y_true, threshold=original_calibrated_threshold, prefix='top-k') 101 | 102 | # Randomk : Now get 2 sets of results for each setup 103 | randomk_masked_results, (y_pred, y_true) = get_ap_and_acc_random(models_randomk_masked, device, dl, 0.5) 104 | randomk_masked_results_calibrated , (_, _) = get_ap_and_acc_with_new_threshold(y_pred=y_pred, y_true=y_true, threshold=original_calibrated_threshold, prefix='random-k') 105 | 106 | # Lowk : Now get 2 sets of results for each setup 107 | lowk_masked_results, (y_pred, y_true) = get_ap_and_acc(model_lowk_masked, device, dl, 0.5) 108 | lowk_masked_results_calibrated , (_, _) = get_ap_and_acc_with_new_threshold(y_pred=y_pred, y_true=y_true, threshold=original_calibrated_threshold, prefix='low-k') 109 | 110 | record_ap = {'gan_name/topk': "{}/{}".format(gan_name, topk), 111 | 112 | 'original_ap': original_results[0], 113 | 'topk_masked_ap': topk_masked_results[0], 114 | 'randomk_masked_ap': randomk_masked_results[0], 115 | 'lowk_masked_ap': 
lowk_masked_results[0], 116 | 117 | } 118 | 119 | record_calibrated_acc = {'gan_name/topk': "{}/{}".format(gan_name, topk), 120 | 'original_threshold_calibrated' : original_calibrated_threshold, 121 | 122 | 'original_r_acc_calibrated': original_results_calibrated[1], 123 | 'original_f_acc_calibrated': original_results_calibrated[2], 124 | 'original_acc_calibrated': original_results_calibrated[3], 125 | 126 | 'topk_masked_r_acc_calibrated': topk_masked_results_calibrated[1], 127 | 'topk_masked_f_acc_calibrated': topk_masked_results_calibrated[2], 128 | 'topk_masked_acc_calibrated': topk_masked_results_calibrated[3], 129 | 130 | 131 | 'randomk_masked_r_acc_calibrated': randomk_masked_results_calibrated[1], 132 | 'randomk_masked_f_acc_calibrated': randomk_masked_results_calibrated[2], 133 | 'randomk_masked_acc_calibrated': randomk_masked_results_calibrated[3], 134 | 135 | 136 | 'lowk_masked_r_acc_calibrated': lowk_masked_results_calibrated[1], 137 | 'lowk_masked_f_acc_calibrated': lowk_masked_results_calibrated[2], 138 | 'lowk_masked_acc_calibrated': lowk_masked_results_calibrated[3], 139 | 140 | } 141 | 142 | # Append to csv 143 | df_ap = df_ap.append(record_ap, ignore_index=True, sort=False)[list(record_ap.keys())] 144 | df_calibrated_acc = df_calibrated_acc.append(record_calibrated_acc, ignore_index=True, sort=False)[list(record_calibrated_acc.keys())] 145 | 146 | df_ap[df_ap.select_dtypes(include=['number']).columns] *= 100.0 147 | df_calibrated_acc[df_calibrated_acc.select_dtypes(include=['number']).columns] *= 100.0 148 | 149 | output_dir = os.path.join('output', 'transfer', arch, '{}_{}'.format(gan_name, aug) ) 150 | os.makedirs(output_dir, exist_ok=True) 151 | df_ap.to_csv('{}/{}-AP-#{}-samples-crop.csv'.format(output_dir, gan_name, dl[0].dataset.__len__()), index=None) 152 | df_calibrated_acc.to_csv('{}/{}-CALIBRATED_ACC-#{}-samples-no-crop.csv'.format(output_dir, gan_name, dl[0].dataset.__len__()), index=None) 153 | 154 | 155 | if __name__=='__main__': 156 | pass 157 | 158 | -------------------------------------------------------------------------------- /src/transfer_sensitivity_analysis.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from sensitivity_assessment.transferability import all_metrics_sensitivity_analysis 3 | import math 4 | 5 | parser = argparse.ArgumentParser(description='Rank feature maps of universal detectors...') 6 | 7 | # architecture 8 | parser.add_argument('--arch', type=str, required=True, choices=['resnet50', 'efb0']) 9 | 10 | # Data augmentation of model 11 | #parser.add_argument('--blur_jpg', type=float, required=True, choices=[0.1, 0.5]) 12 | parser.add_argument('--blur_jpg', type=str, required=True) 13 | 14 | # other metrics 15 | parser.add_argument('--bsize', type=int, default=16) 16 | 17 | # dataset 18 | parser.add_argument('--dataset_dir', type=str, default='/mnt/data/v2.0_CNN_synth_testset/') 19 | parser.add_argument('--gan_name', type=str, nargs='+', default='progan_val') 20 | parser.add_argument('--have_classes', type=int, default=1) 21 | 22 | # Images per class for identifying channels 23 | parser.add_argument('--num_instances', type=int, default=5) 24 | 25 | # topk 26 | parser.add_argument('--topk', type=int, default=114) 27 | 28 | 29 | def main(): 30 | args = parser.parse_args() 31 | 32 | if(type(args.gan_name)==str): 33 | args.gan_name = [args.gan_name,] 34 | 35 | args.have_classes = bool(args.have_classes) 36 | 37 | print(args.gan_name, args.have_classes) 38 | 39 | for gan_name in 
args.gan_name: 40 | all_metrics_sensitivity_analysis(args.arch, 41 | args.dataset_dir, gan_name, args.blur_jpg, 42 | args.have_classes, args.bsize, num_instances=None, topk_list=[args.topk]) 43 | 44 | 45 | if __name__=='__main__': 46 | main() -------------------------------------------------------------------------------- /src/ud_lrp.py: -------------------------------------------------------------------------------- 1 | # Import generic libraries 2 | import os, sys, math, gc 3 | import PIL 4 | 5 | # Import scientific computing libraries 6 | import numpy as np 7 | 8 | # Import torch and dependencies 9 | import torch 10 | from torchvision import models, transforms 11 | 12 | # Import other libraries 13 | import matplotlib.pyplot as plt 14 | 15 | # Import lrp modules 16 | from lrp.resnet_lrp_general import * 17 | from lrp.resnet_wrapper import * 18 | 19 | # Import utils 20 | from utils.heatmap_helpers import * 21 | from utils.dataset_helpers import * 22 | from utils import * 23 | 24 | # Import other modules 25 | import copy 26 | from collections import OrderedDict 27 | import pandas as pd 28 | 29 | 30 | import argparse 31 | parser = argparse.ArgumentParser(description='Rank feature maps of universal detectors...') 32 | 33 | # architecture 34 | parser.add_argument('--arch', type=str, required=True, choices=['resnet50', 'efb0']) 35 | 36 | # architecture 37 | parser.add_argument('--classifier', type=str, required=True, choices=['imagenet', 'ud']) 38 | 39 | args = parser.parse_args() 40 | 41 | 42 | 43 | def get_wrapped_resnet50(weightpath, key, device): 44 | """ 45 | Get Wrapped ResNet50 model loaded into the device. Written by Prof. Alex. 46 | 47 | Args: 48 | weightspath : path of berkeley classifier weights 49 | key : LRP key 50 | device : cuda or cpu to store the model 51 | 52 | Returns resnet50 pytorch object 53 | """ 54 | 55 | if key == 'beta0': 56 | #beta0 57 | lrp_params_def1={ 58 | 'conv2d_ignorebias': True, 59 | 'eltwise_eps': 1e-6, 60 | 'linear_eps': 1e-6, 61 | 'pooling_eps': 1e-6, 62 | 'use_zbeta': True , 63 | } 64 | 65 | lrp_layer2method={ 66 | 'nn.ReLU': relu_wrapper_fct, 67 | 'nn.BatchNorm2d': relu_wrapper_fct, 68 | 'nn.Conv2d': conv2d_beta0_wrapper_fct, 69 | 'nn.Linear': linearlayer_eps_wrapper_fct, 70 | 'nn.AdaptiveAvgPool2d': adaptiveavgpool2d_wrapper_fct, 71 | 'nn.MaxPool2d': maxpool2d_wrapper_fct, 72 | 'sum_stacked2': eltwisesum_stacked2_eps_wrapper_fct, 73 | } 74 | 75 | elif key == 'beta1': 76 | #beta1 77 | lrp_params_def1={ 78 | 'conv2d_ignorebias': True, 79 | 'eltwise_eps': 1e-6, 80 | 'linear_eps': 1e-6, 81 | 'pooling_eps': 1e-6, 82 | 'use_zbeta': True , 83 | 'conv2d_beta': 1.0, 84 | } 85 | 86 | lrp_layer2method={ 87 | 'nn.ReLU': relu_wrapper_fct, 88 | 'nn.BatchNorm2d': relu_wrapper_fct, 89 | 'nn.Conv2d': conv2d_betaany_wrapper_fct, 90 | 'nn.Linear': linearlayer_eps_wrapper_fct, 91 | 'nn.AdaptiveAvgPool2d': adaptiveavgpool2d_wrapper_fct, 92 | 'nn.MaxPool2d': maxpool2d_wrapper_fct, 93 | 'sum_stacked2': eltwisesum_stacked2_eps_wrapper_fct, 94 | } 95 | 96 | elif key == 'betaada': 97 | # betaada 98 | lrp_params_def1={ 99 | 'conv2d_ignorebias': True, 100 | 'eltwise_eps': 1e-6, 101 | 'linear_eps': 1e-6, 102 | 'pooling_eps': 1e-6, 103 | 'use_zbeta': True , 104 | 'conv2d_maxbeta': 4.5, 105 | } 106 | 107 | lrp_layer2method={ 108 | 'nn.ReLU': relu_wrapper_fct, 109 | 'nn.BatchNorm2d': relu_wrapper_fct, 110 | 'nn.Conv2d': conv2d_betaadaptive_wrapper_fct, 111 | 'nn.Linear': linearlayer_eps_wrapper_fct, 112 | 'nn.AdaptiveAvgPool2d': adaptiveavgpool2d_wrapper_fct, 113 | 'nn.MaxPool2d': 
maxpool2d_wrapper_fct, 114 | 'sum_stacked2': eltwisesum_stacked2_eps_wrapper_fct, 115 | } 116 | else: 117 | raise NotImplementedError("Unknown key", key) 118 | 119 | model0 = models.resnet50(pretrained=False) 120 | model0.fc = nn.Linear(2048,1) 121 | 122 | somedict = torch.load(weightpath) 123 | model0.load_state_dict( somedict['model'] ) 124 | 125 | model_e = resnet50_canonized(pretrained=False) 126 | model_e.copyfromresnet(model0, lrp_params = lrp_params_def1, lrp_layer2method = lrp_layer2method ) 127 | model_e = model_e.to(device) 128 | 129 | return model_e 130 | 131 | 132 | 133 | def get_lrp_explanations_for_batch(model, 134 | imagetensor, label, 135 | relfname , save, outpath, minus_fx=False): 136 | """ 137 | Get LRP explanations for a single sample. 138 | 139 | Args: 140 | model : pytorch model 141 | imagetensor : images 142 | label : label 143 | relfname : filenames 144 | save : If True, save LRP explanations 145 | outpath : output path to save pixel space explanations 146 | 147 | Get dict with positive relevances for the sample 148 | """ 149 | model.eval() 150 | 151 | all_lrp_explanations = [] 152 | 153 | os.makedirs(outpath, exist_ok=True) 154 | 155 | if imagetensor.grad is not None: 156 | imagetensor.grad.zero_() 157 | 158 | imagetensor.requires_grad=True # gxinp needs it here 159 | 160 | with torch.enable_grad(): 161 | outputs = model(imagetensor) 162 | 163 | with torch.no_grad(): 164 | probs = outputs.sigmoid().flatten() 165 | preds_labels = torch.where(probs>0.5, 1.0, 0.0).long() 166 | correct_pred_indices = torch.where(torch.eq(preds_labels, label))[0] 167 | #print(correct_pred_indices, correct_pred_indices.size(0)) 168 | #print(outputs, outputs[correct_pred_indices, :]) 169 | 170 | # if not correct_pred_indices.size(0) > 0: 171 | # return all_lrp_explanations 172 | 173 | #Propagate the signals for the correctly predicted samples for LRP (We should get the same LRP results if we use all samples as well.) 174 | with torch.enable_grad(): 175 | if minus_fx: 176 | z = torch.sum( -outputs[correct_pred_indices, :] ) # Explain -f(x) if images are real 177 | 178 | else: 179 | z = torch.sum( outputs[correct_pred_indices, :] ) # Explain f(x) if images are fake 180 | 181 | with torch.no_grad(): 182 | z.backward(retain_graph=True) 183 | rel = imagetensor.grad.data.clone() 184 | 185 | for b in range(imagetensor.shape[0]): 186 | # Check for correct preds and skip incorrect preds 187 | # cond = (probs[b].item() >= 0.5 and label[b].item() == 1) or (probs[b].item() < 0.5 and label[b].item() == 0) 188 | 189 | # if not cond: 190 | # continue 191 | 192 | fn = relfname[b] 193 | lrp_explanations = {} 194 | lrp_explanations['relfname'] = relfname[b] 195 | lrp_explanations['prob'] = probs[b].item() 196 | 197 | for i, (name, mod) in enumerate(model.named_modules()): 198 | if hasattr(mod, 'relfromoutput'): 199 | v = getattr(mod, 'relfromoutput') 200 | #print(i, name, v.shape) # conv rel map 201 | 202 | ftrelevances = v[b,:] 203 | 204 | # take only positives 205 | #ftrelevances[ftrelevances<0] = 0 206 | 207 | # Save feature relevances to LRP explanations dict. Move to cpu since data is big. 
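                # (Added sketch: ftrelevances holds this conv layer's LRP map for
                # sample b, shaped (channels, H, W). If one wanted the per-feature-map
                # relevance scores used by the ranking scripts, a hedged one-liner
                # in the spirit of the commented line below would be, e.g.:
                #     pos_rel = ftrelevances.clamp(min=0).sum(dim=(-2, -1))
                # i.e. sum positive relevance over each map's spatial extent.)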
208 |                 #lrp_explanations[name] = ftrelevances.detach().cpu()
209 |
210 |         # All LRP explanations
211 |         if label[b].item() == 0:
212 |             vis_dir_name = os.path.join(outpath, "visualization", "0_real")
213 |             vis_fname = os.path.join( vis_dir_name, fn.replace('/', '-').replace('.png', '-p={:.3f}.pdf'.format(probs[b].item())) )
214 |             vis_fname = vis_fname.replace('.JPEG', '.pdf')
215 |             print(vis_fname)
216 |             os.makedirs(vis_dir_name, exist_ok=True)
217 |
218 |             save_img_lrp_overlay_only_positive(rel[b].to('cpu'), imagetensor[b].to('cpu'),
219 |                             title="Label: {}, prob: {:.3f}".format( label[b].item(), probs[b].item() ),
220 |                             q=100, outname=vis_fname )
221 |
222 |             # Store LRP values
223 |             if save:
224 |                 lrp_dir_name = os.path.join(outpath, "lrp", "0_real")
225 |                 lrp_fname = os.path.join(lrp_dir_name, fn.replace('/', '-') + '.pt')
226 |                 os.makedirs(lrp_dir_name, exist_ok=True)
227 |                 torch.save(torch.sum(rel[b], dim=0).cpu(), lrp_fname)
228 |
229 |         else:
230 |             vis_dir_name = os.path.join(outpath, "visualization", "1_fake")
231 |             vis_fname = os.path.join( vis_dir_name, fn.replace('/', '-').replace('.png', '-p={:.3f}.pdf'.format(probs[b].item())) )
232 |             vis_fname = vis_fname.replace('.JPEG', '.pdf')
233 |             os.makedirs(vis_dir_name, exist_ok=True)
234 |
235 |             save_img_lrp_overlay_only_positive(rel[b].to('cpu'), imagetensor[b].to('cpu'),
236 |                             title="Label: {}, prob: {:.3f}".format( label[b].item(), probs[b].item() ),
237 |                             q=100, outname=vis_fname)
238 |
239 |             if save:
240 |                 lrp_dir_name = os.path.join(outpath, "lrp", "1_fake")
241 |                 lrp_fname = os.path.join(lrp_dir_name, fn.replace('/', '-') + '.pt')
242 |                 os.makedirs(lrp_dir_name, exist_ok=True)
243 |                 torch.save(torch.sum(rel[b], dim=0).cpu(), lrp_fname)
244 |
245 |         #all_lrp_explanations.append(lrp_explanations)  # accumulation left disabled to save memory: the function returns an empty list and callers rely on the heatmaps/.pt files written above
246 |
247 |     torch.cuda.empty_cache()
248 |     gc.collect()
249 |
250 |     # del ftrelevances
251 |
252 |     return all_lrp_explanations
253 |
254 |
255 | def get_all_lrp_positive_explanations(model, dataloader, device, outpath, save, minus_fx):
256 |     """
257 |     Get all LRP explanations for one folder.
258 | 259 | Args: 260 | model : resnet50 pytorch model 261 | dataloader : pytorch dataloader 262 | device : 263 | outpath : output path to save visualization and lrp numpy files 264 | save : If set to True, save 265 | minus_fx : If set to True, we will use -f(x) signal to calculate relevances 266 | 267 | Returns all LRP explanations 268 | """ 269 | # Global variable to store feature map information 270 | all_lrp_explanations = [] 271 | 272 | # > Explain prediction 273 | for index, data in enumerate(dataloader): 274 | # Get image tensor, filename, stub and labels 275 | imagetensors = data['image'].to(device) 276 | fnames = data['filename'] 277 | relfnames = data['relfilestub'] 278 | labels = data['label'].to(device) 279 | 280 | # Get LRP explanations 281 | lrp_explanations = get_lrp_explanations_for_batch(model, imagetensors, labels, relfnames, save, outpath, minus_fx) 282 | 283 | # Get LRP heatmap for all layers 284 | all_lrp_explanations.extend(lrp_explanations) 285 | 286 | torch.cuda.empty_cache() 287 | gc.collect() 288 | del lrp_explanations 289 | 290 | return all_lrp_explanations 291 | 292 | 293 | 294 | def pipeline(model_e, dl, device, outpath, save, 295 | num_instances, 296 | minus_fx): 297 | """ 298 | Pipeline to run overall algorithm for real or fake images 299 | 300 | Args: 301 | model_e : Wrapped resnet50 302 | dl : dataloader 303 | device : cuda/ cpu 304 | outpath : output path to save the channelwise stats 305 | save : If set to True, all LRP relevances are saved as .pt files 306 | minus_fx: Needs to be set to True for real images 307 | normalize_using_only_positive : If set to True, scheme 1 else scheme 2. 308 | topk : #topk feature maps to return. 309 | """ 310 | # Get LRP explanations 311 | all_lrp_explanations = get_all_lrp_positive_explanations(model_e, dl, device, outpath, save, minus_fx)[:num_instances] 312 | 313 | #print(final_channelwise_stats) 314 | return all_lrp_explanations 315 | 316 | 317 | 318 | def main(): 319 | ## ------------Set Parameters-------- 320 | # Define device and other parameters 321 | device = torch.device('cuda:0') 322 | bsize = 16 323 | 324 | # LRP model keys and weights 325 | key = 'beta0' # 'beta0' 'beta1' , 'betaada' 326 | weightfn = './weights/resnet50/blur_jpg_prob0.5.pth' 327 | 328 | # Directories 329 | #parent_dir = '/mnt/workspace/projects/deepfake_classifiers_interpretability/samples/' 330 | parent_dir = '/mnt/data/CNN_synth_testset/' # Use our version 331 | have_classes = False 332 | gan_and_classes = {} 333 | # gan_and_classes['biggan'] = ['beer_bottle', 'monitor', 'vase', 'table_lamp', 334 | # 'hummingbird', 'church', 'egyptian_cat', 'welsh_springer_spaniel'] 335 | 336 | #gan_and_classes['progan_test'] = os.listdir(os.path.join(parent_dir, 'progan_test')) 337 | #gan_and_classes['progan_test'] = [ 'boat' ] 338 | #gan_and_classes['biggan'] = ['bird'] 339 | gan_and_classes['stylegan2'] = ['church', 'car', 'cat'] 340 | #gan_and_classes['stylegan'] = ['car'] 341 | #gan_and_classes['cyclegan'] = ['horse'] 342 | #gan_and_classes['stargan'] = ['person'] 343 | #gan_and_classes['gaugan'] = ['mscoco'] 344 | 345 | #gan_and_classes['san'] = [''] 346 | #gan_and_classes['stylegan2'] = ['car', 'cat', 'horse', 'church'] 347 | #gan_and_classes['cyclegan'] = ['horse'] 348 | #gan_and_classes['stargan'] = [''] 349 | 350 | for gan_name in gan_and_classes: 351 | for clss in gan_and_classes[gan_name]: 352 | print(gan_name, clss) 353 | root_dir = os.path.join(parent_dir, gan_name, clss) 354 | #clss='sr' 355 | outpath = 
356 | save_pt_files = False # No need to save .pt files.
357 | num_instances_real, num_instances_fake = 1, 500 # Number of real and fake samples used for this analysis
358 | ## ------------End of Parameters--------
359 | 
360 | 
361 | # Model
362 | model_e = get_wrapped_resnet50(weightfn, key, device)
363 | 
364 | def writeintomodule_bwhook(self, grad_input, grad_output):
365 | # grad_output is what arrives from the layer above; its shape equals that of the module output
366 | setattr(self, 'relfromoutput', grad_output[0])
367 | 
368 | # Register hook
369 | for i, (name, mod) in enumerate(model_e.named_modules()):
370 | #print(i, name)
371 | if ('conv' in name) and ('module' not in name):
372 | #print('ok')
373 | mod.register_backward_hook(writeintomodule_bwhook)
374 | 
375 | # Dataset (use the same transforms as Wang et al.: resize to 256, then center-crop to 224)
376 | transform = transforms.Compose([
377 | transforms.Resize(256),
378 | transforms.CenterCrop(224),
379 | transforms.ToTensor(),
380 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
381 | ])
382 | 
383 | # Obtain D_real and D_fake
384 | dl_real = get_dataloader(root_dir, have_classes, num_instances_real , transform, bsize, onlyreal=True, onlyfake=False)
385 | dl_fake = get_dataloader(root_dir, have_classes, num_instances_fake , transform, bsize, onlyreal=False, onlyfake=True)
386 | 
387 | # Pass to the overall algorithm pipeline to obtain C_real_topk, C_fake_topk
388 | real_lrp_explanations = pipeline(model_e, dl_real, device, outpath, save=save_pt_files,
389 | num_instances=num_instances_real,
390 | minus_fx=True)
391 | fake_lrp_explanations = pipeline(model_e, dl_fake, device, outpath, save=save_pt_files,
392 | num_instances=num_instances_fake,
393 | minus_fx=False)
394 | 
395 | 
396 | 
397 | 
398 | return real_lrp_explanations, fake_lrp_explanations
399 | 
400 | 
401 | if __name__=='__main__':
402 | main()
403 | 
--------------------------------------------------------------------------------
/src/ud_lrp_efb0.py:
--------------------------------------------------------------------------------
1 | # Import generic libraries
2 | import os, sys, math, gc
3 | import PIL
4 | 
5 | # Import scientific computing libraries
6 | import numpy as np
7 | 
8 | # Import torch and dependencies
9 | import torch
10 | from torchvision import models, transforms
11 | 
12 | # Import other libraries
13 | import matplotlib.pyplot as plt
14 | 
15 | from efficientnet_pytorch import EfficientNet
16 | # from efficientnet_pytorch.utils import load_pretrained_weights
17 | 
18 | # Import LRP modules
19 | from utils.heatmap_helpers import *
20 | from lrp.ef_lrp_general import *
21 | from lrp.ef_wrapper import *
22 | 
23 | # Import utils
24 | from utils.heatmap_helpers import *
25 | from utils.dataset_helpers import *
26 | from utils import *
27 | 
28 | # Import other modules
29 | import copy
30 | from collections import OrderedDict
31 | import pandas as pd
32 | 
33 | 
34 | import argparse
35 | parser = argparse.ArgumentParser(description='Generate LRP heatmaps for universal detectors...')
36 | 
37 | # architecture
38 | parser.add_argument('--arch', type=str, required=True, choices=['resnet50', 'efb0'])
39 | 
40 | # classifier
41 | parser.add_argument('--classifier', type=str, required=True, choices=['imagenet', 'ud'])
42 | 
43 | args = parser.parse_args()
44 | 
45 | 
46 | def get_wrapped_efficientnet_b0(weightpath, key, device, classifier):
47 | """
48 | Get wrapped EfficientNet-B0 model loaded onto the device.
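The plain EfficientNet weights are copied into an EfficientNet_canonized instance whose layers carry the LRP wrappers (see copyfromefficientnet below).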
49 | 
50 | Args:
51 | weightpath : path of the Berkeley classifier weights
52 | key : LRP key
53 | device : cuda or cpu to store the model
54 | 
55 | Returns EfficientNet-B0 pytorch object
56 | """
57 | 
58 | if key == 'beta0':
59 | #beta0
60 | lrp_params_def1={
61 | 'conv2d_ignorebias': True,
62 | 'eltwise_eps': 1e-6,
63 | 'linear_eps': 1e-6,
64 | 'pooling_eps': 1e-6,
65 | 'use_zbeta': False ,
66 | }
67 | 
68 | lrp_layer2method={
69 | 'Swish': relu_wrapper_fct,
70 | 'nn.BatchNorm2d': relu_wrapper_fct,
71 | 'nn.Conv2d': Conv2dDynamicSamePadding_beta0_wrapper_fct,
72 | 'nn.Linear': linearlayer_eps_wrapper_fct,
73 | 'nn.AdaptiveAvgPool2d': adaptiveavgpool2d_wrapper_fct,
74 | 'sum_stacked2': eltwisesum_stacked2_eps_wrapper_fct,
75 | }
76 | 
77 | elif key == 'beta1':
78 | raise NotImplementedError("LRP key 'beta1' is not implemented for EfficientNet-B0")
79 | elif key == 'betaada':
80 | raise NotImplementedError("LRP key 'betaada' is not implemented for EfficientNet-B0")
81 | else:
82 | raise NotImplementedError("Unknown key", key)
83 | 
84 | if classifier == 'ud':
85 | 
86 | model0 = EfficientNet.from_name('efficientnet-b0', num_classes=1, image_size=None,)
87 | somedict = torch.load(weightpath)
88 | model0.load_state_dict( somedict['model'] )
89 | model0.eval()
90 | #load_pretrained_weights(model0, 'efficientnet-b0', weightpath)
91 | #model0.set_swish( memory_efficient=False)
92 | 
93 | model_e = EfficientNet_canonized.from_pretrained('efficientnet-b0', num_classes=1, image_size=None, dropout_rate= 0.0 , drop_connect_rate=0.0)
94 | #print(model_e)
95 | #model_e.set_swish( memory_efficient=False)
96 | 
97 | model_e.copyfromefficientnet( model0, lrp_params_def1, lrp_layer2method)
98 | model_e.to(device)
99 | 
100 | return model_e
101 | 
102 | else:
103 | model0 = EfficientNet.from_pretrained('efficientnet-b0', image_size=None)
104 | model0.eval()
105 | #load_pretrained_weights(model0, 'efficientnet-b0', weightpath)
106 | #model0.set_swish( memory_efficient=False)
107 | 
108 | model_e = EfficientNet_canonized.from_pretrained('efficientnet-b0', num_classes=1000, image_size=None, dropout_rate= 0.0 , drop_connect_rate=0.0)
109 | #print(model_e)
110 | #model_e.set_swish( memory_efficient=False)
111 | 
112 | model_e.copyfromefficientnet( model0, lrp_params_def1, lrp_layer2method)
113 | model_e.to(device)
114 | 
115 | return model_e
116 | 
117 | 
118 | 
119 | 
120 | def get_lrp_explanations_for_batch(model,
121 | imagetensor, label,
122 | relfname , save, outpath, minus_fx=False):
123 | """
124 | Get LRP explanations for a batch of samples.
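Relevance is propagated only through the correctly classified samples in the batch (see the correct_pred_indices filter below).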
125 | 
126 | Args:
127 | model : pytorch model
128 | imagetensor : batch of input images
129 | label : ground-truth labels (0 = real, 1 = fake)
130 | relfname : relative filename stubs
131 | save : If True, save LRP explanations
132 | outpath : output path to save pixel space explanations
133 | 
134 | Returns a list of dicts with positive relevances for the samples
135 | """
136 | model.eval()
137 | 
138 | all_lrp_explanations = []
139 | 
140 | os.makedirs(outpath, exist_ok=True)
141 | 
142 | if imagetensor.grad is not None:
143 | imagetensor.grad.zero_()
144 | 
145 | imagetensor.requires_grad=True # needed so that imagetensor.grad receives the input-space relevance
146 | 
147 | with torch.enable_grad():
148 | outputs = model(imagetensor)
149 | 
150 | with torch.no_grad():
151 | probs = outputs.sigmoid().flatten()
152 | preds_labels = torch.where(probs>0.5, 1.0, 0.0).long()
153 | correct_pred_indices = torch.where(torch.eq(preds_labels, label))[0]
154 | #print(correct_pred_indices, correct_pred_indices.size(0))
155 | #print(outputs, outputs[correct_pred_indices, :])
156 | 
157 | # if not correct_pred_indices.size(0) > 0:
158 | # return all_lrp_explanations
159 | 
160 | 
161 | 
162 | # Propagate the signal for the correctly predicted samples only (the per-sample LRP results would be the same if all samples were used)
163 | with torch.enable_grad():
164 | if minus_fx:
165 | z = torch.sum( -outputs[correct_pred_indices, :] ) # Explain -f(x) if images are real
166 | 
167 | else:
168 | z = torch.sum( outputs[correct_pred_indices, :] ) # Explain f(x) if images are fake
169 | 
170 | 
171 | with torch.no_grad():
172 | z.backward(retain_graph=True)
173 | rel = imagetensor.grad.data.clone()
174 | 
175 | for b in range(imagetensor.shape[0]):
176 | # Check for correct preds and skip incorrect preds
177 | # cond = (probs[b].item() >= 0.5 and label[b].item() == 1) or (probs[b].item() < 0.5 and label[b].item() == 0)
178 | 
179 | # if not cond:
180 | # continue
181 | 
182 | fn = relfname[b]
183 | lrp_explanations = {}
184 | lrp_explanations['relfname'] = relfname[b]
185 | lrp_explanations['prob'] = probs[b].item()
186 | 
187 | for i, (name, mod) in enumerate(model.named_modules()):
188 | if hasattr(mod, 'relfromoutput'):
189 | v = getattr(mod, 'relfromoutput')
190 | #print(i, name, v.shape) # conv rel map
191 | 
192 | ftrelevances = v[b,:]
193 | 
194 | # take only positives
195 | #ftrelevances[ftrelevances<0] = 0
196 | 
197 | # Save feature relevances to LRP explanations dict. Move to cpu since data is big.
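# (A lighter, hypothetical variant would store only a channel-wise summary, e.g. ftrelevances.sum(dim=(1, 2)).detach().cpu(), instead of the full map that is kept commented out below.)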
198 | #lrp_explanations[name] = ftrelevances.detach().cpu()
199 | 
200 | # All LRP explanations
201 | if label[b].item() == 0:
202 | vis_dir_name = os.path.join(outpath, "visualization", "0_real")
203 | vis_fname = os.path.join( vis_dir_name, fn.replace('/', '-').replace('.png', '-p={:.3f}.pdf'.format(probs[b].item())) )
204 | vis_fname = vis_fname.replace('.JPEG', '.pdf')
205 | print(vis_fname)
206 | os.makedirs(vis_dir_name, exist_ok=True)
207 | 
208 | save_img_lrp_overlay_only_positive(rel[b].to('cpu'), imagetensor[b].to('cpu'),
209 | title="Label: {}, prob: {:.3f}".format( label[b].item(), probs[b].item() ),
210 | q=100, outname=vis_fname )
211 | 
212 | # Store LRP values
213 | if save:
214 | lrp_dir_name = os.path.join(outpath, "lrp", "0_real")
215 | lrp_fname = os.path.join(lrp_dir_name, fn.replace('/', '-') + '.pt')
216 | os.makedirs(lrp_dir_name, exist_ok=True)
217 | torch.save(torch.sum(rel[b], dim=0).cpu(), lrp_fname)
218 | 
219 | else:
220 | vis_dir_name = os.path.join(outpath, "visualization", "1_fake")
221 | vis_fname = os.path.join( vis_dir_name, fn.replace('/', '-').replace('.png', '-p={:.3f}.pdf'.format(probs[b].item())) )
222 | vis_fname = vis_fname.replace('.JPEG', '.pdf')
223 | os.makedirs(vis_dir_name, exist_ok=True)
224 | 
225 | save_img_lrp_overlay_only_positive(rel[b].to('cpu'), imagetensor[b].to('cpu'),
226 | title="Label: {}, prob: {:.3f}".format( label[b].item(), probs[b].item() ),
227 | q=100, outname=vis_fname)
228 | 
229 | if save:
230 | lrp_dir_name = os.path.join(outpath, "lrp", "1_fake")
231 | lrp_fname = os.path.join(lrp_dir_name, fn.replace('/', '-') + '.pt')
232 | os.makedirs(lrp_dir_name, exist_ok=True)
233 | torch.save(torch.sum(rel[b], dim=0).cpu(), lrp_fname)
234 | 
235 | #all_lrp_explanations.append(lrp_explanations)  # NOTE: appending is left disabled to keep memory usage low; the heatmaps/.pt files above are already written to disk
236 | 
237 | torch.cuda.empty_cache()
238 | gc.collect()
239 | 
240 | # del ftrelevances
241 | 
242 | return all_lrp_explanations
243 | 
244 | 
245 | def get_all_lrp_positive_explanations(model, dataloader, device, outpath, save, minus_fx):
246 | """
247 | Get all LRP explanations for one folder.
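Iterates over the dataloader batch by batch and delegates the per-batch work to get_lrp_explanations_for_batch.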
248 | 
249 | Args:
250 | model : wrapped EfficientNet-B0 pytorch model
251 | dataloader : pytorch dataloader
252 | device : cuda/cpu device to run the explanations on
253 | outpath : output path to save visualization and lrp numpy files
254 | save : If set to True, save the LRP relevances as .pt files
255 | minus_fx : If set to True, we will use -f(x) signal to calculate relevances
256 | 
257 | Returns all LRP explanations
258 | """
259 | # Global variable to store feature map information
260 | all_lrp_explanations = []
261 | 
262 | # > Explain prediction
263 | for index, data in enumerate(dataloader):
264 | # Get image tensor, filename, stub and labels
265 | imagetensors = data['image'].to(device)
266 | fnames = data['filename']
267 | relfnames = data['relfilestub']
268 | labels = data['label'].to(device)
269 | 
270 | # Get LRP explanations
271 | lrp_explanations = get_lrp_explanations_for_batch(model, imagetensors, labels, relfnames, save, outpath, minus_fx)
272 | 
273 | # Get LRP heatmap for all layers
274 | all_lrp_explanations.extend(lrp_explanations)
275 | 
276 | torch.cuda.empty_cache()
277 | gc.collect()
278 | del lrp_explanations
279 | 
280 | return all_lrp_explanations
281 | 
282 | 
283 | 
284 | def pipeline(model_e, dl, device, outpath, save,
285 | num_instances,
286 | minus_fx):
287 | """
288 | Pipeline to run overall algorithm for real or fake images
289 | 
290 | Args:
291 | model_e : Wrapped EfficientNet-B0
292 | dl : dataloader
293 | device : cuda/ cpu
294 | outpath : output path to save the channelwise stats
295 | save : If set to True, all LRP relevances are saved as .pt files
296 | minus_fx: Needs to be set to True for real images
297 | num_instances : number of LRP explanations to keep
298 | 
299 | """
300 | # Get LRP explanations
301 | all_lrp_explanations = get_all_lrp_positive_explanations(model_e, dl, device, outpath, save, minus_fx)[:num_instances]
302 | 
303 | #print(final_channelwise_stats)
304 | return all_lrp_explanations
305 | 
306 | 
307 | 
308 | def main():
309 | ## ------------Set Parameters--------
310 | # Define device and other parameters
311 | device = torch.device('cuda:0')
312 | bsize = 16
313 | 
314 | # LRP model keys and weights
315 | key = 'beta0' # one of: 'beta0', 'beta1', 'betaada'
316 | weightfn = './weights/{}/blur_jpg_prob0.5.pth'.format(args.arch)
317 | 
318 | # Directories
319 | #parent_dir = '/mnt/workspace/projects/deepfake_classifiers_interpretability/samples/'
320 | parent_dir = '/mnt/data/CNN_synth_testset/' # Use our version
321 | have_classes = False
322 | gan_and_classes = {}
323 | # gan_and_classes['biggan'] = ['beer_bottle', 'monitor', 'vase', 'table_lamp',
324 | # 'hummingbird', 'church', 'egyptian_cat', 'welsh_springer_spaniel']
325 | 
326 | #gan_and_classes['progan_test'] = os.listdir(os.path.join(parent_dir, 'progan_test'))
327 | #gan_and_classes['progan_test'] = [ 'boat' ]
328 | #gan_and_classes['biggan'] = ['bird']
329 | gan_and_classes['stylegan2'] = ['church', 'car', 'cat']
330 | #gan_and_classes['stylegan'] = ['car']
331 | #gan_and_classes['cyclegan'] = ['horse']
332 | #gan_and_classes['stargan'] = ['person']
333 | #gan_and_classes['gaugan'] = ['mscoco']
334 | 
335 | #gan_and_classes['san'] = ['']
336 | #gan_and_classes['stylegan2'] = ['car', 'cat', 'horse', 'church']
337 | #gan_and_classes['cyclegan'] = ['horse']
338 | #gan_and_classes['stargan'] = ['']
339 | 
340 | 
341 | for gan_name in gan_and_classes:
342 | for clss in gan_and_classes[gan_name]:
343 | print(gan_name, clss)
344 | root_dir = os.path.join(parent_dir, gan_name, clss)
345 | #clss='sr'
346 | outpath = './output/hms/lrp_heatmaps_{}_{}/{}/{}/'.format(args.arch, args.classifier, gan_name, clss)
347 | save_pt_files = False # No need to save .pt files.
348 | num_instances_real, num_instances_fake = 1, 500 # Number of real and fake samples used for this analysis
349 | ## ------------End of Parameters--------
350 | 
351 | 
352 | # Model
353 | model_e = get_wrapped_efficientnet_b0(weightfn, key, device, args.classifier)
354 | 
355 | def writeintomodule_bwhook(self, grad_input, grad_output):
356 | # grad_output is what arrives from the layer above; its shape equals that of the module output
357 | setattr(self, 'relfromoutput', grad_output[0])
358 | 
359 | # Register hook
360 | for i, (name, mod) in enumerate(model_e.named_modules()):
361 | #print(i, name)
362 | if ('conv' in name) and ('module' not in name):
363 | #print('ok')
364 | mod.register_backward_hook(writeintomodule_bwhook)
365 | 
366 | # Dataset (use the same transforms as Wang et al.: resize to 256, then center-crop to 224)
367 | transform = transforms.Compose([
368 | transforms.Resize(256),
369 | transforms.CenterCrop(224),
370 | transforms.ToTensor(),
371 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
372 | ])
373 | 
374 | # Obtain D_real and D_fake
375 | dl_real = get_dataloader(root_dir, have_classes, num_instances_real , transform, bsize, onlyreal=True, onlyfake=False)
376 | dl_fake = get_dataloader(root_dir, have_classes, num_instances_fake , transform, bsize, onlyreal=False, onlyfake=True)
377 | 
378 | # Pass to the overall algorithm pipeline to obtain C_real_topk, C_fake_topk
379 | real_lrp_explanations = pipeline(model_e, dl_real, device, outpath, save=save_pt_files,
380 | num_instances=num_instances_real,
381 | minus_fx=True)
382 | fake_lrp_explanations = pipeline(model_e, dl_fake, device, outpath, save=save_pt_files,
383 | num_instances=num_instances_fake,
384 | minus_fx=False)
385 | 
386 | 
387 | 
388 | 
389 | return real_lrp_explanations, fake_lrp_explanations
390 | 
391 | 
392 | if __name__=='__main__':
393 | main()
394 | 
--------------------------------------------------------------------------------
/src/utils/dataset_helpers.py:
--------------------------------------------------------------------------------
1 | 
2 | # Import base libraries
3 | import os, sys, math
4 | import PIL
5 | 
6 | # Import scientific and plotting libraries
7 | import numpy as np
8 | import matplotlib.pyplot as plt
9 | 
10 | # Import torch libraries
11 | import torch
12 | from torchvision import models, transforms
13 | 
14 | 
15 | def _get_all_paths(image_parent_dir, real_or_fake, have_classes, num_instances):
16 | """
17 | This is an internal function.
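It collects and sorts the image paths under image_parent_dir/<class>/<real_or_fake> (or image_parent_dir/<real_or_fake> when have_classes is False, where real_or_fake is '0_real' or '1_fake') and keeps at most num_instances paths per folder.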
Do not use it explicitly 18 | """ 19 | all_img_paths = [] 20 | 21 | if have_classes: 22 | classes = os.listdir(image_parent_dir) 23 | 24 | for c in classes: 25 | img_dir_location = os.path.join(image_parent_dir, c, real_or_fake) 26 | img_paths = [ os.path.join(img_dir_location, i) for i in os.listdir(img_dir_location) ] 27 | img_paths.sort() 28 | all_img_paths.extend(img_paths[:num_instances]) 29 | 30 | #print(len(all_img_paths)) 31 | #print(all_img_paths) 32 | return all_img_paths 33 | 34 | else: 35 | img_dir_location = os.path.join(image_parent_dir, real_or_fake) 36 | img_paths = [ os.path.join(img_dir_location, i) for i in os.listdir(img_dir_location) ] 37 | img_paths.sort() 38 | all_img_paths.extend(img_paths[:num_instances]) 39 | 40 | #print(len(all_img_paths)) 41 | #print(all_img_paths) 42 | return all_img_paths 43 | 44 | 45 | 46 | class dataset_gan_vs_real(torch.utils.data.Dataset): 47 | """ 48 | Class for binary classification for GAN and real images 49 | 50 | The labels used for training the original classifier are: 51 | 0 : Real images 52 | 1 : GAN/ fake images 53 | """ 54 | def __init__(self, root_dir, have_classes, max_num, transform=None, onlyreal=False, onlyfake=False): 55 | self.classsubpaths =['0_real','1_fake'] 56 | 57 | self.root_dir = root_dir 58 | self.transform = transform 59 | self.imgfilenames=[] 60 | self.labels =[] 61 | self.num_real_images = -1 62 | self.num_fake_images = -1 63 | 64 | real_img_paths = _get_all_paths(self.root_dir, self.classsubpaths[0], have_classes, max_num) 65 | fake_img_paths = _get_all_paths(self.root_dir, self.classsubpaths[1], have_classes, max_num) 66 | 67 | if onlyreal: 68 | self.imgfilenames.extend(real_img_paths) 69 | self.labels.extend([0]*len(real_img_paths)) 70 | self.num_real_images = len(real_img_paths) 71 | 72 | elif onlyfake: 73 | #print(fake_img_paths) 74 | self.imgfilenames.extend(fake_img_paths) 75 | self.labels.extend([1]*len(fake_img_paths)) 76 | self.num_fake_images = len(fake_img_paths) 77 | 78 | else: 79 | self.imgfilenames.extend(real_img_paths) 80 | self.labels.extend([0]*len(real_img_paths)) 81 | self.num_real_images = len(real_img_paths) 82 | 83 | self.imgfilenames.extend(fake_img_paths) 84 | self.labels.extend([1]*len(fake_img_paths)) 85 | self.num_fake_images = len(fake_img_paths) 86 | 87 | 88 | 89 | def __len__(self): 90 | return len(self.imgfilenames) 91 | 92 | 93 | def __getitem__(self, idx): 94 | image = PIL.Image.open(self.imgfilenames[idx]).convert('RGB') 95 | label=self.labels[idx] 96 | 97 | if self.transform: 98 | image = self.transform(image) 99 | 100 | tmpdir = os.path.dirname(self.imgfilenames[idx]) 101 | ind = tmpdir.rfind('/') 102 | ind = tmpdir[:ind].rfind('/') 103 | stub = tmpdir[ind+1:] 104 | 105 | fn = os.path.join(stub, os.path.basename(self.imgfilenames[idx]) ) 106 | 107 | sample = {'image': image, 'label': label, 'filename': self.imgfilenames[idx],'relfilestub': fn} 108 | 109 | return sample 110 | 111 | 112 | 113 | def get_dataloader(root_dir, have_classes, max_num, transform, bsize, onlyreal, onlyfake): 114 | """ 115 | Get dataloader 116 | """ 117 | ds = dataset_gan_vs_real(root_dir = root_dir, have_classes=have_classes, max_num=max_num, transform = transform, 118 | onlyreal=onlyreal, onlyfake=onlyfake) 119 | dl = torch.utils.data.DataLoader(ds, batch_size= bsize, shuffle=False) 120 | 121 | return dl 122 | 123 | 124 | def get_classwise_dataloader(root_dir, have_classes, max_num, transform, bsize, onlyreal, onlyfake): 125 | """ 126 | Get dataloader 127 | """ 128 | dls = [] 129 | clsses = 
os.listdir(root_dir)
130 | 
131 | for clss in clsses:
132 | clss_dir = os.path.join(root_dir, clss)
133 | ds = dataset_gan_vs_real(root_dir = clss_dir, have_classes=False, max_num=max_num, transform = transform,
134 | onlyreal=onlyreal, onlyfake=onlyfake)
135 | dl = torch.utils.data.DataLoader(ds, batch_size= bsize, shuffle=False)
136 | 
137 | dls.append(dl)
138 | 
139 | return dls
140 | 
141 | 
142 | 
143 | 
144 | 
145 | 
146 | 
147 | 
--------------------------------------------------------------------------------
/src/utils/general.py:
--------------------------------------------------------------------------------
1 | import sys, os
2 | import pandas as pd
3 | 
4 | # Import torch and dependencies
5 | import torch
6 | from torchvision import models
7 | from efficientnet_pytorch import EfficientNet
8 | 
9 | from sklearn.metrics import average_precision_score, precision_recall_curve, accuracy_score
10 | from sklearn.metrics import roc_curve
11 | import numpy as np
12 | 
13 | from termcolor import colored
14 | 
15 | # Disable printing to stdout
16 | def blockPrint():
17 | sys.stdout = open(os.devnull, 'w')
18 | 
19 | # Restore printing to stdout
20 | def enablePrint():
21 | sys.stdout = sys.__stdout__
22 | 
23 | 
24 | def get_all_channels(fake_csv_path, topk):
25 | """
26 | Get topk channels from the saved csv.
27 | """
28 | df_fake = pd.read_csv(fake_csv_path)
29 | df_fake = df_fake.sort_values(by="mean_relevance", ascending=False)
30 | all_channels = df_fake.key
31 | 
32 | if topk == 0:
33 | return [], [], all_channels
34 | else:
35 | top_fake_channels = df_fake.key.tolist()[:topk]
36 | # Now see relevant fake ones that are also in real ones
37 | topk_channels = list(top_fake_channels)
38 | lowk_channels = df_fake.key[-topk:]
39 | 
40 | print("#topk channels =", len(topk_channels))
41 | print("#lowk channels =", len(lowk_channels))
42 | print("#all channels =", len(all_channels))
43 | return topk_channels, lowk_channels, all_channels
44 | 
45 | 
46 | def get_resnet50_universal_detector(weightpath):
47 | """
48 | Get ResNet50 model loaded into the device.
49 | 
50 | Args:
51 | weightpath : path of the Berkeley classifier weights
52 | 
53 | Returns ResNet50 pytorch object
54 | 
55 | """
56 | 
57 | model0 = models.resnet50(pretrained=False, num_classes=1)
58 | 
59 | somedict = torch.load(weightpath)
60 | model0.load_state_dict( somedict['model'] )
61 | 
62 | # set name as attribute to each individual module. Hooks will be attached using this name
63 | for n, m in model0.named_modules():
64 | m.auto_name = n
65 | 
66 | return model0
67 | 
68 | 
69 | def get_efb0_universal_detector(weightpath):
70 | """
71 | Get EfficientNet-B0 model loaded into the device.
72 | 
73 | Args:
74 | weightpath : path of the Berkeley classifier weights
75 | 
76 | Returns EfficientNet-B0 pytorch object
77 | """
78 | 
79 | model0 = EfficientNet.from_name('efficientnet-b0', num_classes=1, image_size=None)
80 | somedict = torch.load(weightpath)
81 | model0.load_state_dict( somedict['model'] )
82 | model0.eval()
83 | 
84 | # set name as attribute to each individual module.
Hooks will be attached using this name 85 | for n, m in model0.named_modules(): 86 | m.auto_name = n 87 | 88 | return model0 89 | 90 | 91 | 92 | 93 | def get_probs(model, dataloaders, device): 94 | model.eval() 95 | probs = [] 96 | gt = [] 97 | 98 | with torch.no_grad(): 99 | for dataloader in dataloaders: 100 | for index, data in enumerate(dataloader): 101 | imagetensors = data['image'].to(device) 102 | fnames = data['filename'] 103 | relfnames = data['relfilestub'] 104 | labels = data['label'] 105 | 106 | prob = model(imagetensors).sigmoid().flatten().detach().cpu().numpy() 107 | probs.extend(list(prob)) 108 | gt.extend(list(labels.cpu().numpy().flatten())) 109 | 110 | return np.asarray(probs), np.asarray(gt) 111 | 112 | 113 | def get_calibrated_thres(y_true, y_pred, num_samples=None): 114 | cal_y_true = np.concatenate( [ y_true[y_true==0][:num_samples], y_true[y_true==1][:num_samples] ] ) 115 | cal_y_pred = np.concatenate( [ y_pred[y_true==0][:num_samples], y_pred[y_true==1][:num_samples] ] ) 116 | 117 | fpr, tpr, thresholds = roc_curve(cal_y_true, cal_y_pred) 118 | 119 | # Calculate the G-mean 120 | gmean = np.sqrt(tpr * (1 - fpr)) 121 | 122 | # Find the optimal threshold 123 | index = np.argmax(gmean) 124 | thresholdOpt = thresholds[index] 125 | gmeanOpt = gmean[index] 126 | fprOpt = fpr[index] 127 | tprOpt = tpr[index] 128 | print(colored("> Calibration results", 'cyan')) 129 | print('Best Threshold: {:.6f} with G-Mean: {:.6f}'.format(thresholdOpt, gmeanOpt)) 130 | print('FPR: {}, TPR: {}'.format(fprOpt, tprOpt)) 131 | 132 | threshold = thresholdOpt 133 | 134 | return threshold 135 | 136 | 137 | 138 | def get_ap_and_acc(model, device, dl, threshold): 139 | y_pred, y_true = get_probs(model, dl, device) 140 | 141 | r_acc = accuracy_score(y_true[y_true==0], y_pred[y_true==0] > threshold) 142 | f_acc = accuracy_score(y_true[y_true==1], y_pred[y_true==1] > threshold) 143 | acc = accuracy_score(y_true, y_pred > threshold) 144 | ap = average_precision_score(y_true, y_pred) 145 | 146 | #print('AP: {:2.2f}, Acc: {:2.2f}, Acc (real): {:2.2f}, Acc (fake): {:2.2f}'.format(ap*100., acc*100., r_acc*100., f_acc*100.)) 147 | 148 | return (ap, r_acc, f_acc, acc, y_pred[y_true==0].mean(), y_pred[y_true==1].mean(), y_pred[y_true==0].std(), y_pred[y_true==1].std() ), \ 149 | (y_pred, y_true) 150 | 151 | 152 | def get_ap_and_acc_with_new_threshold(y_pred, y_true, threshold, prefix): 153 | r_acc = accuracy_score(y_true[y_true==0], y_pred[y_true==0] > threshold) 154 | f_acc = accuracy_score(y_true[y_true==1], y_pred[y_true==1] > threshold) 155 | acc = accuracy_score(y_true, y_pred > threshold) 156 | ap = average_precision_score(y_true, y_pred) 157 | 158 | print('{} => AP: {:2.2f}, Acc: {:2.2f}, Acc (real): {:2.2f}, Acc (fake): {:2.2f}'.format(prefix, ap*100., acc*100., r_acc*100., f_acc*100.)) 159 | 160 | return (ap, r_acc, f_acc, acc, y_pred[y_true==0].mean(), y_pred[y_true==1].mean(), y_pred[y_true==0].std(), y_pred[y_true==1].std() ), \ 161 | (y_pred, y_true) 162 | 163 | 164 | 165 | def get_ap_and_acc_random(models, device, dl, threshold, recalibrate=False): 166 | y_pred_global = [] 167 | y_true = [] 168 | 169 | for i in range(len(models)): 170 | y_pred, y_true = get_probs(models[i], dl, device) 171 | y_pred_global.append(y_pred.flatten().tolist()) # append to global 172 | 173 | y_pred = np.mean(np.asarray(y_pred_global), axis=0) 174 | 175 | if recalibrate: 176 | print("Recalibrating...") 177 | threshold = get_calibrated_thres(y_true, y_pred) 178 | 179 | r_acc = accuracy_score(y_true[y_true==0], 
y_pred[y_true==0] > threshold) 180 | f_acc = accuracy_score(y_true[y_true==1], y_pred[y_true==1] > threshold) 181 | acc = accuracy_score(y_true, y_pred > threshold) 182 | ap = average_precision_score(y_true, y_pred) 183 | 184 | #print('AP: {:2.2f}, Acc: {:2.2f}, Acc (real): {:2.2f}, Acc (fake): {:2.2f}'.format(ap*100., acc*100., r_acc*100., f_acc*100.)) 185 | 186 | return (ap, r_acc, f_acc, acc, y_pred[y_true==0].mean(), y_pred[y_true==1].mean(), y_pred[y_true==0].std(), y_pred[y_true==1].std() ), \ 187 | (y_pred, y_true) -------------------------------------------------------------------------------- /src/utils/heatmap_helpers.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | import torch 4 | import cv2 5 | 6 | 7 | from colour import Color 8 | from matplotlib.colors import LinearSegmentedColormap 9 | from matplotlib.transforms import Bbox 10 | 11 | from PIL import Image 12 | 13 | plt.rcParams["font.family"] = "Times New Roman" 14 | plt.rcParams['axes.xmargin'] = 0 15 | 16 | def make_plot_custom_cmap( cmap_colors ): 17 | """ 18 | Pass your own colors to create a custom colorbar 19 | 20 | params: 21 | cmap_colors : List of hex colors 22 | """ 23 | color_cmap = LinearSegmentedColormap.from_list( 'my_list', [ Color( c1 ).rgb for c1 in cmap_colors ] ) 24 | plt.figure( figsize = (15,3)) 25 | plt.imshow( [list(np.arange(0, len( cmap_colors ) , 0.1)) ] , interpolation='nearest', origin='lower', cmap= color_cmap ) 26 | plt.xticks([]) 27 | plt.yticks([]) 28 | plt.show() 29 | return color_cmap 30 | 31 | 32 | 33 | def invert_normalize(ten, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): 34 | """ 35 | Invert normalized images 36 | """ 37 | s = torch.tensor(np.asarray(std,dtype=np.float32)).unsqueeze(1).unsqueeze(2) 38 | m = torch.tensor(np.asarray(mean,dtype=np.float32)).unsqueeze(1).unsqueeze(2) 39 | 40 | res = ten*s+m 41 | return res 42 | 43 | 44 | def invert_normalize_image(inpdata): 45 | """ 46 | Take an image tensor, invert normalize and convert to numpy image 47 | """ 48 | ts=invert_normalize(inpdata) 49 | a=ts.data.squeeze(0).permute(1, 2, 0).numpy() 50 | saveimg=(a*255.0).astype(np.uint8) 51 | return saveimg 52 | 53 | 54 | 55 | def rgb2gray(rgb): 56 | """ 57 | Convert RGB to grayscale image 58 | """ 59 | r, g, b = rgb[:,:,0], rgb[:,:,1], rgb[:,:,2] 60 | gray = 0.2989 * r + 0.5870 * g + 0.1140 * b 61 | return gray 62 | 63 | 64 | 65 | def make_register_custom_cmap(cmap_colors): 66 | """ 67 | Take a list of custom colors, create a matplotlib Linear Segmented Color map, register the color map and return the cmap object. 
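The map is registered under the fixed name 'my_cmap', so it can later be retrieved with get_mpl_colormap('my_cmap') for the cv2 overlays below.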
68 | """
69 | color_cmap = LinearSegmentedColormap.from_list( 'my_cmap', [ Color( c1 ).rgb for c1 in cmap_colors ] )
70 | plt.register_cmap(cmap = color_cmap)
71 | return color_cmap
72 | 
73 | 
74 | def get_mpl_colormap(cmap_name):
75 | """
76 | Convert matplotlib cmap to be used for cv2 mapping
77 | Ref : https://stackoverflow.com/questions/52498777/apply-matplotlib-or-custom-colormap-to-opencv-image
78 | """
79 | cmap = plt.get_cmap(cmap_name)
80 | 
81 | # Initialize the matplotlib color map
82 | sm = plt.cm.ScalarMappable(cmap=cmap)
83 | 
84 | # Obtain linear color range
85 | color_range = sm.to_rgba(np.linspace(0, 1, 256), bytes=True)[:,2::-1]
86 | 
87 | return color_range.reshape(256, 1, 3)
88 | 
89 | 
90 | def save_img_lrp_overlay_only_positive( lrp_result, imgtensor, q=100, title=None, outname=None):
91 | """
92 | Take image tensor and lrp result, create the heatmaps and store in `outname`
93 | """
94 | 
95 | # Create colormaps for lrp
96 | # This is a simple black-to-orange color map; black is used because the pixel value (0, 0, 0) makes it easy to overlay later.
97 | cmap_lrp = make_register_custom_cmap( [
98 | '#000000', '#ffc300', '#ffaa00',
99 | ] )
100 | 
101 | # Normalize and smooth lrp heatmap
102 | hm = lrp_result.squeeze().sum(dim=0).numpy()
103 | hm [ hm < 0 ] = 0.0 # Only consider positive relevances
104 | clim = np.percentile(np.abs(hm), q)
105 | 
106 | # Now normalize, smooth and create colormap
107 | final_hm = hm.copy()/clim
108 | final_hm = cv2.medianBlur(final_hm, 5) # Smooth, otherwise the result looks ugly for ResNet-50
109 | final_hm = cv2.applyColorMap(np.uint8(255 * final_hm), get_mpl_colormap('my_cmap'))
110 | final_hm = cv2.cvtColor(final_hm, cv2.COLOR_BGR2RGB)
111 | 
112 | 
113 | # Size of output in pixels
114 | h = 224
115 | w = 224
116 | my_dpi = 1200
117 | fig, ax = plt.subplots(1, figsize=(w/my_dpi, h/my_dpi), dpi=my_dpi)
118 | 
119 | # Image
120 | img_np = invert_normalize_image(imgtensor) # No need to blur the image
121 | 
122 | # Now overlay and create final visualization
123 | cam = cv2.addWeighted(final_hm, 0.55, img_np, 0.45, 0)
124 | ax.imshow(cam)
125 | ax.axis("off")
126 | 
127 | # Save image as pdf
128 | fig.subplots_adjust(top=1.0, bottom=0, right=1.0, left=0, hspace=0, wspace=0)
129 | fig.savefig("{}".format(outname), format='pdf', bbox_inches=Bbox([[0, 0], [w/my_dpi, h/my_dpi]]), dpi=my_dpi)
130 | plt.close()
131 | 
132 | 
133 | 
134 | def save_img_guided_gradcam_overlay_only_positive( gb_cam_result, imgtensor, q=100, title=None, outname=None):
135 | """
136 | Take image tensor and guided Grad-CAM result, create the heatmap and store in `outname`
137 | """
138 | 
139 | # Sum all the channels for relevance
140 | gb_cam_result = np.sum(gb_cam_result, axis=2, keepdims=False)
141 | 
142 | # Create colormaps for the heatmap
143 | # This is a simple black-to-orange color map; black is used because the pixel value (0, 0, 0) makes it easy to overlay later.
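# (Anchoring the map at black means zero-relevance pixels stay dark and visually drop out of the weighted overlay, so only positively relevant regions pick up the orange tint.)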
144 | cmap_lrp = make_register_custom_cmap( [
145 | '#000000', '#ffc300', '#ffaa00',
146 | ] )
147 | 
148 | # Normalize and smooth the guided Grad-CAM heatmap
149 | hm = gb_cam_result
150 | hm [ hm < 0 ] = 0.0 # Negative values do not make sense for gradcam
151 | clim = np.percentile(np.abs(hm), q)
152 | 
153 | # Now normalize, smooth and create colormap
154 | final_hm = hm.copy()/clim
155 | #final_hm = cv2.medianBlur(final_hm, 5) # Smooth, otherwise the result looks ugly for ResNet-50
156 | final_hm = cv2.applyColorMap(np.uint8(255 * final_hm), get_mpl_colormap('my_cmap'))
157 | final_hm = cv2.cvtColor(final_hm, cv2.COLOR_BGR2RGB)
158 | 
159 | 
160 | # Size of output in pixels
161 | h = 224
162 | w = 224
163 | my_dpi = 1200
164 | fig, ax = plt.subplots(1, figsize=(w/my_dpi, h/my_dpi), dpi=my_dpi)
165 | 
166 | # Image
167 | img_np = np.uint8(imgtensor*255.0) # No need to blur the image
168 | 
169 | # Now overlay and create final visualization
170 | cam = cv2.addWeighted(final_hm, 0.55, img_np, 0.45, 0)
171 | ax.imshow(cam)
172 | ax.axis("off")
173 | 
174 | # Save image as pdf
175 | fig.subplots_adjust(top=1.0, bottom=0, right=1.0, left=0, hspace=0, wspace=0)
176 | fig.savefig("{}".format(outname), format='pdf', bbox_inches=Bbox([[0, 0], [w/my_dpi, h/my_dpi]]), dpi=my_dpi)
177 | plt.close()
178 | 
179 | 
180 | 
181 | 
182 | def explain_save_img_lrp_overlay_only_positive( lrp_result, imgtensor, q=100, title=None, outname=None):
183 | """
184 | Take image tensor and lrp result, create the heatmaps and store in `outname`
185 | """
186 | 
187 | # Create colormaps for lrp
188 | # This is a simple black-to-orange color map; black is used because the pixel value (0, 0, 0) makes it easy to overlay later.
189 | cmap_lrp = make_register_custom_cmap( [
190 | '#000000', '#ffc300', '#ffaa00',
191 | ] )
192 | 
193 | # Normalize and smooth lrp heatmap
194 | hm = lrp_result.squeeze().sum(dim=0).numpy()
195 | hm [ hm < 0 ] = 0.0 # Only consider positive relevances
196 | clim = np.percentile(np.abs(hm), q)
197 | 
198 | # Now normalize, smooth and create colormap
199 | final_hm = hm.copy()/clim
200 | #final_hm = cv2.medianBlur(final_hm, 5) # Smooth, otherwise the result looks ugly for ResNet-50
201 | final_hm = cv2.applyColorMap(np.uint8(255 * final_hm), get_mpl_colormap('my_cmap'))
202 | final_hm = cv2.cvtColor(final_hm, cv2.COLOR_BGR2RGB)
203 | 
204 | 
205 | # Size of output in pixels
206 | h = 256
207 | w = 256
208 | my_dpi = 1200
209 | fig, ax = plt.subplots(1, figsize=(w/my_dpi, h/my_dpi), dpi=my_dpi)
210 | 
211 | # Image
212 | img_np = invert_normalize_image(imgtensor) # No need to blur the image
213 | 
214 | # Now overlay and create final visualization
215 | cam = cv2.addWeighted(final_hm, 0.55, img_np, 0.45, 0)
216 | ax.imshow(cam)
217 | ax.axis("off")
218 | 
219 | # Save the overlay as jpg
220 | fig.subplots_adjust(top=1.0, bottom=0, right=1.0, left=0, hspace=0, wspace=0)
221 | fig.savefig("{}".format(outname), format='jpg', bbox_inches=Bbox([[0, 0], [w/my_dpi, h/my_dpi]]), dpi=my_dpi)
222 | plt.close()
223 | 
224 | plt.imsave("{}_image".format(outname), img_np, format='jpg')
225 | 
226 | 
227 | def relevance_bounding_box(bool_mask, q=90):
228 | """
229 | Use thresholded lrp mask to obtain ROI
230 | """
231 | bool_mask = bool_mask.copy()
232 | bool_mask[ bool_mask<=np.percentile((bool_mask), q) ] = 0.0
233 | bool_mask[ bool_mask>np.percentile((bool_mask), q) ] = 1.0
234 | bool_mask = bool_mask.astype(np.uint8)
235 | itemindex= np.where(bool_mask==True)
236 | # print(np.min(itemindex[0]), np.min(itemindex[1]))
237 | # print(np.max(itemindex[0]), np.max(itemindex[1]))
238 | 
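# Guard against an empty mask: if thresholding zeroed out every pixel, np.where returns empty index arrays and the full mask is returned as the ROI.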
239 | if len(itemindex[0]) == 0 or len(itemindex[1]) == 0:
240 | return bool_mask[ 0:bool_mask.shape[0], 0: bool_mask.shape[1] ], (0, bool_mask.shape[0], 0, bool_mask.shape[1])
241 | 
242 | 
243 | x_min, x_max = np.min(itemindex[0]), np.max(itemindex[0])
244 | y_min, y_max = np.min(itemindex[1]), np.max(itemindex[1])
245 | 
246 | return bool_mask[ x_min:x_max, y_min:y_max ], (x_min, x_max, y_min, y_max)
247 | 
248 | 
249 | 
250 | def extract_patch( lrp_result, imgtensor, q=100, title=None, outname=None):
251 | """
252 | Take image tensor and lrp result, extract the most relevant image patch and store it in `outname`
253 | """
254 | # Create colormaps for lrp
255 | # This is a simple black-to-orange color map; black is used because the pixel value (0, 0, 0) makes it easy to overlay later.
256 | cmap_lrp = make_register_custom_cmap( [
257 | '#000000', '#ffc300', '#ffaa00',
258 | ] )
259 | 
260 | # Normalize and smooth lrp heatmap
261 | hm = lrp_result.squeeze().sum(dim=0).numpy()
262 | hm [ hm < 0 ] = 0.0 # Only consider positive relevances
263 | clim = np.percentile(np.abs(hm), q)
264 | 
265 | # Now normalize, smooth and create colormap
266 | final_hm = hm.copy()/clim
267 | 
268 | # Extract patch
269 | _, (x_min, x_max, y_min, y_max) = relevance_bounding_box(final_hm, q=75)
270 | # print( (x_min, x_max, y_min, y_max) )
271 | 
272 | # Image
273 | img_np = invert_normalize_image(imgtensor) [ x_min:x_max, y_min:y_max ] # No need to blur the image
274 | 
275 | # Convert to PIL
276 | pil_image = Image.fromarray(img_np)
277 | pil_image.save("{}".format(outname), quality=95)
--------------------------------------------------------------------------------
/src/utils/mask_fmaps.py:
--------------------------------------------------------------------------------
1 | 
2 | # Import base libraries
3 | import os, sys, math
4 | 
5 | # Import torch libraries
6 | import torch
7 | 
8 | # Import other libraries
9 | import random
10 | from tqdm import tqdm
11 | 
12 | def feature_map_dropout(feature_map_idx):
13 | """
14 | Mask the output produced by the specific channel
15 | """
16 | def hook(module, inputs, outputs):
17 | if not module.training: # (B, C, H, W)
18 | outputs[:, feature_map_idx, :, :] = outputs[:, feature_map_idx, :, :]*0.0 # multiply the channel output by 0.0
19 | else:
20 | raise NotImplementedError("Please set your model in evaluation mode for sensitivity assessments")
21 | return hook
22 | 
23 | 
24 | def mask_target_channels(model, topk_dict):
25 | """
26 | Mask target channels (can be topk or lowk channels)
27 | topk_dict entries look like: layer0.conv1.#33(T=64)
28 | 
29 | The register_forward_hook function can modify the output directly
(but it cannot modify the inputs).
30 | """
31 | hooks = []
32 | 
33 | with tqdm(total=len(topk_dict)) as pbar:
34 | for filter_name in topk_dict:
35 | total_filters = filter_name.split('.#')[-1].split('=')[1][:-1]
36 | feature_map_idx = int(filter_name.split('.#')[-1].split('(')[0])
37 | 
38 | for i, (name, mod) in enumerate(model.named_modules()):
39 | # print(i)
40 | if filter_name.split("#")[0][:-1] == mod.auto_name:
41 | #print("hit >> ", filter_name, feature_map_idx, total_filters)
42 | pbar.set_description("dropout >> filter_name : {}, index : {}/{}".format(filter_name,
43 | feature_map_idx, total_filters))
44 | hook = mod.register_forward_hook(feature_map_dropout(feature_map_idx))
45 | hooks.append(hook)
46 | 
47 | pbar.update(1)
48 | 
49 | return model, hooks
50 | 
51 | 
52 | 
53 | def mask_random_channels(model, topk, topk_dict, all_feature_maps):
54 | """
55 | Mask random channels not in the topk dict
56 | topk_dict entries look like: layer0.conv1.#33(T=64)
57 | 
58 | The register_forward_hook function can modify the output directly (but it cannot modify the inputs).
59 | """
60 | hooks = []
61 | all_feature_maps = random.sample(list(all_feature_maps), k=len(all_feature_maps)) # Shuffle without replacement (random.choices would sample with replacement and silently drop feature maps)
62 | total = 0
63 | 
64 | if topk == 0:
65 | return model, hooks
66 | 
67 | with tqdm(total=len(topk_dict)) as pbar:
68 | for feature_map_name in all_feature_maps:
69 | if feature_map_name in topk_dict:
70 | pbar.set_description("{} occurs in topk, skipping".format(feature_map_name))
71 | #print("{} occurs in topk, skipping".format(feature_map_name))
72 | continue
73 | 
74 | total_filters = int(feature_map_name.split('.#')[-1].split('=')[1][:-1])
75 | feature_map_idx = int(feature_map_name.split('.#')[-1].split('(')[0])
76 | 
77 | # Attach masking hook
78 | for i, (name, mod) in enumerate(model.named_modules()):
79 | if feature_map_name.split("#")[0][:-1] == mod.auto_name:
80 | # print("hit >> ", feature_map_name, feature_map_idx, total_filters)
81 | pbar.set_description("dropout >> filter_name : {}, index : {}/{}".format(feature_map_name,
82 | feature_map_idx, total_filters))
83 | hook = mod.register_forward_hook(feature_map_dropout(feature_map_idx))
84 | hooks.append(hook)
85 | 
86 | total += 1
87 | pbar.update(1)
88 | 
89 | if total == topk:
90 | break
91 | 
92 | return model, hooks
--------------------------------------------------------------------------------