├── .gitignore
├── pytorch-gradual-warmup-lr
│   ├── warmup_scheduler
│   │   ├── __init__.py
│   │   ├── run.py
│   │   └── scheduler.py
│   └── setup.py
├── Deblurring
│   ├── utils
│   │   ├── __init__.py
│   │   ├── dir_utils.py
│   │   ├── image_utils.py
│   │   ├── dataset_utils.py
│   │   └── model_utils.py
│   ├── pretrained_models
│   │   └── README.md
│   ├── data_RGB.py
│   ├── training.yml
│   ├── Datasets
│   │   └── README.md
│   ├── evaluate_GOPRO_HIDE.m
│   ├── losses.py
│   ├── README.md
│   ├── test.py
│   ├── config.py
│   ├── evaluate_RealBlur.py
│   ├── dataset_RGB.py
│   ├── train.py
│   └── MPRNet.py
├── Denoising
│   ├── utils
│   │   ├── __init__.py
│   │   ├── dir_utils.py
│   │   ├── image_utils.py
│   │   ├── dataset_utils.py
│   │   └── model_utils.py
│   ├── pretrained_models
│   │   └── README.md
│   ├── data_RGB.py
│   ├── training.yml
│   ├── evaluate_SIDD.m
│   ├── Datasets
│   │   └── README.md
│   ├── README.md
│   ├── losses.py
│   ├── generate_patches_SIDD.py
│   ├── test_SIDD.py
│   ├── test_DND.py
│   ├── config.py
│   ├── dataset_RGB.py
│   ├── train.py
│   └── MPRNet.py
├── Deraining
│   ├── utils
│   │   ├── __init__.py
│   │   ├── dir_utils.py
│   │   ├── image_utils.py
│   │   ├── dataset_utils.py
│   │   └── model_utils.py
│   ├── pretrained_models
│   │   └── README.md
│   ├── data_RGB.py
│   ├── Datasets
│   │   └── README.md
│   ├── training.yml
│   ├── README.md
│   ├── losses.py
│   ├── test.py
│   ├── config.py
│   ├── dataset_RGB.py
│   ├── train.py
│   ├── evaluate_PSNR_SSIM.m
│   └── MPRNet.py
├── LICENSE.md
├── demo.py
└── README.md

/.gitignore:
--------------------------------------------------------------------------------
1 | .ipynb_checkpoints/
2 | *.pyc
3 | .DS_Store
4 | 
5 | 
--------------------------------------------------------------------------------
/pytorch-gradual-warmup-lr/warmup_scheduler/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | from warmup_scheduler.scheduler import GradualWarmupScheduler
3 | 
--------------------------------------------------------------------------------
/Deblurring/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .dir_utils import *
2 | from .image_utils import *
3 | from .model_utils import *
4 | from .dataset_utils import *
5 | 
--------------------------------------------------------------------------------
/Denoising/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .dir_utils import *
2 | from .image_utils import *
3 | from .model_utils import *
4 | from .dataset_utils import *
5 | 
--------------------------------------------------------------------------------
/Deraining/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .dir_utils import *
2 | from .image_utils import *
3 | from .model_utils import *
4 | from .dataset_utils import *
5 | 
--------------------------------------------------------------------------------
/Denoising/pretrained_models/README.md:
--------------------------------------------------------------------------------
1 | pre-trained denoising model is available [here](https://drive.google.com/file/d/1LODPt9kYmxwU98g96UrRA0_Eh5HYcsRw/view?usp=sharing)
--------------------------------------------------------------------------------
/Deraining/pretrained_models/README.md:
--------------------------------------------------------------------------------
1 | pre-trained deraining model is available [here](https://drive.google.com/file/d/1O3WEJbcat7eTY6doXWeorAbQ1l_WmMnM/view?usp=sharing)
--------------------------------------------------------------------------------
/Deblurring/pretrained_models/README.md:
-------------------------------------------------------------------------------- 1 | pre-trained deblurring model is available [here](https://drive.google.com/file/d/1QwQUVbk6YVOJViCsOKYNykCsdJSVGRtb/view?usp=sharing) -------------------------------------------------------------------------------- /Deblurring/utils/dir_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from natsort import natsorted 3 | from glob import glob 4 | 5 | def mkdirs(paths): 6 | if isinstance(paths, list) and not isinstance(paths, str): 7 | for path in paths: 8 | mkdir(path) 9 | else: 10 | mkdir(paths) 11 | 12 | def mkdir(path): 13 | if not os.path.exists(path): 14 | os.makedirs(path) 15 | 16 | def get_last_path(path, session): 17 | x = natsorted(glob(os.path.join(path,'*%s'%session)))[-1] 18 | return x -------------------------------------------------------------------------------- /Denoising/utils/dir_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from natsort import natsorted 3 | from glob import glob 4 | 5 | def mkdirs(paths): 6 | if isinstance(paths, list) and not isinstance(paths, str): 7 | for path in paths: 8 | mkdir(path) 9 | else: 10 | mkdir(paths) 11 | 12 | def mkdir(path): 13 | if not os.path.exists(path): 14 | os.makedirs(path) 15 | 16 | def get_last_path(path, session): 17 | x = natsorted(glob(os.path.join(path,'*%s'%session)))[-1] 18 | return x -------------------------------------------------------------------------------- /Deraining/utils/dir_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from natsort import natsorted 3 | from glob import glob 4 | 5 | def mkdirs(paths): 6 | if isinstance(paths, list) and not isinstance(paths, str): 7 | for path in paths: 8 | mkdir(path) 9 | else: 10 | mkdir(paths) 11 | 12 | def mkdir(path): 13 | if not os.path.exists(path): 14 | os.makedirs(path) 15 | 16 | def get_last_path(path, session): 17 | x = natsorted(glob(os.path.join(path,'*%s'%session)))[-1] 18 | return x -------------------------------------------------------------------------------- /Denoising/data_RGB.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataset_RGB import DataLoaderTrain, DataLoaderVal, DataLoaderTest 3 | 4 | def get_training_data(rgb_dir, img_options): 5 | assert os.path.exists(rgb_dir) 6 | return DataLoaderTrain(rgb_dir, img_options) 7 | 8 | def get_validation_data(rgb_dir, img_options): 9 | assert os.path.exists(rgb_dir) 10 | return DataLoaderVal(rgb_dir, img_options) 11 | 12 | def get_test_data(rgb_dir, img_options): 13 | assert os.path.exists(rgb_dir) 14 | return DataLoaderTest(rgb_dir, img_options) 15 | -------------------------------------------------------------------------------- /Deraining/data_RGB.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataset_RGB import DataLoaderTrain, DataLoaderVal, DataLoaderTest 3 | 4 | def get_training_data(rgb_dir, img_options): 5 | assert os.path.exists(rgb_dir) 6 | return DataLoaderTrain(rgb_dir, img_options) 7 | 8 | def get_validation_data(rgb_dir, img_options): 9 | assert os.path.exists(rgb_dir) 10 | return DataLoaderVal(rgb_dir, img_options) 11 | 12 | def get_test_data(rgb_dir, img_options): 13 | assert os.path.exists(rgb_dir) 14 | return DataLoaderTest(rgb_dir, img_options) 15 | 
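The three `get_*_data` helpers above are thin wrappers around the dataset classes defined in `dataset_RGB.py` (not included in this dump). A minimal sketch of how they would feed PyTorch `DataLoader`s, assuming `img_options` carries the crop size under a hypothetical `patch_size` key and borrowing the GoPro paths from `training.yml`:

```
# Sketch only: the 'patch_size' key and dataset paths are assumptions for illustration.
from torch.utils.data import DataLoader
from data_RGB import get_training_data, get_validation_data

train_dataset = get_training_data('./Datasets/GoPro/train', {'patch_size': 256})
train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True,
                          num_workers=4, drop_last=False, pin_memory=True)

val_dataset = get_validation_data('./Datasets/GoPro/test', {'patch_size': 256})
val_loader = DataLoader(dataset=val_dataset, batch_size=16, shuffle=False,
                        num_workers=4, pin_memory=True)
```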
-------------------------------------------------------------------------------- /Deblurring/data_RGB.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataset_RGB import DataLoaderTrain, DataLoaderVal, DataLoaderTest 3 | 4 | def get_training_data(rgb_dir, img_options): 5 | assert os.path.exists(rgb_dir) 6 | return DataLoaderTrain(rgb_dir, img_options) 7 | 8 | def get_validation_data(rgb_dir, img_options): 9 | assert os.path.exists(rgb_dir) 10 | return DataLoaderVal(rgb_dir, img_options) 11 | 12 | def get_test_data(rgb_dir, img_options): 13 | assert os.path.exists(rgb_dir) 14 | return DataLoaderTest(rgb_dir, img_options) 15 | -------------------------------------------------------------------------------- /Deraining/Datasets/README.md: -------------------------------------------------------------------------------- 1 | Download datasets from the google drive links and place them in this directory. Your directory structure should look something like this 2 | 3 | `Synthetic_Rain_Datasets` 4 | `├──`[train](https://drive.google.com/drive/folders/1Hnnlc5kI0v9_BtfMytC2LR5VpLAFZtVe?usp=sharing) 5 | `└──`[test](https://drive.google.com/drive/folders/1PDWggNh8ylevFmrjo-JEvlmqsDlWWvZs?usp=sharing) 6 | `├──Test100` 7 | `├──Rain100H` 8 | `├──Rain100L` 9 | `├──Test1200` 10 | `└──Test2800` 11 | -------------------------------------------------------------------------------- /Denoising/utils/image_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import cv2 4 | 5 | def torchPSNR(tar_img, prd_img): 6 | imdff = torch.clamp(prd_img,0,1) - torch.clamp(tar_img,0,1) 7 | rmse = (imdff**2).mean().sqrt() 8 | ps = 20*torch.log10(1/rmse) 9 | return ps 10 | 11 | def save_img(filepath, img): 12 | cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) 13 | 14 | def numpyPSNR(tar_img, prd_img): 15 | imdff = np.float32(prd_img) - np.float32(tar_img) 16 | rmse = np.sqrt(np.mean(imdff**2)) 17 | ps = 20*np.log10(255/rmse) 18 | return ps 19 | -------------------------------------------------------------------------------- /Deraining/utils/image_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import cv2 4 | 5 | def torchPSNR(tar_img, prd_img): 6 | imdff = torch.clamp(prd_img,0,1) - torch.clamp(tar_img,0,1) 7 | rmse = (imdff**2).mean().sqrt() 8 | ps = 20*torch.log10(1/rmse) 9 | return ps 10 | 11 | def save_img(filepath, img): 12 | cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) 13 | 14 | def numpyPSNR(tar_img, prd_img): 15 | imdff = np.float32(prd_img) - np.float32(tar_img) 16 | rmse = np.sqrt(np.mean(imdff**2)) 17 | ps = 20*np.log10(255/rmse) 18 | return ps 19 | -------------------------------------------------------------------------------- /Deblurring/utils/image_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import cv2 4 | 5 | def torchPSNR(tar_img, prd_img): 6 | imdff = torch.clamp(prd_img,0,1) - torch.clamp(tar_img,0,1) 7 | rmse = (imdff**2).mean().sqrt() 8 | ps = 20*torch.log10(1/rmse) 9 | return ps 10 | 11 | def save_img(filepath, img): 12 | cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) 13 | 14 | def numpyPSNR(tar_img, prd_img): 15 | imdff = np.float32(prd_img) - np.float32(tar_img) 16 | rmse = np.sqrt(np.mean(imdff**2)) 17 | ps = 20*np.log10(255/rmse) 18 | return ps 19 | 
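Both PSNR helpers above compute 20·log10(MAX/RMSE) but assume different peak values: `torchPSNR` expects float images in [0,1] (peak 1), while `numpyPSNR` expects the 8-bit range (peak 255). A small self-check sketch, not part of the repo, runnable from any of the three task directories:

```
# The two conventions should agree up to uint8 quantization error.
import numpy as np
import torch
from utils.image_utils import torchPSNR, numpyPSNR

tar = torch.rand(3, 64, 64)
prd = (tar + 0.01 * torch.randn_like(tar)).clamp(0, 1)

p_float = torchPSNR(tar, prd)                      # peak value 1.0
p_uint8 = numpyPSNR(np.uint8(tar.numpy() * 255),   # peak value 255
                    np.uint8(prd.numpy() * 255))
print(float(p_float), p_uint8)
```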
-------------------------------------------------------------------------------- /Deblurring/utils/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class MixUp_AUG: 4 | def __init__(self): 5 | self.dist = torch.distributions.beta.Beta(torch.tensor([0.6]), torch.tensor([0.6])) 6 | 7 | def aug(self, rgb_gt, rgb_noisy): 8 | bs = rgb_gt.size(0) 9 | indices = torch.randperm(bs) 10 | rgb_gt2 = rgb_gt[indices] 11 | rgb_noisy2 = rgb_noisy[indices] 12 | 13 | lam = self.dist.rsample((bs,1)).view(-1,1,1,1).cuda() 14 | 15 | rgb_gt = lam * rgb_gt + (1-lam) * rgb_gt2 16 | rgb_noisy = lam * rgb_noisy + (1-lam) * rgb_noisy2 17 | 18 | return rgb_gt, rgb_noisy -------------------------------------------------------------------------------- /Denoising/utils/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class MixUp_AUG: 4 | def __init__(self): 5 | self.dist = torch.distributions.beta.Beta(torch.tensor([0.6]), torch.tensor([0.6])) 6 | 7 | def aug(self, rgb_gt, rgb_noisy): 8 | bs = rgb_gt.size(0) 9 | indices = torch.randperm(bs) 10 | rgb_gt2 = rgb_gt[indices] 11 | rgb_noisy2 = rgb_noisy[indices] 12 | 13 | lam = self.dist.rsample((bs,1)).view(-1,1,1,1).cuda() 14 | 15 | rgb_gt = lam * rgb_gt + (1-lam) * rgb_gt2 16 | rgb_noisy = lam * rgb_noisy + (1-lam) * rgb_noisy2 17 | 18 | return rgb_gt, rgb_noisy -------------------------------------------------------------------------------- /Deraining/utils/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class MixUp_AUG: 4 | def __init__(self): 5 | self.dist = torch.distributions.beta.Beta(torch.tensor([0.6]), torch.tensor([0.6])) 6 | 7 | def aug(self, rgb_gt, rgb_noisy): 8 | bs = rgb_gt.size(0) 9 | indices = torch.randperm(bs) 10 | rgb_gt2 = rgb_gt[indices] 11 | rgb_noisy2 = rgb_noisy[indices] 12 | 13 | lam = self.dist.rsample((bs,1)).view(-1,1,1,1).cuda() 14 | 15 | rgb_gt = lam * rgb_gt + (1-lam) * rgb_gt2 16 | rgb_noisy = lam * rgb_noisy + (1-lam) * rgb_noisy2 17 | 18 | return rgb_gt, rgb_noisy -------------------------------------------------------------------------------- /Denoising/training.yml: -------------------------------------------------------------------------------- 1 | ############### 2 | ## 3 | #### 4 | 5 | GPU: [0,1,2,3] 6 | 7 | VERBOSE: True 8 | 9 | MODEL: 10 | MODE: 'Denoising' 11 | SESSION: 'MPRNet' 12 | 13 | # Optimization arguments. 14 | OPTIM: 15 | BATCH_SIZE: 16 16 | NUM_EPOCHS: 80 17 | # NEPOCH_DECAY: [10] 18 | LR_INITIAL: 2e-4 19 | LR_MIN: 1e-6 20 | # BETA1: 0.9 21 | 22 | TRAINING: 23 | VAL_AFTER_EVERY: 1 24 | RESUME: False 25 | TRAIN_PS: 128 26 | VAL_PS: 256 27 | TRAIN_DIR: './Datasets/SIDD/train' # path to training data 28 | VAL_DIR: './Datasets/SIDD/val' # path to validation data 29 | SAVE_DIR: './checkpoints' # path to save models and images 30 | # SAVE_IMAGES: False 31 | -------------------------------------------------------------------------------- /Deblurring/training.yml: -------------------------------------------------------------------------------- 1 | ############### 2 | ## 3 | #### 4 | 5 | GPU: [0,1,2,3] 6 | 7 | VERBOSE: True 8 | 9 | MODEL: 10 | MODE: 'Deblurring' 11 | SESSION: 'MPRNet' 12 | 13 | # Optimization arguments. 
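# Note: train.py is not part of this dump, so the schedule is inferred; given
# the bundled GradualWarmupScheduler, LR_INITIAL is presumably the post-warmup
# peak learning rate and LR_MIN the floor the decay approaches over NUM_EPOCHS.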
14 | OPTIM: 15 | BATCH_SIZE: 16 16 | NUM_EPOCHS: 3000 17 | # NEPOCH_DECAY: [10] 18 | LR_INITIAL: 2e-4 19 | LR_MIN: 1e-6 20 | # BETA1: 0.9 21 | 22 | TRAINING: 23 | VAL_AFTER_EVERY: 20 24 | RESUME: False 25 | TRAIN_PS: 256 26 | VAL_PS: 256 27 | TRAIN_DIR: './Datasets/GoPro/train' # path to training data 28 | VAL_DIR: './Datasets/GoPro/test' # path to validation data 29 | SAVE_DIR: './checkpoints' # path to save models and images 30 | # SAVE_IMAGES: False 31 | -------------------------------------------------------------------------------- /Deraining/training.yml: -------------------------------------------------------------------------------- 1 | ############### 2 | ## 3 | #### 4 | 5 | GPU: [0,1,2,3] 6 | 7 | VERBOSE: True 8 | 9 | MODEL: 10 | MODE: 'Deraining' 11 | SESSION: 'MPRNet' 12 | 13 | # Optimization arguments. 14 | OPTIM: 15 | BATCH_SIZE: 16 16 | NUM_EPOCHS: 250 17 | # NEPOCH_DECAY: [10] 18 | LR_INITIAL: 2e-4 19 | LR_MIN: 1e-6 20 | # BETA1: 0.9 21 | 22 | TRAINING: 23 | VAL_AFTER_EVERY: 5 24 | RESUME: False 25 | TRAIN_PS: 256 26 | VAL_PS: 128 27 | TRAIN_DIR: './Datasets/train' # path to training data 28 | VAL_DIR: './Datasets/test/Rain100L' # path to validation data 29 | SAVE_DIR: './checkpoints' # path to save models and images 30 | # SAVE_IMAGES: False 31 | -------------------------------------------------------------------------------- /pytorch-gradual-warmup-lr/setup.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import setuptools 6 | 7 | _VERSION = '0.3' 8 | 9 | REQUIRED_PACKAGES = [ 10 | ] 11 | 12 | DEPENDENCY_LINKS = [ 13 | ] 14 | 15 | setuptools.setup( 16 | name='warmup_scheduler', 17 | version=_VERSION, 18 | description='Gradually Warm-up LR Scheduler for Pytorch', 19 | install_requires=REQUIRED_PACKAGES, 20 | dependency_links=DEPENDENCY_LINKS, 21 | url='https://github.com/ildoonet/pytorch-gradual-warmup-lr', 22 | license='MIT License', 23 | package_dir={}, 24 | packages=setuptools.find_packages(exclude=['tests']), 25 | ) 26 | -------------------------------------------------------------------------------- /Deblurring/Datasets/README.md: -------------------------------------------------------------------------------- 1 | Download datasets from the google drive links and place them in this directory. Your directory tree should look like this 2 | 3 | `GoPro` 4 | `├──`[train](https://drive.google.com/drive/folders/1AsgIP9_X0bg0olu2-1N6karm2x15cJWE?usp=sharing) 5 | `└──`[test](https://drive.google.com/drive/folders/1a2qKfXWpNuTGOm2-Jex8kfNSzYJLbqkf?usp=sharing) 6 | 7 | `HIDE` 8 | `└──`[test](https://drive.google.com/drive/folders/1nRsTXj4iTUkTvBhTcGg8cySK8nd3vlhK?usp=sharing) 9 | 10 | `RealBlur_J` 11 | `└──`[test](https://drive.google.com/drive/folders/1KYtzeKCiDRX9DSvC-upHrCqvC4sPAiJ1?usp=sharing) 12 | 13 | `RealBlur_R` 14 | `└──`[test](https://drive.google.com/drive/folders/1EwDoajf5nStPIAcU4s9rdc8SPzfm3tW1?usp=sharing) 15 | -------------------------------------------------------------------------------- /Deraining/README.md: -------------------------------------------------------------------------------- 1 | 2 | ## Training 3 | - Download the [Datasets](Datasets/README.md) 4 | 5 | - Train the model with default arguments by running 6 | 7 | ``` 8 | python train.py 9 | ``` 10 | 11 | 12 | ## Evaluation 13 | 14 | 1. 
Download the [model](https://drive.google.com/file/d/1O3WEJbcat7eTY6doXWeorAbQ1l_WmMnM/view?usp=sharing) and place it in `./pretrained_models/`
15 | 
16 | 2. Download test datasets (Test100, Rain100H, Rain100L, Test1200, Test2800) from [here](https://drive.google.com/drive/folders/1PDWggNh8ylevFmrjo-JEvlmqsDlWWvZs?usp=sharing) and place them in `./Datasets/Synthetic_Rain_Datasets/test/`
17 | 
18 | 3. Run
19 | ```
20 | python test.py
21 | ```
22 | 
23 | #### To reproduce PSNR/SSIM scores of the paper, run
24 | ```
25 | evaluate_PSNR_SSIM.m
26 | ```
27 | 
--------------------------------------------------------------------------------
/Denoising/evaluate_SIDD.m:
--------------------------------------------------------------------------------
1 | close all;clear all;
2 | 
3 | denoised = load('Idenoised.mat');
4 | gt = load('ValidationGtBlocksSrgb.mat');
5 | 
6 | denoised = denoised.Idenoised;
7 | gt = gt.ValidationGtBlocksSrgb;
8 | gt = im2single(gt);
9 | 
10 | total_psnr = 0;
11 | total_ssim = 0;
12 | for i = 1:40
13 |     for k = 1:32
14 |         denoised_patch = squeeze(denoised(i,k,:,:,:));
15 |         gt_patch = squeeze(gt(i,k,:,:,:));
16 |         ssim_val = ssim(denoised_patch, gt_patch);
17 |         psnr_val = psnr(denoised_patch, gt_patch);
18 |         total_ssim = total_ssim + ssim_val;
19 |         total_psnr = total_psnr + psnr_val;
20 |     end
21 | end
22 | qm_psnr = total_psnr / (40*32);
23 | qm_ssim = total_ssim / (40*32);
24 | 
25 | fprintf('PSNR: %f SSIM: %f\n', qm_psnr, qm_ssim);
26 | 
27 | 
--------------------------------------------------------------------------------
/Denoising/Datasets/README.md:
--------------------------------------------------------------------------------
1 | Download datasets from the provided links and place them in this directory. Your directory structure should look something like this
2 | 
3 | `SIDD`
4 | `├──`[train](https://www.eecs.yorku.ca/~kamel/sidd/dataset.php)
5 | `├──`[val](https://drive.google.com/drive/folders/1S44fHXaVxAYW3KLNxK41NYCnyX9S79su?usp=sharing)
6 | `└──`[test](https://www.eecs.yorku.ca/~kamel/sidd/benchmark.php)
7 | `├──ValidationNoisyBlocksSrgb.mat`
8 | `└──ValidationGtBlocksSrgb.mat`
9 | 
10 | `DND`
11 | `└──`[test](https://noise.visinf.tu-darmstadt.de/downloads/)
12 | `├──info.mat`
13 | `└──images_srgb`
14 | `├──0001.mat`
15 | `├──0002.mat`
16 | `├── ... `
17 | `└──0050.mat`
--------------------------------------------------------------------------------
/pytorch-gradual-warmup-lr/warmup_scheduler/run.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.lr_scheduler import StepLR, ExponentialLR
3 | from torch.optim.sgd import SGD
4 | 
5 | from warmup_scheduler import GradualWarmupScheduler
6 | 
7 | 
8 | if __name__ == '__main__':
9 |     model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
10 |     optim = SGD(model, 0.1)
11 | 
12 |     # scheduler_warmup is chained with scheduler_steplr
13 |     scheduler_steplr = StepLR(optim, step_size=10, gamma=0.1)
14 |     scheduler_warmup = GradualWarmupScheduler(optim, multiplier=1, total_epoch=5, after_scheduler=scheduler_steplr)
15 | 
16 |     # this zero gradient update is needed to avoid a warning message, issue #8.
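    # Background: since PyTorch 1.1.0, calling scheduler.step() before any
    # optimizer.step() raises a UserWarning about the call order; the dummy
    # zero_grad()/step() pair below keeps this demo warning-free.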
17 |     optim.zero_grad()
18 |     optim.step()
19 | 
20 |     for epoch in range(1, 20):
21 |         scheduler_warmup.step(epoch)
22 |         print(epoch, optim.param_groups[0]['lr'])
23 | 
24 |         optim.step()    # parameter update (a real training loop would call loss.backward() first)
--------------------------------------------------------------------------------
/Denoising/README.md:
--------------------------------------------------------------------------------
1 | 
2 | ## Training
3 | - Download the [Datasets](Datasets/README.md)
4 | 
5 | - Generate image patches from full-resolution training images of SIDD dataset
6 | ```
7 | python generate_patches_SIDD.py --ps 256 --num_patches 300 --num_cores 10
8 | ```
9 | - Train the model with default arguments by running
10 | 
11 | ```
12 | python train.py
13 | ```
14 | 
15 | 
16 | ## Evaluation
17 | 
18 | - Download the [model](https://drive.google.com/file/d/1LODPt9kYmxwU98g96UrRA0_Eh5HYcsRw/view?usp=sharing) and place it in `./pretrained_models/`
19 | 
20 | #### Testing on SIDD dataset
21 | - Download SIDD Validation Data and Ground Truth from [here](https://www.eecs.yorku.ca/~kamel/sidd/benchmark.php) and place them in `./Datasets/SIDD/test/`
22 | - Run
23 | ```
24 | python test_SIDD.py --save_images
25 | ```
26 | #### Testing on DND dataset
27 | - Download DND Benchmark Data from [here](https://noise.visinf.tu-darmstadt.de/downloads/) and place it in `./Datasets/DND/test/`
28 | - Run
29 | ```
30 | python test_DND.py --save_images
31 | ```
32 | 
33 | #### To reproduce PSNR/SSIM scores of the paper, run the MATLAB script
34 | ```
35 | evaluate_SIDD.m
36 | ```
37 | 
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | ## ACADEMIC PUBLIC LICENSE
2 | 
3 | ### Permissions
4 | :heavy_check_mark: Non-Commercial use
5 | :heavy_check_mark: Modification
6 | :heavy_check_mark: Distribution
7 | :heavy_check_mark: Private use
8 | 
9 | ### Limitations
10 | :x: Commercial Use
11 | :x: Liability
12 | :x: Warranty
13 | 
14 | ### Conditions
15 | :information_source: License and copyright notice
16 | :information_source: Same License
17 | 
18 | MPRNet is free for use in noncommercial settings: at academic institutions for teaching and research use, and at non-profit research organizations.
19 | You can use MPRNet in your research, academic work, non-commercial work, projects and personal work. We only ask you to credit us appropriately.
20 | 
21 | You have the right to use the software, to distribute copies, to receive source code, to change the software and distribute your modifications or the modified software.
22 | If you distribute verbatim or modified copies of this software, they must be distributed under this license.
23 | This license guarantees that you're safe when using MPRNet in your work, for teaching or research.
24 | This license guarantees that MPRNet will remain available free of charge for nonprofit use.
25 | You can modify MPRNet to your purposes, and you can also share your modifications.
26 | 
27 | If you would like to use MPRNet in commercial settings, contact us so we can discuss options.
Send an email to waqas.zamir@inceptioniai.org 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /Deblurring/evaluate_GOPRO_HIDE.m: -------------------------------------------------------------------------------- 1 | %% Multi-Stage Progressive Image Restoration 2 | %% Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, Ming-Hsuan Yang, and Ling Shao 3 | %% https://arxiv.org/abs/2102.02808 4 | 5 | close all;clear all; 6 | 7 | % datasets = {'GoPro'}; 8 | datasets = {'GoPro', 'HIDE'}; 9 | num_set = length(datasets); 10 | 11 | tic 12 | delete(gcp('nocreate')) 13 | parpool('local',20); 14 | 15 | for idx_set = 1:num_set 16 | file_path = strcat('./results/', datasets{idx_set}, '/'); 17 | gt_path = strcat('./Datasets/test/', datasets{idx_set}, '/target/'); 18 | path_list = [dir(strcat(file_path,'*.jpg')); dir(strcat(file_path,'*.png'))]; 19 | gt_list = [dir(strcat(gt_path,'*.jpg')); dir(strcat(gt_path,'*.png'))]; 20 | img_num = length(path_list); 21 | 22 | total_psnr = 0; 23 | total_ssim = 0; 24 | if img_num > 0 25 | parfor j = 1:img_num 26 | image_name = path_list(j).name; 27 | gt_name = gt_list(j).name; 28 | input = imread(strcat(file_path,image_name)); 29 | gt = imread(strcat(gt_path, gt_name)); 30 | ssim_val = ssim(input, gt); 31 | psnr_val = psnr(input, gt); 32 | total_ssim = total_ssim + ssim_val; 33 | total_psnr = total_psnr + psnr_val; 34 | end 35 | end 36 | qm_psnr = total_psnr / img_num; 37 | qm_ssim = total_ssim / img_num; 38 | 39 | fprintf('For %s dataset PSNR: %f SSIM: %f\n', datasets{idx_set}, qm_psnr, qm_ssim); 40 | 41 | end 42 | delete(gcp('nocreate')) 43 | toc 44 | -------------------------------------------------------------------------------- /Deblurring/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class CharbonnierLoss(nn.Module): 6 | """Charbonnier Loss (L1)""" 7 | 8 | def __init__(self, eps=1e-3): 9 | super(CharbonnierLoss, self).__init__() 10 | self.eps = eps 11 | 12 | def forward(self, x, y): 13 | diff = x - y 14 | # loss = torch.sum(torch.sqrt(diff * diff + self.eps)) 15 | loss = torch.mean(torch.sqrt((diff * diff) + (self.eps*self.eps))) 16 | return loss 17 | 18 | class EdgeLoss(nn.Module): 19 | def __init__(self): 20 | super(EdgeLoss, self).__init__() 21 | k = torch.Tensor([[.05, .25, .4, .25, .05]]) 22 | self.kernel = torch.matmul(k.t(),k).unsqueeze(0).repeat(3,1,1,1) 23 | if torch.cuda.is_available(): 24 | self.kernel = self.kernel.cuda() 25 | self.loss = CharbonnierLoss() 26 | 27 | def conv_gauss(self, img): 28 | n_channels, _, kw, kh = self.kernel.shape 29 | img = F.pad(img, (kw//2, kh//2, kw//2, kh//2), mode='replicate') 30 | return F.conv2d(img, self.kernel, groups=n_channels) 31 | 32 | def laplacian_kernel(self, current): 33 | filtered = self.conv_gauss(current) # filter 34 | down = filtered[:,:,::2,::2] # downsample 35 | new_filter = torch.zeros_like(filtered) 36 | new_filter[:,:,::2,::2] = down*4 # upsample 37 | filtered = self.conv_gauss(new_filter) # filter 38 | diff = current - filtered 39 | return diff 40 | 41 | def forward(self, x, y): 42 | loss = self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y)) 43 | return loss 44 | -------------------------------------------------------------------------------- /Denoising/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as 
nn 3 | import torch.nn.functional as F 4 | 5 | class CharbonnierLoss(nn.Module): 6 | """Charbonnier Loss (L1)""" 7 | 8 | def __init__(self, eps=1e-3): 9 | super(CharbonnierLoss, self).__init__() 10 | self.eps = eps 11 | 12 | def forward(self, x, y): 13 | diff = x - y 14 | # loss = torch.sum(torch.sqrt(diff * diff + self.eps)) 15 | loss = torch.mean(torch.sqrt((diff * diff) + (self.eps*self.eps))) 16 | return loss 17 | 18 | class EdgeLoss(nn.Module): 19 | def __init__(self): 20 | super(EdgeLoss, self).__init__() 21 | k = torch.Tensor([[.05, .25, .4, .25, .05]]) 22 | self.kernel = torch.matmul(k.t(),k).unsqueeze(0).repeat(3,1,1,1) 23 | if torch.cuda.is_available(): 24 | self.kernel = self.kernel.cuda() 25 | self.loss = CharbonnierLoss() 26 | 27 | def conv_gauss(self, img): 28 | n_channels, _, kw, kh = self.kernel.shape 29 | img = F.pad(img, (kw//2, kh//2, kw//2, kh//2), mode='replicate') 30 | return F.conv2d(img, self.kernel, groups=n_channels) 31 | 32 | def laplacian_kernel(self, current): 33 | filtered = self.conv_gauss(current) # filter 34 | down = filtered[:,:,::2,::2] # downsample 35 | new_filter = torch.zeros_like(filtered) 36 | new_filter[:,:,::2,::2] = down*4 # upsample 37 | filtered = self.conv_gauss(new_filter) # filter 38 | diff = current - filtered 39 | return diff 40 | 41 | def forward(self, x, y): 42 | loss = self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y)) 43 | return loss 44 | -------------------------------------------------------------------------------- /Deraining/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class CharbonnierLoss(nn.Module): 6 | """Charbonnier Loss (L1)""" 7 | 8 | def __init__(self, eps=1e-3): 9 | super(CharbonnierLoss, self).__init__() 10 | self.eps = eps 11 | 12 | def forward(self, x, y): 13 | diff = x - y 14 | # loss = torch.sum(torch.sqrt(diff * diff + self.eps)) 15 | loss = torch.mean(torch.sqrt((diff * diff) + (self.eps*self.eps))) 16 | return loss 17 | 18 | class EdgeLoss(nn.Module): 19 | def __init__(self): 20 | super(EdgeLoss, self).__init__() 21 | k = torch.Tensor([[.05, .25, .4, .25, .05]]) 22 | self.kernel = torch.matmul(k.t(),k).unsqueeze(0).repeat(3,1,1,1) 23 | if torch.cuda.is_available(): 24 | self.kernel = self.kernel.cuda() 25 | self.loss = CharbonnierLoss() 26 | 27 | def conv_gauss(self, img): 28 | n_channels, _, kw, kh = self.kernel.shape 29 | img = F.pad(img, (kw//2, kh//2, kw//2, kh//2), mode='replicate') 30 | return F.conv2d(img, self.kernel, groups=n_channels) 31 | 32 | def laplacian_kernel(self, current): 33 | filtered = self.conv_gauss(current) # filter 34 | down = filtered[:,:,::2,::2] # downsample 35 | new_filter = torch.zeros_like(filtered) 36 | new_filter[:,:,::2,::2] = down*4 # upsample 37 | filtered = self.conv_gauss(new_filter) # filter 38 | diff = current - filtered 39 | return diff 40 | 41 | def forward(self, x, y): 42 | loss = self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y)) 43 | return loss 44 | -------------------------------------------------------------------------------- /Deblurring/README.md: -------------------------------------------------------------------------------- 1 | ## Training 2 | - Download the [Datasets](Datasets/README.md) 3 | 4 | - Train the model with default arguments by running 5 | 6 | ``` 7 | python train.py 8 | ``` 9 | 10 | ## Evaluation 11 | 12 | ### Download the 
[model](https://drive.google.com/file/d/1QwQUVbk6YVOJViCsOKYNykCsdJSVGRtb/view?usp=sharing) and place it in ./pretrained_models/ 13 | 14 | #### Testing on GoPro dataset 15 | - Download [images](https://drive.google.com/drive/folders/1a2qKfXWpNuTGOm2-Jex8kfNSzYJLbqkf?usp=sharing) of GoPro and place them in `./Datasets/GoPro/test/` 16 | - Run 17 | ``` 18 | python test.py --dataset GoPro 19 | ``` 20 | 21 | #### Testing on HIDE dataset 22 | - Download [images](https://drive.google.com/drive/folders/1nRsTXj4iTUkTvBhTcGg8cySK8nd3vlhK?usp=sharing) of HIDE and place them in `./Datasets/HIDE/test/` 23 | - Run 24 | ``` 25 | python test.py --dataset HIDE 26 | ``` 27 | 28 | 29 | #### Testing on RealBlur-J dataset 30 | - Download [images](https://drive.google.com/drive/folders/1KYtzeKCiDRX9DSvC-upHrCqvC4sPAiJ1?usp=sharing) of RealBlur-J and place them in `./Datasets/RealBlur_J/test/` 31 | - Run 32 | ``` 33 | python test.py --dataset RealBlur_J 34 | ``` 35 | 36 | 37 | 38 | #### Testing on RealBlur-R dataset 39 | - Download [images](https://drive.google.com/drive/folders/1EwDoajf5nStPIAcU4s9rdc8SPzfm3tW1?usp=sharing) of RealBlur-R and place them in `./Datasets/RealBlur_R/test/` 40 | - Run 41 | ``` 42 | python test.py --dataset RealBlur_R 43 | ``` 44 | 45 | #### To reproduce PSNR/SSIM scores of the paper on GoPro and HIDE datasets, run this MATLAB script 46 | ``` 47 | evaluate_GOPRO_HIDE.m 48 | ``` 49 | 50 | #### To reproduce PSNR/SSIM scores of the paper on RealBlur dataset, run 51 | ``` 52 | evaluate_RealBlur.py 53 | ``` 54 | -------------------------------------------------------------------------------- /Denoising/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from collections import OrderedDict 4 | 5 | def freeze(model): 6 | for p in model.parameters(): 7 | p.requires_grad=False 8 | 9 | def unfreeze(model): 10 | for p in model.parameters(): 11 | p.requires_grad=True 12 | 13 | def is_frozen(model): 14 | x = [p.requires_grad for p in model.parameters()] 15 | return not all(x) 16 | 17 | def save_checkpoint(model_dir, state, session): 18 | epoch = state['epoch'] 19 | model_out_path = os.path.join(model_dir,"model_epoch_{}_{}.pth".format(epoch,session)) 20 | torch.save(state, model_out_path) 21 | 22 | def load_checkpoint(model, weights): 23 | checkpoint = torch.load(weights) 24 | try: 25 | model.load_state_dict(checkpoint["state_dict"]) 26 | except: 27 | state_dict = checkpoint["state_dict"] 28 | new_state_dict = OrderedDict() 29 | for k, v in state_dict.items(): 30 | name = k[7:] # remove `module.` 31 | new_state_dict[name] = v 32 | model.load_state_dict(new_state_dict) 33 | 34 | 35 | def load_checkpoint_multigpu(model, weights): 36 | checkpoint = torch.load(weights) 37 | state_dict = checkpoint["state_dict"] 38 | new_state_dict = OrderedDict() 39 | for k, v in state_dict.items(): 40 | name = k[7:] # remove `module.` 41 | new_state_dict[name] = v 42 | model.load_state_dict(new_state_dict) 43 | 44 | def load_start_epoch(weights): 45 | checkpoint = torch.load(weights) 46 | epoch = checkpoint["epoch"] 47 | return epoch 48 | 49 | def load_optim(optimizer, weights): 50 | checkpoint = torch.load(weights) 51 | optimizer.load_state_dict(checkpoint['optimizer']) 52 | # for p in optimizer.param_groups: lr = p['lr'] 53 | # return lr 54 | -------------------------------------------------------------------------------- /Deraining/utils/model_utils.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from collections import OrderedDict 4 | 5 | def freeze(model): 6 | for p in model.parameters(): 7 | p.requires_grad=False 8 | 9 | def unfreeze(model): 10 | for p in model.parameters(): 11 | p.requires_grad=True 12 | 13 | def is_frozen(model): 14 | x = [p.requires_grad for p in model.parameters()] 15 | return not all(x) 16 | 17 | def save_checkpoint(model_dir, state, session): 18 | epoch = state['epoch'] 19 | model_out_path = os.path.join(model_dir,"model_epoch_{}_{}.pth".format(epoch,session)) 20 | torch.save(state, model_out_path) 21 | 22 | def load_checkpoint(model, weights): 23 | checkpoint = torch.load(weights) 24 | try: 25 | model.load_state_dict(checkpoint["state_dict"]) 26 | except: 27 | state_dict = checkpoint["state_dict"] 28 | new_state_dict = OrderedDict() 29 | for k, v in state_dict.items(): 30 | name = k[7:] # remove `module.` 31 | new_state_dict[name] = v 32 | model.load_state_dict(new_state_dict) 33 | 34 | 35 | def load_checkpoint_multigpu(model, weights): 36 | checkpoint = torch.load(weights) 37 | state_dict = checkpoint["state_dict"] 38 | new_state_dict = OrderedDict() 39 | for k, v in state_dict.items(): 40 | name = k[7:] # remove `module.` 41 | new_state_dict[name] = v 42 | model.load_state_dict(new_state_dict) 43 | 44 | def load_start_epoch(weights): 45 | checkpoint = torch.load(weights) 46 | epoch = checkpoint["epoch"] 47 | return epoch 48 | 49 | def load_optim(optimizer, weights): 50 | checkpoint = torch.load(weights) 51 | optimizer.load_state_dict(checkpoint['optimizer']) 52 | # for p in optimizer.param_groups: lr = p['lr'] 53 | # return lr 54 | -------------------------------------------------------------------------------- /Deblurring/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from collections import OrderedDict 4 | 5 | def freeze(model): 6 | for p in model.parameters(): 7 | p.requires_grad=False 8 | 9 | def unfreeze(model): 10 | for p in model.parameters(): 11 | p.requires_grad=True 12 | 13 | def is_frozen(model): 14 | x = [p.requires_grad for p in model.parameters()] 15 | return not all(x) 16 | 17 | def save_checkpoint(model_dir, state, session): 18 | epoch = state['epoch'] 19 | model_out_path = os.path.join(model_dir,"model_epoch_{}_{}.pth".format(epoch,session)) 20 | torch.save(state, model_out_path) 21 | 22 | def load_checkpoint(model, weights): 23 | checkpoint = torch.load(weights) 24 | try: 25 | model.load_state_dict(checkpoint["state_dict"]) 26 | except: 27 | state_dict = checkpoint["state_dict"] 28 | new_state_dict = OrderedDict() 29 | for k, v in state_dict.items(): 30 | name = k[7:] # remove `module.` 31 | new_state_dict[name] = v 32 | model.load_state_dict(new_state_dict) 33 | 34 | 35 | def load_checkpoint_multigpu(model, weights): 36 | checkpoint = torch.load(weights) 37 | state_dict = checkpoint["state_dict"] 38 | new_state_dict = OrderedDict() 39 | for k, v in state_dict.items(): 40 | name = k[7:] # remove `module.` 41 | new_state_dict[name] = v 42 | model.load_state_dict(new_state_dict) 43 | 44 | def load_start_epoch(weights): 45 | checkpoint = torch.load(weights) 46 | epoch = checkpoint["epoch"] 47 | return epoch 48 | 49 | def load_optim(optimizer, weights): 50 | checkpoint = torch.load(weights) 51 | optimizer.load_state_dict(checkpoint['optimizer']) 52 | # for p in optimizer.param_groups: lr = p['lr'] 53 | # return lr 54 | 
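The `module.`-stripping fallback in `load_checkpoint` above exists because `nn.DataParallel` (used by the test scripts in this repo) prefixes every key of a wrapped model's `state_dict` with `module.`. A minimal round-trip sketch, with a toy model and filename chosen purely for illustration:

```
# Why the except-branch strips 'module.': DataParallel renames all state_dict keys.
import torch
import torch.nn as nn
from utils.model_utils import load_checkpoint, save_checkpoint

wrapped = nn.DataParallel(nn.Linear(4, 4))
print(next(iter(wrapped.state_dict())))            # -> 'module.weight'

state = {'epoch': 1,
         'state_dict': wrapped.state_dict(),
         'optimizer': torch.optim.Adam(wrapped.parameters()).state_dict()}
save_checkpoint('.', state, 'MPRNet')              # -> ./model_epoch_1_MPRNet.pth

bare = nn.Linear(4, 4)
load_checkpoint(bare, 'model_epoch_1_MPRNet.pth')  # falls back to prefix stripping
```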
-------------------------------------------------------------------------------- /Denoising/generate_patches_SIDD.py: -------------------------------------------------------------------------------- 1 | from glob import glob 2 | from tqdm import tqdm 3 | import numpy as np 4 | import os 5 | from natsort import natsorted 6 | import cv2 7 | from joblib import Parallel, delayed 8 | import multiprocessing 9 | import argparse 10 | 11 | parser = argparse.ArgumentParser(description='Generate patches from Full Resolution images') 12 | parser.add_argument('--src_dir', default='../SIDD_Medium_Srgb/Data', type=str, help='Directory for full resolution images') 13 | parser.add_argument('--tar_dir', default='../SIDD_patches/train',type=str, help='Directory for image patches') 14 | parser.add_argument('--ps', default=256, type=int, help='Image Patch Size') 15 | parser.add_argument('--num_patches', default=300, type=int, help='Number of patches per image') 16 | parser.add_argument('--num_cores', default=10, type=int, help='Number of CPU Cores') 17 | 18 | args = parser.parse_args() 19 | 20 | src = args.src_dir 21 | tar = args.tar_dir 22 | PS = args.ps 23 | NUM_PATCHES = args.num_patches 24 | NUM_CORES = args.num_cores 25 | 26 | noisy_patchDir = os.path.join(tar, 'input') 27 | clean_patchDir = os.path.join(tar, 'groundtruth') 28 | 29 | if os.path.exists(tar): 30 | os.system("rm -r {}".format(tar)) 31 | 32 | os.makedirs(noisy_patchDir) 33 | os.makedirs(clean_patchDir) 34 | 35 | #get sorted folders 36 | files = natsorted(glob(os.path.join(src, '*', '*.PNG'))) 37 | 38 | noisy_files, clean_files = [], [] 39 | for file_ in files: 40 | filename = os.path.split(file_)[-1] 41 | if 'GT' in filename: 42 | clean_files.append(file_) 43 | if 'NOISY' in filename: 44 | noisy_files.append(file_) 45 | 46 | def save_files(i): 47 | noisy_file, clean_file = noisy_files[i], clean_files[i] 48 | noisy_img = cv2.imread(noisy_file) 49 | clean_img = cv2.imread(clean_file) 50 | 51 | H = noisy_img.shape[0] 52 | W = noisy_img.shape[1] 53 | for j in range(NUM_PATCHES): 54 | rr = np.random.randint(0, H - PS) 55 | cc = np.random.randint(0, W - PS) 56 | noisy_patch = noisy_img[rr:rr + PS, cc:cc + PS, :] 57 | clean_patch = clean_img[rr:rr + PS, cc:cc + PS, :] 58 | 59 | cv2.imwrite(os.path.join(noisy_patchDir, '{}_{}.png'.format(i+1,j+1)), noisy_patch) 60 | cv2.imwrite(os.path.join(clean_patchDir, '{}_{}.png'.format(i+1,j+1)), clean_patch) 61 | 62 | Parallel(n_jobs=NUM_CORES)(delayed(save_files)(i) for i in tqdm(range(len(noisy_files)))) 63 | -------------------------------------------------------------------------------- /Deraining/test.py: -------------------------------------------------------------------------------- 1 | """ 2 | ## Multi-Stage Progressive Image Restoration 3 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, Ming-Hsuan Yang, and Ling Shao 4 | ## https://arxiv.org/abs/2102.02808 5 | """ 6 | 7 | import numpy as np 8 | import os 9 | import argparse 10 | from tqdm import tqdm 11 | 12 | import torch.nn as nn 13 | import torch 14 | from torch.utils.data import DataLoader 15 | import torch.nn.functional as F 16 | import utils 17 | 18 | from data_RGB import get_test_data 19 | from MPRNet import MPRNet 20 | from skimage import img_as_ubyte 21 | from pdb import set_trace as stx 22 | 23 | parser = argparse.ArgumentParser(description='Image Deraining using MPRNet') 24 | 25 | parser.add_argument('--input_dir', default='./Datasets/test/', type=str, help='Directory of validation images') 26 | 
parser.add_argument('--result_dir', default='./results/', type=str, help='Directory for results') 27 | parser.add_argument('--weights', default='./pretrained_models/model_deraining.pth', type=str, help='Path to weights') 28 | parser.add_argument('--gpus', default='0', type=str, help='CUDA_VISIBLE_DEVICES') 29 | 30 | args = parser.parse_args() 31 | 32 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 33 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus 34 | 35 | model_restoration = MPRNet() 36 | 37 | utils.load_checkpoint(model_restoration,args.weights) 38 | print("===>Testing using weights: ",args.weights) 39 | model_restoration.cuda() 40 | model_restoration = nn.DataParallel(model_restoration) 41 | model_restoration.eval() 42 | 43 | datasets = ['Rain100L', 'Rain100H', 'Test100', 'Test1200', 'Test2800'] 44 | # datasets = ['Rain100L'] 45 | 46 | for dataset in datasets: 47 | rgb_dir_test = os.path.join(args.input_dir, dataset, 'input') 48 | test_dataset = get_test_data(rgb_dir_test, img_options={}) 49 | test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=4, drop_last=False, pin_memory=True) 50 | 51 | result_dir = os.path.join(args.result_dir, dataset) 52 | utils.mkdir(result_dir) 53 | 54 | with torch.no_grad(): 55 | for ii, data_test in enumerate(tqdm(test_loader), 0): 56 | torch.cuda.ipc_collect() 57 | torch.cuda.empty_cache() 58 | 59 | input_ = data_test[0].cuda() 60 | filenames = data_test[1] 61 | 62 | restored = model_restoration(input_) 63 | restored = torch.clamp(restored[0],0,1) 64 | 65 | restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy() 66 | 67 | for batch in range(len(restored)): 68 | restored_img = img_as_ubyte(restored[batch]) 69 | utils.save_img((os.path.join(result_dir, filenames[batch]+'.png')), restored_img) 70 | -------------------------------------------------------------------------------- /Denoising/test_SIDD.py: -------------------------------------------------------------------------------- 1 | """ 2 | ## Multi-Stage Progressive Image Restoration 3 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, Ming-Hsuan Yang, and Ling Shao 4 | ## https://arxiv.org/abs/2102.02808 5 | """ 6 | 7 | import numpy as np 8 | import os 9 | import argparse 10 | from tqdm import tqdm 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import utils 16 | 17 | from MPRNet import MPRNet 18 | from skimage import img_as_ubyte 19 | import h5py 20 | import scipy.io as sio 21 | from pdb import set_trace as stx 22 | 23 | parser = argparse.ArgumentParser(description='Image Denoising using MPRNet') 24 | 25 | parser.add_argument('--input_dir', default='./Datasets/SIDD/test/', type=str, help='Directory of validation images') 26 | parser.add_argument('--result_dir', default='./results/SIDD/', type=str, help='Directory for results') 27 | parser.add_argument('--weights', default='./pretrained_models/model_denoising.pth', type=str, help='Path to weights') 28 | parser.add_argument('--gpus', default='1', type=str, help='CUDA_VISIBLE_DEVICES') 29 | parser.add_argument('--save_images', action='store_true', help='Save denoised images in result directory') 30 | 31 | args = parser.parse_args() 32 | 33 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 34 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus 35 | 36 | result_dir = os.path.join(args.result_dir, 'mat') 37 | utils.mkdir(result_dir) 38 | 39 | if args.save_images: 40 | result_dir_img = os.path.join(args.result_dir, 'png') 41 | 
utils.mkdir(result_dir_img) 42 | 43 | model_restoration = MPRNet() 44 | 45 | utils.load_checkpoint(model_restoration,args.weights) 46 | print("===>Testing using weights: ",args.weights) 47 | model_restoration.cuda() 48 | model_restoration = nn.DataParallel(model_restoration) 49 | model_restoration.eval() 50 | 51 | # Process data 52 | filepath = os.path.join(args.input_dir, 'ValidationNoisyBlocksSrgb.mat') 53 | img = sio.loadmat(filepath) 54 | Inoisy = np.float32(np.array(img['ValidationNoisyBlocksSrgb'])) 55 | Inoisy /=255. 56 | restored = np.zeros_like(Inoisy) 57 | with torch.no_grad(): 58 | for i in tqdm(range(40)): 59 | for k in range(32): 60 | noisy_patch = torch.from_numpy(Inoisy[i,k,:,:,:]).unsqueeze(0).permute(0,3,1,2).cuda() 61 | restored_patch = model_restoration(noisy_patch) 62 | restored_patch = torch.clamp(restored_patch[0],0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0) 63 | restored[i,k,:,:,:] = restored_patch 64 | 65 | if args.save_images: 66 | save_file = os.path.join(result_dir_img, '%04d_%02d.png'%(i+1,k+1)) 67 | utils.save_img(save_file, img_as_ubyte(restored_patch)) 68 | 69 | # save denoised data 70 | sio.savemat(os.path.join(result_dir, 'Idenoised.mat'), {"Idenoised": restored,}) 71 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torchvision.transforms.functional as TF 5 | from PIL import Image 6 | import os 7 | from runpy import run_path 8 | from skimage import img_as_ubyte 9 | from collections import OrderedDict 10 | from natsort import natsorted 11 | from glob import glob 12 | import cv2 13 | import argparse 14 | 15 | parser = argparse.ArgumentParser(description='Demo MPRNet') 16 | parser.add_argument('--input_dir', default='./samples/input/', type=str, help='Input images') 17 | parser.add_argument('--result_dir', default='./samples/output/', type=str, help='Directory for results') 18 | parser.add_argument('--task', required=True, type=str, help='Task to run', choices=['Deblurring', 'Denoising', 'Deraining']) 19 | 20 | args = parser.parse_args() 21 | 22 | def save_img(filepath, img): 23 | cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) 24 | 25 | def load_checkpoint(model, weights): 26 | checkpoint = torch.load(weights) 27 | try: 28 | model.load_state_dict(checkpoint["state_dict"]) 29 | except: 30 | state_dict = checkpoint["state_dict"] 31 | new_state_dict = OrderedDict() 32 | for k, v in state_dict.items(): 33 | name = k[7:] # remove `module.` 34 | new_state_dict[name] = v 35 | model.load_state_dict(new_state_dict) 36 | 37 | task = args.task 38 | inp_dir = args.input_dir 39 | out_dir = args.result_dir 40 | 41 | os.makedirs(out_dir, exist_ok=True) 42 | 43 | files = natsorted(glob(os.path.join(inp_dir, '*.jpg')) 44 | + glob(os.path.join(inp_dir, '*.JPG')) 45 | + glob(os.path.join(inp_dir, '*.png')) 46 | + glob(os.path.join(inp_dir, '*.PNG'))) 47 | 48 | if len(files) == 0: 49 | raise Exception(f"No files found at {inp_dir}") 50 | 51 | # Load corresponding model architecture and weights 52 | load_file = run_path(os.path.join(task, "MPRNet.py")) 53 | model = load_file['MPRNet']() 54 | model.cuda() 55 | 56 | weights = os.path.join(task, "pretrained_models", "model_"+task.lower()+".pth") 57 | load_checkpoint(model, weights) 58 | model.eval() 59 | 60 | img_multiple_of = 8 61 | 62 | for file_ in files: 63 | img = Image.open(file_).convert('RGB') 64 
| input_ = TF.to_tensor(img).unsqueeze(0).cuda() 65 | 66 | # Pad the input if not_multiple_of 8 67 | h,w = input_.shape[2], input_.shape[3] 68 | H,W = ((h+img_multiple_of)//img_multiple_of)*img_multiple_of, ((w+img_multiple_of)//img_multiple_of)*img_multiple_of 69 | padh = H-h if h%img_multiple_of!=0 else 0 70 | padw = W-w if w%img_multiple_of!=0 else 0 71 | input_ = F.pad(input_, (0,padw,0,padh), 'reflect') 72 | 73 | with torch.no_grad(): 74 | restored = model(input_) 75 | restored = restored[0] 76 | restored = torch.clamp(restored, 0, 1) 77 | 78 | # Unpad the output 79 | restored = restored[:,:,:h,:w] 80 | 81 | restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy() 82 | restored = img_as_ubyte(restored[0]) 83 | 84 | f = os.path.splitext(os.path.split(file_)[-1])[0] 85 | save_img((os.path.join(out_dir, f+'.png')), restored) 86 | 87 | print(f"Files saved at {out_dir}") 88 | -------------------------------------------------------------------------------- /Deblurring/test.py: -------------------------------------------------------------------------------- 1 | """ 2 | ## Multi-Stage Progressive Image Restoration 3 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, Ming-Hsuan Yang, and Ling Shao 4 | ## https://arxiv.org/abs/2102.02808 5 | """ 6 | 7 | import numpy as np 8 | import os 9 | import argparse 10 | from tqdm import tqdm 11 | 12 | import torch.nn as nn 13 | import torch 14 | from torch.utils.data import DataLoader 15 | import torch.nn.functional as F 16 | import utils 17 | 18 | from data_RGB import get_test_data 19 | from MPRNet import MPRNet 20 | from skimage import img_as_ubyte 21 | from pdb import set_trace as stx 22 | 23 | parser = argparse.ArgumentParser(description='Image Deblurring using MPRNet') 24 | 25 | parser.add_argument('--input_dir', default='./Datasets/', type=str, help='Directory of validation images') 26 | parser.add_argument('--result_dir', default='./results/', type=str, help='Directory for results') 27 | parser.add_argument('--weights', default='./pretrained_models/model_deblurring.pth', type=str, help='Path to weights') 28 | parser.add_argument('--dataset', default='GoPro', type=str, help='Test Dataset') # ['GoPro', 'HIDE', 'RealBlur_J', 'RealBlur_R'] 29 | parser.add_argument('--gpus', default='0', type=str, help='CUDA_VISIBLE_DEVICES') 30 | 31 | args = parser.parse_args() 32 | 33 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 34 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus 35 | 36 | model_restoration = MPRNet() 37 | 38 | utils.load_checkpoint(model_restoration,args.weights) 39 | print("===>Testing using weights: ",args.weights) 40 | model_restoration.cuda() 41 | model_restoration = nn.DataParallel(model_restoration) 42 | model_restoration.eval() 43 | 44 | dataset = args.dataset 45 | rgb_dir_test = os.path.join(args.input_dir, dataset, 'test', 'input') 46 | test_dataset = get_test_data(rgb_dir_test, img_options={}) 47 | test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=4, drop_last=False, pin_memory=True) 48 | 49 | result_dir = os.path.join(args.result_dir, dataset) 50 | utils.mkdir(result_dir) 51 | 52 | with torch.no_grad(): 53 | for ii, data_test in enumerate(tqdm(test_loader), 0): 54 | torch.cuda.ipc_collect() 55 | torch.cuda.empty_cache() 56 | 57 | input_ = data_test[0].cuda() 58 | filenames = data_test[1] 59 | 60 | # Padding in case images are not multiples of 8 61 | if dataset == 'RealBlur_J' or dataset == 'RealBlur_R': 62 | factor = 8 63 | h,w = input_.shape[2], 
input_.shape[3]
64 |             H,W = ((h+factor)//factor)*factor, ((w+factor)//factor)*factor
65 |             padh = H-h if h%factor!=0 else 0
66 |             padw = W-w if w%factor!=0 else 0
67 |             input_ = F.pad(input_, (0,padw,0,padh), 'reflect')
68 | 
69 |         restored = model_restoration(input_)
70 |         restored = torch.clamp(restored[0],0,1)
71 | 
72 |         # Unpad images to original dimensions
73 |         if dataset == 'RealBlur_J' or dataset == 'RealBlur_R':
74 |             restored = restored[:,:,:h,:w]
75 | 
76 |         restored = restored.permute(0, 2, 3, 1).cpu().detach().numpy()
77 | 
78 |         for batch in range(len(restored)):
79 |             restored_img = img_as_ubyte(restored[batch])
80 |             utils.save_img((os.path.join(result_dir, filenames[batch]+'.png')), restored_img)
81 | 
--------------------------------------------------------------------------------
/pytorch-gradual-warmup-lr/warmup_scheduler/scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import _LRScheduler
2 | from torch.optim.lr_scheduler import ReduceLROnPlateau
3 | 
4 | 
5 | class GradualWarmupScheduler(_LRScheduler):
6 |     """ Gradually warms up (increases) the learning rate in the optimizer.
7 |     Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
8 | 
9 |     Args:
10 |         optimizer (Optimizer): Wrapped optimizer.
11 |         multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. If multiplier = 1.0, the lr starts from 0 and ends at base_lr.
12 |         total_epoch: the target learning rate is reached gradually at total_epoch.
13 |         after_scheduler: after total_epoch, use this scheduler (e.g. ReduceLROnPlateau).
14 |     """
15 | 
16 |     def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
17 |         self.multiplier = multiplier
18 |         if self.multiplier < 1.:
19 |             raise ValueError('multiplier should be greater than or equal to 1.')
20 |         self.total_epoch = total_epoch
21 |         self.after_scheduler = after_scheduler
22 |         self.finished = False
23 |         super(GradualWarmupScheduler, self).__init__(optimizer)
24 | 
25 |     def get_lr(self):
26 |         if self.last_epoch > self.total_epoch:
27 |             if self.after_scheduler:
28 |                 if not self.finished:
29 |                     self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
30 |                     self.finished = True
31 |                 return self.after_scheduler.get_lr()
32 |             return [base_lr * self.multiplier for base_lr in self.base_lrs]
33 | 
34 |         if self.multiplier == 1.0:
35 |             return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
36 |         else:
37 |             return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
38 | 
39 |     def step_ReduceLROnPlateau(self, metrics, epoch=None):
40 |         if epoch is None:
41 |             epoch = self.last_epoch + 1
42 |         self.last_epoch = epoch if epoch != 0 else 1  # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
43 |         if self.last_epoch <= self.total_epoch:
44 |             warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.)
for base_lr in self.base_lrs] 45 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr): 46 | param_group['lr'] = lr 47 | else: 48 | if epoch is None: 49 | self.after_scheduler.step(metrics, None) 50 | else: 51 | self.after_scheduler.step(metrics, epoch - self.total_epoch) 52 | 53 | def step(self, epoch=None, metrics=None): 54 | if type(self.after_scheduler) != ReduceLROnPlateau: 55 | if self.finished and self.after_scheduler: 56 | if epoch is None: 57 | self.after_scheduler.step(None) 58 | else: 59 | self.after_scheduler.step(epoch - self.total_epoch) 60 | else: 61 | return super(GradualWarmupScheduler, self).step(epoch) 62 | else: 63 | self.step_ReduceLROnPlateau(metrics, epoch) 64 | -------------------------------------------------------------------------------- /Denoising/test_DND.py: -------------------------------------------------------------------------------- 1 | """ 2 | ## Multi-Stage Progressive Image Restoration 3 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, Ming-Hsuan Yang, and Ling Shao 4 | ## https://arxiv.org/abs/2102.02808 5 | """ 6 | 7 | import numpy as np 8 | import os 9 | import argparse 10 | from tqdm import tqdm 11 | 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.functional as F 15 | import utils 16 | 17 | from MPRNet import MPRNet 18 | from skimage import img_as_ubyte 19 | import h5py 20 | import scipy.io as sio 21 | from pdb import set_trace as stx 22 | 23 | parser = argparse.ArgumentParser(description='Image Denoising using MPRNet') 24 | 25 | parser.add_argument('--input_dir', default='./Datasets/DND/', type=str, help='Directory of validation images') 26 | parser.add_argument('--result_dir', default='./results/DND/test/', type=str, help='Directory for results') 27 | parser.add_argument('--weights', default='./pretrained_models/model_denoising.pth', type=str, help='Path to weights') 28 | parser.add_argument('--gpus', default='1', type=str, help='CUDA_VISIBLE_DEVICES') 29 | parser.add_argument('--save_images', action='store_true', help='Save denoised images in result directory') 30 | 31 | args = parser.parse_args() 32 | 33 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 34 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus 35 | 36 | result_dir = os.path.join(args.result_dir, 'mat') 37 | utils.mkdir(result_dir) 38 | 39 | if args.save_images: 40 | result_dir_img = os.path.join(args.result_dir, 'png') 41 | utils.mkdir(result_dir_img) 42 | 43 | model_restoration = MPRNet() 44 | 45 | utils.load_checkpoint(model_restoration,args.weights) 46 | print("===>Testing using weights: ",args.weights) 47 | model_restoration.cuda() 48 | model_restoration = nn.DataParallel(model_restoration) 49 | model_restoration.eval() 50 | 51 | israw = False 52 | eval_version="1.0" 53 | 54 | # Load info 55 | infos = h5py.File(os.path.join(args.input_dir, 'info.mat'), 'r') 56 | info = infos['info'] 57 | bb = info['boundingboxes'] 58 | 59 | # Process data 60 | with torch.no_grad(): 61 | for i in tqdm(range(50)): 62 | Idenoised = np.zeros((20,), dtype=np.object) 63 | filename = '%04d.mat'%(i+1) 64 | filepath = os.path.join(args.input_dir, 'images_srgb', filename) 65 | img = h5py.File(filepath, 'r') 66 | Inoisy = np.float32(np.array(img['InoisySRGB']).T) 67 | 68 | # bounding box 69 | ref = bb[0][i] 70 | boxes = np.array(info[ref]).T 71 | 72 | for k in range(20): 73 | idx = [int(boxes[k,0]-1),int(boxes[k,2]),int(boxes[k,1]-1),int(boxes[k,3])] 74 | noisy_patch = 
torch.from_numpy(Inoisy[idx[0]:idx[1],idx[2]:idx[3],:]).unsqueeze(0).permute(0,3,1,2).cuda() 75 | restored_patch = model_restoration(noisy_patch) 76 | restored_patch = torch.clamp(restored_patch[0],0,1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy() 77 | Idenoised[k] = restored_patch 78 | 79 | if args.save_images: 80 | save_file = os.path.join(result_dir_img, '%04d_%02d.png'%(i+1,k+1)) 81 | denoised_img = img_as_ubyte(restored_patch) 82 | utils.save_img(save_file, denoised_img) 83 | 84 | # save denoised data 85 | sio.savemat(os.path.join(result_dir, filename), 86 | {"Idenoised": Idenoised, 87 | "israw": israw, 88 | "eval_version": eval_version}, 89 | ) 90 | -------------------------------------------------------------------------------- /Deblurring/config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Tue Jul 23 14:35:48 2019 5 | 6 | @author: aditya 7 | """ 8 | 9 | r"""This module provides package-wide configuration management.""" 10 | from typing import Any, List 11 | 12 | from yacs.config import CfgNode as CN 13 | 14 | 15 | class Config(object): 16 | r""" 17 | A collection of all the required configuration parameters. This class is a nested dict-like 18 | structure, with nested keys accessible as attributes. It contains sensible default values for 19 | all the parameters, which may be overriden by (first) through a YAML file and (second) through 20 | a list of attributes and values. 21 | 22 | Extended Summary 23 | ---------------- 24 | This class definition contains default values corresponding to ``joint_training`` phase, as it 25 | is the final training phase and uses almost all the configuration parameters. Modification of 26 | any parameter after instantiating this class is not possible, so you must override required 27 | parameter values in either through ``config_yaml`` file or ``config_override`` list. 28 | 29 | Parameters 30 | ---------- 31 | config_yaml: str 32 | Path to a YAML file containing configuration parameters to override. 33 | config_override: List[Any], optional (default= []) 34 | A list of sequential attributes and values of parameters to override. This happens after 35 | overriding from YAML file. 
35 | overriding from YAML file.
36 | 37 | Examples 38 | -------- 39 | Let a YAML file named "config.yaml" specify these parameters to override:: 40 | 41 | ALPHA: 1000.0 42 | BETA: 0.5 43 | 44 | >>> _C = Config("config.yaml", ["OPTIM.BATCH_SIZE", 2048, "BETA", 0.7]) 45 | >>> _C.ALPHA # default: 100.0 46 | 1000.0 47 | >>> _C.BATCH_SIZE # default: 256 48 | 2048 49 | >>> _C.BETA # default: 0.1 50 | 0.7 51 | 52 | Attributes 53 | ---------- 54 | """ 55 | 56 | def __init__(self, config_yaml: str, config_override: List[Any] = []): 57 | 58 | self._C = CN() 59 | self._C.GPU = [0] 60 | self._C.VERBOSE = False 61 | 62 | self._C.MODEL = CN() 63 | self._C.MODEL.MODE = 'global' 64 | self._C.MODEL.SESSION = 'ps128_bs1' 65 | 66 | self._C.OPTIM = CN() 67 | self._C.OPTIM.BATCH_SIZE = 1 68 | self._C.OPTIM.NUM_EPOCHS = 100 69 | self._C.OPTIM.NEPOCH_DECAY = [100] 70 | self._C.OPTIM.LR_INITIAL = 0.0002 71 | self._C.OPTIM.LR_MIN = 0.0002 72 | self._C.OPTIM.BETA1 = 0.5 73 | 74 | self._C.TRAINING = CN() 75 | self._C.TRAINING.VAL_AFTER_EVERY = 3 76 | self._C.TRAINING.RESUME = False 77 | self._C.TRAINING.SAVE_IMAGES = False 78 | self._C.TRAINING.TRAIN_DIR = 'images_dir/train' 79 | self._C.TRAINING.VAL_DIR = 'images_dir/val' 80 | self._C.TRAINING.SAVE_DIR = 'checkpoints' 81 | self._C.TRAINING.TRAIN_PS = 64 82 | self._C.TRAINING.VAL_PS = 64 83 | 84 | # Override parameter values from YAML file first, then from override list. 85 | self._C.merge_from_file(config_yaml) 86 | self._C.merge_from_list(config_override) 87 | 88 | # Make an instantiated object of this class immutable. 89 | self._C.freeze() 90 | 91 | def dump(self, file_path: str): 92 | r"""Save config at the specified file path. 93 | 94 | Parameters 95 | ---------- 96 | file_path: str 97 | (YAML) path to save config at. 98 | """ 99 | self._C.dump(stream=open(file_path, "w")) 100 | 101 | def __getattr__(self, attr: str): 102 | return self._C.__getattr__(attr) 103 | 104 | def __repr__(self): 105 | return self._C.__repr__() 106 | -------------------------------------------------------------------------------- /Denoising/config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Tue Jul 23 14:35:48 2019 5 | 6 | @author: aditya 7 | """ 8 | 9 | r"""This module provides package-wide configuration management.""" 10 | from typing import Any, List 11 | 12 | from yacs.config import CfgNode as CN 13 | 14 | 15 | class Config(object): 16 | r""" 17 | A collection of all the required configuration parameters. This class is a nested dict-like 18 | structure, with nested keys accessible as attributes. It contains sensible default values for 19 | all the parameters, which may be overriden by (first) through a YAML file and (second) through 20 | a list of attributes and values. 21 | 22 | Extended Summary 23 | ---------------- 24 | This class definition contains default values corresponding to ``joint_training`` phase, as it 25 | is the final training phase and uses almost all the configuration parameters. Modification of 26 | any parameter after instantiating this class is not possible, so you must override required 27 | parameter values in either through ``config_yaml`` file or ``config_override`` list. 28 | 29 | Parameters 30 | ---------- 31 | config_yaml: str 32 | Path to a YAML file containing configuration parameters to override. 33 | config_override: List[Any], optional (default= []) 34 | A list of sequential attributes and values of parameters to override. 
This happens after 35 | overriding from YAML file. 36 | 37 | Examples 38 | -------- 39 | Let a YAML file named "config.yaml" specify these parameters to override:: 40 | 41 | ALPHA: 1000.0 42 | BETA: 0.5 43 | 44 | >>> _C = Config("config.yaml", ["OPTIM.BATCH_SIZE", 2048, "BETA", 0.7]) 45 | >>> _C.ALPHA # default: 100.0 46 | 1000.0 47 | >>> _C.BATCH_SIZE # default: 256 48 | 2048 49 | >>> _C.BETA # default: 0.1 50 | 0.7 51 | 52 | Attributes 53 | ---------- 54 | """ 55 | 56 | def __init__(self, config_yaml: str, config_override: List[Any] = []): 57 | 58 | self._C = CN() 59 | self._C.GPU = [0] 60 | self._C.VERBOSE = False 61 | 62 | self._C.MODEL = CN() 63 | self._C.MODEL.MODE = 'global' 64 | self._C.MODEL.SESSION = 'ps128_bs1' 65 | 66 | self._C.OPTIM = CN() 67 | self._C.OPTIM.BATCH_SIZE = 1 68 | self._C.OPTIM.NUM_EPOCHS = 100 69 | self._C.OPTIM.NEPOCH_DECAY = [100] 70 | self._C.OPTIM.LR_INITIAL = 0.0002 71 | self._C.OPTIM.LR_MIN = 0.0002 72 | self._C.OPTIM.BETA1 = 0.5 73 | 74 | self._C.TRAINING = CN() 75 | self._C.TRAINING.VAL_AFTER_EVERY = 3 76 | self._C.TRAINING.RESUME = False 77 | self._C.TRAINING.SAVE_IMAGES = False 78 | self._C.TRAINING.TRAIN_DIR = 'images_dir/train' 79 | self._C.TRAINING.VAL_DIR = 'images_dir/val' 80 | self._C.TRAINING.SAVE_DIR = 'checkpoints' 81 | self._C.TRAINING.TRAIN_PS = 64 82 | self._C.TRAINING.VAL_PS = 64 83 | 84 | # Override parameter values from YAML file first, then from override list. 85 | self._C.merge_from_file(config_yaml) 86 | self._C.merge_from_list(config_override) 87 | 88 | # Make an instantiated object of this class immutable. 89 | self._C.freeze() 90 | 91 | def dump(self, file_path: str): 92 | r"""Save config at the specified file path. 93 | 94 | Parameters 95 | ---------- 96 | file_path: str 97 | (YAML) path to save config at. 98 | """ 99 | self._C.dump(stream=open(file_path, "w")) 100 | 101 | def __getattr__(self, attr: str): 102 | return self._C.__getattr__(attr) 103 | 104 | def __repr__(self): 105 | return self._C.__repr__() 106 | -------------------------------------------------------------------------------- /Deraining/config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | """ 4 | Created on Tue Jul 23 14:35:48 2019 5 | 6 | @author: aditya 7 | """ 8 | 9 | r"""This module provides package-wide configuration management.""" 10 | from typing import Any, List 11 | 12 | from yacs.config import CfgNode as CN 13 | 14 | 15 | class Config(object): 16 | r""" 17 | A collection of all the required configuration parameters. This class is a nested dict-like 18 | structure, with nested keys accessible as attributes. It contains sensible default values for 19 | all the parameters, which may be overriden by (first) through a YAML file and (second) through 20 | a list of attributes and values. 21 | 22 | Extended Summary 23 | ---------------- 24 | This class definition contains default values corresponding to ``joint_training`` phase, as it 25 | is the final training phase and uses almost all the configuration parameters. Modification of 26 | any parameter after instantiating this class is not possible, so you must override required 27 | parameter values in either through ``config_yaml`` file or ``config_override`` list. 28 | 29 | Parameters 30 | ---------- 31 | config_yaml: str 32 | Path to a YAML file containing configuration parameters to override. 
33 | config_override: List[Any], optional (default= []) 34 | A list of sequential attributes and values of parameters to override. This happens after 35 | overriding from YAML file. 36 | 37 | Examples 38 | -------- 39 | Let a YAML file named "config.yaml" specify these parameters to override:: 40 | 41 | ALPHA: 1000.0 42 | BETA: 0.5 43 | 44 | >>> _C = Config("config.yaml", ["OPTIM.BATCH_SIZE", 2048, "BETA", 0.7]) 45 | >>> _C.ALPHA # default: 100.0 46 | 1000.0 47 | >>> _C.BATCH_SIZE # default: 256 48 | 2048 49 | >>> _C.BETA # default: 0.1 50 | 0.7 51 | 52 | Attributes 53 | ---------- 54 | """ 55 | 56 | def __init__(self, config_yaml: str, config_override: List[Any] = []): 57 | 58 | self._C = CN() 59 | self._C.GPU = [0] 60 | self._C.VERBOSE = False 61 | 62 | self._C.MODEL = CN() 63 | self._C.MODEL.MODE = 'global' 64 | self._C.MODEL.SESSION = 'ps128_bs1' 65 | 66 | self._C.OPTIM = CN() 67 | self._C.OPTIM.BATCH_SIZE = 1 68 | self._C.OPTIM.NUM_EPOCHS = 100 69 | self._C.OPTIM.NEPOCH_DECAY = [100] 70 | self._C.OPTIM.LR_INITIAL = 0.0002 71 | self._C.OPTIM.LR_MIN = 0.0002 72 | self._C.OPTIM.BETA1 = 0.5 73 | 74 | self._C.TRAINING = CN() 75 | self._C.TRAINING.VAL_AFTER_EVERY = 3 76 | self._C.TRAINING.RESUME = False 77 | self._C.TRAINING.SAVE_IMAGES = False 78 | self._C.TRAINING.TRAIN_DIR = 'images_dir/train' 79 | self._C.TRAINING.VAL_DIR = 'images_dir/val' 80 | self._C.TRAINING.SAVE_DIR = 'checkpoints' 81 | self._C.TRAINING.TRAIN_PS = 64 82 | self._C.TRAINING.VAL_PS = 64 83 | 84 | # Override parameter values from YAML file first, then from override list. 85 | self._C.merge_from_file(config_yaml) 86 | self._C.merge_from_list(config_override) 87 | 88 | # Make an instantiated object of this class immutable. 89 | self._C.freeze() 90 | 91 | def dump(self, file_path: str): 92 | r"""Save config at the specified file path. 93 | 94 | Parameters 95 | ---------- 96 | file_path: str 97 | (YAML) path to save config at. 98 | """ 99 | self._C.dump(stream=open(file_path, "w")) 100 | 101 | def __getattr__(self, attr: str): 102 | return self._C.__getattr__(attr) 103 | 104 | def __repr__(self): 105 | return self._C.__repr__() 106 | -------------------------------------------------------------------------------- /Deblurring/evaluate_RealBlur.py: -------------------------------------------------------------------------------- 1 | ## Multi-Stage Progressive Image Restoration 2 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, Ming-Hsuan Yang, and Ling Shao 3 | ## https://arxiv.org/abs/2102.02808 4 | 5 | import os 6 | import numpy as np 7 | from glob import glob 8 | from natsort import natsorted 9 | from skimage import io 10 | import cv2 11 | from skimage.metrics import structural_similarity 12 | from tqdm import tqdm 13 | import concurrent.futures 14 | 15 | def image_align(deblurred, gt): 16 | # this function is based on kohler evaluation code 17 | z = deblurred 18 | c = np.ones_like(z) 19 | x = gt 20 | 21 | zs = (np.sum(x * z) / np.sum(z * z)) * z # simple intensity matching 22 | 23 | warp_mode = cv2.MOTION_HOMOGRAPHY 24 | warp_matrix = np.eye(3, 3, dtype=np.float32) 25 | 26 | # Specify the number of iterations. 27 | number_of_iterations = 100 28 | 29 | termination_eps = 0 30 | 31 | criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 32 | number_of_iterations, termination_eps) 33 | 34 | # Run the ECC algorithm. The results are stored in warp_matrix. 
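# With termination_eps = 0, the ECC refinement always runs the full
# number_of_iterations before returning; the inputMask/gaussFiltSize keyword
# arguments assume a recent OpenCV (4.1+).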
35 | (cc, warp_matrix) = cv2.findTransformECC(cv2.cvtColor(x, cv2.COLOR_RGB2GRAY), cv2.cvtColor(zs, cv2.COLOR_RGB2GRAY), warp_matrix, warp_mode, criteria, inputMask=None, gaussFiltSize=5) 36 | 37 | target_shape = x.shape 38 | shift = warp_matrix 39 | 40 | zr = cv2.warpPerspective( 41 | zs, 42 | warp_matrix, 43 | (target_shape[1], target_shape[0]), 44 | flags=cv2.INTER_CUBIC+ cv2.WARP_INVERSE_MAP, 45 | borderMode=cv2.BORDER_REFLECT) 46 | 47 | cr = cv2.warpPerspective( 48 | np.ones_like(zs, dtype='float32'), 49 | warp_matrix, 50 | (target_shape[1], target_shape[0]), 51 | flags=cv2.INTER_NEAREST+ cv2.WARP_INVERSE_MAP, 52 | borderMode=cv2.BORDER_CONSTANT, 53 | borderValue=0) 54 | 55 | zr = zr * cr 56 | xr = x * cr 57 | 58 | return zr, xr, cr, shift 59 | 60 | def compute_psnr(image_true, image_test, image_mask, data_range=None): 61 | # this function is based on skimage.metrics.peak_signal_noise_ratio 62 | err = np.sum((image_true - image_test) ** 2, dtype=np.float64) / np.sum(image_mask) 63 | return 10 * np.log10((data_range ** 2) / err) 64 | 65 | 66 | def compute_ssim(tar_img, prd_img, cr1): 67 | ssim_pre, ssim_map = structural_similarity(tar_img, prd_img, multichannel=True, gaussian_weights=True, use_sample_covariance=False, data_range = 1.0, full=True) 68 | ssim_map = ssim_map * cr1 69 | r = int(3.5 * 1.5 + 0.5) # radius as in ndimage 70 | win_size = 2 * r + 1 71 | pad = (win_size - 1) // 2 72 | ssim = ssim_map[pad:-pad,pad:-pad,:] 73 | crop_cr1 = cr1[pad:-pad,pad:-pad,:] 74 | ssim = ssim.sum(axis=0).sum(axis=0)/crop_cr1.sum(axis=0).sum(axis=0) 75 | ssim = np.mean(ssim) 76 | return ssim 77 | 78 | def proc(filename): 79 | tar,prd = filename 80 | tar_img = io.imread(tar) 81 | prd_img = io.imread(prd) 82 | 83 | tar_img = tar_img.astype(np.float32)/255.0 84 | prd_img = prd_img.astype(np.float32)/255.0 85 | 86 | prd_img, tar_img, cr1, shift = image_align(prd_img, tar_img) 87 | 88 | PSNR = compute_psnr(tar_img, prd_img, cr1, data_range=1) 89 | SSIM = compute_ssim(tar_img, prd_img, cr1) 90 | return (PSNR,SSIM) 91 | 92 | datasets = ['RealBlur_J', 'RealBlur_R'] 93 | 94 | for dataset in datasets: 95 | 96 | file_path = os.path.join('results' , dataset) 97 | gt_path = os.path.join('Datasets', dataset, 'test', 'target') 98 | 99 | path_list = natsorted(glob(os.path.join(file_path, '*.png')) + glob(os.path.join(file_path, '*.jpg'))) 100 | gt_list = natsorted(glob(os.path.join(gt_path, '*.png')) + glob(os.path.join(gt_path, '*.jpg'))) 101 | 102 | assert len(path_list) != 0, "Predicted files not found" 103 | assert len(gt_list) != 0, "Target files not found" 104 | 105 | psnr, ssim = [], [] 106 | img_files =[(i, j) for i,j in zip(gt_list,path_list)] 107 | with concurrent.futures.ProcessPoolExecutor(max_workers=10) as executor: 108 | for filename, PSNR_SSIM in zip(img_files, executor.map(proc, img_files)): 109 | psnr.append(PSNR_SSIM[0]) 110 | ssim.append(PSNR_SSIM[1]) 111 | 112 | avg_psnr = sum(psnr)/len(psnr) 113 | avg_ssim = sum(ssim)/len(ssim) 114 | 115 | print('For {:s} dataset PSNR: {:f} SSIM: {:f}\n'.format(dataset, avg_psnr, avg_ssim)) 116 | -------------------------------------------------------------------------------- /Deraining/dataset_RGB.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.data import Dataset 3 | import torch 4 | from PIL import Image 5 | import torchvision.transforms.functional as TF 6 | from pdb import set_trace as stx 7 | import random 8 | 9 | def is_image_file(filename): 10 | return 
any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif']) 11 | 12 | class DataLoaderTrain(Dataset): 13 | def __init__(self, rgb_dir, img_options=None): 14 | super(DataLoaderTrain, self).__init__() 15 | 16 | inp_files = sorted(os.listdir(os.path.join(rgb_dir, 'input'))) 17 | tar_files = sorted(os.listdir(os.path.join(rgb_dir, 'target'))) 18 | 19 | self.inp_filenames = [os.path.join(rgb_dir, 'input', x) for x in inp_files if is_image_file(x)] 20 | self.tar_filenames = [os.path.join(rgb_dir, 'target', x) for x in tar_files if is_image_file(x)] 21 | 22 | self.img_options = img_options 23 | self.sizex = len(self.tar_filenames) # get the size of target 24 | 25 | self.ps = self.img_options['patch_size'] 26 | 27 | def __len__(self): 28 | return self.sizex 29 | 30 | def __getitem__(self, index): 31 | index_ = index % self.sizex 32 | ps = self.ps 33 | 34 | inp_path = self.inp_filenames[index_] 35 | tar_path = self.tar_filenames[index_] 36 | 37 | inp_img = Image.open(inp_path) 38 | tar_img = Image.open(tar_path) 39 | 40 | w,h = tar_img.size 41 | padw = ps-w if w 1: 54 | print("\n\nLet's use", torch.cuda.device_count(), "GPUs!\n\n") 55 | 56 | 57 | new_lr = opt.OPTIM.LR_INITIAL 58 | 59 | optimizer = optim.Adam(model_restoration.parameters(), lr=new_lr, betas=(0.9, 0.999),eps=1e-8) 60 | 61 | 62 | ######### Scheduler ########### 63 | warmup_epochs = 3 64 | scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.OPTIM.NUM_EPOCHS-warmup_epochs, eta_min=opt.OPTIM.LR_MIN) 65 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine) 66 | scheduler.step() 67 | 68 | ######### Resume ########### 69 | if opt.TRAINING.RESUME: 70 | path_chk_rest = utils.get_last_path(model_dir, '_latest.pth') 71 | utils.load_checkpoint(model_restoration,path_chk_rest) 72 | start_epoch = utils.load_start_epoch(path_chk_rest) + 1 73 | utils.load_optim(optimizer, path_chk_rest) 74 | 75 | for i in range(1, start_epoch): 76 | scheduler.step() 77 | new_lr = scheduler.get_lr()[0] 78 | print('------------------------------------------------------------------------------') 79 | print("==> Resuming Training with learning rate:", new_lr) 80 | print('------------------------------------------------------------------------------') 81 | 82 | if len(device_ids)>1: 83 | model_restoration = nn.DataParallel(model_restoration, device_ids = device_ids) 84 | 85 | ######### Loss ########### 86 | criterion_char = losses.CharbonnierLoss() 87 | criterion_edge = losses.EdgeLoss() 88 | 89 | ######### DataLoaders ########### 90 | train_dataset = get_training_data(train_dir, {'patch_size':opt.TRAINING.TRAIN_PS}) 91 | train_loader = DataLoader(dataset=train_dataset, batch_size=opt.OPTIM.BATCH_SIZE, shuffle=True, num_workers=16, drop_last=False, pin_memory=True) 92 | 93 | val_dataset = get_validation_data(val_dir, {'patch_size':opt.TRAINING.VAL_PS}) 94 | val_loader = DataLoader(dataset=val_dataset, batch_size=16, shuffle=False, num_workers=8, drop_last=False, pin_memory=True) 95 | 96 | print('===> Start Epoch {} End Epoch {}'.format(start_epoch,opt.OPTIM.NUM_EPOCHS + 1)) 97 | print('===> Loading datasets') 98 | 99 | best_psnr = 0 100 | best_epoch = 0 101 | 102 | for epoch in range(start_epoch, opt.OPTIM.NUM_EPOCHS + 1): 103 | epoch_start_time = time.time() 104 | epoch_loss = 0 105 | train_id = 1 106 | 107 | model_restoration.train() 108 | for i, data in enumerate(tqdm(train_loader), 0): 109 | 110 | # zero_grad 111 | for param in 
model_restoration.parameters():
112 |             param.grad = None
113 |
114 |         target = data[0].cuda()
115 |         input_ = data[1].cuda()
116 |
117 |         restored = model_restoration(input_)
118 |
119 |         # Compute loss at each stage
120 |         loss_char = sum([criterion_char(restored[j],target) for j in range(len(restored))])
121 |         loss_edge = sum([criterion_edge(restored[j],target) for j in range(len(restored))])
122 |         loss = (loss_char) + (0.05*loss_edge)
123 |
124 |         loss.backward()
125 |         optimizer.step()
126 |         epoch_loss +=loss.item()
127 |
128 |     #### Evaluation ####
129 |     if epoch%opt.TRAINING.VAL_AFTER_EVERY == 0:
130 |         model_restoration.eval()
131 |         psnr_val_rgb = []
132 |         for ii, data_val in enumerate((val_loader), 0):
133 |             target = data_val[0].cuda()
134 |             input_ = data_val[1].cuda()
135 |
136 |             with torch.no_grad():
137 |                 restored = model_restoration(input_)
138 |             restored = restored[0]
139 |
140 |             for res,tar in zip(restored,target):
141 |                 psnr_val_rgb.append(utils.torchPSNR(res, tar))
142 |
143 |         psnr_val_rgb = torch.stack(psnr_val_rgb).mean().item()
144 |
145 |         if psnr_val_rgb > best_psnr:
146 |             best_psnr = psnr_val_rgb
147 |             best_epoch = epoch
148 |             torch.save({'epoch': epoch,
149 |                         'state_dict': model_restoration.state_dict(),
150 |                         'optimizer' : optimizer.state_dict()
151 |                         }, os.path.join(model_dir,"model_best.pth"))
152 |
153 |         print("[epoch %d PSNR: %.4f --- best_epoch %d Best_PSNR %.4f]" % (epoch, psnr_val_rgb, best_epoch, best_psnr))
154 |
155 |         torch.save({'epoch': epoch,
156 |                     'state_dict': model_restoration.state_dict(),
157 |                     'optimizer' : optimizer.state_dict()
158 |                     }, os.path.join(model_dir,f"model_epoch_{epoch}.pth"))
159 |
160 |     scheduler.step()
161 |
162 |     print("------------------------------------------------------------------")
163 |     print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time()-epoch_start_time, epoch_loss, scheduler.get_lr()[0]))
164 |     print("------------------------------------------------------------------")
165 |
166 |     torch.save({'epoch': epoch,
167 |                 'state_dict': model_restoration.state_dict(),
168 |                 'optimizer' : optimizer.state_dict()
169 |                 }, os.path.join(model_dir,"model_latest.pth"))
170 |
171 |
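The stage-wise losses above reduce plain Python lists of scalar tensors, which is why the builtin `sum` is the right reduction (`torch.sum` expects a tensor, not a list). A self-contained sketch of the same pattern, with random tensors standing in for the three stage outputs and a Charbonnier-style penalty standing in for `losses.CharbonnierLoss`:
```
import torch

def charbonnier(x, y, eps=1e-3):
    # Smooth L1-like penalty used throughout MPRNet-style training.
    return torch.sqrt((x - y) ** 2 + eps ** 2).mean()

target = torch.rand(2, 3, 64, 64)
restored = [torch.rand(2, 3, 64, 64, requires_grad=True) for _ in range(3)]

# Builtin sum over a list of 0-dim tensors keeps the autograd graph intact.
loss = sum(charbonnier(r, target) for r in restored)
loss.backward()
```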
-------------------------------------------------------------------------------- /Deraining/train.py: --------------------------------------------------------------------------------
1 | import os
2 | from config import Config
3 | opt = Config('training.yml')
4 |
5 | gpus = ','.join([str(i) for i in opt.GPU])
6 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
7 | os.environ["CUDA_VISIBLE_DEVICES"] = gpus
8 |
9 | import torch
10 | torch.backends.cudnn.benchmark = True
11 |
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | import torch.optim as optim
15 | from torch.utils.data import DataLoader
16 |
17 | import random
18 | import time
19 | import numpy as np
20 |
21 | import utils
22 | from data_RGB import get_training_data, get_validation_data
23 | from MPRNet import MPRNet
24 | import losses
25 | from warmup_scheduler import GradualWarmupScheduler
26 | from tqdm import tqdm
27 | from pdb import set_trace as stx
28 |
29 | ######### Set Seeds ###########
30 | random.seed(1234)
31 | np.random.seed(1234)
32 | torch.manual_seed(1234)
33 | torch.cuda.manual_seed_all(1234)
34 |
35 | start_epoch = 1
36 | mode = opt.MODEL.MODE
37 | session = opt.MODEL.SESSION
38 |
39 | result_dir = os.path.join(opt.TRAINING.SAVE_DIR, mode, 'results', session)
40 | model_dir = os.path.join(opt.TRAINING.SAVE_DIR, mode, 'models', session)
41 |
42 | utils.mkdir(result_dir)
43 | utils.mkdir(model_dir)
44 |
45 | train_dir = opt.TRAINING.TRAIN_DIR
46 | val_dir = opt.TRAINING.VAL_DIR
47 |
48 | ######### Model ###########
49 | model_restoration = MPRNet()
50 | model_restoration.cuda()
51 |
52 | device_ids = [i for i in range(torch.cuda.device_count())]
53 | if torch.cuda.device_count() > 1:
54 |     print("\n\nLet's use", torch.cuda.device_count(), "GPUs!\n\n")
55 |
56 |
57 | new_lr = opt.OPTIM.LR_INITIAL
58 |
59 | optimizer = optim.Adam(model_restoration.parameters(), lr=new_lr, betas=(0.9, 0.999),eps=1e-8)
60 |
61 |
62 | ######### Scheduler ###########
63 | warmup_epochs = 3
64 | scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.OPTIM.NUM_EPOCHS-warmup_epochs, eta_min=opt.OPTIM.LR_MIN)
65 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine)
66 | scheduler.step()
67 |
68 | ######### Resume ###########
69 | if opt.TRAINING.RESUME:
70 |     path_chk_rest = utils.get_last_path(model_dir, '_latest.pth')
71 |     utils.load_checkpoint(model_restoration,path_chk_rest)
72 |     start_epoch = utils.load_start_epoch(path_chk_rest) + 1
73 |     utils.load_optim(optimizer, path_chk_rest)
74 |
75 |     for i in range(1, start_epoch):
76 |         scheduler.step()
77 |     new_lr = scheduler.get_lr()[0]
78 |     print('------------------------------------------------------------------------------')
79 |     print("==> Resuming Training with learning rate:", new_lr)
80 |     print('------------------------------------------------------------------------------')
81 |
82 | if len(device_ids)>1:
83 |     model_restoration = nn.DataParallel(model_restoration, device_ids = device_ids)
84 |
85 | ######### Loss ###########
86 | criterion_char = losses.CharbonnierLoss()
87 | criterion_edge = losses.EdgeLoss()
88 |
89 | ######### DataLoaders ###########
90 | train_dataset = get_training_data(train_dir, {'patch_size':opt.TRAINING.TRAIN_PS})
91 | train_loader = DataLoader(dataset=train_dataset, batch_size=opt.OPTIM.BATCH_SIZE, shuffle=True, num_workers=16, drop_last=False, pin_memory=True)
92 |
93 | val_dataset = get_validation_data(val_dir, {'patch_size':opt.TRAINING.VAL_PS})
94 | val_loader = DataLoader(dataset=val_dataset, batch_size=16, shuffle=False, num_workers=8, drop_last=False, pin_memory=True)
95 |
96 | print('===> Start Epoch {} End Epoch {}'.format(start_epoch,opt.OPTIM.NUM_EPOCHS + 1))
97 | print('===> Loading datasets')
98 |
99 | best_psnr = 0
100 | best_epoch = 0
101 |
102 | for epoch in range(start_epoch, opt.OPTIM.NUM_EPOCHS + 1):
103 |     epoch_start_time = time.time()
104 |     epoch_loss = 0
105 |     train_id = 1
106 |
107 |     model_restoration.train()
108 |     for i, data in enumerate(tqdm(train_loader), 0):
109 |
110 |         # zero_grad
111 |         for param in model_restoration.parameters():
112 |             param.grad = None
113 |
114 |         target = data[0].cuda()
115 |         input_ = data[1].cuda()
116 |
117 |         restored = model_restoration(input_)
118 |
119 |         # Compute loss at each stage
120 |         loss_char = sum([criterion_char(restored[j],target) for j in range(len(restored))])
121 |         loss_edge = sum([criterion_edge(restored[j],target) for j in range(len(restored))])
122 |         loss = (loss_char) + (0.05*loss_edge)
123 |
124 |         loss.backward()
125 |         optimizer.step()
126 |         epoch_loss +=loss.item()
127 |
128 |     #### Evaluation ####
129 |     if epoch%opt.TRAINING.VAL_AFTER_EVERY == 0:
130 |         model_restoration.eval()
131 |         psnr_val_rgb = []
132 |         for ii, data_val in enumerate((val_loader), 0):
133 |             target = data_val[0].cuda()
134 |             input_ = data_val[1].cuda()
135 |
136 |             with torch.no_grad():
137 |                 restored = model_restoration(input_)
138 |             restored = restored[0]
139 |
140 |             for res,tar in zip(restored,target):
141 |                 psnr_val_rgb.append(utils.torchPSNR(res, tar))
142 |
143 |         psnr_val_rgb = torch.stack(psnr_val_rgb).mean().item()
144 |
145 |         if psnr_val_rgb > best_psnr:
146 |             best_psnr = psnr_val_rgb
147 |             best_epoch = epoch
148 |             torch.save({'epoch': epoch,
149 |                         'state_dict': model_restoration.state_dict(),
150 |                         'optimizer' : optimizer.state_dict()
151 |                         }, os.path.join(model_dir,"model_best.pth"))
152 |
153 |         print("[epoch %d PSNR: %.4f --- best_epoch %d Best_PSNR %.4f]" % (epoch, psnr_val_rgb, best_epoch, best_psnr))
154 |
155 |         torch.save({'epoch': epoch,
156 |                     'state_dict': model_restoration.state_dict(),
157 |                     'optimizer' : optimizer.state_dict()
158 |                     }, os.path.join(model_dir,f"model_epoch_{epoch}.pth"))
159 |
160 |     scheduler.step()
161 |
162 |     print("------------------------------------------------------------------")
163 |     print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.8f}".format(epoch, time.time()-epoch_start_time, epoch_loss, scheduler.get_lr()[0]))
164 |     print("------------------------------------------------------------------")
165 |
166 |     torch.save({'epoch': epoch,
167 |                 'state_dict': model_restoration.state_dict(),
168 |                 'optimizer' : optimizer.state_dict()
169 |                 }, os.path.join(model_dir,"model_latest.pth"))
170 |
171 |
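`losses.py` itself is not reproduced in this dump. As a rough sketch of what an edge-aware term like `losses.EdgeLoss` computes (an assumption about its internals, not the repository's exact implementation), one can apply the Charbonnier penalty to high-pass-filtered images:
```
import torch
import torch.nn.functional as F

def laplacian(x):
    # 3x3 Laplacian applied per channel (groups=C) to expose high-frequency structure.
    k = torch.tensor([[0., 1., 0.], [1., -4., 1.], [0., 1., 0.]], device=x.device)
    k = k.view(1, 1, 3, 3).repeat(x.shape[1], 1, 1, 1)
    return F.conv2d(F.pad(x, (1, 1, 1, 1), mode='reflect'), k, groups=x.shape[1])

def edge_loss(pred, target, eps=1e-3):
    diff = laplacian(pred) - laplacian(target)
    return torch.sqrt(diff ** 2 + eps ** 2).mean()
```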
-------------------------------------------------------------------------------- /Denoising/train.py: --------------------------------------------------------------------------------
1 | import os
2 | from config import Config
3 | opt = Config('training.yml')
4 |
5 | gpus = ','.join([str(i) for i in opt.GPU])
6 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
7 | os.environ["CUDA_VISIBLE_DEVICES"] = gpus
8 |
9 | import torch
10 | torch.backends.cudnn.benchmark = True
11 |
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 | import torch.optim as optim
15 | from torch.utils.data import DataLoader
16 |
17 | import random
18 | import time
19 | import numpy as np
20 |
21 | import utils
22 | from data_RGB import get_training_data, get_validation_data
23 | from MPRNet import MPRNet
24 | import losses
25 | from warmup_scheduler import GradualWarmupScheduler
26 | from tqdm import tqdm
27 | from pdb import set_trace as stx
28 |
29 | ######### Set Seeds ###########
30 | random.seed(1234)
31 | np.random.seed(1234)
32 | torch.manual_seed(1234)
33 | torch.cuda.manual_seed_all(1234)
34 |
35 | start_epoch = 1
36 | mode = opt.MODEL.MODE
37 | session = opt.MODEL.SESSION
38 |
39 | result_dir = os.path.join(opt.TRAINING.SAVE_DIR, mode, 'results', session)
40 | model_dir = os.path.join(opt.TRAINING.SAVE_DIR, mode, 'models', session)
41 |
42 | utils.mkdir(result_dir)
43 | utils.mkdir(model_dir)
44 |
45 | train_dir = opt.TRAINING.TRAIN_DIR
46 | val_dir = opt.TRAINING.VAL_DIR
47 |
48 | ######### Model ###########
49 | model_restoration = MPRNet()
50 | model_restoration.cuda()
51 |
52 | device_ids = [i for i in range(torch.cuda.device_count())]
53 | if torch.cuda.device_count() > 1:
54 |     print("\n\nLet's use", torch.cuda.device_count(), "GPUs!\n\n")
55 |
56 |
57 | new_lr = opt.OPTIM.LR_INITIAL
58 |
59 | optimizer = optim.Adam(model_restoration.parameters(), lr=new_lr, betas=(0.9, 0.999),eps=1e-8, weight_decay=1e-8)
60 |
61 |
62 | ######### Scheduler ###########
63 | warmup_epochs = 3
64 | scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer,
opt.OPTIM.NUM_EPOCHS-warmup_epochs+40, eta_min=opt.OPTIM.LR_MIN)
65 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine)
66 | scheduler.step()
67 |
68 | ######### Resume ###########
69 | if opt.TRAINING.RESUME:
70 |     path_chk_rest = utils.get_last_path(model_dir, '_latest.pth')
71 |     utils.load_checkpoint(model_restoration,path_chk_rest)
72 |     start_epoch = utils.load_start_epoch(path_chk_rest) + 1
73 |     utils.load_optim(optimizer, path_chk_rest)
74 |
75 |     for i in range(1, start_epoch):
76 |         scheduler.step()
77 |     new_lr = scheduler.get_lr()[0]
78 |     print('------------------------------------------------------------------------------')
79 |     print("==> Resuming Training with learning rate:", new_lr)
80 |     print('------------------------------------------------------------------------------')
81 |
82 | if len(device_ids)>1:
83 |     model_restoration = nn.DataParallel(model_restoration, device_ids = device_ids)
84 |
85 | ######### Loss ###########
86 | criterion = losses.CharbonnierLoss()
87 |
88 | ######### DataLoaders ###########
89 | train_dataset = get_training_data(train_dir, {'patch_size':opt.TRAINING.TRAIN_PS})
90 | train_loader = DataLoader(dataset=train_dataset, batch_size=opt.OPTIM.BATCH_SIZE, shuffle=True, num_workers=16, drop_last=False, pin_memory=True)
91 |
92 | val_dataset = get_validation_data(val_dir, {'patch_size':opt.TRAINING.VAL_PS})
93 | val_loader = DataLoader(dataset=val_dataset, batch_size=16, shuffle=False, num_workers=8, drop_last=False, pin_memory=True)
94 |
95 | print('===> Start Epoch {} End Epoch {}'.format(start_epoch,opt.OPTIM.NUM_EPOCHS + 1))
96 | print('===> Loading datasets')
97 |
98 | best_psnr = 0
99 | best_epoch = 0
100 | best_iter = 0
101 |
102 | eval_now = len(train_loader)//3 - 1
103 | print(f"\nEval after every {eval_now} Iterations !!!\n")
104 | mixup = utils.MixUp_AUG()
105 |
106 | for epoch in range(start_epoch, opt.OPTIM.NUM_EPOCHS + 1):
107 |     epoch_start_time = time.time()
108 |     epoch_loss = 0
109 |     train_id = 1
110 |
111 |     model_restoration.train()
112 |     for i, data in enumerate(tqdm(train_loader), 0):
113 |
114 |         # zero_grad
115 |         for param in model_restoration.parameters():
116 |             param.grad = None
117 |
118 |         target = data[0].cuda()
119 |         input_ = data[1].cuda()
120 |
121 |         if epoch>5:
122 |             target, input_ = mixup.aug(target, input_)
123 |
124 |         restored = model_restoration(input_)
125 |
126 |         # Compute loss at each stage
127 |         loss = sum([criterion(torch.clamp(restored[j],0,1),target) for j in range(len(restored))])
128 |
129 |         loss.backward()
130 |         optimizer.step()
131 |         epoch_loss +=loss.item()
132 |
133 |         #### Evaluation ####
134 |         if i%eval_now==0 and i>0 and (epoch in [1,25,45] or epoch>60):
135 |             model_restoration.eval()
136 |             psnr_val_rgb = []
137 |             for ii, data_val in enumerate((val_loader), 0):
138 |                 target = data_val[0].cuda()
139 |                 input_ = data_val[1].cuda()
140 |
141 |                 with torch.no_grad():
142 |                     restored = model_restoration(input_)
143 |                 restored = restored[0]
144 |
145 |                 for res,tar in zip(restored,target):
146 |                     psnr_val_rgb.append(utils.torchPSNR(res, tar))
147 |
148 |             psnr_val_rgb = torch.stack(psnr_val_rgb).mean().item()
149 |
150 |             if psnr_val_rgb > best_psnr:
151 |                 best_psnr = psnr_val_rgb
152 |                 best_epoch = epoch
153 |                 best_iter = i
154 |                 torch.save({'epoch': epoch,
155 |                             'state_dict': model_restoration.state_dict(),
156 |                             'optimizer' : optimizer.state_dict()
157 |                             }, os.path.join(model_dir,"model_best.pth"))
158 |
159 |             print("[epoch %d it %d PSNR: %.4f --- best_epoch %d best_iter %d Best_PSNR %.4f]" % (epoch, i, psnr_val_rgb, best_epoch, best_iter, best_psnr))
160 |
161 |             torch.save({'epoch': epoch,
162 |                         'state_dict': model_restoration.state_dict(),
163 |                         'optimizer' : optimizer.state_dict()
164 |                         }, os.path.join(model_dir,f"model_epoch_{epoch}.pth"))
165 |
166 |             model_restoration.train()
167 |
168 |     scheduler.step()
169 |
170 |     print("------------------------------------------------------------------")
171 |     print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time()-epoch_start_time, epoch_loss, scheduler.get_lr()[0]))
172 |     print("------------------------------------------------------------------")
173 |
174 |     torch.save({'epoch': epoch,
175 |                 'state_dict': model_restoration.state_dict(),
176 |                 'optimizer' : optimizer.state_dict()
177 |                 }, os.path.join(model_dir,"model_latest.pth"))
178 |
179 |
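`utils.MixUp_AUG` is likewise not shown in this dump; a sketch of the mixup-style augmentation it stands for (an assumption about its internals), blending each clean/noisy pair with a randomly chosen partner from the batch:
```
import torch

def mixup_aug(target, input_, alpha=1.2):
    # One Beta-distributed mixing weight per sample, shared by input and target.
    lam = torch.distributions.Beta(alpha, alpha).sample((target.size(0), 1, 1, 1)).to(target.device)
    perm = torch.randperm(target.size(0), device=target.device)
    return lam * target + (1 - lam) * target[perm], lam * input_ + (1 - lam) * input_[perm]
```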
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 |
2 |
3 | # Multi-Stage Progressive Image Restoration (CVPR 2021)
4 |
5 | [Syed Waqas Zamir](https://scholar.google.ae/citations?hl=en&user=POoai-QAAAAJ), [Aditya Arora](https://adityac8.github.io/), [Salman Khan](https://salman-h-khan.github.io/), [Munawar Hayat](https://scholar.google.com/citations?user=Mx8MbWYAAAAJ&hl=en), [Fahad Shahbaz Khan](https://scholar.google.es/citations?user=zvaeYnUAAAAJ&hl=en), [Ming-Hsuan Yang](https://scholar.google.com/citations?user=p9-ohHsAAAAJ&hl=en), and [Ling Shao](https://scholar.google.com/citations?user=z84rLjoAAAAJ&hl=en)
6 |
7 | [Paper](https://arxiv.org/abs/2102.02808)
8 | [Supplementary](https://drive.google.com/file/d/1mbfljawUuFUQN9V5g0Rmw1UdauJdckCu/view?usp=sharing)
9 | [Video](https://www.youtube.com/watch?v=0SMTPiLw5Vw)
10 | [Slides](https://drive.google.com/file/d/1-L43wj-VTppkrR9AL6cPBJI2RJi3Hc_z/view?usp=sharing)
11 |
12 |
13 |
14 | ### News
15 |
16 | We are happy to see that our work has inspired the **Winning Solutions in NTIRE 2021 challenges**:
17 | - [Dual-pixel Defocus Deblurring Challenge](https://openaccess.thecvf.com/content/CVPR2021W/NTIRE/papers/Abuolaim_NTIRE_2021_Challenge_for_Defocus_Deblurring_Using_Dual-Pixel_Images_Methods_CVPRW_2021_paper.pdf) -- MRNet: Multi Refinement Network for Dual-pixel Images Defocus Deblurring
18 | - [Image Deblurring Challenge](https://openaccess.thecvf.com/content/CVPR2021W/NTIRE/papers/Nah_NTIRE_2021_Challenge_on_Image_Deblurring_CVPRW_2021_paper.pdf) -- HINet: Half Instance Normalization Network for Image Restoration
19 |
20 |
21 |
22 | > **Abstract:** *Image restoration tasks demand a complex balance between spatial details and high-level contextualized information while recovering images. In this paper, we propose a novel synergistic design that can optimally balance these competing goals. Our main proposal is a multi-stage architecture, that progressively learns restoration functions for the degraded inputs, thereby breaking down the overall recovery process into more manageable steps. Specifically, our model first learns the contextualized features using encoder-decoder architectures and later combines them with a high-resolution branch that retains local information. At each stage, we introduce a novel per-pixel adaptive design that leverages in-situ supervised attention to reweight the local features. A key ingredient in such a multi-stage architecture is the information exchange between different stages.
To this end, we propose a two-faceted approach where the information is not only exchanged sequentially from early to late stages, but lateral connections between feature processing blocks also exist to avoid any loss of information. The resulting tightly interlinked multi-stage architecture, named as MPRNet, delivers strong performance gains on ten datasets across a range of tasks including image deraining, deblurring, and denoising. For example, on the Rain100L, GoPro and DND datasets, we obtain PSNR gains of 4 dB, 0.81 dB and 0.21 dB, respectively, compared to the state-of-the-art.*
23 |
24 | ## Network Architecture
25 |
26 |
27 |
28 |
29 |
30 |
31 | Overall Framework of MPRNet
32 | Supervised Attention Module (SAM)
33 |
34 |
35 |
36 | ## Installation
37 | The model is built in PyTorch 1.1.0 and tested on Ubuntu 16.04 environment (Python3.7, CUDA9.0, cuDNN7.5).
38 |
39 | To install, follow these instructions
40 | ```
41 | conda create -n pytorch1 python=3.7
42 | conda activate pytorch1
43 | conda install pytorch=1.1 torchvision=0.3 cudatoolkit=9.0 -c pytorch
44 | pip install matplotlib scikit-image opencv-python yacs joblib natsort h5py tqdm
45 | ```
46 |
47 | Install warmup scheduler
48 |
49 | ```
50 | cd pytorch-gradual-warmup-lr; python setup.py install; cd ..
51 | ```
52 |
53 | ## Quick Run
54 |
55 | To test the pre-trained models of [Deblurring](https://drive.google.com/file/d/1QwQUVbk6YVOJViCsOKYNykCsdJSVGRtb/view?usp=sharing), [Deraining](https://drive.google.com/file/d/1O3WEJbcat7eTY6doXWeorAbQ1l_WmMnM/view?usp=sharing), [Denoising](https://drive.google.com/file/d/1LODPt9kYmxwU98g96UrRA0_Eh5HYcsRw/view?usp=sharing) on your own images, run
56 | ```
57 | python demo.py --task Task_Name --input_dir path_to_images --result_dir save_images_here
58 | ```
59 | Here is an example to perform Deblurring:
60 | ```
61 | python demo.py --task Deblurring --input_dir ./samples/input/ --result_dir ./samples/output/
62 | ```
63 |
64 | ## Training and Evaluation
65 |
66 | Training and testing code for deblurring, deraining and denoising is provided in the respective directories.
67 |
68 | ## Results
69 | Experiments are performed for different image processing tasks including image deblurring, image deraining, and image denoising. Images produced by MPRNet can be downloaded from Google Drive links: [Deblurring](https://drive.google.com/drive/folders/12jgrGdIh_lfiSsXyo-QicQuZYcLXp9rP?usp=sharing), [Deraining](https://drive.google.com/drive/folders/1IpF_jCGBhqsXN4f1vBNQ6DGpr7Pk6LdO?usp=sharing), and [Denoising](https://drive.google.com/drive/folders/1usbZKuYg8c7UrUml2bdZSbuxh_JrHW67?usp=sharing).
70 |
71 |
72 | Image Deblurring
73 |
74 |
75 |
76 |
77 |
78 |
79 | Deblurring on Synthetic Datasets.
80 | Deblurring on Real Dataset.
81 |
82 |
83 |
84 |
85 | Image Deraining
86 |
87 |
88 |
89 | Image Denoising
90 |
91 |
92 | ## Citation
93 | If you use MPRNet, please consider citing:
94 |
95 |     @inproceedings{Zamir2021MPRNet,
96 |         title={Multi-Stage Progressive Image Restoration},
97 |         author={Syed Waqas Zamir and Aditya Arora and Salman Khan and Munawar Hayat
98 |                 and Fahad Shahbaz Khan and Ming-Hsuan Yang and Ling Shao},
99 |         booktitle={CVPR},
100 |         year={2021}
101 |     }
102 |
103 | ## Contact
104 | Should you have any questions, please contact waqas.zamir@inceptioniai.org
105 |
106 | ## Our Related Works
107 | - Learning Enriched Features for Fast Image Restoration and Enhancement, TPAMI 2022. [Paper](https://www.waqaszamir.com/publication/zamir-2022-mirnetv2/) | [Code](https://github.com/swz30/MIRNetv2)
108 | - Restormer: Efficient Transformer for High-Resolution Image Restoration, CVPR 2022. [Paper](https://arxiv.org/abs/2111.09881) | [Code](https://github.com/swz30/Restormer)
109 | - Learning Enriched Features for Real Image Restoration and Enhancement, ECCV 2020. [Paper](https://arxiv.org/abs/2003.06792) | [Code](https://github.com/swz30/MIRNet)
110 | - CycleISP: Real Image Restoration via Improved Data Synthesis, CVPR 2020. [Paper](https://arxiv.org/abs/2003.07761) | [Code](https://github.com/swz30/CycleISP)
111 |
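Beyond `demo.py`, a minimal scripted-inference sketch (run from inside `Deblurring/`; the checkpoint file name and the `{'state_dict': ...}` layout are assumptions carried over from the training scripts above):
```
import torch
import torch.nn.functional as F
from skimage import img_as_ubyte, io
from MPRNet import MPRNet

model = MPRNet().cuda().eval()
ckpt = torch.load('pretrained_models/model_deblurring.pth')
model.load_state_dict({k.replace('module.', ''): v for k, v in ckpt['state_dict'].items()})

img = torch.from_numpy(io.imread('blurry.png')).float().div(255)
inp = img.permute(2, 0, 1).unsqueeze(0).cuda()

# Pad H and W to multiples of 8 (as in the test scripts), run, then crop back.
h, w = inp.shape[2:]
H, W = ((h + 7) // 8) * 8, ((w + 7) // 8) * 8
inp = F.pad(inp, (0, W - w, 0, H - h), 'reflect')
with torch.no_grad():
    out = torch.clamp(model(inp)[0], 0, 1)[:, :, :h, :w]
io.imsave('deblurred.png', img_as_ubyte(out.squeeze(0).permute(1, 2, 0).cpu().numpy()))
```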
-------------------------------------------------------------------------------- /Deraining/evaluate_PSNR_SSIM.m: --------------------------------------------------------------------------------
1 | % Multi-Stage Progressive Image Restoration
2 | % Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, Ming-Hsuan Yang, and Ling Shao
3 | % https://arxiv.org/abs/2102.02808
4 |
5 |
6 |
7 | clc;close all;clear all;
8 |
9 | % datasets = {'Rain100L'};
10 | datasets = {'Test100', 'Rain100H', 'Rain100L', 'Test2800', 'Test1200'};
11 | num_set = length(datasets);
12 |
13 | psnr_alldatasets = 0;
14 | ssim_alldatasets = 0;
15 |
16 | tic
17 | delete(gcp('nocreate'))
18 | parpool('local',20);
19 |
20 | for idx_set = 1:num_set
21 |     file_path = strcat('./results/', datasets{idx_set}, '/');
22 |     gt_path = strcat('./Datasets/test/', datasets{idx_set}, '/target/');
23 |     path_list = [dir(strcat(file_path,'*.jpg')); dir(strcat(file_path,'*.png'))];
24 |     gt_list = [dir(strcat(gt_path,'*.jpg')); dir(strcat(gt_path,'*.png'))];
25 |     img_num = length(path_list);
26 |
27 |     total_psnr = 0;
28 |     total_ssim = 0;
29 |     if img_num > 0
30 |         parfor j = 1:img_num
31 |             image_name = path_list(j).name;
32 |             gt_name = gt_list(j).name;
33 |             input = imread(strcat(file_path,image_name));
34 |             gt = imread(strcat(gt_path, gt_name));
35 |             ssim_val = compute_ssim(input, gt);
36 |             psnr_val = compute_psnr(input, gt);
37 |             total_ssim = total_ssim + ssim_val;
38 |             total_psnr = total_psnr + psnr_val;
39 |         end
40 |     end
41 |     qm_psnr = total_psnr / img_num;
42 |     qm_ssim = total_ssim / img_num;
43 |
44 |     fprintf('For %s dataset PSNR: %f SSIM: %f\n', datasets{idx_set}, qm_psnr, qm_ssim);
45 |
46 |     psnr_alldatasets = psnr_alldatasets + qm_psnr;
47 |     ssim_alldatasets = ssim_alldatasets + qm_ssim;
48 |
49 | end
50 |
51 | fprintf('For all datasets PSNR: %f SSIM: %f\n', psnr_alldatasets/num_set, ssim_alldatasets/num_set);
52 |
53 | delete(gcp('nocreate'))
54 | toc
55 |
56 | function ssim_mean=compute_ssim(img1,img2)
57 |     if size(img1, 3) == 3
58 |         img1 = rgb2ycbcr(img1);
59 |         img1 = img1(:, :, 1);
60 |     end
61 |
62 |     if size(img2, 3) == 3
63 |         img2 = rgb2ycbcr(img2);
64 |         img2 = img2(:, :, 1);
65 |     end
66 |     ssim_mean = SSIM_index(img1, img2);
67 | end
68 |
69 | function psnr=compute_psnr(img1,img2)
70 |     if size(img1, 3) == 3
71 |         img1 = rgb2ycbcr(img1);
72 |         img1 = img1(:, :, 1);
73 |     end
74 |
75 |     if size(img2, 3) == 3
76 |         img2 = rgb2ycbcr(img2);
77 |         img2 = img2(:, :, 1);
78 |     end
79 |
80 |     imdff = double(img1) - double(img2);
81 |     imdff = imdff(:);
82 |     rmse = sqrt(mean(imdff.^2));
83 |     psnr = 20*log10(255/rmse);
84 |
85 | end
86 |
87 | function [mssim, ssim_map] = SSIM_index(img1, img2, K, window, L)
88 |
89 | %========================================================================
90 | %SSIM Index, Version 1.0
91 | %Copyright(c) 2003 Zhou Wang
92 | %All Rights Reserved.
93 | % 94 | %The author is with Howard Hughes Medical Institute, and Laboratory 95 | %for Computational Vision at Center for Neural Science and Courant 96 | %Institute of Mathematical Sciences, New York University. 97 | % 98 | %---------------------------------------------------------------------- 99 | %Permission to use, copy, or modify this software and its documentation 100 | %for educational and research purposes only and without fee is hereby 101 | %granted, provided that this copyright notice and the original authors' 102 | %names appear on all copies and supporting documentation. This program 103 | %shall not be used, rewritten, or adapted as the basis of a commercial 104 | %software or hardware product without first obtaining permission of the 105 | %authors. The authors make no representations about the suitability of 106 | %this software for any purpose. It is provided "as is" without express 107 | %or implied warranty. 108 | %---------------------------------------------------------------------- 109 | % 110 | %This is an implementation of the algorithm for calculating the 111 | %Structural SIMilarity (SSIM) index between two images. Please refer 112 | %to the following paper: 113 | % 114 | %Z. Wang, A. C. Bovik, H. R. Sheikh, and E. P. Simoncelli, "Image 115 | %quality assessment: From error measurement to structural similarity" 116 | %IEEE Transactios on Image Processing, vol. 13, no. 1, Jan. 2004. 117 | % 118 | %Kindly report any suggestions or corrections to zhouwang@ieee.org 119 | % 120 | %---------------------------------------------------------------------- 121 | % 122 | %Input : (1) img1: the first image being compared 123 | % (2) img2: the second image being compared 124 | % (3) K: constants in the SSIM index formula (see the above 125 | % reference). defualt value: K = [0.01 0.03] 126 | % (4) window: local window for statistics (see the above 127 | % reference). default widnow is Gaussian given by 128 | % window = fspecial('gaussian', 11, 1.5); 129 | % (5) L: dynamic range of the images. default: L = 255 130 | % 131 | %Output: (1) mssim: the mean SSIM index value between 2 images. 132 | % If one of the images being compared is regarded as 133 | % perfect quality, then mssim can be considered as the 134 | % quality measure of the other image. 135 | % If img1 = img2, then mssim = 1. 136 | % (2) ssim_map: the SSIM index map of the test image. The map 137 | % has a smaller size than the input images. The actual size: 138 | % size(img1) - size(window) + 1. 139 | % 140 | %Default Usage: 141 | % Given 2 test images img1 and img2, whose dynamic range is 0-255 142 | % 143 | % [mssim ssim_map] = ssim_index(img1, img2); 144 | % 145 | %Advanced Usage: 146 | % User defined parameters. 
For example 147 | % 148 | % K = [0.05 0.05]; 149 | % window = ones(8); 150 | % L = 100; 151 | % [mssim ssim_map] = ssim_index(img1, img2, K, window, L); 152 | % 153 | %See the results: 154 | % 155 | % mssim %Gives the mssim value 156 | % imshow(max(0, ssim_map).^4) %Shows the SSIM index map 157 | % 158 | %======================================================================== 159 | 160 | 161 | if (nargin < 2 || nargin > 5) 162 | ssim_index = -Inf; 163 | ssim_map = -Inf; 164 | return; 165 | end 166 | 167 | if (size(img1) ~= size(img2)) 168 | ssim_index = -Inf; 169 | ssim_map = -Inf; 170 | return; 171 | end 172 | 173 | [M N] = size(img1); 174 | 175 | if (nargin == 2) 176 | if ((M < 11) || (N < 11)) 177 | ssim_index = -Inf; 178 | ssim_map = -Inf; 179 | return 180 | end 181 | window = fspecial('gaussian', 11, 1.5); % 182 | K(1) = 0.01; % default settings 183 | K(2) = 0.03; % 184 | L = 255; % 185 | end 186 | 187 | if (nargin == 3) 188 | if ((M < 11) || (N < 11)) 189 | ssim_index = -Inf; 190 | ssim_map = -Inf; 191 | return 192 | end 193 | window = fspecial('gaussian', 11, 1.5); 194 | L = 255; 195 | if (length(K) == 2) 196 | if (K(1) < 0 || K(2) < 0) 197 | ssim_index = -Inf; 198 | ssim_map = -Inf; 199 | return; 200 | end 201 | else 202 | ssim_index = -Inf; 203 | ssim_map = -Inf; 204 | return; 205 | end 206 | end 207 | 208 | if (nargin == 4) 209 | [H W] = size(window); 210 | if ((H*W) < 4 || (H > M) || (W > N)) 211 | ssim_index = -Inf; 212 | ssim_map = -Inf; 213 | return 214 | end 215 | L = 255; 216 | if (length(K) == 2) 217 | if (K(1) < 0 || K(2) < 0) 218 | ssim_index = -Inf; 219 | ssim_map = -Inf; 220 | return; 221 | end 222 | else 223 | ssim_index = -Inf; 224 | ssim_map = -Inf; 225 | return; 226 | end 227 | end 228 | 229 | if (nargin == 5) 230 | [H W] = size(window); 231 | if ((H*W) < 4 || (H > M) || (W > N)) 232 | ssim_index = -Inf; 233 | ssim_map = -Inf; 234 | return 235 | end 236 | if (length(K) == 2) 237 | if (K(1) < 0 || K(2) < 0) 238 | ssim_index = -Inf; 239 | ssim_map = -Inf; 240 | return; 241 | end 242 | else 243 | ssim_index = -Inf; 244 | ssim_map = -Inf; 245 | return; 246 | end 247 | end 248 | 249 | C1 = (K(1)*L)^2; 250 | C2 = (K(2)*L)^2; 251 | window = window/sum(sum(window)); 252 | img1 = double(img1); 253 | img2 = double(img2); 254 | 255 | mu1 = filter2(window, img1, 'valid'); 256 | mu2 = filter2(window, img2, 'valid'); 257 | mu1_sq = mu1.*mu1; 258 | mu2_sq = mu2.*mu2; 259 | mu1_mu2 = mu1.*mu2; 260 | sigma1_sq = filter2(window, img1.*img1, 'valid') - mu1_sq; 261 | sigma2_sq = filter2(window, img2.*img2, 'valid') - mu2_sq; 262 | sigma12 = filter2(window, img1.*img2, 'valid') - mu1_mu2; 263 | 264 | if (C1 > 0 & C2 > 0) 265 | ssim_map = ((2*mu1_mu2 + C1).*(2*sigma12 + C2))./((mu1_sq + mu2_sq + C1).*(sigma1_sq + sigma2_sq + C2)); 266 | else 267 | numerator1 = 2*mu1_mu2 + C1; 268 | numerator2 = 2*sigma12 + C2; 269 | denominator1 = mu1_sq + mu2_sq + C1; 270 | denominator2 = sigma1_sq + sigma2_sq + C2; 271 | ssim_map = ones(size(mu1)); 272 | index = (denominator1.*denominator2 > 0); 273 | ssim_map(index) = (numerator1(index).*numerator2(index))./(denominator1(index).*denominator2(index)); 274 | index = (denominator1 ~= 0) & (denominator2 == 0); 275 | ssim_map(index) = numerator1(index)./denominator1(index); 276 | end 277 | 278 | mssim = mean2(ssim_map); 279 | 280 | end 281 | -------------------------------------------------------------------------------- /Deblurring/MPRNet.py: -------------------------------------------------------------------------------- 1 | """ 2 | ## Multi-Stage 
Progressive Image Restoration 3 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, Ming-Hsuan Yang, and Ling Shao 4 | ## https://arxiv.org/abs/2102.02808 5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from pdb import set_trace as stx 11 | 12 | ########################################################################## 13 | def conv(in_channels, out_channels, kernel_size, bias=False, stride = 1): 14 | return nn.Conv2d( 15 | in_channels, out_channels, kernel_size, 16 | padding=(kernel_size//2), bias=bias, stride = stride) 17 | 18 | 19 | ########################################################################## 20 | ## Channel Attention Layer 21 | class CALayer(nn.Module): 22 | def __init__(self, channel, reduction=16, bias=False): 23 | super(CALayer, self).__init__() 24 | # global average pooling: feature --> point 25 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 26 | # feature channel downscale and upscale --> channel weight 27 | self.conv_du = nn.Sequential( 28 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias), 29 | nn.ReLU(inplace=True), 30 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias), 31 | nn.Sigmoid() 32 | ) 33 | 34 | def forward(self, x): 35 | y = self.avg_pool(x) 36 | y = self.conv_du(y) 37 | return x * y 38 | 39 | 40 | ########################################################################## 41 | ## Channel Attention Block (CAB) 42 | class CAB(nn.Module): 43 | def __init__(self, n_feat, kernel_size, reduction, bias, act): 44 | super(CAB, self).__init__() 45 | modules_body = [] 46 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 47 | modules_body.append(act) 48 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 49 | 50 | self.CA = CALayer(n_feat, reduction, bias=bias) 51 | self.body = nn.Sequential(*modules_body) 52 | 53 | def forward(self, x): 54 | res = self.body(x) 55 | res = self.CA(res) 56 | res += x 57 | return res 58 | 59 | ########################################################################## 60 | ## Supervised Attention Module 61 | class SAM(nn.Module): 62 | def __init__(self, n_feat, kernel_size, bias): 63 | super(SAM, self).__init__() 64 | self.conv1 = conv(n_feat, n_feat, kernel_size, bias=bias) 65 | self.conv2 = conv(n_feat, 3, kernel_size, bias=bias) 66 | self.conv3 = conv(3, n_feat, kernel_size, bias=bias) 67 | 68 | def forward(self, x, x_img): 69 | x1 = self.conv1(x) 70 | img = self.conv2(x) + x_img 71 | x2 = torch.sigmoid(self.conv3(img)) 72 | x1 = x1*x2 73 | x1 = x1+x 74 | return x1, img 75 | 76 | ########################################################################## 77 | ## U-Net 78 | 79 | class Encoder(nn.Module): 80 | def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff): 81 | super(Encoder, self).__init__() 82 | 83 | self.encoder_level1 = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2)] 84 | self.encoder_level2 = [CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(2)] 85 | self.encoder_level3 = [CAB(n_feat+(scale_unetfeats*2), kernel_size, reduction, bias=bias, act=act) for _ in range(2)] 86 | 87 | self.encoder_level1 = nn.Sequential(*self.encoder_level1) 88 | self.encoder_level2 = nn.Sequential(*self.encoder_level2) 89 | self.encoder_level3 = nn.Sequential(*self.encoder_level3) 90 | 91 | self.down12 = DownSample(n_feat, scale_unetfeats) 92 | self.down23 = DownSample(n_feat+scale_unetfeats, scale_unetfeats) 93 
| 94 | # Cross Stage Feature Fusion (CSFF) 95 | if csff: 96 | self.csff_enc1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias) 97 | self.csff_enc2 = nn.Conv2d(n_feat+scale_unetfeats, n_feat+scale_unetfeats, kernel_size=1, bias=bias) 98 | self.csff_enc3 = nn.Conv2d(n_feat+(scale_unetfeats*2), n_feat+(scale_unetfeats*2), kernel_size=1, bias=bias) 99 | 100 | self.csff_dec1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias) 101 | self.csff_dec2 = nn.Conv2d(n_feat+scale_unetfeats, n_feat+scale_unetfeats, kernel_size=1, bias=bias) 102 | self.csff_dec3 = nn.Conv2d(n_feat+(scale_unetfeats*2), n_feat+(scale_unetfeats*2), kernel_size=1, bias=bias) 103 | 104 | def forward(self, x, encoder_outs=None, decoder_outs=None): 105 | enc1 = self.encoder_level1(x) 106 | if (encoder_outs is not None) and (decoder_outs is not None): 107 | enc1 = enc1 + self.csff_enc1(encoder_outs[0]) + self.csff_dec1(decoder_outs[0]) 108 | 109 | x = self.down12(enc1) 110 | 111 | enc2 = self.encoder_level2(x) 112 | if (encoder_outs is not None) and (decoder_outs is not None): 113 | enc2 = enc2 + self.csff_enc2(encoder_outs[1]) + self.csff_dec2(decoder_outs[1]) 114 | 115 | x = self.down23(enc2) 116 | 117 | enc3 = self.encoder_level3(x) 118 | if (encoder_outs is not None) and (decoder_outs is not None): 119 | enc3 = enc3 + self.csff_enc3(encoder_outs[2]) + self.csff_dec3(decoder_outs[2]) 120 | 121 | return [enc1, enc2, enc3] 122 | 123 | class Decoder(nn.Module): 124 | def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats): 125 | super(Decoder, self).__init__() 126 | 127 | self.decoder_level1 = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2)] 128 | self.decoder_level2 = [CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(2)] 129 | self.decoder_level3 = [CAB(n_feat+(scale_unetfeats*2), kernel_size, reduction, bias=bias, act=act) for _ in range(2)] 130 | 131 | self.decoder_level1 = nn.Sequential(*self.decoder_level1) 132 | self.decoder_level2 = nn.Sequential(*self.decoder_level2) 133 | self.decoder_level3 = nn.Sequential(*self.decoder_level3) 134 | 135 | self.skip_attn1 = CAB(n_feat, kernel_size, reduction, bias=bias, act=act) 136 | self.skip_attn2 = CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias=bias, act=act) 137 | 138 | self.up21 = SkipUpSample(n_feat, scale_unetfeats) 139 | self.up32 = SkipUpSample(n_feat+scale_unetfeats, scale_unetfeats) 140 | 141 | def forward(self, outs): 142 | enc1, enc2, enc3 = outs 143 | dec3 = self.decoder_level3(enc3) 144 | 145 | x = self.up32(dec3, self.skip_attn2(enc2)) 146 | dec2 = self.decoder_level2(x) 147 | 148 | x = self.up21(dec2, self.skip_attn1(enc1)) 149 | dec1 = self.decoder_level1(x) 150 | 151 | return [dec1,dec2,dec3] 152 | 153 | ########################################################################## 154 | ##---------- Resizing Modules ---------- 155 | class DownSample(nn.Module): 156 | def __init__(self, in_channels,s_factor): 157 | super(DownSample, self).__init__() 158 | self.down = nn.Sequential(nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False), 159 | nn.Conv2d(in_channels, in_channels+s_factor, 1, stride=1, padding=0, bias=False)) 160 | 161 | def forward(self, x): 162 | x = self.down(x) 163 | return x 164 | 165 | class UpSample(nn.Module): 166 | def __init__(self, in_channels,s_factor): 167 | super(UpSample, self).__init__() 168 | self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), 169 | nn.Conv2d(in_channels+s_factor, 
in_channels, 1, stride=1, padding=0, bias=False)) 170 | 171 | def forward(self, x): 172 | x = self.up(x) 173 | return x 174 | 175 | class SkipUpSample(nn.Module): 176 | def __init__(self, in_channels,s_factor): 177 | super(SkipUpSample, self).__init__() 178 | self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), 179 | nn.Conv2d(in_channels+s_factor, in_channels, 1, stride=1, padding=0, bias=False)) 180 | 181 | def forward(self, x, y): 182 | x = self.up(x) 183 | x = x + y 184 | return x 185 | 186 | ########################################################################## 187 | ## Original Resolution Block (ORB) 188 | class ORB(nn.Module): 189 | def __init__(self, n_feat, kernel_size, reduction, act, bias, num_cab): 190 | super(ORB, self).__init__() 191 | modules_body = [] 192 | modules_body = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(num_cab)] 193 | modules_body.append(conv(n_feat, n_feat, kernel_size)) 194 | self.body = nn.Sequential(*modules_body) 195 | 196 | def forward(self, x): 197 | res = self.body(x) 198 | res += x 199 | return res 200 | 201 | ########################################################################## 202 | class ORSNet(nn.Module): 203 | def __init__(self, n_feat, scale_orsnetfeats, kernel_size, reduction, act, bias, scale_unetfeats, num_cab): 204 | super(ORSNet, self).__init__() 205 | 206 | self.orb1 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab) 207 | self.orb2 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab) 208 | self.orb3 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab) 209 | 210 | self.up_enc1 = UpSample(n_feat, scale_unetfeats) 211 | self.up_dec1 = UpSample(n_feat, scale_unetfeats) 212 | 213 | self.up_enc2 = nn.Sequential(UpSample(n_feat+scale_unetfeats, scale_unetfeats), UpSample(n_feat, scale_unetfeats)) 214 | self.up_dec2 = nn.Sequential(UpSample(n_feat+scale_unetfeats, scale_unetfeats), UpSample(n_feat, scale_unetfeats)) 215 | 216 | self.conv_enc1 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias) 217 | self.conv_enc2 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias) 218 | self.conv_enc3 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias) 219 | 220 | self.conv_dec1 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias) 221 | self.conv_dec2 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias) 222 | self.conv_dec3 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias) 223 | 224 | def forward(self, x, encoder_outs, decoder_outs): 225 | x = self.orb1(x) 226 | x = x + self.conv_enc1(encoder_outs[0]) + self.conv_dec1(decoder_outs[0]) 227 | 228 | x = self.orb2(x) 229 | x = x + self.conv_enc2(self.up_enc1(encoder_outs[1])) + self.conv_dec2(self.up_dec1(decoder_outs[1])) 230 | 231 | x = self.orb3(x) 232 | x = x + self.conv_enc3(self.up_enc2(encoder_outs[2])) + self.conv_dec3(self.up_dec2(decoder_outs[2])) 233 | 234 | return x 235 | 236 | 237 | ########################################################################## 238 | class MPRNet(nn.Module): 239 | def __init__(self, in_c=3, out_c=3, n_feat=96, scale_unetfeats=48, scale_orsnetfeats=32, num_cab=8, kernel_size=3, reduction=4, bias=False): 240 | super(MPRNet, self).__init__() 241 | 242 | act=nn.PReLU() 243 | self.shallow_feat1 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat,kernel_size, reduction, bias=bias, act=act)) 244 | 
self.shallow_feat2 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat,kernel_size, reduction, bias=bias, act=act)) 245 | self.shallow_feat3 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat,kernel_size, reduction, bias=bias, act=act)) 246 | 247 | # Cross Stage Feature Fusion (CSFF) 248 | self.stage1_encoder = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=False) 249 | self.stage1_decoder = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats) 250 | 251 | self.stage2_encoder = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=True) 252 | self.stage2_decoder = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats) 253 | 254 | self.stage3_orsnet = ORSNet(n_feat, scale_orsnetfeats, kernel_size, reduction, act, bias, scale_unetfeats, num_cab) 255 | 256 | self.sam12 = SAM(n_feat, kernel_size=1, bias=bias) 257 | self.sam23 = SAM(n_feat, kernel_size=1, bias=bias) 258 | 259 | self.concat12 = conv(n_feat*2, n_feat, kernel_size, bias=bias) 260 | self.concat23 = conv(n_feat*2, n_feat+scale_orsnetfeats, kernel_size, bias=bias) 261 | self.tail = conv(n_feat+scale_orsnetfeats, out_c, kernel_size, bias=bias) 262 | 263 | def forward(self, x3_img): 264 | # Original-resolution Image for Stage 3 265 | H = x3_img.size(2) 266 | W = x3_img.size(3) 267 | 268 | # Multi-Patch Hierarchy: Split Image into four non-overlapping patches 269 | 270 | # Two Patches for Stage 2 271 | x2top_img = x3_img[:,:,0:int(H/2),:] 272 | x2bot_img = x3_img[:,:,int(H/2):H,:] 273 | 274 | # Four Patches for Stage 1 275 | x1ltop_img = x2top_img[:,:,:,0:int(W/2)] 276 | x1rtop_img = x2top_img[:,:,:,int(W/2):W] 277 | x1lbot_img = x2bot_img[:,:,:,0:int(W/2)] 278 | x1rbot_img = x2bot_img[:,:,:,int(W/2):W] 279 | 280 | ##------------------------------------------- 281 | ##-------------- Stage 1--------------------- 282 | ##------------------------------------------- 283 | ## Compute Shallow Features 284 | x1ltop = self.shallow_feat1(x1ltop_img) 285 | x1rtop = self.shallow_feat1(x1rtop_img) 286 | x1lbot = self.shallow_feat1(x1lbot_img) 287 | x1rbot = self.shallow_feat1(x1rbot_img) 288 | 289 | ## Process features of all 4 patches with Encoder of Stage 1 290 | feat1_ltop = self.stage1_encoder(x1ltop) 291 | feat1_rtop = self.stage1_encoder(x1rtop) 292 | feat1_lbot = self.stage1_encoder(x1lbot) 293 | feat1_rbot = self.stage1_encoder(x1rbot) 294 | 295 | ## Concat deep features 296 | feat1_top = [torch.cat((k,v), 3) for k,v in zip(feat1_ltop,feat1_rtop)] 297 | feat1_bot = [torch.cat((k,v), 3) for k,v in zip(feat1_lbot,feat1_rbot)] 298 | 299 | ## Pass features through Decoder of Stage 1 300 | res1_top = self.stage1_decoder(feat1_top) 301 | res1_bot = self.stage1_decoder(feat1_bot) 302 | 303 | ## Apply Supervised Attention Module (SAM) 304 | x2top_samfeats, stage1_img_top = self.sam12(res1_top[0], x2top_img) 305 | x2bot_samfeats, stage1_img_bot = self.sam12(res1_bot[0], x2bot_img) 306 | 307 | ## Output image at Stage 1 308 | stage1_img = torch.cat([stage1_img_top, stage1_img_bot],2) 309 | ##------------------------------------------- 310 | ##-------------- Stage 2--------------------- 311 | ##------------------------------------------- 312 | ## Compute Shallow Features 313 | x2top = self.shallow_feat2(x2top_img) 314 | x2bot = self.shallow_feat2(x2bot_img) 315 | 316 | ## Concatenate SAM features of Stage 1 with shallow features of Stage 2 317 | x2top_cat = self.concat12(torch.cat([x2top, x2top_samfeats], 1)) 318 | x2bot_cat = 
self.concat12(torch.cat([x2bot, x2bot_samfeats], 1)) 319 | 320 | ## Process features of both patches with Encoder of Stage 2 321 | feat2_top = self.stage2_encoder(x2top_cat, feat1_top, res1_top) 322 | feat2_bot = self.stage2_encoder(x2bot_cat, feat1_bot, res1_bot) 323 | 324 | ## Concat deep features 325 | feat2 = [torch.cat((k,v), 2) for k,v in zip(feat2_top,feat2_bot)] 326 | 327 | ## Pass features through Decoder of Stage 2 328 | res2 = self.stage2_decoder(feat2) 329 | 330 | ## Apply SAM 331 | x3_samfeats, stage2_img = self.sam23(res2[0], x3_img) 332 | 333 | 334 | ##------------------------------------------- 335 | ##-------------- Stage 3--------------------- 336 | ##------------------------------------------- 337 | ## Compute Shallow Features 338 | x3 = self.shallow_feat3(x3_img) 339 | 340 | ## Concatenate SAM features of Stage 2 with shallow features of Stage 3 341 | x3_cat = self.concat23(torch.cat([x3, x3_samfeats], 1)) 342 | 343 | x3_cat = self.stage3_orsnet(x3_cat, feat2, res2) 344 | 345 | stage3_img = self.tail(x3_cat) 346 | 347 | return [stage3_img+x3_img, stage2_img, stage1_img] 348 | -------------------------------------------------------------------------------- /Denoising/MPRNet.py: -------------------------------------------------------------------------------- 1 | """ 2 | ## Multi-Stage Progressive Image Restoration 3 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, Ming-Hsuan Yang, and Ling Shao 4 | ## https://arxiv.org/abs/2102.02808 5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from pdb import set_trace as stx 11 | 12 | ########################################################################## 13 | def conv(in_channels, out_channels, kernel_size, bias=False, stride = 1): 14 | return nn.Conv2d( 15 | in_channels, out_channels, kernel_size, 16 | padding=(kernel_size//2), bias=bias, stride = stride) 17 | 18 | 19 | ########################################################################## 20 | ## Channel Attention Layer 21 | class CALayer(nn.Module): 22 | def __init__(self, channel, reduction=16, bias=False): 23 | super(CALayer, self).__init__() 24 | # global average pooling: feature --> point 25 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 26 | # feature channel downscale and upscale --> channel weight 27 | self.conv_du = nn.Sequential( 28 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias), 29 | nn.ReLU(inplace=True), 30 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias), 31 | nn.Sigmoid() 32 | ) 33 | 34 | def forward(self, x): 35 | y = self.avg_pool(x) 36 | y = self.conv_du(y) 37 | return x * y 38 | 39 | 40 | ########################################################################## 41 | ## Channel Attention Block (CAB) 42 | class CAB(nn.Module): 43 | def __init__(self, n_feat, kernel_size, reduction, bias, act): 44 | super(CAB, self).__init__() 45 | modules_body = [] 46 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 47 | modules_body.append(act) 48 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 49 | 50 | self.CA = CALayer(n_feat, reduction, bias=bias) 51 | self.body = nn.Sequential(*modules_body) 52 | 53 | def forward(self, x): 54 | res = self.body(x) 55 | res = self.CA(res) 56 | res += x 57 | return res 58 | 59 | ########################################################################## 60 | ## Supervised Attention Module 61 | class SAM(nn.Module): 62 | def __init__(self, n_feat, kernel_size, bias): 63 | super(SAM, 
self).__init__() 64 | self.conv1 = conv(n_feat, n_feat, kernel_size, bias=bias) 65 | self.conv2 = conv(n_feat, 3, kernel_size, bias=bias) 66 | self.conv3 = conv(3, n_feat, kernel_size, bias=bias) 67 | 68 | def forward(self, x, x_img): 69 | x1 = self.conv1(x) 70 | img = self.conv2(x) + x_img 71 | x2 = torch.sigmoid(self.conv3(img)) 72 | x1 = x1*x2 73 | x1 = x1+x 74 | return x1, img 75 | 76 | ########################################################################## 77 | ## U-Net 78 | 79 | class Encoder(nn.Module): 80 | def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff): 81 | super(Encoder, self).__init__() 82 | 83 | self.encoder_level1 = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2)] 84 | self.encoder_level2 = [CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(2)] 85 | self.encoder_level3 = [CAB(n_feat+(scale_unetfeats*2), kernel_size, reduction, bias=bias, act=act) for _ in range(2)] 86 | 87 | self.encoder_level1 = nn.Sequential(*self.encoder_level1) 88 | self.encoder_level2 = nn.Sequential(*self.encoder_level2) 89 | self.encoder_level3 = nn.Sequential(*self.encoder_level3) 90 | 91 | self.down12 = DownSample(n_feat, scale_unetfeats) 92 | self.down23 = DownSample(n_feat+scale_unetfeats, scale_unetfeats) 93 | 94 | # Cross Stage Feature Fusion (CSFF) 95 | if csff: 96 | self.csff_enc1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias) 97 | self.csff_enc2 = nn.Conv2d(n_feat+scale_unetfeats, n_feat+scale_unetfeats, kernel_size=1, bias=bias) 98 | self.csff_enc3 = nn.Conv2d(n_feat+(scale_unetfeats*2), n_feat+(scale_unetfeats*2), kernel_size=1, bias=bias) 99 | 100 | self.csff_dec1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias) 101 | self.csff_dec2 = nn.Conv2d(n_feat+scale_unetfeats, n_feat+scale_unetfeats, kernel_size=1, bias=bias) 102 | self.csff_dec3 = nn.Conv2d(n_feat+(scale_unetfeats*2), n_feat+(scale_unetfeats*2), kernel_size=1, bias=bias) 103 | 104 | def forward(self, x, encoder_outs=None, decoder_outs=None): 105 | enc1 = self.encoder_level1(x) 106 | if (encoder_outs is not None) and (decoder_outs is not None): 107 | enc1 = enc1 + self.csff_enc1(encoder_outs[0]) + self.csff_dec1(decoder_outs[0]) 108 | 109 | x = self.down12(enc1) 110 | 111 | enc2 = self.encoder_level2(x) 112 | if (encoder_outs is not None) and (decoder_outs is not None): 113 | enc2 = enc2 + self.csff_enc2(encoder_outs[1]) + self.csff_dec2(decoder_outs[1]) 114 | 115 | x = self.down23(enc2) 116 | 117 | enc3 = self.encoder_level3(x) 118 | if (encoder_outs is not None) and (decoder_outs is not None): 119 | enc3 = enc3 + self.csff_enc3(encoder_outs[2]) + self.csff_dec3(decoder_outs[2]) 120 | 121 | return [enc1, enc2, enc3] 122 | 123 | class Decoder(nn.Module): 124 | def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats): 125 | super(Decoder, self).__init__() 126 | 127 | self.decoder_level1 = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2)] 128 | self.decoder_level2 = [CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(2)] 129 | self.decoder_level3 = [CAB(n_feat+(scale_unetfeats*2), kernel_size, reduction, bias=bias, act=act) for _ in range(2)] 130 | 131 | self.decoder_level1 = nn.Sequential(*self.decoder_level1) 132 | self.decoder_level2 = nn.Sequential(*self.decoder_level2) 133 | self.decoder_level3 = nn.Sequential(*self.decoder_level3) 134 | 135 | self.skip_attn1 = CAB(n_feat, kernel_size, reduction, bias=bias, act=act) 136 | 
self.skip_attn2 = CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias=bias, act=act) 137 | 138 | self.up21 = SkipUpSample(n_feat, scale_unetfeats) 139 | self.up32 = SkipUpSample(n_feat+scale_unetfeats, scale_unetfeats) 140 | 141 | def forward(self, outs): 142 | enc1, enc2, enc3 = outs 143 | dec3 = self.decoder_level3(enc3) 144 | 145 | x = self.up32(dec3, self.skip_attn2(enc2)) 146 | dec2 = self.decoder_level2(x) 147 | 148 | x = self.up21(dec2, self.skip_attn1(enc1)) 149 | dec1 = self.decoder_level1(x) 150 | 151 | return [dec1,dec2,dec3] 152 | 153 | ########################################################################## 154 | ##---------- Resizing Modules ---------- 155 | class DownSample(nn.Module): 156 | def __init__(self, in_channels,s_factor): 157 | super(DownSample, self).__init__() 158 | self.down = nn.Sequential(nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False), 159 | nn.Conv2d(in_channels, in_channels+s_factor, 1, stride=1, padding=0, bias=False)) 160 | 161 | def forward(self, x): 162 | x = self.down(x) 163 | return x 164 | 165 | class UpSample(nn.Module): 166 | def __init__(self, in_channels,s_factor): 167 | super(UpSample, self).__init__() 168 | self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), 169 | nn.Conv2d(in_channels+s_factor, in_channels, 1, stride=1, padding=0, bias=False)) 170 | 171 | def forward(self, x): 172 | x = self.up(x) 173 | return x 174 | 175 | class SkipUpSample(nn.Module): 176 | def __init__(self, in_channels,s_factor): 177 | super(SkipUpSample, self).__init__() 178 | self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), 179 | nn.Conv2d(in_channels+s_factor, in_channels, 1, stride=1, padding=0, bias=False)) 180 | 181 | def forward(self, x, y): 182 | x = self.up(x) 183 | x = x + y 184 | return x 185 | 186 | ########################################################################## 187 | ## Original Resolution Block (ORB) 188 | class ORB(nn.Module): 189 | def __init__(self, n_feat, kernel_size, reduction, act, bias, num_cab): 190 | super(ORB, self).__init__() 191 | modules_body = [] 192 | modules_body = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(num_cab)] 193 | modules_body.append(conv(n_feat, n_feat, kernel_size)) 194 | self.body = nn.Sequential(*modules_body) 195 | 196 | def forward(self, x): 197 | res = self.body(x) 198 | res += x 199 | return res 200 | 201 | ########################################################################## 202 | class ORSNet(nn.Module): 203 | def __init__(self, n_feat, scale_orsnetfeats, kernel_size, reduction, act, bias, scale_unetfeats, num_cab): 204 | super(ORSNet, self).__init__() 205 | 206 | self.orb1 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab) 207 | self.orb2 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab) 208 | self.orb3 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab) 209 | 210 | self.up_enc1 = UpSample(n_feat, scale_unetfeats) 211 | self.up_dec1 = UpSample(n_feat, scale_unetfeats) 212 | 213 | self.up_enc2 = nn.Sequential(UpSample(n_feat+scale_unetfeats, scale_unetfeats), UpSample(n_feat, scale_unetfeats)) 214 | self.up_dec2 = nn.Sequential(UpSample(n_feat+scale_unetfeats, scale_unetfeats), UpSample(n_feat, scale_unetfeats)) 215 | 216 | self.conv_enc1 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias) 217 | self.conv_enc2 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, 
bias=bias) 218 | self.conv_enc3 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias) 219 | 220 | self.conv_dec1 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias) 221 | self.conv_dec2 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias) 222 | self.conv_dec3 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias) 223 | 224 | def forward(self, x, encoder_outs, decoder_outs): 225 | x = self.orb1(x) 226 | x = x + self.conv_enc1(encoder_outs[0]) + self.conv_dec1(decoder_outs[0]) 227 | 228 | x = self.orb2(x) 229 | x = x + self.conv_enc2(self.up_enc1(encoder_outs[1])) + self.conv_dec2(self.up_dec1(decoder_outs[1])) 230 | 231 | x = self.orb3(x) 232 | x = x + self.conv_enc3(self.up_enc2(encoder_outs[2])) + self.conv_dec3(self.up_dec2(decoder_outs[2])) 233 | 234 | return x 235 | 236 | 237 | ########################################################################## 238 | class MPRNet(nn.Module): 239 | def __init__(self, in_c=3, out_c=3, n_feat=80, scale_unetfeats=48, scale_orsnetfeats=32, num_cab=8, kernel_size=3, reduction=4, bias=False): 240 | super(MPRNet, self).__init__() 241 | 242 | act=nn.PReLU() 243 | self.shallow_feat1 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat,kernel_size, reduction, bias=bias, act=act)) 244 | self.shallow_feat2 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat,kernel_size, reduction, bias=bias, act=act)) 245 | self.shallow_feat3 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat,kernel_size, reduction, bias=bias, act=act)) 246 | 247 | # Cross Stage Feature Fusion (CSFF) 248 | self.stage1_encoder = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=False) 249 | self.stage1_decoder = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats) 250 | 251 | self.stage2_encoder = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=True) 252 | self.stage2_decoder = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats) 253 | 254 | self.stage3_orsnet = ORSNet(n_feat, scale_orsnetfeats, kernel_size, reduction, act, bias, scale_unetfeats, num_cab) 255 | 256 | self.sam12 = SAM(n_feat, kernel_size=1, bias=bias) 257 | self.sam23 = SAM(n_feat, kernel_size=1, bias=bias) 258 | 259 | self.concat12 = conv(n_feat*2, n_feat, kernel_size, bias=bias) 260 | self.concat23 = conv(n_feat*2, n_feat+scale_orsnetfeats, kernel_size, bias=bias) 261 | self.tail = conv(n_feat+scale_orsnetfeats, out_c, kernel_size, bias=bias) 262 | 263 | def forward(self, x3_img): 264 | # Original-resolution Image for Stage 3 265 | H = x3_img.size(2) 266 | W = x3_img.size(3) 267 | 268 | # Multi-Patch Hierarchy: Split Image into four non-overlapping patches 269 | 270 | # Two Patches for Stage 2 271 | x2top_img = x3_img[:,:,0:int(H/2),:] 272 | x2bot_img = x3_img[:,:,int(H/2):H,:] 273 | 274 | # Four Patches for Stage 1 275 | x1ltop_img = x2top_img[:,:,:,0:int(W/2)] 276 | x1rtop_img = x2top_img[:,:,:,int(W/2):W] 277 | x1lbot_img = x2bot_img[:,:,:,0:int(W/2)] 278 | x1rbot_img = x2bot_img[:,:,:,int(W/2):W] 279 | 280 | ##------------------------------------------- 281 | ##-------------- Stage 1--------------------- 282 | ##------------------------------------------- 283 | ## Compute Shallow Features 284 | x1ltop = self.shallow_feat1(x1ltop_img) 285 | x1rtop = self.shallow_feat1(x1rtop_img) 286 | x1lbot = self.shallow_feat1(x1lbot_img) 287 | x1rbot = self.shallow_feat1(x1rbot_img) 288 | 289 | ## Process features of all 4 
patches with Encoder of Stage 1 290 | feat1_ltop = self.stage1_encoder(x1ltop) 291 | feat1_rtop = self.stage1_encoder(x1rtop) 292 | feat1_lbot = self.stage1_encoder(x1lbot) 293 | feat1_rbot = self.stage1_encoder(x1rbot) 294 | 295 | ## Concat deep features 296 | feat1_top = [torch.cat((k,v), 3) for k,v in zip(feat1_ltop,feat1_rtop)] 297 | feat1_bot = [torch.cat((k,v), 3) for k,v in zip(feat1_lbot,feat1_rbot)] 298 | 299 | ## Pass features through Decoder of Stage 1 300 | res1_top = self.stage1_decoder(feat1_top) 301 | res1_bot = self.stage1_decoder(feat1_bot) 302 | 303 | ## Apply Supervised Attention Module (SAM) 304 | x2top_samfeats, stage1_img_top = self.sam12(res1_top[0], x2top_img) 305 | x2bot_samfeats, stage1_img_bot = self.sam12(res1_bot[0], x2bot_img) 306 | 307 | ## Output image at Stage 1 308 | stage1_img = torch.cat([stage1_img_top, stage1_img_bot],2) 309 | ##------------------------------------------- 310 | ##-------------- Stage 2--------------------- 311 | ##------------------------------------------- 312 | ## Compute Shallow Features 313 | x2top = self.shallow_feat2(x2top_img) 314 | x2bot = self.shallow_feat2(x2bot_img) 315 | 316 | ## Concatenate SAM features of Stage 1 with shallow features of Stage 2 317 | x2top_cat = self.concat12(torch.cat([x2top, x2top_samfeats], 1)) 318 | x2bot_cat = self.concat12(torch.cat([x2bot, x2bot_samfeats], 1)) 319 | 320 | ## Process features of both patches with Encoder of Stage 2 321 | feat2_top = self.stage2_encoder(x2top_cat, feat1_top, res1_top) 322 | feat2_bot = self.stage2_encoder(x2bot_cat, feat1_bot, res1_bot) 323 | 324 | ## Concat deep features 325 | feat2 = [torch.cat((k,v), 2) for k,v in zip(feat2_top,feat2_bot)] 326 | 327 | ## Pass features through Decoder of Stage 2 328 | res2 = self.stage2_decoder(feat2) 329 | 330 | ## Apply SAM 331 | x3_samfeats, stage2_img = self.sam23(res2[0], x3_img) 332 | 333 | 334 | ##------------------------------------------- 335 | ##-------------- Stage 3--------------------- 336 | ##------------------------------------------- 337 | ## Compute Shallow Features 338 | x3 = self.shallow_feat3(x3_img) 339 | 340 | ## Concatenate SAM features of Stage 2 with shallow features of Stage 3 341 | x3_cat = self.concat23(torch.cat([x3, x3_samfeats], 1)) 342 | 343 | x3_cat = self.stage3_orsnet(x3_cat, feat2, res2) 344 | 345 | stage3_img = self.tail(x3_cat) 346 | 347 | return [stage3_img+x3_img, stage2_img, stage1_img] 348 | -------------------------------------------------------------------------------- /Deraining/MPRNet.py: -------------------------------------------------------------------------------- 1 | """ 2 | ## Multi-Stage Progressive Image Restoration 3 | ## Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan, Ming-Hsuan Yang, and Ling Shao 4 | ## https://arxiv.org/abs/2102.02808 5 | """ 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | from pdb import set_trace as stx 11 | 12 | ########################################################################## 13 | def conv(in_channels, out_channels, kernel_size, bias=False, stride = 1): 14 | return nn.Conv2d( 15 | in_channels, out_channels, kernel_size, 16 | padding=(kernel_size//2), bias=bias, stride = stride) 17 | 18 | 19 | ########################################################################## 20 | ## Channel Attention Layer 21 | class CALayer(nn.Module): 22 | def __init__(self, channel, reduction=16, bias=False): 23 | super(CALayer, self).__init__() 24 | # global average pooling: feature --> point 
25 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 26 | # feature channel downscale and upscale --> channel weight 27 | self.conv_du = nn.Sequential( 28 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=bias), 29 | nn.ReLU(inplace=True), 30 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=bias), 31 | nn.Sigmoid() 32 | ) 33 | 34 | def forward(self, x): 35 | y = self.avg_pool(x) 36 | y = self.conv_du(y) 37 | return x * y 38 | 39 | 40 | ########################################################################## 41 | ## Channel Attention Block (CAB) 42 | class CAB(nn.Module): 43 | def __init__(self, n_feat, kernel_size, reduction, bias, act): 44 | super(CAB, self).__init__() 45 | modules_body = [] 46 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 47 | modules_body.append(act) 48 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 49 | 50 | self.CA = CALayer(n_feat, reduction, bias=bias) 51 | self.body = nn.Sequential(*modules_body) 52 | 53 | def forward(self, x): 54 | res = self.body(x) 55 | res = self.CA(res) 56 | res += x 57 | return res 58 | 59 | ########################################################################## 60 | ## Supervised Attention Module 61 | class SAM(nn.Module): 62 | def __init__(self, n_feat, kernel_size, bias): 63 | super(SAM, self).__init__() 64 | self.conv1 = conv(n_feat, n_feat, kernel_size, bias=bias) 65 | self.conv2 = conv(n_feat, 3, kernel_size, bias=bias) 66 | self.conv3 = conv(3, n_feat, kernel_size, bias=bias) 67 | 68 | def forward(self, x, x_img): 69 | x1 = self.conv1(x) 70 | img = self.conv2(x) + x_img 71 | x2 = torch.sigmoid(self.conv3(img)) 72 | x1 = x1*x2 73 | x1 = x1+x 74 | return x1, img 75 | 76 | ########################################################################## 77 | ## U-Net 78 | 79 | class Encoder(nn.Module): 80 | def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff): 81 | super(Encoder, self).__init__() 82 | 83 | self.encoder_level1 = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2)] 84 | self.encoder_level2 = [CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(2)] 85 | self.encoder_level3 = [CAB(n_feat+(scale_unetfeats*2), kernel_size, reduction, bias=bias, act=act) for _ in range(2)] 86 | 87 | self.encoder_level1 = nn.Sequential(*self.encoder_level1) 88 | self.encoder_level2 = nn.Sequential(*self.encoder_level2) 89 | self.encoder_level3 = nn.Sequential(*self.encoder_level3) 90 | 91 | self.down12 = DownSample(n_feat, scale_unetfeats) 92 | self.down23 = DownSample(n_feat+scale_unetfeats, scale_unetfeats) 93 | 94 | # Cross Stage Feature Fusion (CSFF) 95 | if csff: 96 | self.csff_enc1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias) 97 | self.csff_enc2 = nn.Conv2d(n_feat+scale_unetfeats, n_feat+scale_unetfeats, kernel_size=1, bias=bias) 98 | self.csff_enc3 = nn.Conv2d(n_feat+(scale_unetfeats*2), n_feat+(scale_unetfeats*2), kernel_size=1, bias=bias) 99 | 100 | self.csff_dec1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias) 101 | self.csff_dec2 = nn.Conv2d(n_feat+scale_unetfeats, n_feat+scale_unetfeats, kernel_size=1, bias=bias) 102 | self.csff_dec3 = nn.Conv2d(n_feat+(scale_unetfeats*2), n_feat+(scale_unetfeats*2), kernel_size=1, bias=bias) 103 | 104 | def forward(self, x, encoder_outs=None, decoder_outs=None): 105 | enc1 = self.encoder_level1(x) 106 | if (encoder_outs is not None) and (decoder_outs is not None): 107 | enc1 = enc1 + self.csff_enc1(encoder_outs[0]) + 
self.csff_dec1(decoder_outs[0]) 108 | 109 | x = self.down12(enc1) 110 | 111 | enc2 = self.encoder_level2(x) 112 | if (encoder_outs is not None) and (decoder_outs is not None): 113 | enc2 = enc2 + self.csff_enc2(encoder_outs[1]) + self.csff_dec2(decoder_outs[1]) 114 | 115 | x = self.down23(enc2) 116 | 117 | enc3 = self.encoder_level3(x) 118 | if (encoder_outs is not None) and (decoder_outs is not None): 119 | enc3 = enc3 + self.csff_enc3(encoder_outs[2]) + self.csff_dec3(decoder_outs[2]) 120 | 121 | return [enc1, enc2, enc3] 122 | 123 | class Decoder(nn.Module): 124 | def __init__(self, n_feat, kernel_size, reduction, act, bias, scale_unetfeats): 125 | super(Decoder, self).__init__() 126 | 127 | self.decoder_level1 = [CAB(n_feat, kernel_size, reduction, bias=bias, act=act) for _ in range(2)] 128 | self.decoder_level2 = [CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias=bias, act=act) for _ in range(2)] 129 | self.decoder_level3 = [CAB(n_feat+(scale_unetfeats*2), kernel_size, reduction, bias=bias, act=act) for _ in range(2)] 130 | 131 | self.decoder_level1 = nn.Sequential(*self.decoder_level1) 132 | self.decoder_level2 = nn.Sequential(*self.decoder_level2) 133 | self.decoder_level3 = nn.Sequential(*self.decoder_level3) 134 | 135 | self.skip_attn1 = CAB(n_feat, kernel_size, reduction, bias=bias, act=act) 136 | self.skip_attn2 = CAB(n_feat+scale_unetfeats, kernel_size, reduction, bias=bias, act=act) 137 | 138 | self.up21 = SkipUpSample(n_feat, scale_unetfeats) 139 | self.up32 = SkipUpSample(n_feat+scale_unetfeats, scale_unetfeats) 140 | 141 | def forward(self, outs): 142 | enc1, enc2, enc3 = outs 143 | dec3 = self.decoder_level3(enc3) 144 | 145 | x = self.up32(dec3, self.skip_attn2(enc2)) 146 | dec2 = self.decoder_level2(x) 147 | 148 | x = self.up21(dec2, self.skip_attn1(enc1)) 149 | dec1 = self.decoder_level1(x) 150 | 151 | return [dec1,dec2,dec3] 152 | 153 | ########################################################################## 154 | ##---------- Resizing Modules ---------- 155 | class DownSample(nn.Module): 156 | def __init__(self, in_channels,s_factor): 157 | super(DownSample, self).__init__() 158 | self.down = nn.Sequential(nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False), 159 | nn.Conv2d(in_channels, in_channels+s_factor, 1, stride=1, padding=0, bias=False)) 160 | 161 | def forward(self, x): 162 | x = self.down(x) 163 | return x 164 | 165 | class UpSample(nn.Module): 166 | def __init__(self, in_channels,s_factor): 167 | super(UpSample, self).__init__() 168 | self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), 169 | nn.Conv2d(in_channels+s_factor, in_channels, 1, stride=1, padding=0, bias=False)) 170 | 171 | def forward(self, x): 172 | x = self.up(x) 173 | return x 174 | 175 | class SkipUpSample(nn.Module): 176 | def __init__(self, in_channels,s_factor): 177 | super(SkipUpSample, self).__init__() 178 | self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), 179 | nn.Conv2d(in_channels+s_factor, in_channels, 1, stride=1, padding=0, bias=False)) 180 | 181 | def forward(self, x, y): 182 | x = self.up(x) 183 | x = x + y 184 | return x 185 | 186 | ########################################################################## 187 | ## Original Resolution Block (ORB) 188 | class ORB(nn.Module): 189 | def __init__(self, n_feat, kernel_size, reduction, act, bias, num_cab): 190 | super(ORB, self).__init__() 191 | modules_body = [] 192 | modules_body = [CAB(n_feat, kernel_size, reduction, bias=bias, 
act=act) for _ in range(num_cab)] 193 | modules_body.append(conv(n_feat, n_feat, kernel_size)) 194 | self.body = nn.Sequential(*modules_body) 195 | 196 | def forward(self, x): 197 | res = self.body(x) 198 | res += x 199 | return res 200 | 201 | ########################################################################## 202 | class ORSNet(nn.Module): 203 | def __init__(self, n_feat, scale_orsnetfeats, kernel_size, reduction, act, bias, scale_unetfeats, num_cab): 204 | super(ORSNet, self).__init__() 205 | 206 | self.orb1 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab) 207 | self.orb2 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab) 208 | self.orb3 = ORB(n_feat+scale_orsnetfeats, kernel_size, reduction, act, bias, num_cab) 209 | 210 | self.up_enc1 = UpSample(n_feat, scale_unetfeats) 211 | self.up_dec1 = UpSample(n_feat, scale_unetfeats) 212 | 213 | self.up_enc2 = nn.Sequential(UpSample(n_feat+scale_unetfeats, scale_unetfeats), UpSample(n_feat, scale_unetfeats)) 214 | self.up_dec2 = nn.Sequential(UpSample(n_feat+scale_unetfeats, scale_unetfeats), UpSample(n_feat, scale_unetfeats)) 215 | 216 | self.conv_enc1 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias) 217 | self.conv_enc2 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias) 218 | self.conv_enc3 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias) 219 | 220 | self.conv_dec1 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias) 221 | self.conv_dec2 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias) 222 | self.conv_dec3 = nn.Conv2d(n_feat, n_feat+scale_orsnetfeats, kernel_size=1, bias=bias) 223 | 224 | def forward(self, x, encoder_outs, decoder_outs): 225 | x = self.orb1(x) 226 | x = x + self.conv_enc1(encoder_outs[0]) + self.conv_dec1(decoder_outs[0]) 227 | 228 | x = self.orb2(x) 229 | x = x + self.conv_enc2(self.up_enc1(encoder_outs[1])) + self.conv_dec2(self.up_dec1(decoder_outs[1])) 230 | 231 | x = self.orb3(x) 232 | x = x + self.conv_enc3(self.up_enc2(encoder_outs[2])) + self.conv_dec3(self.up_dec2(decoder_outs[2])) 233 | 234 | return x 235 | 236 | 237 | ########################################################################## 238 | class MPRNet(nn.Module): 239 | def __init__(self, in_c=3, out_c=3, n_feat=40, scale_unetfeats=20, scale_orsnetfeats=16, num_cab=8, kernel_size=3, reduction=4, bias=False): 240 | super(MPRNet, self).__init__() 241 | 242 | act=nn.PReLU() 243 | self.shallow_feat1 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat,kernel_size, reduction, bias=bias, act=act)) 244 | self.shallow_feat2 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat,kernel_size, reduction, bias=bias, act=act)) 245 | self.shallow_feat3 = nn.Sequential(conv(in_c, n_feat, kernel_size, bias=bias), CAB(n_feat,kernel_size, reduction, bias=bias, act=act)) 246 | 247 | # Cross Stage Feature Fusion (CSFF) 248 | self.stage1_encoder = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=False) 249 | self.stage1_decoder = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats) 250 | 251 | self.stage2_encoder = Encoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats, csff=True) 252 | self.stage2_decoder = Decoder(n_feat, kernel_size, reduction, act, bias, scale_unetfeats) 253 | 254 | self.stage3_orsnet = ORSNet(n_feat, scale_orsnetfeats, kernel_size, reduction, act, bias, scale_unetfeats, num_cab) 255 | 256 | self.sam12 = 
SAM(n_feat, kernel_size=1, bias=bias) 257 | self.sam23 = SAM(n_feat, kernel_size=1, bias=bias) 258 | 259 | self.concat12 = conv(n_feat*2, n_feat, kernel_size, bias=bias) 260 | self.concat23 = conv(n_feat*2, n_feat+scale_orsnetfeats, kernel_size, bias=bias) 261 | self.tail = conv(n_feat+scale_orsnetfeats, out_c, kernel_size, bias=bias) 262 | 263 | def forward(self, x3_img): 264 | # Original-resolution Image for Stage 3 265 | H = x3_img.size(2) 266 | W = x3_img.size(3) 267 | 268 | # Multi-Patch Hierarchy: Split Image into four non-overlapping patches 269 | 270 | # Two Patches for Stage 2 271 | x2top_img = x3_img[:,:,0:int(H/2),:] 272 | x2bot_img = x3_img[:,:,int(H/2):H,:] 273 | 274 | # Four Patches for Stage 1 275 | x1ltop_img = x2top_img[:,:,:,0:int(W/2)] 276 | x1rtop_img = x2top_img[:,:,:,int(W/2):W] 277 | x1lbot_img = x2bot_img[:,:,:,0:int(W/2)] 278 | x1rbot_img = x2bot_img[:,:,:,int(W/2):W] 279 | 280 | ##------------------------------------------- 281 | ##-------------- Stage 1--------------------- 282 | ##------------------------------------------- 283 | ## Compute Shallow Features 284 | x1ltop = self.shallow_feat1(x1ltop_img) 285 | x1rtop = self.shallow_feat1(x1rtop_img) 286 | x1lbot = self.shallow_feat1(x1lbot_img) 287 | x1rbot = self.shallow_feat1(x1rbot_img) 288 | 289 | ## Process features of all 4 patches with Encoder of Stage 1 290 | feat1_ltop = self.stage1_encoder(x1ltop) 291 | feat1_rtop = self.stage1_encoder(x1rtop) 292 | feat1_lbot = self.stage1_encoder(x1lbot) 293 | feat1_rbot = self.stage1_encoder(x1rbot) 294 | 295 | ## Concat deep features 296 | feat1_top = [torch.cat((k,v), 3) for k,v in zip(feat1_ltop,feat1_rtop)] 297 | feat1_bot = [torch.cat((k,v), 3) for k,v in zip(feat1_lbot,feat1_rbot)] 298 | 299 | ## Pass features through Decoder of Stage 1 300 | res1_top = self.stage1_decoder(feat1_top) 301 | res1_bot = self.stage1_decoder(feat1_bot) 302 | 303 | ## Apply Supervised Attention Module (SAM) 304 | x2top_samfeats, stage1_img_top = self.sam12(res1_top[0], x2top_img) 305 | x2bot_samfeats, stage1_img_bot = self.sam12(res1_bot[0], x2bot_img) 306 | 307 | ## Output image at Stage 1 308 | stage1_img = torch.cat([stage1_img_top, stage1_img_bot],2) 309 | ##------------------------------------------- 310 | ##-------------- Stage 2--------------------- 311 | ##------------------------------------------- 312 | ## Compute Shallow Features 313 | x2top = self.shallow_feat2(x2top_img) 314 | x2bot = self.shallow_feat2(x2bot_img) 315 | 316 | ## Concatenate SAM features of Stage 1 with shallow features of Stage 2 317 | x2top_cat = self.concat12(torch.cat([x2top, x2top_samfeats], 1)) 318 | x2bot_cat = self.concat12(torch.cat([x2bot, x2bot_samfeats], 1)) 319 | 320 | ## Process features of both patches with Encoder of Stage 2 321 | feat2_top = self.stage2_encoder(x2top_cat, feat1_top, res1_top) 322 | feat2_bot = self.stage2_encoder(x2bot_cat, feat1_bot, res1_bot) 323 | 324 | ## Concat deep features 325 | feat2 = [torch.cat((k,v), 2) for k,v in zip(feat2_top,feat2_bot)] 326 | 327 | ## Pass features through Decoder of Stage 2 328 | res2 = self.stage2_decoder(feat2) 329 | 330 | ## Apply SAM 331 | x3_samfeats, stage2_img = self.sam23(res2[0], x3_img) 332 | 333 | 334 | ##------------------------------------------- 335 | ##-------------- Stage 3--------------------- 336 | ##------------------------------------------- 337 | ## Compute Shallow Features 338 | x3 = self.shallow_feat3(x3_img) 339 | 340 | ## Concatenate SAM features of Stage 2 with shallow features of Stage 3 341 | x3_cat = 
self.concat23(torch.cat([x3, x3_samfeats], 1)) 342 | 343 | x3_cat = self.stage3_orsnet(x3_cat, feat2, res2) 344 | 345 | stage3_img = self.tail(x3_cat) 346 | 347 | return [stage3_img+x3_img, stage2_img, stage1_img] 348 | --------------------------------------------------------------------------------
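The three task-specific MPRNet.py files above share the same architecture and differ only in their default capacity: n_feat=96 with scale_unetfeats=48 and scale_orsnetfeats=32 for deblurring, n_feat=80 with the same scales for denoising, and n_feat=40 with scale_unetfeats=20 and scale_orsnetfeats=16 for deraining. A minimal usage sketch, assuming it is run from inside one of the task folders so MPRNet.py is importable, with random weights rather than the released checkpoints; the 256x256 input size is only illustrative:

import torch
from MPRNet import MPRNet

model = MPRNet().eval()  # defaults depend on which task folder's copy is used

# Height and width should be multiples of 8: forward() halves the image for
# the stage-1/stage-2 patch hierarchy, and each U-Net encoder downsamples
# twice more, so all patch splits and skip connections must align exactly.
x = torch.randn(1, 3, 256, 256)

with torch.no_grad():
    stage3, stage2, stage1 = model(x)

# The model returns [stage3, stage2, stage1] restored images, all at the
# input resolution; stage3 (the first element) is the final output.
print(stage3.shape, stage2.shape, stage1.shape)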
[README figure captions: "Overall Framework of MPRNet", "Supervised Attention Module (SAM)"; result-table captions: "Deblurring on Synthetic Datasets", "Deblurring on Real Dataset"]
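The SAM caption above refers to the Supervised Attention Module defined in each MPRNet.py. A small standalone sketch, again assuming one of the task folders is on the import path; the feature width and spatial size are illustrative, and the weights are random:

import torch
from MPRNet import SAM

sam = SAM(n_feat=40, kernel_size=1, bias=False)
feats = torch.randn(1, 40, 64, 64)    # decoder features from one stage
degraded = torch.randn(1, 3, 64, 64)  # degraded image at the same scale

next_feats, restored = sam(feats, degraded)
# 'restored' is the stage's intermediate restored image (the output that
# receives supervision); 'next_feats' are the input features re-weighted by
# per-pixel attention derived from that image, then passed to the next stage.
print(next_feats.shape, restored.shape)  # (1, 40, 64, 64) (1, 3, 64, 64)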