├── README.md
├── additional_utils.py
├── config_parser.py
├── exp_denoising.sh
├── loss.py
├── main.py
├── models
│   ├── S2Snet.py
│   ├── __init__.py
│   ├── common.py
│   ├── downsampler.py
│   ├── iterBN.py
│   ├── resnet.py
│   ├── skip.py
│   ├── snet.py
│   ├── texture_nets.py
│   └── unet.py
├── requirements.txt
├── tasks.py
├── testset
│   ├── BSD68
│   │   ├── test001.png
│   │   ├── test002.png
│   │   ├── test003.png
│   │   ├── test004.png
│   │   ├── test005.png
│   │   ├── test006.png
│   │   ├── test007.png
│   │   ├── test008.png
│   │   ├── test009.png
│   │   ├── test010.png
│   │   ├── test011.png
│   │   ├── test012.png
│   │   ├── test013.png
│   │   ├── test014.png
│   │   ├── test015.png
│   │   ├── test016.png
│   │   ├── test017.png
│   │   ├── test018.png
│   │   ├── test019.png
│   │   ├── test020.png
│   │   ├── test021.png
│   │   ├── test022.png
│   │   ├── test023.png
│   │   ├── test024.png
│   │   ├── test025.png
│   │   ├── test026.png
│   │   ├── test027.png
│   │   ├── test028.png
│   │   ├── test029.png
│   │   ├── test030.png
│   │   ├── test031.png
│   │   ├── test032.png
│   │   ├── test033.png
│   │   ├── test034.png
│   │   ├── test035.png
│   │   ├── test036.png
│   │   ├── test037.png
│   │   ├── test038.png
│   │   ├── test039.png
│   │   ├── test040.png
│   │   ├── test041.png
│   │   ├── test042.png
│   │   ├── test043.png
│   │   ├── test044.png
│   │   ├── test045.png
│   │   ├── test046.png
│   │   ├── test047.png
│   │   ├── test048.png
│   │   ├── test049.png
│   │   ├── test050.png
│   │   ├── test051.png
│   │   ├── test052.png
│   │   ├── test053.png
│   │   ├── test054.png
│   │   ├── test055.png
│   │   ├── test056.png
│   │   ├── test057.png
│   │   ├── test058.png
│   │   ├── test059.png
│   │   ├── test060.png
│   │   ├── test061.png
│   │   ├── test062.png
│   │   ├── test063.png
│   │   ├── test064.png
│   │   ├── test065.png
│   │   ├── test066.png
│   │   ├── test067.png
│   │   └── test068.png
│   ├── CSet9
│   │   ├── image_Baboon512rgb.png
│   │   ├── image_F16_512rgb.png
│   │   ├── image_House256rgb.png
│   │   ├── image_Lena512rgb.png
│   │   ├── image_Peppers512rgb.png
│   │   ├── kodim01.png
│   │   ├── kodim02.png
│   │   ├── kodim03.png
│   │   └── kodim12.png
│   ├── MNIST
│   │   ├── 1.png
│   │   ├── 10.png
│   │   ├── 2.png
│   │   ├── 3.png
│   │   ├── 4.png
│   │   ├── 5.png
│   │   ├── 6.png
│   │   ├── 7.png
│   │   ├── 8.png
│   │   └── 9.png
│   └── Set12
│       ├── 01.png
│       ├── 02.png
│       ├── 03.png
│       ├── 04.png
│       ├── 05.png
│       ├── 06.png
│       ├── 07.png
│       ├── 08.png
│       ├── 09.png
│       ├── 10.png
│       ├── 11.png
│       └── 12.png
└── utils
    ├── PerceptualSimilarity
    │   ├── __init__.py
    │   ├── base_model.py
    │   ├── dist_model.py
    │   ├── networks_basic.py
    │   ├── pretrained_networks.py
    │   ├── util.py
    │   └── weights
    │       ├── v0.0
    │       │   ├── alex.pth
    │       │   ├── squeeze.pth
    │       │   └── vgg.pth
    │       └── v0.1
    │           ├── alex.pth
    │           ├── squeeze.pth
    │           └── vgg.pth
    ├── REDutils.py
    ├── __init__.py
    ├── blur_utils.py
    ├── common_utils.py
    ├── denoising_utils.py
    └── parse_utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # DIP-denosing
2 | 
3 | This is a code repo for [Rethinking Deep Image Prior for Denoising](https://arxiv.org/abs/2108.12841) (ICCV 2021).
4 | 
5 | Addressing the relationship between the deep image prior and effective degrees of freedom, DIP-SURE with STE (stochastic temporal ensemble) achieves reasonable results on single-image denoising.
6 | 
7 | If you use any of this code, please cite the following publication:
8 | 
9 | ``` Citation
10 | @inproceedings{jo2021dipdenoising,
11 |     author    = {Jo, Yeonsik and Chun, Se Young and Choi, Jonghyun},
12 |     title     = {Rethinking Deep Image Prior for Denoising},
13 |     booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV)},
14 |     month     = {October},
15 |     year      = {2021},
16 |     pages     = {5087-5096}
17 | }
18 | ```
19 | 
20 | ## Working environment
21 | 
22 | - TITAN Xp
23 | - ubuntu 18.04.4
24 | - pytorch 1.6
25 | 
26 | 
27 | **Note:**
28 | Experimental results have not been verified in other environments.
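To quickly confirm that your environment roughly matches the one above, here is a minimal sanity check (illustrative only, not part of the original repo):

```python
# minimal environment check (not from the repo)
import torch
print(torch.__version__)          # the experiments used pytorch 1.6
print(torch.cuda.is_available())  # a CUDA GPU (TITAN Xp) was used
```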
29 | 
30 | ## Set-up
31 | 
32 | - Make your own environment
33 | 
34 | ```bash
35 | conda create --name DIP --file requirements.txt
36 | conda activate DIP
37 | pip install tqdm
38 | ```
39 | 
40 | ## Inference
41 | 
42 | - Produce the CSet9 results
43 | ```bash
44 | bash exp_denoising.sh CSet9 <GPU_ID>
45 | ```
46 | 
47 | - For your own data with the sigma=25 setup
48 | ```bash
49 | mkdir testset/<your_data>
50 | python main.py --dip_type eSURE_new --net_type s2s --exp_tag <exp_tag> --optim RAdam --force_steplr --desc sigma25 denoising --sigma 25 --eval_data <your_data>
51 | ```
52 | 
53 | ## Browsing experimental results
54 | 
55 | - We provide reporting code with [invoke](https://www.pyinvoke.org/).
56 | ```bash
57 | invoke showtable csv/<task_type>/<exp_tag>/
58 | ```
59 | 
60 | - Example.
61 | ```bash
62 | invoke showtable csv/poisson/MNIST/
63 | PURE_dc_scale001_new optimal stopping : 384.30, 31.97/0.02 | ZCSC : 447.60, 31.26/0.02 | STE 31.99/0.02
64 | PURE_dc_scale01_new optimal stopping : 94.70, 24.96/0.12 | ZCSC : 144.60, 24.04/0.14 | STE 24.89/0.12
65 | PURE_dc_scale02_new optimal stopping : 70.30, 22.92/0.20 | ZCSC : 110.00, 21.82/0.22 | STE 22.83/0.20
66 | optimal stopping :, / | ZCSC : , /| STE /
67 | ```
68 | The reported numbers are PSNR/LPIPS.
69 | 
70 | ## Results in paper
71 | For the results reported in the paper, please refer to [this link](https://drive.google.com/drive/folders/1wAdBUguLTwALFmgmTNNbgwY5zku-c5Pz?usp=sharing).
72 | 
73 | ## SSIM score
74 | For the SSIM score of color images, I used the same MATLAB code as the author of [S2S](https://github.com/scut-mingqinchen/self2self).
75 | This is the demo code I received from the S2S author.
76 | Thank you Mingqin!
77 | ```Matlab
78 | % examples
79 | ref = im2double(imread('gt.png'));
80 | noisy = im2double(imread('noisy.png'));
81 | psnr_result = psnr(ref, noisy);
82 | ssim_result = ssim(ref, noisy);
83 | ```
84 | 
85 | ## License
86 | 
87 | MIT license.
88 | 
89 | ## Contacts
90 | 
91 | For questions, please send an email to **dustlrdk@gmail.com**
--------------------------------------------------------------------------------
/additional_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | 
4 | import torch
5 | import torch.nn as nn
6 | from torch.optim.optimizer import Optimizer
7 | 
8 | 
9 | def fix_running_statistic(net):
10 |     cnt = 0
11 |     for i in net.modules():
12 |         if isinstance(i, nn.BatchNorm2d):
13 |             i.track_running_stats = False
14 |             cnt += 1
15 |         elif isinstance(i, nn.InstanceNorm2d):
16 |             i.track_running_stats = False
17 |             cnt += 1
18 | 
19 |     print("[*] %d normalization layers switched to fixed statistics" % cnt)
20 | 
21 | # DATA AUG.
22 | def data_aug_with_mode(img, mode=0):
23 |     """
24 |     :param img: should be a 4-dim tensor [B, C, H, W]
25 |     :param mode: 0 = identity, 1 = flip along H, 2 = flip along W, 3 = flip along both
26 |     :return: the augmented tensor
27 |     """
28 |     if mode == 0:
29 |         return img
30 |     elif mode == 1:
31 |         return torch.flip(img, dims=[2])
32 |     elif mode == 2:
33 |         return torch.flip(img, dims=[3])
34 |     elif mode == 3:
35 |         return torch.flip(img, dims=[2, 3])
36 | 
37 | # init optimization.
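# init_optim() below resets an Adam-family optimizer's per-parameter state
# (step counter and both moment buffers) without touching the weights.
# A hedged usage sketch mirroring the call site in main.py, where it is run
# every `--optim_init` iterations; variable names here are illustrative:
#
#     optimizer = RAdam(net.parameters(), lr=0.1, betas=(0.9, 0.999))
#     for ep in range(epochs):
#         ...  # one DIP update step
#         if optim_init > 0 and ep % optim_init == 0:
#             init_optim(net, optimizer)  # restart moment estimates from zero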
38 | def init_optim(model, optim): 39 | group = optim.param_groups[0] 40 | for n in [x for x in model.parameters()]: 41 | state = optim.state[n] 42 | state["step"] = 0 43 | state['exp_avg'] = torch.zeros_like(n, memory_format=torch.preserve_format) 44 | # Exponential moving average of squared gradient values 45 | state['exp_avg_sq'] = torch.zeros_like(n, memory_format=torch.preserve_format) 46 | 47 | class RAdam(Optimizer): 48 | 49 | def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, degenerated_to_sgd=True): 50 | if not 0.0 <= lr: 51 | raise ValueError("Invalid learning rate: {}".format(lr)) 52 | if not 0.0 <= eps: 53 | raise ValueError("Invalid epsilon value: {}".format(eps)) 54 | if not 0.0 <= betas[0] < 1.0: 55 | raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) 56 | if not 0.0 <= betas[1] < 1.0: 57 | raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) 58 | 59 | self.degenerated_to_sgd = degenerated_to_sgd 60 | if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): 61 | for param in params: 62 | if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]): 63 | param['buffer'] = [[None, None, None] for _ in range(10)] 64 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, 65 | buffer=[[None, None, None] for _ in range(10)]) 66 | super(RAdam, self).__init__(params, defaults) 67 | 68 | def __setstate__(self, state): 69 | super(RAdam, self).__setstate__(state) 70 | 71 | def step(self, closure=None): 72 | 73 | loss = None 74 | if closure is not None: 75 | loss = closure() 76 | 77 | for group in self.param_groups: 78 | 79 | for p in group['params']: 80 | if p.grad is None: 81 | continue 82 | grad = p.grad.data.float() 83 | if grad.is_sparse: 84 | raise RuntimeError('RAdam does not support sparse gradients') 85 | 86 | p_data_fp32 = p.data.float() 87 | 88 | state = self.state[p] 89 | 90 | if len(state) == 0: 91 | state['step'] = 0 92 | state['exp_avg'] = torch.zeros_like(p_data_fp32) 93 | state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) 94 | else: 95 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 96 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) 97 | 98 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 99 | beta1, beta2 = group['betas'] 100 | 101 | exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) 102 | exp_avg.mul_(beta1).add_(1 - beta1, grad) 103 | 104 | state['step'] += 1 105 | buffered = group['buffer'][int(state['step'] % 10)] 106 | if state['step'] == buffered[0]: 107 | N_sma, step_size = buffered[1], buffered[2] 108 | else: 109 | buffered[0] = state['step'] 110 | beta2_t = beta2 ** state['step'] 111 | N_sma_max = 2 / (1 - beta2) - 1 112 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) 113 | buffered[1] = N_sma 114 | 115 | # more conservative since it's an approximated value 116 | if N_sma >= 5: 117 | step_size = math.sqrt( 118 | (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / ( 119 | N_sma_max - 2)) / (1 - beta1 ** state['step']) 120 | elif self.degenerated_to_sgd: 121 | step_size = 1.0 / (1 - beta1 ** state['step']) 122 | else: 123 | step_size = -1 124 | buffered[2] = step_size 125 | 126 | # more conservative since it's an approximated value 127 | if N_sma >= 5: 128 | if group['weight_decay'] != 0: 129 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 130 | denom = exp_avg_sq.sqrt().add_(group['eps']) 131 
| p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) 132 | p.data.copy_(p_data_fp32) 133 | elif step_size > 0: 134 | if group['weight_decay'] != 0: 135 | p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) 136 | p_data_fp32.add_(-step_size * group['lr'], exp_avg) 137 | p.data.copy_(p_data_fp32) 138 | 139 | return loss -------------------------------------------------------------------------------- /config_parser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | def main_parser(): 5 | parser = argparse.ArgumentParser() 6 | # shared param 7 | task_parsers = parser.add_subparsers(dest='task_type') 8 | parser.add_argument("--dip_type", default = "dip") 9 | parser.add_argument("--gray", action="store_true") 10 | 11 | # denoising param 12 | denoising_parser = task_parsers.add_parser('denoising') 13 | denoising_parser.add_argument("--eval_data", default="CSet9") 14 | denoising_parser.add_argument("--sigma", default=50, type=int) 15 | denoising_parser.add_argument("--lr", default=0.1, type=float) 16 | denoising_parser.add_argument('--reg_noise_std', default=1./20., type=float) 17 | 18 | # poisson param 19 | deblur_parser = task_parsers.add_parser('poisson') 20 | deblur_parser.add_argument("--eval_data", default="MNIST") 21 | deblur_parser.add_argument("--scale", default=0.1, type=float) 22 | deblur_parser.add_argument("--lr", default=0.1, type=float) 23 | deblur_parser.add_argument('--reg_noise_std', default=0.01, type=float) 24 | 25 | # network param 26 | parser.add_argument('--input_depth', default=3, type=int) 27 | parser.add_argument('--hidden_layer', default=64, type=int) 28 | parser.add_argument('--act_func', default="soft", type=str) # temporal experiment 29 | parser.add_argument("--optim", default="RAdam") 30 | parser.add_argument('--sigma_z', default=0.5, type=float) 31 | parser.add_argument("--net_type", default="s2s", type=str) 32 | 33 | # Additional methods. 34 | parser.add_argument("--force_steplr", action= "store_true") 35 | parser.add_argument("--extending", action="store_true") 36 | 37 | # BatchNorm methods. 38 | parser.add_argument("--bn_type", default="bn", type=str) 39 | parser.add_argument("--bn_fix_epoch", default=-1, type=int) 40 | 41 | # Extra_method related to DIP. 42 | parser.add_argument('--running_avg_ratio', default=0.99, type=float) 43 | 44 | # power of perturbation in divergence. 45 | parser.add_argument("--epsilon", default=0.5, type=float)# 1.6e-4 46 | parser.add_argument('--desc', default="", type=str) 47 | parser.add_argument("--exp_tag", default="", type=str) 48 | parser.add_argument('--show_every', default=500, type=int) 49 | parser.add_argument('--optim_init', default=0, type=int) 50 | parser.add_argument('--save_np', action='store_true') 51 | parser.add_argument('--epoch', default=0, type= int) 52 | parser.add_argument('--beta1', default=0.9, type=float) # Momentum. 53 | parser.add_argument('--beta2', default=0.999, type=float) # Adaptive learning rate. 
54 |     parser.add_argument('--noisy_map', action="store_true")
55 |     parser.add_argument('--GT_noise', action="store_true")
56 | 
57 |     args = parser.parse_args()
58 |     args.desc = "_" + args.desc
59 | 
60 |     return args
61 | 
--------------------------------------------------------------------------------
/exp_denoising.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | 
3 | echo task $1 GPU_ID $2
4 | export CUDA_VISIBLE_DEVICES=$2
5 | 
6 | # Note: the [ ] test syntax requires spaces around the brackets.
7 | if [ $1 == "dip" ]; then
8 |     tag=dip_set9
9 |     python main.py --dip_type dip --net_type s2s --exp_tag $tag --desc sigma15 denoising --sigma 15
10 |     python main.py --dip_type dip --net_type s2s --exp_tag $tag --desc sigma25 denoising --sigma 25
11 |     python main.py --dip_type dip --net_type s2s --exp_tag $tag --desc sigma50 denoising --sigma 50
12 | 
13 | elif [ $1 == "ablation" ]; then
14 |     tag=ablation
15 |     python main.py --dip_type dip --net_type s2s --exp_tag $tag --desc sigma25 denoising --sigma 25
16 |     python main.py --dip_type dip_sure --net_type s2s --exp_tag $tag --desc sigma25 denoising --sigma 25
17 |     python main.py --dip_type eSURE_uniform --net_type s2s --exp_tag $tag --optim RAdam --force_steplr --desc sigma25 denoising --sigma 25
18 | 
19 | elif [ $1 == "CSet9" ]; then
20 |     tag=CSet9
21 |     python main.py --dip_type eSURE_uniform --net_type s2s --exp_tag $tag --optim RAdam --force_steplr --desc sigma15 denoising --sigma 15
22 |     python main.py --dip_type eSURE_uniform --net_type s2s --exp_tag $tag --optim RAdam --force_steplr --desc sigma25 denoising --sigma 25
23 |     python main.py --dip_type eSURE_uniform --net_type s2s --exp_tag $tag --optim RAdam --force_steplr --desc sigma50 denoising --sigma 50
24 | 
25 | elif [ $1 == "set12" ]; then
26 |     tag=Set12
27 |     python main.py --dip_type eSURE_uniform --gray --net_type s2s --exp_tag $tag --optim RAdam --force_steplr --sigma_z 0.3 --desc sigma15 denoising --sigma 15 --eval_data Set12
28 |     python main.py --dip_type eSURE_uniform --gray --net_type s2s --exp_tag $tag --optim RAdam --force_steplr --sigma_z 0.3 --desc sigma25 denoising --sigma 25 --eval_data Set12
29 |     python main.py --dip_type eSURE_uniform --gray --net_type s2s --exp_tag $tag --optim RAdam --force_steplr --sigma_z 0.3 --desc sigma50 denoising --sigma 50 --eval_data Set12
30 | 
31 | elif [ $1 == "McM" ]; then
32 |     tag=McM
33 |     python main.py --dip_type eSURE_uniform --net_type s2s --exp_tag $tag --optim RAdam --force_steplr --desc sigma15 denoising --sigma 15 --eval_data McM
34 |     python main.py --dip_type eSURE_uniform --net_type s2s --exp_tag $tag --optim RAdam --force_steplr --desc sigma25 denoising --sigma 25 --eval_data McM
35 |     python main.py --dip_type eSURE_uniform --net_type s2s --exp_tag $tag --optim RAdam --force_steplr --desc sigma50 denoising --sigma 50 --eval_data McM
36 | 
37 | elif [ $1 == "kodak" ]; then
38 |     tag=kodak
39 |     python main.py --dip_type eSURE_uniform --net_type s2s --exp_tag $tag --optim RAdam --force_steplr --desc sigma15 denoising --sigma 15 --eval_data Kodak
40 |     python main.py --dip_type eSURE_uniform --net_type s2s --exp_tag $tag --optim RAdam --force_steplr --desc sigma25 denoising --sigma 25 --eval_data Kodak
41 |     python main.py --dip_type eSURE_uniform --net_type s2s --exp_tag $tag --optim RAdam --force_steplr --desc sigma50 denoising --sigma 50 --eval_data Kodak
42 | 
43 | elif [ $1 == "CBSD" ]; then
44 |     tag=CBSD
45 |     python main.py --dip_type eSURE_uniform --net_type s2s --exp_tag $tag --optim RAdam --force_steplr --desc
sigma15 denoising --sigma 15 --eval_data CBSD68 46 | python main.py --dip_type eSURE_uniform --net_type s2s --exp_tag $tag --optim RAdam --force_steplr --desc sigma25 denoising --sigma 25 --eval_data CBSD68 47 | python main.py --dip_type eSURE_uniform --net_type s2s --exp_tag $tag --optim RAdam --force_steplr --desc sigma50 denoising --sigma 50 --eval_data CBSD68 48 | 49 | elif [ $1 == "BSD" ]; then 50 | tag=BSD 51 | python main.py --dip_type eSURE_uniform --gray --net_type s2s --exp_tag $tag --optim RAdam --force_steplr --sigma_z 0.3 --desc sigma15 denoising --sigma 15 --eval_data BSD68 52 | python main.py --dip_type eSURE_uniform --gray --net_type s2s --exp_tag $tag --optim RAdam --force_steplr --sigma_z 0.3 --desc sigma25 denoising --sigma 25 --eval_data BSD68 53 | python main.py --dip_type eSURE_uniform --gray --net_type s2s --exp_tag $tag --optim RAdam --force_steplr --sigma_z 0.3 --desc sigma50 denoising --sigma 50 --eval_data BSD68 54 | 55 | elif [ $1 == "DIP_MNIST" ]; then 56 | tag=DIP_MNIST 57 | python main.py --dip_type dip --gray --running_avg_ratio 0.9 --exp_tag $tag --optim RAdam --force_steplr --desc scale001 poisson --scale 0.01 --eval_data MNIST 58 | python main.py --dip_type dip --gray --running_avg_ratio 0.9 --exp_tag $tag --optim RAdam --force_steplr --desc scale01 poisson --scale 0.1 --eval_data MNIST 59 | python main.py --dip_type dip --gray --running_avg_ratio 0.9 --exp_tag $tag --optim RAdam --force_steplr --desc scale02 poisson --scale 0.2 --eval_data MNIST 60 | 61 | elif [ $1 == "MNIST" ]; then 62 | tag=MNIST 63 | python main.py --dip_type PURE_dc --gray --running_avg_ratio 0.9 --net_type s2s_normal --exp_tag $tag --optim RAdam --desc scale001 poisson --scale 0.01 --eval_data MNIST 64 | python main.py --dip_type PURE_dc --gray --running_avg_ratio 0.9 --net_type s2s_normal --exp_tag $tag --optim RAdam --desc scale01 poisson --scale 0.1 --eval_data MNIST 65 | python main.py --dip_type PURE_dc --gray --running_avg_ratio 0.9 --net_type s2s_normal --exp_tag $tag --optim RAdam --desc scale02 poisson --scale 0.2 --eval_data MNIST 66 | 67 | else 68 | echo wrong task name 69 | fi 70 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import additional_utils 4 | from utils.blur_utils import * # blur functions 5 | 6 | class DIPloss(nn.Module): 7 | def __init__(self, net, net_input, args): 8 | super(DIPloss, self).__init__() 9 | self.net = net 10 | self.dip_type = args.dip_type 11 | self.reg_noise_std = args.reg_noise_std 12 | self.task_type = args.task_type 13 | self.dtype = args.dtype 14 | self.mse = torch.nn.MSELoss(reduction="sum") 15 | self.net_input_saved = net_input.detach().clone() 16 | self.cnt = 0 17 | self.epsilon_decay = False 18 | self.reduction = "mean" 19 | 20 | def set_sigma(self, sigma): 21 | self.sigma = sigma 22 | self.sigma_y = sigma 23 | self.sigma_z = sigma * self.arg_sigma_z 24 | self.eps_y = torch.ones([1], device="cuda").reshape([-1, 1, 1, 1]) * self.sigma_y / 255.0 25 | self.eps_tf = self.eps_y * self.arg_epsilon 26 | self.eps_tf_init = self.eps_tf.clone() 27 | self.vary = (self.eps_y) ** 2 28 | 29 | def DIP(self, net_input, noisy_torch): 30 | if self.reg_noise_std > 0: 31 | net_input = self.net_input_saved + (torch.rand_like(self.net_input_saved).normal_() * self.reg_noise_std) 32 | out = self.inference(net_input) 33 | total_loss = torch.mean((out - noisy_torch) ** 2) 34 | return 
total_loss, out 35 | 36 | def SURE(self, output, target, divergence, sigma): 37 | batch, c, h, w = output.shape 38 | divergence = divergence * sigma 39 | mse = (output - target) ** 2 40 | esure = mse + 2 * divergence - sigma 41 | esure = torch.sum(esure) 42 | esure = esure if self.reduction == "sum" else esure / (h * w * c) 43 | return esure 44 | 45 | def DIP_SURE(self, net_input, noisy_torch): 46 | if self.sigma_z > 0 or self.uniform_sigma: 47 | if self.uniform_sigma: 48 | self.eSigma = np.random.uniform(0, self.sigma_z) / 255.0 49 | else: 50 | self.eSigma = self.sigma_z / 255.0 51 | net_input = self.net_input_saved + torch.randn_like(net_input).type(self.dtype) * self.eSigma 52 | net_input = net_input.requires_grad_() 53 | 54 | out = self.inference(net_input)#.contiguous(memory_format=torch.channels_last) 55 | divergence = self.divergence(net_input, out) 56 | total_loss = self.SURE(out, noisy_torch, divergence, self.vary) 57 | return total_loss, out 58 | 59 | def divergence_ty(self, net_input, out): 60 | if self.epsilon_decay: 61 | self.eps_tf = self.eps_tf_init * (0.9 ** (self.cnt // 200)) 62 | b_prime = torch.randn_like(net_input).type(self.dtype) 63 | out_ptb = self.inference(net_input + b_prime * self.eps_tf) 64 | divergence = (b_prime * (out_ptb - out)) / self.eps_tf 65 | return divergence 66 | 67 | def divergence_new(self, net_input, out): 68 | b_prime = torch.randn_like(net_input).type(self.dtype) 69 | nh_y = torch.sum(b_prime * out, dim=[1, 2, 3]) 70 | vector = torch.ones(1).to(out) 71 | divergence = b_prime * \ 72 | torch.autograd.grad(nh_y, net_input, grad_outputs=vector, retain_graph=True, create_graph=True)[0] 73 | return divergence 74 | 75 | def inference(self, x): 76 | return self.net(x) 77 | 78 | def forward(self, input, target): 79 | self.cnt += 1 80 | return self.loss(input, target) 81 | 82 | class Denoising_loss(DIPloss): 83 | def __init__(self, net, net_input, args): 84 | super(Denoising_loss, self).__init__(net, net_input, args) 85 | # parameter related to SURE. 86 | self.arg_sigma_z = args.sigma_z 87 | self.arg_epsilon = args.epsilon 88 | self.set_sigma(args.sigma) 89 | self.cnt = 0 90 | self.epsilon_decay = False 91 | self.reduction = "mean" 92 | 93 | print("[*] loss type : %s" % args.dip_type) 94 | print("[*] sigma : %.2f" % self.sigma) 95 | print("[*] sigma_z : %.2f" % self.sigma_z) 96 | 97 | self.divergence = self.divergence_ty 98 | self.uniform_sigma = False 99 | self.clip_divergence = False 100 | if self.dip_type == "dip": 101 | self.loss = self.DIP 102 | elif self.dip_type == "dip_sure": 103 | self.sigma_z = 0 104 | self.loss = self.DIP_SURE 105 | elif self.dip_type == "dip_sure_new": 106 | self.sigma_z = 0 107 | self.loss = self.DIP_SURE 108 | self.divergence = self.divergence_new 109 | elif self.dip_type == "eSURE": 110 | self.loss = self.DIP_SURE 111 | elif self.dip_type == "eSURE_alpha": 112 | self.epsilon_decay = True 113 | self.loss = self.DIP_SURE 114 | elif self.dip_type == "eSURE_new": 115 | self.divergence = self.divergence_new 116 | self.loss = self.DIP_SURE 117 | elif self.dip_type == "eSURE_uniform": 118 | self.uniform_sigma = True 119 | self.divergence = self.divergence_new 120 | self.loss = self.DIP_SURE 121 | else: 122 | print("[!] 
Undefined loss function.")
123 |             raise NotImplementedError
124 | 
125 | class Poisson_loss(DIPloss):
126 |     def __init__(self, net, net_input, args):
127 |         super(Poisson_loss, self).__init__(net, net_input, args)
128 |         self.net = net
129 |         self.dip_type = args.dip_type
130 |         self.reg_noise_std = args.reg_noise_std
131 |         self.task_type = args.task_type
132 |         self.dtype = args.dtype
133 |         self.mse = torch.nn.MSELoss(reduction="sum")
134 |         self.net_input_saved = net_input.detach().clone()
135 | 
136 |         self.arg_sigma_z = args.sigma_z
137 |         self.arg_epsilon = args.epsilon
138 |         self.divergence = self.divergence_ty
139 |         self.uniform_sigma = False
140 |         self.clip_divergence = False
141 | 
142 |         # parameter related to PURE.
143 |         self.scale = args.scale
144 |         self.eps = 0.01
145 |         self.epsilon_decay = False
146 |         self.reduction = "mean"
147 | 
148 |         print("[*] loss type : %s" % args.dip_type)
149 |         print("[*] Poisson scale : %.2f" % self.scale)
150 |         # print("[*] sigma_z : %.2f" % self.sigma_z)
151 |         self.uniform_sigma = False
152 |         self.eps_decay = False
153 |         if self.dip_type == "dip":
154 |             self.loss = self.DIP
155 |         elif self.dip_type == "PURE":
156 |             self.loss = self.DIP_PURE
157 |         elif self.dip_type == "PURE_dc":
158 |             self.eps = 0.1
159 |             self.eps_decay = True
160 |             self.loss = self.DIP_PURE
161 | 
162 |     def PURE(self, output, target, scale):
163 |         Y_ = output
164 |         Y = target
165 |         b_prime = 2 * (torch.randint_like(target, 0, 2) - 0.5)  # Rademacher vector with entries in {-1, +1}
166 |         if self.eps_decay and (self.cnt % 20 == 9):
167 |             self.eps *= 0.9
168 |         Z = Y + self.eps * b_prime
169 |         Z_ = self.inference(Z)
170 |         batch, c, h, w = output.shape
171 |         mse = torch.mean((Y - Y_) ** 2)
172 |         T1 = -scale * torch.mean(target)  # / batch
173 |         gradient = 2 * (scale / (self.eps * batch)) * torch.mean((b_prime * Y) * (Z_ - Y_))
174 |         return mse + T1 + gradient
175 | 
176 |     def DIP_PURE(self, net_input, noisy_torch):
177 |         out = self.inference(net_input)  # .contiguous(memory_format=torch.channels_last)
178 |         total_loss = self.PURE(out, noisy_torch, self.scale)
179 |         return total_loss, out
180 | 
181 | 
182 | def get_loss(net, net_input, args):
183 |     if args.task_type == "denoising":
184 |         print("[!]
Denoising mode setup.") 185 | return Denoising_loss(net, net_input, args) 186 | elif args.task_type == "poisson": 187 | return Poisson_loss(net, net_input, args) 188 | else: 189 | raise NotImplementedError 190 | 191 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import glob 4 | import json 5 | 6 | import cv2 7 | import torch 8 | import numpy as np 9 | import pandas as pd 10 | 11 | import loss 12 | import models 13 | import config_parser 14 | 15 | 16 | from utils.common_utils import * 17 | from utils.denoising_utils import * 18 | from torch.utils.tensorboard import SummaryWriter 19 | 20 | # beta version code 21 | import additional_utils 22 | 23 | def get_net(img_np, noise_np, args): 24 | net = models.get_net(args) 25 | 26 | if args.dip_type in ["dip_sure", "eSURE", "NCV_y", "eSURE_fixed", 'eSURE_new', 'eSURE_alpha', "eSURE_uniform", "eSURE_clip","eSURE_real", "no_div", "PURE", "PURE_dc", "dip_sure_new"]: 27 | net_input = cv2_to_torch(noise_np, dtype) 28 | print("[*] input_type : noisy image") 29 | else: 30 | INPUT = 'noise' 31 | input_depth = 1 if args.gray else 3 32 | # For SR, the get_noise should be same as img_np 33 | net_input = get_noise(input_depth, INPUT, (img_np.shape[1], img_np.shape[2])).type(dtype).detach() 34 | print("[*] input_type : noise") 35 | 36 | return net, net_input 37 | 38 | def get_optim(name, net, lr, beta): 39 | if name == "adam": 40 | print("[*] optim_type : Adam") 41 | return torch.optim.Adam(net.parameters(), lr, beta) 42 | elif name == "adamw": 43 | print("[*] optim_type : AdamW (wd : 1e-2)") 44 | return torch.optim.AdamW(net.parameters(), lr, beta) # default weight decay is 1e-2. 45 | elif name == "RAdam": 46 | return additional_utils.RAdam(net.parameters(), lr, beta) 47 | else: 48 | raise NotImplementedError 49 | 50 | def image_restorazation(file, args): 51 | # MAIN 52 | stat = {} 53 | task_type = args.task_type 54 | 55 | # Step 1. prepare clean & degradation(noisy) pair 56 | img_np, noisy_np = load_image_pair(file, task_type, args) 57 | if args.GT_noise: 58 | args.sigma = (img_np.astype(np.float) - noisy_np.astype(np.float)).std() 59 | # np_to_torch function from utils.common_utils. 60 | # _np : C,H,W [0, 255] -> _torch : C,H,W [0,1] scale 61 | img_torch = cv2_to_torch(img_np, args.dtype) 62 | noise_torch = cv2_to_torch(noisy_np, args.dtype) 63 | 64 | # For PSNR measure. 65 | noisy_clip_np = np.clip(noisy_np, 0, 255) 66 | # Step 2. make model and model input 67 | net, net_input = get_net(img_np, noisy_np, args) 68 | net.train() 69 | 70 | # Step 3. set loss function. 71 | cal_loss = loss.get_loss(net, net_input, args) 72 | optimizer = get_optim(args.optim, net, args.lr, (args.beta1, args.beta2)) 73 | if args.force_steplr: 74 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=.9, step_size=300) 75 | else: 76 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[2000, 3000], gamma=0.5) 77 | 78 | # Step 4. optimization and inference. 79 | # Hyper_param for Learning 80 | psnr_noisy_last = 0 81 | psnr_gt_running = 0 82 | 83 | save_dir = args.save_dir 84 | 85 | # Ensemble methods. 
86 | running_avg = None 87 | running_avg_ratio = args.running_avg_ratio 88 | 89 | image_name = file.split("/")[-1][:-4] 90 | np_save_dir = os.path.join(args.save_dir, image_name) 91 | os.makedirs(np_save_dir, exist_ok=True) 92 | 93 | stat["max_psnr"] = 0 94 | stat["max_ssim"] = 0 95 | stat["NUM_Backtracking"] = 0 96 | 97 | args.writer = SummaryWriter(log_dir="runs/%s/%s" % (args.exp_tag, args.desc + image_name)) 98 | for ep in range(args.epoch): 99 | optimizer.zero_grad() 100 | total_loss, out = cal_loss(net_input, noise_torch) 101 | with torch.no_grad(): 102 | mse_loss = torch.nn.functional.mse_loss(out, img_torch).item() 103 | diff_loss = total_loss.item() - mse_loss 104 | args.writer.add_scalar("loss/used_loss", total_loss.item(), global_step=ep) 105 | args.writer.add_scalar("loss/MSE_loss", mse_loss, global_step=ep) 106 | args.writer.add_scalar("loss/diff", diff_loss, global_step=ep) 107 | 108 | # _torch : C,H,W [0,1] scale => _np : C,H,W [0, 255] 109 | #out = torch_to_cv2(net(net_input)) 110 | out = torch_to_cv2(out) 111 | psnr_noisy = calculate_psnr(noisy_clip_np, out) 112 | psnr_gt = calculate_psnr(img_np, out) 113 | lpips_noisy = calculate_lpips(noisy_clip_np, out, args.lpips) 114 | lpips_gt = calculate_lpips(img_np, out, args.lpips) 115 | args.writer.add_scalar("psnr/noisy_to_out", psnr_noisy, global_step=ep) 116 | args.writer.add_scalar("psnr/clean_to_out", psnr_gt, global_step=ep) 117 | args.writer.add_scalar("lpips/noisy_to_out", lpips_noisy, global_step=ep) 118 | args.writer.add_scalar("lpips/clean_to_out", lpips_gt, global_step=ep) 119 | 120 | if total_loss < 0: 121 | print('\nLoss is less than 0') 122 | for new_param, net_param in zip(last_net, net.parameters()): 123 | net_param.data.copy_(new_param.cuda()) 124 | break 125 | if (psnr_noisy - psnr_noisy_last < -5) and (ep > 5) : 126 | print('\nFalling back to previous checkpoint.') 127 | for new_param, net_param in zip(last_net, net.parameters()): 128 | net_param.data.copy_(new_param.cuda()) 129 | stat["NUM_Backtracking"] += 1 130 | if stat["NUM_Backtracking"] > 10: 131 | break 132 | # continue 133 | else: 134 | # Running ensemble 135 | if True: #(ep % 50 == 0) and 136 | if running_avg is None: 137 | running_avg = out 138 | else: 139 | running_avg = running_avg * running_avg_ratio + out * (1 - running_avg_ratio) 140 | psnr_gt_running = calculate_psnr(img_np, running_avg) 141 | lpips_gt_running = calculate_lpips(img_np, running_avg, args.lpips, color="BGR") 142 | args.writer.add_scalar("psnr/clean_to_avg", psnr_gt_running, global_step=ep) 143 | args.writer.add_scalar("lpips/clean_to_avg", lpips_gt_running, global_step=ep) 144 | 145 | if (stat["max_psnr"] <= psnr_gt): 146 | stat["max_step"] = ep 147 | stat["max_psnr"] = psnr_gt 148 | stat["max_psnr_avg"] = psnr_gt_running 149 | stat["max_lpips_avg"] = lpips_gt_running 150 | stat["max_lpips"] = lpips_gt 151 | max_out, maxavg_out = out.copy(),running_avg.copy() 152 | 153 | #save file 154 | if args.save_np: 155 | state_dict = net.state_dict() 156 | torch.save(state_dict, os.path.join(np_save_dir, "max_psnr_state_dict.pth")) 157 | 158 | if (ep == 200 or ep == 10) and (psnr_gt_running < psnr_gt): 159 | running_avg = None 160 | 161 | # args.writer.add_image("result/gt_noise_out_avg", np.concatenate([img_np, noisy_np, out, running_avg], axis=2), ep) 162 | print('Iteration %05d total loss / MSE / diff %f / %f / %f PSNR_noisy: %f psnr_gt: %f PSNR_gt_sm: %f' % ( 163 | ep, total_loss.item(), mse_loss, diff_loss, psnr_noisy, psnr_gt, psnr_gt_running), end='\r') 164 | 165 | last_net = 
[x.detach().cpu() for x in net.parameters()] 166 | psnr_noisy_last=psnr_noisy 167 | total_loss.backward() 168 | optimizer.step() 169 | scheduler.step() 170 | torch.cuda.empty_cache() 171 | 172 | if args.optim_init > 0: 173 | if ep % args.optim_init == 0: 174 | additional_utils.init_optim(net, optimizer) 175 | 176 | stat["final_ep"] = ep 177 | stat["final_psnr"] = psnr_gt 178 | stat["final_psnr_avg"] = psnr_gt_running 179 | stat["final_lpips_avg"]= lpips_gt_running 180 | stat["final_lpips"] = lpips_gt 181 | 182 | 183 | # Make final images 184 | if True: 185 | save_CHW_np(save_dir + "/%s.png" % (image_name), out) 186 | save_CHW_np(save_dir + "/%s_avg.png" % (image_name), running_avg) 187 | save_CHW_np(save_dir + "/%s_max.png" % (image_name), max_out) 188 | save_CHW_np(save_dir + "/%s_max_avg.png" % (image_name), maxavg_out) 189 | 190 | if args.gray: 191 | stat["final_ssim"] = calculate_ssim(img_np, out) 192 | stat["final_ssim_avg"] = calculate_ssim(img_np, running_avg) 193 | stat["max_ssim"] = calculate_ssim(img_np, max_out) 194 | stat["max_ssim_avg"] = calculate_ssim(img_np, maxavg_out) 195 | log_file = open(save_dir + "/%s_log.txt" % (image_name), "w") 196 | print(stat, file=log_file) 197 | print("%s psnr clean_out : %.2f, %.2f noise_out : %.2f, max %.2f, %.2f" % ( 198 | image_name, psnr_gt_running, lpips_gt_running, psnr_noisy, stat["max_psnr"], stat["max_lpips"]), " " * 100) 199 | print(stat) 200 | args.writer.close() 201 | torch.cuda.empty_cache() 202 | return stat 203 | 204 | 205 | def read_dataset_file_list(eval_data): 206 | dataset_dir = "./testset/%s/" % eval_data 207 | file_list1 = glob.glob(dataset_dir + "*.tif") 208 | file_list2 = glob.glob(dataset_dir + "*.png") 209 | file_list3 = glob.glob(dataset_dir + "*.JPG") 210 | file_list = file_list1 + file_list2 + file_list3 211 | return file_list 212 | 213 | 214 | if __name__ == "__main__": 215 | # For REPRODUCIBILITY 216 | print("[*] reproduce mode On") 217 | torch.manual_seed(0) 218 | np.random.seed(0) 219 | if torch.cuda.is_available(): 220 | torch.backends.cudnn.enabled = True 221 | torch.backends.cudnn.deterministic = True 222 | torch.backends.cudnn.benchmark = False 223 | dtype = torch.cuda.FloatTensor 224 | lpips = get_lpips("cuda") 225 | else: 226 | dtype = torch.FloatTensor 227 | lpips = get_lpips("cpu") 228 | args = config_parser.main_parser() 229 | args.save_dir = "./result/%s/%s/%s" % (args.task_type, args.exp_tag, args.dip_type + args.desc) 230 | os.makedirs(args.save_dir, exist_ok = True) 231 | 232 | # default epoch setup. 233 | if args.task_type == "denoising": 234 | args.epoch = 3000 if args.epoch == 0 else args.epoch 235 | args.save_point = [1, 10, 100, 500, 1000, 2000, 3000, 4000] 236 | elif args.task_type == "poisson": 237 | args.epoch = 3000 if args.epoch == 0 else args.epoch 238 | args.save_point = [1, 10, 100, 500, 1000, 2000, 3000, 4000] 239 | 240 | 241 | with open(os.path.join(args.save_dir, 'args.json'), 'w') as f: 242 | json.dump(args.__dict__, f, indent=2) 243 | args.dtype = dtype 244 | args.lpips = lpips 245 | 246 | # file_list. 
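# read_dataset_file_list() (defined above) globs *.tif, *.png and *.JPG, so
# the expected layout for `--eval_data <name>` is:
#
#     ./testset/<name>/xxx.png
#
# e.g. `--eval_data CSet9` processes every matching image in ./testset/CSet9/.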
247 | file_list = read_dataset_file_list(args.eval_data) 248 | file_list = sorted(file_list) 249 | stat_list = [] 250 | for file in file_list: 251 | print("[*] process image file : %s" % file) 252 | stat = image_restorazation(file, args) 253 | stat_list.append(stat) 254 | 255 | data = pd.DataFrame(stat_list, index= [i.split("/")[-1] for i in file_list]) 256 | os.makedirs("./csv/%s/%s/" % (args.task_type, args.exp_tag), exist_ok=True) 257 | data.to_csv("./csv/%s/%s/%s.csv" % ( args.task_type, args.exp_tag ,args.dip_type+args.desc)) 258 | print("experiment done") 259 | print(data) 260 | -------------------------------------------------------------------------------- /models/S2Snet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.common import * 4 | 5 | 6 | class conv_block_sp(nn.Module): 7 | def __init__(self, ch_in, ch_out, down=False, act_fun='LeakyReLU', pad='reflection', group=1, bn_mode = "bn", bias=True): 8 | super(conv_block_sp, self).__init__() 9 | self.conv1 = conv(ch_in, ch_out, kernel_size=3, stride=1 if down is False else 2, bias=bias, pad=pad, 10 | group=group) 11 | self.conv2 = conv(ch_out, ch_out, kernel_size=3, stride=1, bias=bias, pad=pad, group=group) 12 | self.conv = nn.Sequential( 13 | self.conv1, bn(ch_out, bn_mode if group == 1 else "groupNorm"), act(act_fun), 14 | self.conv2, bn(ch_out, bn_mode if group == 1 else "groupNorm"), act(act_fun)) 15 | 16 | def forward(self, x): 17 | x = self.conv(x) 18 | return x 19 | 20 | 21 | class conv_block_last(nn.Module): 22 | def __init__(self, ch_in, ch_out, down=False, act_fun='LeakyReLU', pad='reflection', group=1, bn_mode = "bn", bias=True): 23 | super(conv_block_last, self).__init__() 24 | self.conv1 = conv(ch_in, 64, kernel_size=3, stride=1 if down is False else 2, bias=bias, pad=pad, group=group) 25 | self.conv2 = conv(64, 32, kernel_size=3, stride=1, bias=bias, pad=pad, group=group) 26 | self.conv3 = conv(32, ch_out, kernel_size=3, stride=1, bias=bias, pad=pad, group=group) 27 | self.conv = nn.Sequential( 28 | self.conv1, bn(64, bn_mode if group == 1 else "groupNorm"), act(act_fun), 29 | self.conv2, bn(32, bn_mode if group == 1 else "groupNorm"), act(act_fun), 30 | self.conv3) 31 | 32 | def forward(self, x): 33 | x = self.conv(x) 34 | return x 35 | 36 | 37 | class SIREN_layer(nn.Module): 38 | def __init__(self, ch_in, ch_out, frist = False, act_fun='sine', omega_0=30): 39 | super(SIREN_layer, self).__init__() 40 | self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=1, bias=True) 41 | self.act_fun = act(act_fun) 42 | self.omega_0 = omega_0 43 | self.in_features = ch_in 44 | self.frist = frist 45 | self.init() 46 | 47 | 48 | def init(self): 49 | with torch.no_grad(): 50 | if self.frist: 51 | self.conv1.weight.uniform_(-1 / self.in_features, 52 | 1 / self.in_features) 53 | else: 54 | self.conv1.weight.uniform_(-np.sqrt(6 / self.in_features) / self.omega_0, 55 | np.sqrt(6 / self.in_features) / self.omega_0) 56 | 57 | def forward(self, x): 58 | x = self.conv1(x) 59 | return self.act_fun(self.omega_0 * x) 60 | 61 | class SIREN_CONV(nn.Module): 62 | def __init__(self, ch_in, ch_out): 63 | super(SIREN_CONV, self).__init__() 64 | self.conv1 = SIREN_layer(ch_in, 64, frist=True) 65 | self.conv2 = SIREN_layer(64, 32) 66 | self.conv3 = SIREN_layer(32, ch_out) 67 | self.conv = nn.Sequential( 68 | self.conv1, 69 | self.conv2, 70 | self.conv3) 71 | 72 | def forward(self, x): 73 | x = self.conv(x) 74 | return x 75 | 76 | 77 | 78 | class 
conv_block_skip(nn.Module): 79 | def __init__(self, ch_in, ch_out, act_fun='LeakyReLU', pad='reflection', group=1, bn_mode = "bn", bias=True): 80 | super(conv_block_skip, self).__init__() 81 | self.conv1 = conv(ch_in, ch_out, kernel_size=1, stride=1, bias=bias, pad=pad, group=group) 82 | self.conv = nn.Sequential( 83 | self.conv1, bn(ch_out, bn_mode if group == 1 else "groupNorm"), act(act_fun)) 84 | 85 | def forward(self, x): 86 | x = self.conv(x) 87 | return x 88 | 89 | 90 | class conv_block_concat(nn.Module): 91 | def __init__(self, ch_in, ch_out, act_fun='LeakyReLU', pad='reflection', group=1, bn_mode = "bn", bias=True): 92 | super(conv_block_concat, self).__init__() 93 | self.conv1 = conv(ch_in, ch_out, kernel_size=3, stride=1, bias=bias, pad=pad, group=group) 94 | self.conv2 = conv(ch_out, ch_out, kernel_size=3, stride=1, bias=bias, pad=pad, group=group) 95 | self.up = nn.Sequential( 96 | bn(ch_in), 97 | self.conv1, bn(ch_out, bn_mode if group == 1 else "groupNorm"), act(act_fun), 98 | self.conv2, bn(ch_out, bn_mode if group == 1 else "groupNorm"), act(act_fun)) 99 | 100 | def forward(self, x): 101 | x = self.up(x) 102 | return x 103 | 104 | 105 | class Concat_layer(nn.Module): 106 | def __init__(self, dim): 107 | super(Concat_layer, self).__init__() 108 | self.dim = dim 109 | 110 | def forward(self, inputs): 111 | inputs_shapes2 = [x.shape[2] for x in inputs] 112 | inputs_shapes3 = [x.shape[3] for x in inputs] 113 | if np.all(np.array(inputs_shapes2) == min(inputs_shapes2)) and np.all( 114 | np.array(inputs_shapes3) == min(inputs_shapes3)): 115 | inputs_ = inputs 116 | else: 117 | target_shape2 = min(inputs_shapes2) 118 | target_shape3 = min(inputs_shapes3) 119 | inputs_ = [] 120 | for inp in inputs: 121 | diff2 = (inp.size(2) - target_shape2) // 2 122 | diff3 = (inp.size(3) - target_shape3) // 2 123 | inputs_.append(inp[:, :, diff2: diff2 + target_shape2, diff3:diff3 + target_shape3]) 124 | return torch.cat(inputs_, dim=self.dim) 125 | 126 | def __len__(self): 127 | return len(self._modules) 128 | 129 | 130 | class S2Snet(nn.Module): 131 | def __init__(self, img_ch=3, output_ch=3, act_type="LeakyReLU", pad='reflection'): 132 | super(S2Snet, self).__init__() 133 | enc_ch = [48, 48, 48, 48, 48] # fixed 134 | dec_ch = [96, 96, 96, 96, 96] # fixed 135 | self.upsample = nn.Upsample(scale_factor=2) 136 | self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2) 137 | self.concat = Concat_layer(1) 138 | self.Conv1 = conv_block_sp(ch_in=img_ch, ch_out=enc_ch[0], down=True, act_fun=act_type, pad=pad) # h/2, w/2 139 | self.Conv2 = conv_block_sp(ch_in=enc_ch[0], ch_out=enc_ch[1], down=True, act_fun=act_type, pad=pad) # h/4, w/4 140 | self.Conv3 = conv_block_sp(ch_in=enc_ch[1], ch_out=enc_ch[2], down=True, act_fun=act_type, pad=pad) # h/8, w/8 141 | self.Conv4 = conv_block_sp(ch_in=enc_ch[2], ch_out=enc_ch[3], down=True, act_fun=act_type, 142 | pad=pad) # h/16, w/16 143 | self.Conv5 = conv_block_sp(ch_in=enc_ch[3], ch_out=enc_ch[4], down=True, act_fun=act_type, 144 | pad=pad) # h/32, w/32 145 | 146 | self.conv = nn.Sequential( 147 | conv(enc_ch[4], dec_ch[4], kernel_size=3, stride=1, bias=True, pad=pad), 148 | bn(dec_ch[4]), 149 | act(act_type)) 150 | 151 | self.Up_conv4 = conv_block_concat(ch_in=dec_ch[4] + enc_ch[3], ch_out=dec_ch[3], act_fun=act_type, pad=pad) 152 | self.Up_conv3 = conv_block_concat(ch_in=dec_ch[3] + enc_ch[2], ch_out=dec_ch[2], act_fun=act_type, pad=pad) 153 | self.Up_conv2 = conv_block_concat(ch_in=dec_ch[2] + enc_ch[1], ch_out=dec_ch[1], act_fun=act_type, pad=pad) 154 | 
self.Up_conv1 = conv_block_concat(ch_in=dec_ch[1] + enc_ch[0], ch_out=dec_ch[0], act_fun=act_type, pad=pad)
155 |         self.Conv_last1 = conv_block_last(dec_ch[0] + img_ch, output_ch, act_fun=act_type, pad=pad)  # concat
156 |         self.Sig = nn.Sigmoid()
157 | 
158 |     def forward(self, x):
159 |         # encoding path
160 |         x1 = self.Conv1(x)  # h/2 w/2
161 |         x2 = self.Conv2(x1)  # h/4 w/4
162 |         x3 = self.Conv3(x2)  # h/8 w/8
163 |         x4 = self.Conv4(x3)  # h/16 w/16
164 |         x5 = self.Conv5(x4)  # h/32 w/32
165 | 
166 |         x6 = self.conv(x5)
167 | 
168 |         d5 = self.upsample(x6)  # h/16 w/16
169 |         d5 = self.concat([d5, x4])
170 |         d4 = self.Up_conv4(d5)
171 | 
172 |         d4 = self.upsample(d4)  # h/8 w/8
173 |         d4 = self.concat([d4, x3])
174 |         d3 = self.Up_conv3(d4)
175 | 
176 |         d3 = self.upsample(d3)  # h/4 w/4
177 |         d3 = self.concat([d3, x2])
178 |         d2 = self.Up_conv2(d3)
179 | 
180 |         d2 = self.upsample(d2)  # h/2 w/2
181 |         d2 = self.concat([d2, x1])
182 |         d1 = self.Up_conv1(d2)
183 | 
184 |         d0 = self.upsample(d1)  # h w
185 |         d0 = self.concat([d0, x])
186 |         d0 = self.Conv_last1(d0)
187 |         return self.Sig(d0)
188 | 
189 | 
190 | class s2s_(nn.Module):
191 |     def __init__(self, img_ch=3, output_ch=3, act_type="LeakyReLU", pad='reflection', bn_mode="bn"):
192 |         super(s2s_, self).__init__()
193 |         enc_ch = [48, 48, 48, 48, 48]  # fixed
194 |         dec_ch = [96, 96, 96, 96, 96]  # fixed
195 |         self.upsample = nn.Upsample(scale_factor=2)
196 |         self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
197 |         self.concat = Concat_layer(1)
198 |         self.Conv_first1 = conv_block_last(img_ch, enc_ch[0], act_fun=act_type, pad=pad, bn_mode=bn_mode, bias=False)  # concat
199 |         self.Conv1 = conv_block_sp(ch_in=img_ch + enc_ch[0], ch_out=enc_ch[0], down=True, act_fun=act_type, pad=pad, bn_mode=bn_mode, bias=False)  # h/2, w/2
200 |         self.Conv2 = conv_block_sp(ch_in=enc_ch[0], ch_out=enc_ch[1], down=True, act_fun=act_type, pad=pad, bn_mode=bn_mode, bias=False)  # h/4, w/4
201 |         self.Conv3 = conv_block_sp(ch_in=enc_ch[1], ch_out=enc_ch[2], down=True, act_fun=act_type, pad=pad, bn_mode=bn_mode, bias=False)  # h/8, w/8
202 |         self.Conv4 = conv_block_sp(ch_in=enc_ch[2], ch_out=enc_ch[3], down=True, act_fun=act_type,
203 |                                    pad=pad, bn_mode=bn_mode)  # h/16, w/16
204 |         self.Conv5 = conv_block_sp(ch_in=enc_ch[3], ch_out=enc_ch[4], down=True, act_fun=act_type,
205 |                                    pad=pad, bn_mode=bn_mode)  # h/32, w/32
206 | 
207 |         self.conv = nn.Sequential(
208 |             conv(enc_ch[4], dec_ch[4], kernel_size=3, stride=1, bias=False, pad=pad),
209 |             bn(dec_ch[4], bn_mode),
210 |             act(act_type))
211 | 
212 |         self.Up_conv4 = conv_block_concat(ch_in=dec_ch[4] + enc_ch[3], ch_out=dec_ch[3], act_fun=act_type, pad=pad, bn_mode=bn_mode, bias=False)
213 |         self.Up_conv3 = conv_block_concat(ch_in=dec_ch[3] + enc_ch[2], ch_out=dec_ch[2], act_fun=act_type, pad=pad, bn_mode=bn_mode, bias=False)
214 |         self.Up_conv2 = conv_block_concat(ch_in=dec_ch[2] + enc_ch[1], ch_out=dec_ch[1], act_fun=act_type, pad=pad, bn_mode=bn_mode, bias=False)
215 |         self.Up_conv1 = conv_block_concat(ch_in=dec_ch[1] + enc_ch[0], ch_out=dec_ch[0], act_fun=act_type, pad=pad, bn_mode=bn_mode, bias=False)
216 |         self.Conv_last1 = conv_block_last(dec_ch[0] + img_ch, output_ch, act_fun=act_type, pad=pad, bn_mode=bn_mode, bias=False)  # concat
217 | 
218 |         self.Sig = nn.Sigmoid()
219 |         if img_ch == 4:
220 |             self.mean = torch.tensor([0.406, 0.456, 0.485, 0]).view(1, 4, 1, 1)
221 |             self.std = torch.tensor([0.225, 0.224, 0.229, 1]).view(1, 4, 1, 1)
222 |         elif img_ch == 3:  # elif, not if: a plain `if` would send img_ch == 4 into the grayscale else-branch
223 |             self.mean = torch.tensor([0.406, 0.456, 0.485]).view(1, 3, 1, 1)
224 |             self.std =
torch.tensor([0.225, 0.224, 0.229]).view(1, 3, 1, 1) 225 | else: 226 | self.mean = torch.tensor([0.449]) 227 | self.std = torch.tensor([0.226]) 228 | 229 | def forward(self, x): 230 | x = (x - self.mean.to(x.device)) / self.std.to(x.device) 231 | # encoding path 232 | x0 = self.Conv_first1(x) 233 | x0 = self.concat([x0, x]) 234 | x1 = self.Conv1(x0) # h/2 w/2 235 | x2 = self.Conv2(x1) # h/4 w/4 236 | x3 = self.Conv3(x2) # h/8 w/8 237 | x4 = self.Conv4(x3) # h/16 w/16 238 | x5 = self.Conv5(x4) # h/32 w/32 239 | 240 | x6 = self.conv(x5) 241 | 242 | d5 = self.upsample(x6) # h/16 w/16 243 | d5 = self.concat([d5, x4]) 244 | d4 = self.Up_conv4(d5) 245 | 246 | d4 = self.upsample(d4) # h/8 w/8 247 | d4 = self.concat([d4, x3]) 248 | d3 = self.Up_conv3(d4) 249 | 250 | d3 = self.upsample(d3) # h/4 w/4 251 | d3 = self.concat([d3, x2]) 252 | d2 = self.Up_conv2(d3) 253 | 254 | d2 = self.upsample(d2) # h/2 w/2 255 | d2 = self.concat([d2, x1]) 256 | d1 = self.Up_conv1(d2) 257 | 258 | d0 = self.upsample(d1) # h w 259 | d0 = self.concat([d0, x]) 260 | d0 = self.Conv_last1(d0) 261 | return self.Sig(d0) 262 | 263 | class s2s_fixed(nn.Module): 264 | def __init__(self, img_ch=3, output_ch=3, act_type="LeakyReLU", pad='reflection', bn_mode = "bn"): 265 | super(s2s_fixed, self).__init__() 266 | enc_ch = [48, 48, 48, 48, 48] # fixed 267 | dec_ch = [96, 96, 96, 96, 96] # fixed 268 | self.upsample = nn.Upsample(scale_factor=2) 269 | self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2) 270 | self.concat = Concat_layer(1) 271 | self.Conv1 = conv_block_sp(ch_in=img_ch, ch_out=enc_ch[0], down=True, act_fun=act_type, pad=pad, bn_mode=bn_mode) # h/2, w/2 272 | self.Conv2 = conv_block_sp(ch_in=enc_ch[0], ch_out=enc_ch[1], down=True, act_fun=act_type, pad=pad, bn_mode=bn_mode) # h/4, w/4 273 | self.Conv3 = conv_block_sp(ch_in=enc_ch[1], ch_out=enc_ch[2], down=True, act_fun=act_type, pad=pad, bn_mode=bn_mode) # h/8, w/8 274 | self.Conv4 = conv_block_sp(ch_in=enc_ch[2], ch_out=enc_ch[3], down=True, act_fun=act_type, 275 | pad=pad, bn_mode=bn_mode) # h/16, w/16 276 | self.Conv5 = conv_block_sp(ch_in=enc_ch[3], ch_out=enc_ch[4], down=True, act_fun=act_type, 277 | pad=pad, bn_mode=bn_mode) # h/32, w/32 278 | 279 | self.conv = nn.Sequential( 280 | conv(enc_ch[4], dec_ch[4], kernel_size=3, stride=1, bias=True, pad=pad), 281 | bn(dec_ch[4], bn_mode), 282 | act(act_type)) 283 | 284 | self.Up_conv4 = conv_block_concat(ch_in=dec_ch[4] + enc_ch[3], ch_out=dec_ch[3], act_fun=act_type, pad=pad, bn_mode=bn_mode) 285 | self.Up_conv3 = conv_block_concat(ch_in=dec_ch[3] + enc_ch[2], ch_out=dec_ch[2], act_fun=act_type, pad=pad, bn_mode=bn_mode) 286 | self.Up_conv2 = conv_block_concat(ch_in=dec_ch[2] + enc_ch[1], ch_out=dec_ch[1], act_fun=act_type, pad=pad, bn_mode=bn_mode) 287 | self.Up_conv1 = conv_block_concat(ch_in=dec_ch[1] + enc_ch[0], ch_out=dec_ch[0], act_fun=act_type, pad=pad, bn_mode=bn_mode) 288 | self.Conv_last1 = conv_block_last(dec_ch[0] + img_ch, output_ch, act_fun=act_type, pad=pad, bn_mode=bn_mode) # concat 289 | self.Sig = nn.Sigmoid() 290 | 291 | def forward(self, x): 292 | # encoding path 293 | x1 = self.Conv1(x) # h/2 w/2 294 | x2 = self.Conv2(x1) # h/4 w/4 295 | x3 = self.Conv3(x2) # h/8 w/8 296 | x4 = self.Conv4(x3) # h/16 w/16 297 | x5 = self.Conv5(x4) # h/32 w/32 298 | 299 | x6 = self.conv(x5) 300 | 301 | d5 = self.upsample(x6) # h/16 w/16 302 | d5 = self.concat([d5, x4]) 303 | d4 = self.Up_conv4(d5) 304 | 305 | d4 = self.upsample(d4) # h/8 w/8 306 | d4 = self.concat([d4, x3]) 307 | d3 = self.Up_conv3(d4) 308 | 309 | 
d3 = self.upsample(d3)  # h/4 w/4
310 |         d3 = self.concat([d3, x2])
311 |         d2 = self.Up_conv2(d3)
312 | 
313 |         d2 = self.upsample(d2)  # h/2 w/2
314 |         d2 = self.concat([d2, x1])
315 |         d1 = self.Up_conv1(d2)
316 | 
317 |         d0 = self.upsample(d1)  # h w
318 |         d0 = self.concat([d0, x])
319 |         d0 = self.Conv_last1(d0)
320 |         return self.Sig(d0)
321 | 
322 | class s2s_normal(nn.Module):
323 |     def __init__(self, img_ch=3, output_ch=3, act_type="LeakyReLU", pad='reflection', bn_mode="bn"):
324 |         super(s2s_normal, self).__init__()
325 |         enc_ch = [48, 48, 48, 48, 48]  # fixed
326 |         dec_ch = [96, 96, 96, 96, 96]  # fixed
327 |         self.upsample = nn.Upsample(scale_factor=2)
328 |         self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
329 |         self.concat = Concat_layer(1)
330 |         self.Conv1 = conv_block_sp(ch_in=img_ch, ch_out=enc_ch[0], down=True, act_fun=act_type, pad=pad, bn_mode=bn_mode)  # h/2, w/2
331 |         self.Conv2 = conv_block_sp(ch_in=enc_ch[0], ch_out=enc_ch[1], down=True, act_fun=act_type, pad=pad, bn_mode=bn_mode)  # h/4, w/4
332 |         self.Conv3 = conv_block_sp(ch_in=enc_ch[1], ch_out=enc_ch[2], down=True, act_fun=act_type, pad=pad, bn_mode=bn_mode)  # h/8, w/8
333 |         self.Conv4 = conv_block_sp(ch_in=enc_ch[2], ch_out=enc_ch[3], down=True, act_fun=act_type,
334 |                                    pad=pad, bn_mode=bn_mode)  # h/16, w/16
335 |         self.Conv5 = conv_block_sp(ch_in=enc_ch[3], ch_out=enc_ch[4], down=True, act_fun=act_type,
336 |                                    pad=pad, bn_mode=bn_mode)  # h/32, w/32
337 | 
338 |         self.conv = nn.Sequential(
339 |             conv(enc_ch[4], dec_ch[4], kernel_size=3, stride=1, bias=True, pad=pad),
340 |             bn(dec_ch[4], bn_mode),
341 |             act(act_type))
342 | 
343 |         self.Up_conv4 = conv_block_concat(ch_in=dec_ch[4] + enc_ch[3], ch_out=dec_ch[3], act_fun=act_type, pad=pad, bn_mode=bn_mode)
344 |         self.Up_conv3 = conv_block_concat(ch_in=dec_ch[3] + enc_ch[2], ch_out=dec_ch[2], act_fun=act_type, pad=pad, bn_mode=bn_mode)
345 |         self.Up_conv2 = conv_block_concat(ch_in=dec_ch[2] + enc_ch[1], ch_out=dec_ch[1], act_fun=act_type, pad=pad, bn_mode=bn_mode)
346 |         self.Up_conv1 = conv_block_concat(ch_in=dec_ch[1] + enc_ch[0], ch_out=dec_ch[0], act_fun=act_type, pad=pad, bn_mode=bn_mode)
347 |         self.Conv_last1 = conv_block_last(dec_ch[0] + img_ch, output_ch, act_fun=act_type, pad=pad, bn_mode=bn_mode)  # concat
348 |         self.Sig = nn.Sigmoid()
349 |         if img_ch == 4:
350 |             self.mean = torch.tensor([0.406, 0.456, 0.485, 0]).view(1, 4, 1, 1)
351 |             self.std = torch.tensor([0.225, 0.224, 0.229, 1]).view(1, 4, 1, 1)
352 |         elif img_ch == 3:  # elif, not if: otherwise img_ch == 4 falls through to the grayscale else-branch
353 |             self.mean = torch.tensor([0.406, 0.456, 0.485]).view(1, 3, 1, 1)
354 |             self.std = torch.tensor([0.225, 0.224, 0.229]).view(1, 3, 1, 1)
355 |         else:
356 |             self.mean = torch.tensor([0.449])
357 |             self.std = torch.tensor([0.226])
358 | 
359 |     def forward(self, x):
360 |         x = (x - self.mean.to(x.device)) / self.std.to(x.device)
361 |         # encoding path
362 |         x1 = self.Conv1(x)  # h/2 w/2
363 |         x2 = self.Conv2(x1)  # h/4 w/4
364 |         x3 = self.Conv3(x2)  # h/8 w/8
365 |         x4 = self.Conv4(x3)  # h/16 w/16
366 |         x5 = self.Conv5(x4)  # h/32 w/32
367 | 
368 |         x6 = self.conv(x5)
369 | 
370 |         d5 = self.upsample(x6)  # h/16 w/16
371 |         d5 = self.concat([d5, x4])
372 |         d4 = self.Up_conv4(d5)
373 | 
374 |         d4 = self.upsample(d4)  # h/8 w/8
375 |         d4 = self.concat([d4, x3])
376 |         d3 = self.Up_conv3(d4)
377 | 
378 |         d3 = self.upsample(d3)  # h/4 w/4
379 |         d3 = self.concat([d3, x2])
380 |         d2 = self.Up_conv2(d3)
381 | 
382 |         d2 = self.upsample(d2)  # h/2 w/2
383 |         d2 = self.concat([d2, x1])
384 |         d1 = self.Up_conv1(d2)
385 | 
386 |         d0 = self.upsample(d1)  # h w
387 |         d0 = self.concat([d0, x])
388 |
d0 = self.Conv_last1(d0) 389 | return self.Sig(d0) 390 | 391 | class s2s_g(nn.Module): 392 | def __init__(self, img_ch=3, output_ch=3, act_type="LeakyReLU", pad='reflection', bn_mode = "groupNorm"): 393 | super(s2s_g, self).__init__() 394 | enc_ch = [24, 24, 24, 24, 24] # fixed 395 | dec_ch = [24, 24, 24, 24, 24] # fixed 396 | self.img_ch = img_ch 397 | self.upsample = nn.Upsample(scale_factor=2) 398 | self.concat = Concat_layer(1) 399 | self.Conv1 = conv_block_sp(ch_in=img_ch *4, ch_out=enc_ch[0], down=True, act_fun=act_type, pad=pad, group=4, 400 | bn_mode=bn_mode) # h/2, w/2 401 | self.Conv2 = conv_block_sp(ch_in=enc_ch[0], ch_out=enc_ch[1], down=True, act_fun=act_type, pad=pad, group=4, 402 | bn_mode=bn_mode) # h/4, w/4 403 | self.Conv3 = conv_block_sp(ch_in=enc_ch[1], ch_out=enc_ch[2], down=True, act_fun=act_type, pad=pad, group=4, 404 | bn_mode=bn_mode) # h/8, w/8 405 | self.Conv4 = conv_block_sp(ch_in=enc_ch[2], ch_out=enc_ch[3], down=True, act_fun=act_type, pad=pad, group=4, 406 | bn_mode=bn_mode) # h/16, w/16 407 | self.Conv5 = conv_block_sp(ch_in=enc_ch[3], ch_out=enc_ch[4], down=True, act_fun=act_type, pad=pad, group=4, 408 | bn_mode=bn_mode) # h/32, w/32 409 | 410 | self.conv = nn.Sequential( 411 | conv(enc_ch[4], dec_ch[4], kernel_size=3, stride=1, bias=True, pad=pad), 412 | bn(dec_ch[4], bn_mode), 413 | act(act_type)) 414 | 415 | self.Up_conv4 = conv_block_concat(ch_in=dec_ch[4] + enc_ch[3], ch_out=dec_ch[3], act_fun=act_type, pad=pad, 416 | group=4, bn_mode=bn_mode) 417 | self.Up_conv3 = conv_block_concat(ch_in=dec_ch[3] + enc_ch[2], ch_out=dec_ch[2], act_fun=act_type, pad=pad, 418 | group=4, bn_mode=bn_mode) 419 | self.Up_conv2 = conv_block_concat(ch_in=dec_ch[2] + enc_ch[1], ch_out=dec_ch[1], act_fun=act_type, pad=pad, 420 | group=4, bn_mode=bn_mode) 421 | self.Up_conv1 = conv_block_concat(ch_in=dec_ch[1] + enc_ch[0], ch_out=dec_ch[0], act_fun=act_type, pad=pad, 422 | group=4, bn_mode=bn_mode) 423 | self.Conv_last1 = conv_block_last(dec_ch[0] + img_ch * 4, output_ch * 4, act_fun=act_type, pad=pad, group= 4, 424 | bn_mode=bn_mode) # concat 425 | self.Sig = nn.Sigmoid() 426 | 427 | def group_concat(self, d, x, stride1=6, stride2=6): 428 | return self.concat([d[:,:stride1,:,:], x[:,:stride2,:,:], d[:,stride1:2*stride1,:,:], x[:,stride2:2*stride2,:,:], 429 | d[:,2*stride1:3*stride1,:,:], x[:,2*stride2:3*stride2,:,:], d[:,3*stride1:4*stride1,:,:], x[:,3*stride2:4*stride2,:,:]]) 430 | 431 | def group_mean(self, out, stride=3): 432 | return (out[:,:stride] + out[:,stride:2*stride] + out[:,2*stride:3*stride] + out[:,3*stride:4*stride])/4. 
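    # How the grouped pieces of s2s_g fit together (a summary of the code in
    # this class; channel counts assume img_ch == 3):
    #   * forward() below tiles the normalized input into 4 copies along the
    #     channel axis (3 -> 12), and every conv block runs with group=4, so
    #     the network behaves as 4 parallel branches packed into one tensor.
    #   * group_concat() interleaves decoder and encoder (skip) channels per
    #     branch, so each branch only ever sees its own features.
    #   * group_mean() averages the 4 branch outputs (stride=3 channels each)
    #     into a single output_ch-channel prediction.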
433 | 
434 |     def forward(self, x):
435 |         x = (x - x.mean()) / (x.std())
436 |         x_tmp = self.concat([x, x, x, x])
437 |         # encoding path
438 |         x1 = self.Conv1(x_tmp)  # h/2 w/2
439 |         x2 = self.Conv2(x1)  # h/4 w/4
440 |         x3 = self.Conv3(x2)  # h/8 w/8
441 |         x4 = self.Conv4(x3)  # h/16 w/16
442 |         x5 = self.Conv5(x4)  # h/32 w/32
443 | 
444 |         x6 = self.conv(x5)
445 | 
446 |         d5 = self.upsample(x6)  # h/16 w/16
447 |         d5 = self.group_concat(d5, x4)
448 |         d4 = self.Up_conv4(d5)
449 | 
450 |         d4 = self.upsample(d4)  # h/8 w/8
451 |         d4 = self.group_concat(d4, x3)
452 |         d3 = self.Up_conv3(d4)
453 | 
454 |         d3 = self.upsample(d3)  # h/4 w/4
455 |         d3 = self.group_concat(d3, x2)
456 |         d2 = self.Up_conv2(d3)
457 | 
458 |         d2 = self.upsample(d2)  # h/2 w/2
459 |         d2 = self.group_concat(d2, x1)
460 |         d1 = self.Up_conv1(d2)
461 | 
462 |         d0 = self.upsample(d1)  # h w
463 | 
464 |         d0 = self.group_concat(d0, x_tmp, stride2=self.img_ch)
465 |         d0 = self.Conv_last1(d0)
466 |         out = self.Sig(d0)
467 |         return self.group_mean(out)
468 | 
469 | 
470 | 
471 | class Attention_block(nn.Module):
472 |     def __init__(self, F_g, F_l, F_int):
473 |         super(Attention_block, self).__init__()
474 |         self.W_g = nn.Sequential(
475 |             nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
476 |             nn.BatchNorm2d(F_int)
477 |         )
478 | 
479 |         self.W_x = nn.Sequential(
480 |             nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
481 |             nn.BatchNorm2d(F_int)
482 |         )
483 | 
484 |         self.psi = nn.Sequential(
485 |             nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
486 |             nn.BatchNorm2d(1),
487 |             nn.Sigmoid()
488 |         )
489 | 
490 |         self.relu = nn.ReLU(inplace=True)
491 | 
492 |     def forward(self, g, x):
493 |         inputs_shapes2 = [t.shape[2] for t in [g, x]]
494 |         inputs_shapes3 = [t.shape[3] for t in [g, x]]
495 |         if np.all(np.array(inputs_shapes2) == min(inputs_shapes2)) and np.all(
496 |                 np.array(inputs_shapes3) == min(inputs_shapes3)):
497 |             pass
498 |         else:
499 |             target_shape2 = min(inputs_shapes2)
500 |             target_shape3 = min(inputs_shapes3)
501 |             diff2 = (g.size(2) - target_shape2) // 2
502 |             diff3 = (g.size(3) - target_shape3) // 2
503 |             g = g[:, :, diff2: diff2 + target_shape2, diff3:diff3 + target_shape3]
504 |             diff2 = (x.size(2) - target_shape2) // 2  # offsets must come from x; g was already cropped above
505 |             diff3 = (x.size(3) - target_shape3) // 2
506 |             x = x[:, :, diff2: diff2 + target_shape2, diff3:diff3 + target_shape3]
507 | 
508 |         g1 = self.W_g(g)
509 |         x1 = self.W_x(x)
510 |         psi = self.relu(g1 + x1)
511 |         psi = self.psi(psi)
512 |         return x * psi
513 | 
514 | class S2SATnet1(nn.Module):
515 |     def __init__(self, img_ch=3, output_ch=3, act_type="LeakyReLU", bn_mode="bn", pad='reflection'):
516 |         super(S2SATnet1, self).__init__()
517 |         enc_ch = [48, 48, 48, 48, 48]  # fixed
518 |         dec_ch = [96, 96, 96, 96, 96]  # fixed
519 |         print("[*] input/output channel : %d / %d" % (img_ch, output_ch))
520 |         print("[*] act_type : %s" % act_type)
521 |         print("[*] bn_type : %s" % bn_mode)
522 |         self.upsample = nn.Upsample(scale_factor=2)
523 |         self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
524 |         self.concat = Concat_layer(1)
525 |         self.Conv1 = conv_block_sp(ch_in=img_ch, ch_out=enc_ch[0], down=True, act_fun=act_type, bn_mode=bn_mode, pad=pad)  # h/2, w/2
526 |         self.Conv2 = conv_block_sp(ch_in=enc_ch[0], ch_out=enc_ch[1], down=True, act_fun=act_type, bn_mode=bn_mode, pad=pad)  # h/4, w/4
527 |         self.Conv3 = conv_block_sp(ch_in=enc_ch[1], ch_out=enc_ch[2], down=True, act_fun=act_type, bn_mode=bn_mode, pad=pad)  # h/8, w/8
528 |         self.Conv4 = conv_block_sp(ch_in=enc_ch[2], ch_out=enc_ch[3],
down=True, act_fun= act_type, bn_mode=bn_mode, pad = pad) # h/16, w/16 529 | self.Conv5 = conv_block_sp(ch_in=enc_ch[3], ch_out=enc_ch[4], down=True, act_fun= act_type, bn_mode=bn_mode, pad = pad) # h/32, w/32 530 | 531 | self.conv = nn.Sequential( 532 | conv(enc_ch[4], dec_ch[4], kernel_size=3, stride=1, bias=True, pad = pad), 533 | bn(dec_ch[4], bn_mode), 534 | act(act_type)) 535 | 536 | self.Up_conv4 = conv_block_concat(ch_in=dec_ch[4] + enc_ch[3], ch_out=dec_ch[3], act_fun= act_type, bn_mode=bn_mode, pad= pad) 537 | self.Up_conv3 = conv_block_concat(ch_in=dec_ch[3] + enc_ch[2], ch_out=dec_ch[2], act_fun= act_type, bn_mode=bn_mode, pad= pad) 538 | self.Up_conv2 = conv_block_concat(ch_in=dec_ch[2] + enc_ch[1], ch_out=dec_ch[1], act_fun= act_type, bn_mode=bn_mode, pad= pad) 539 | self.Up_conv1 = conv_block_concat(ch_in=dec_ch[1] + enc_ch[0], ch_out=dec_ch[0], act_fun= act_type, bn_mode=bn_mode, pad= pad) 540 | self.Conv_last1 = conv_block_last(dec_ch[0] + img_ch, output_ch, act_fun= act_type, bn_mode=bn_mode, pad = pad) # concat 541 | self.Sig = nn.Sigmoid() 542 | 543 | self.Att1 = Attention_block(dec_ch[0], enc_ch[0], 48) 544 | self.Att2 = Attention_block(dec_ch[1], enc_ch[1], 48) 545 | self.Att3 = Attention_block(dec_ch[2], enc_ch[2], 48) 546 | self.Att4 = Attention_block(dec_ch[3], enc_ch[3], 48) 547 | self.Att5 = Attention_block(dec_ch[4], enc_ch[4], 48) 548 | 549 | 550 | def forward(self, x): 551 | # encoding path 552 | x1 = self.Conv1(x) # h/2 w/2 553 | x2 = self.Conv2(x1) # h/4 w/4 554 | x3 = self.Conv3(x2) # h/8 w/8 555 | x4 = self.Conv4(x3) # h/16 w/16 556 | x5 = self.Conv5(x4) # h/32 w/32 557 | 558 | x6 = self.conv(x5) 559 | 560 | d5 = self.upsample(x6) # h/16 w/16 561 | x4 = self.Att5(g=d5, x=x4) 562 | d5 = self.concat([d5, x4]) 563 | d4 = self.Up_conv4(d5) 564 | 565 | d4 = self.upsample(d4) # h/8 w/8 566 | x3 = self.Att4(g=d4, x=x3) 567 | d4 = self.concat([d4, x3]) 568 | d3 = self.Up_conv3(d4) 569 | 570 | d3 = self.upsample(d3) # h/4 w/4 571 | x2 = self.Att3(g=d3, x=x2) 572 | d3 = self.concat([d3, x2]) 573 | d2 = self.Up_conv2(d3) 574 | 575 | d2 = self.upsample(d2) # h/2 w/2 576 | x1 = self.Att2(g=d2, x=x1) 577 | d2 = self.concat([d2, x1]) 578 | d1 = self.Up_conv1(d2) 579 | 580 | d0 = self.upsample(d1) # h w 581 | d0 = self.concat([d0, x]) 582 | d0 = self.Conv_last1(d0) 583 | return self.Sig(d0) 584 | 585 | class S2SATnet2(nn.Module): 586 | def __init__(self, img_ch=3, output_ch=3, act_type = "LeakyReLU", bn_mode = "bn", pad='reflection'): 587 | super(S2SATnet2, self).__init__() 588 | enc_ch = [48, 48, 48, 48, 48] # fixed 589 | dec_ch = [96, 96, 96, 96, 96] # fixed 590 | print("[*] input/output channel : %d / %d" % (img_ch, output_ch)) 591 | print("[*] act_type : %s" % act_type) 592 | print("[*] bn_type : %s" % bn_mode) 593 | self.upsample = nn.Upsample(scale_factor=2) 594 | self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2) 595 | self.concat = Concat_layer(1) 596 | self.Conv1 = conv_block_sp(ch_in=img_ch, ch_out=enc_ch[0], down=True, act_fun= act_type, bn_mode=bn_mode, pad = pad) # h/2, w/2 597 | self.Conv2 = conv_block_sp(ch_in=enc_ch[0], ch_out=enc_ch[1], down=True, act_fun= act_type, bn_mode=bn_mode, pad = pad) # h/4, w/4 598 | self.Conv3 = conv_block_sp(ch_in=enc_ch[1], ch_out=enc_ch[2], down=True, act_fun= act_type, bn_mode=bn_mode, pad = pad) # h/8, w/8 599 | self.Conv4 = conv_block_sp(ch_in=enc_ch[2], ch_out=enc_ch[3], down=True, act_fun= act_type, bn_mode=bn_mode, pad = pad) # h/16, w/16 600 | self.Conv5 = conv_block_sp(ch_in=enc_ch[3], ch_out=enc_ch[4], down=True, 
act_fun= act_type, bn_mode=bn_mode, pad = pad) # h/32, w/32 601 | 602 | self.conv = nn.Sequential( 603 | conv(enc_ch[4], dec_ch[4], kernel_size=3, stride=1, bias=True, pad = pad), 604 | bn(dec_ch[4], bn_mode), 605 | act(act_type)) 606 | 607 | self.Up_conv4 = conv_block_concat(ch_in=dec_ch[4] + enc_ch[3], ch_out=dec_ch[3], act_fun= act_type, bn_mode=bn_mode, pad= pad) 608 | self.Up_conv3 = conv_block_concat(ch_in=dec_ch[3] + enc_ch[2], ch_out=dec_ch[2], act_fun= act_type, bn_mode=bn_mode, pad= pad) 609 | self.Up_conv2 = conv_block_concat(ch_in=dec_ch[2] + enc_ch[1], ch_out=dec_ch[1], act_fun= act_type, bn_mode=bn_mode, pad= pad) 610 | self.Up_conv1 = conv_block_concat(ch_in=dec_ch[1] + enc_ch[0], ch_out=dec_ch[0], act_fun= act_type, bn_mode=bn_mode, pad= pad) 611 | self.Conv_last1 = conv_block_last(dec_ch[0] + img_ch, output_ch, act_fun= act_type, bn_mode=bn_mode, pad = pad) # concat 612 | self.Sig = nn.Sigmoid() 613 | 614 | self.Att1 = Attention_block(enc_ch[0], dec_ch[0], 48) 615 | self.Att2 = Attention_block(enc_ch[1], dec_ch[1], 48) 616 | self.Att3 = Attention_block(enc_ch[2], dec_ch[2], 48) 617 | self.Att4 = Attention_block(enc_ch[3], dec_ch[3], 48) 618 | self.Att5 = Attention_block(enc_ch[4], dec_ch[4], 48) 619 | 620 | 621 | def forward(self, x): 622 | # encoding path 623 | x1 = self.Conv1(x) # h/2 w/2 624 | x2 = self.Conv2(x1) # h/4 w/4 625 | x3 = self.Conv3(x2) # h/8 w/8 626 | x4 = self.Conv4(x3) # h/16 w/16 627 | x5 = self.Conv5(x4) # h/32 w/32 628 | 629 | x6 = self.conv(x5) 630 | 631 | d5 = self.upsample(x6) # h/16 w/16 632 | d5 = self.Att5(g=x4, x=d5) 633 | d5 = self.concat([d5, x4]) 634 | d4 = self.Up_conv4(d5) 635 | 636 | d4 = self.upsample(d4) # h/8 w/8 637 | d4 = self.Att4(g=x3, x=d4) 638 | d4 = self.concat([d4, x3]) 639 | d3 = self.Up_conv3(d4) 640 | 641 | d3 = self.upsample(d3) # h/4 w/4 642 | d3 = self.Att3(g=x2, x=d3) 643 | d3 = self.concat([d3, x2]) 644 | d2 = self.Up_conv2(d3) 645 | 646 | d2 = self.upsample(d2) # h/2 w/2 647 | d2 = self.Att2(g=x1, x=d2) 648 | d2 = self.concat([d2, x1]) 649 | d1 = self.Up_conv1(d2) 650 | 651 | d0 = self.upsample(d1) # h w 652 | d0 = self.concat([d0, x]) 653 | d0 = self.Conv_last1(d0) 654 | return self.Sig(d0) 655 | 656 | 657 | 658 | 659 | 660 | class Testnet(nn.Module): 661 | def __init__(self, img_ch=3, output_ch=3, act_type="LeakyReLU", pad='reflection'): 662 | super(Testnet, self).__init__() 663 | enc_ch = [48, 48, 48, 48, 48] # fixed 664 | dec_ch = [48, 48, 48, 48, 48] # fixed 665 | 666 | self.upsample = nn.Upsample(scale_factor=2) 667 | self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2) 668 | self.Avgpool = nn.AvgPool2d(kernel_size=2, stride=2) 669 | self.concat = Concat_layer(1) 670 | self.Conv1 = conv_block_sp(ch_in=img_ch, ch_out=enc_ch[0], down=True, act_fun=act_type, pad=pad) # h/2, w/2 671 | self.Conv2 = conv_block_sp(ch_in=enc_ch[0], ch_out=enc_ch[1], down=True, act_fun=act_type, pad=pad) # h/4, w/4 672 | self.Conv3 = conv_block_sp(ch_in=enc_ch[1], ch_out=enc_ch[2], down=True, act_fun=act_type, pad=pad) # h/8, w/8 673 | self.Conv4 = conv_block_sp(ch_in=enc_ch[2], ch_out=enc_ch[3], down=True, act_fun=act_type, 674 | pad=pad) # h/16, w/16 675 | self.Conv5 = conv_block_sp(ch_in=enc_ch[3], ch_out=enc_ch[4], down=True, act_fun=act_type, 676 | pad=pad) # h/32, w/32 677 | 678 | self.conv = nn.Sequential( 679 | conv(enc_ch[4], dec_ch[4], kernel_size=3, stride=1, bias=True, pad=pad), 680 | bn(dec_ch[4]), 681 | act(act_type)) 682 | 683 | self.Up_conv4 = conv_block_concat(ch_in=dec_ch[4] + enc_ch[3] + img_ch, ch_out=dec_ch[3], 
act_fun=act_type, pad=pad) 684 | self.Up_conv3 = conv_block_concat(ch_in=dec_ch[3] + enc_ch[2] + img_ch, ch_out=dec_ch[2], act_fun=act_type, pad=pad) 685 | self.Up_conv2 = conv_block_concat(ch_in=dec_ch[2] + enc_ch[1] + img_ch, ch_out=dec_ch[1], act_fun=act_type, pad=pad) 686 | self.Up_conv1 = conv_block_concat(ch_in=dec_ch[1] + enc_ch[0] + img_ch, ch_out=dec_ch[0], act_fun=act_type, pad=pad) 687 | self.Conv_last1 = conv_block_last(dec_ch[0] + img_ch, output_ch, act_fun=act_type, pad=pad) # concat 688 | self.Sig = nn.Sigmoid() 689 | 690 | def forward(self, x): 691 | # encoding path 692 | x_1 = self.Avgpool(x) 693 | x_2 = self.Avgpool(x_1) 694 | x_3 = self.Avgpool(x_2) 695 | x_4 = self.Avgpool(x_3) 696 | 697 | x1 = self.Conv1(x) # h/2 w/2 698 | x2 = self.Conv2(x1) + self.Avgpool(x1) # h/4 w/4 699 | x3 = self.Conv3(x2) + self.Avgpool(x2) # h/8 w/8 700 | x4 = self.Conv4(x3) + self.Avgpool(x3) # h/16 w/16 701 | x5 = self.Conv5(x4) + self.Avgpool(x4) # h/32 w/32 702 | 703 | x6 = self.conv(x5) 704 | 705 | d5 = self.upsample(x6) # h/16 w/16 706 | d5 = self.concat([d5, x4, x_4]) 707 | d4 = self.Up_conv4(d5) 708 | 709 | d4 = self.upsample(d4) # h/8 w/8 710 | d4 = self.concat([d4, x3, x_3]) 711 | d3 = self.Up_conv3(d4) 712 | 713 | d3 = self.upsample(d3) # h/4 w/4 714 | d3 = self.concat([d3, x2, x_2]) 715 | d2 = self.Up_conv2(d3) 716 | 717 | d2 = self.upsample(d2) # h/2 w/2 718 | d2 = self.concat([d2, x1, x_1]) 719 | d1 = self.Up_conv1(d2) 720 | 721 | d0 = self.upsample(d1) # h w 722 | d0 = self.concat([d0, x]) 723 | d0 = self.Conv_last1(d0) 724 | return self.Sig(d0) 725 | 726 | if __name__ == '__main__': 727 | x = torch.rand([1, 3, 481, 321]) 728 | net = s2s_() 729 | 730 | print(net(x).shape) 731 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .skip import skip 2 | from .texture_nets import get_texture_nets 3 | from .resnet import ResNet 4 | from .unet import UNet 5 | from .S2Snet import * 6 | from .snet import * 7 | import torch.nn as nn 8 | import torchvision 9 | import torch 10 | 11 | class torchmodel(nn.Module): 12 | def __init__(self, model): 13 | super(torchmodel, self).__init__() 14 | self.s = nn.Sigmoid() 15 | self.model = model 16 | 17 | def forward(self, x): 18 | x_ = torch.cat([x,x]) 19 | out = self.s(self.model(x_)['out']) 20 | out = torch.mean(out, dim = 0, keepdim=True) 21 | return out 22 | 23 | def get_net(args): 24 | pad = 'reflection' 25 | input_depth = 1 if args.gray else 3 26 | if args.noisy_map: 27 | input_depth += 1 28 | n_channels = 1 if args.gray else 3 29 | NET_TYPE = args.net_type 30 | act_fun = args.act_func 31 | upsample_mode = 'bilinear' 32 | downsample_mode = 'stride' 33 | sigmoid = True 34 | 35 | if NET_TYPE == 'skip': 36 | skip_n33d = skip_n33u = args.hidden_layer 37 | skip_n11 = 4 38 | num_scales = 5 39 | print("[*] Net_type : skip with %d layer" % skip_n33d) 40 | net = skip(input_depth, n_channels, num_channels_down = [skip_n33d]*num_scales if isinstance(skip_n33d, int) else skip_n33d, 41 | num_channels_up = [skip_n33u]*num_scales if isinstance(skip_n33u, int) else skip_n33u, 42 | num_channels_skip = [skip_n11]*num_scales if isinstance(skip_n11, int) else skip_n11, 43 | upsample_mode=upsample_mode, downsample_mode=downsample_mode, 44 | need_sigmoid=sigmoid, need_bias=True, pad=pad, act_fun=act_fun) 45 | elif NET_TYPE == 's2s': 46 | print("[*] Net_type : s2s") 47 | net = S2Snet(input_depth, n_channels, act_type=act_fun) 48 | 
elif NET_TYPE == 's2s_fixed': 49 | print("[*] Net_type : s2s_fixed") 50 | net = s2s_fixed(input_depth, n_channels, act_type= act_fun, bn_mode = args.bn_type) 51 | elif NET_TYPE == 'texture_nets': 52 | net = get_texture_nets(inp=input_depth, ratios = [32, 16, 8, 4, 2, 1], fill_noise=False,pad=pad) 53 | elif NET_TYPE =='UNet': 54 | net = UNet(num_input_channels=input_depth, num_output_channels=3, 55 | feature_scale=4, more_layers=0, concat_x=False, 56 | upsample_mode=upsample_mode, pad=pad, norm_layer=nn.BatchNorm2d, need_sigmoid=True, need_bias=True) 57 | elif NET_TYPE == 'dncnn': 58 | net = dncnn_s() 59 | elif NET_TYPE == 's2s96': 60 | net = S2Snet96(input_depth, n_channels, act_type= act_fun, bn_mode = args.bn_type) 61 | elif NET_TYPE == 's2sW': 62 | net = S2SnetW(input_depth, n_channels, act_type= act_fun) 63 | elif NET_TYPE == 's2sT': 64 | net = S2SnetT(input_depth, n_channels, act_type= act_fun) 65 | elif NET_TYPE == 's2s_': 66 | net = s2s_(input_depth, n_channels, act_type=act_fun) 67 | elif NET_TYPE == 's2s_g': 68 | net = s2s_g(input_depth, n_channels, act_type=act_fun) 69 | elif NET_TYPE == 's2s_4x': 70 | net = snet_4x(input_depth, n_channels, act_type= act_fun) 71 | elif NET_TYPE == 's2s_normal': 72 | net = s2s_normal(input_depth, n_channels, act_type=act_fun) 73 | elif NET_TYPE == 'S2SATnet1': 74 | net = S2SATnet1(input_depth, n_channels, act_type=act_fun) 75 | elif NET_TYPE == 'S2SATnet2': 76 | net = S2SATnet2(input_depth, n_channels, act_type= act_fun) 77 | 78 | 79 | elif NET_TYPE == 'Testnet': 80 | net = Testnet(input_depth, n_channels, act_type= act_fun) 81 | elif NET_TYPE == "MemNet": 82 | net = MemNet(input_depth, n_channels, 4, 6) 83 | else: 84 | assert False 85 | 86 | return net.type(args.dtype) 87 | -------------------------------------------------------------------------------- /models/common.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from .downsampler import Downsampler 5 | from .iterBN import IterNorm 6 | 7 | def add_module(self, module): 8 | self.add_module(str(len(self) + 1), module) 9 | 10 | torch.nn.Module.add = add_module 11 | 12 | class Concat(nn.Module): 13 | def __init__(self, dim, *args): 14 | super(Concat, self).__init__() 15 | self.dim = dim 16 | 17 | for idx, module in enumerate(args): 18 | self.add_module(str(idx), module) 19 | 20 | def forward(self, input): 21 | inputs = [] 22 | for module in self._modules.values(): 23 | inputs.append(module(input)) 24 | 25 | inputs_shapes2 = [x.shape[2] for x in inputs] 26 | inputs_shapes3 = [x.shape[3] for x in inputs] 27 | 28 | if np.all(np.array(inputs_shapes2) == min(inputs_shapes2)) and np.all(np.array(inputs_shapes3) == min(inputs_shapes3)): 29 | inputs_ = inputs 30 | else: 31 | target_shape2 = min(inputs_shapes2) 32 | target_shape3 = min(inputs_shapes3) 33 | 34 | inputs_ = [] 35 | for inp in inputs: 36 | diff2 = (inp.size(2) - target_shape2) // 2 37 | diff3 = (inp.size(3) - target_shape3) // 2 38 | inputs_.append(inp[:, :, diff2: diff2 + target_shape2, diff3:diff3 + target_shape3]) 39 | 40 | return torch.cat(inputs_, dim=self.dim) 41 | 42 | def __len__(self): 43 | return len(self._modules) 44 | 45 | 46 | class GenNoise(nn.Module): 47 | def __init__(self, dim2): 48 | super(GenNoise, self).__init__() 49 | self.dim2 = dim2 50 | 51 | def forward(self, input): 52 | a = list(input.size()) 53 | a[1] = self.dim2 54 | # print (input.data.type()) 55 | 56 | b = torch.zeros(a).type_as(input.data) 57 | b.normal_() 58 
| 59 | x = torch.autograd.Variable(b) 60 | 61 | return x 62 | 63 | 64 | class Swish(nn.Module): 65 | """ 66 | https://arxiv.org/abs/1710.05941 67 | The hype was so huge that I could not help but try it 68 | """ 69 | def __init__(self): 70 | super(Swish, self).__init__() 71 | self.s = nn.Sigmoid() 72 | 73 | def forward(self, x): 74 | return x * self.s(x) 75 | 76 | class Tanh(nn.Module): 77 | """ 78 | https://arxiv.org/abs/1710.05941 79 | The hype was so huge that I could not help but try it 80 | """ 81 | def __init__(self): 82 | super(Tanh, self).__init__() 83 | 84 | def forward(self, x): 85 | return torch.tanh(x) 86 | 87 | class Sin(nn.Module): 88 | """ 89 | https://arxiv.org/abs/1710.05941 90 | The hype was so huge that I could not help but try it 91 | """ 92 | def __init__(self): 93 | super(Sin, self).__init__() 94 | 95 | def forward(self, x): 96 | return torch.sin(x) 97 | 98 | def act(act_fun = 'LeakyReLU'): 99 | ''' 100 | Either string defining an activation function or module (e.g. nn.ReLU) 101 | ''' 102 | if isinstance(act_fun, str): 103 | if act_fun == 'LeakyReLU': 104 | return nn.LeakyReLU(0.2, inplace=True) 105 | elif act_fun == 'Swish': 106 | return Swish() 107 | elif act_fun[:3] == 'ELU': 108 | if len(act_fun)> 3: 109 | param = float(act_fun[3:]) 110 | return nn.ELU(param, inplace=True) 111 | return nn.ELU(inplace=True) 112 | elif act_fun == 'ReLU': 113 | return nn.ReLU() 114 | elif act_fun == 'tanh': 115 | return Tanh() 116 | elif act_fun == 'sine': 117 | return Sin() 118 | elif act_fun == 'soft': 119 | return nn.Softplus() 120 | elif act_fun == 'none': 121 | return nn.Sequential() 122 | else: 123 | assert False 124 | else: 125 | return act_fun() 126 | 127 | 128 | def bn(num_features, mode = "bn"): 129 | if mode == "bn": 130 | return nn.BatchNorm2d(num_features) 131 | elif mode == "bn_kai": 132 | return nn.BatchNorm2d(num_features, momentum=0.9, eps=1e-4) 133 | elif mode == "In": 134 | return nn.InstanceNorm2d(num_features, affine= True) 135 | elif mode == "None": 136 | return nn.Sequential() 137 | 138 | elif mode == "bn_kai7": 139 | return nn.BatchNorm2d(num_features, momentum=0.7, eps=1e-4) 140 | elif mode == "bn_kai5": 141 | return nn.BatchNorm2d(num_features, momentum=0.5, eps=1e-4) 142 | elif mode == "bn_eps": 143 | return nn.BatchNorm2d(num_features, eps=1e-4) 144 | elif mode == "iterbn": 145 | return IterNorm(num_features) 146 | 147 | elif mode == "groupNorm": 148 | return nn.GroupNorm(4, num_features) 149 | 150 | else : 151 | return None 152 | 153 | 154 | def conv(in_f, out_f, kernel_size, stride=1, bias=True, pad='zero', group = 1, downsample_mode='stride'): 155 | downsampler = None 156 | if stride != 1 and downsample_mode != 'stride': 157 | 158 | if downsample_mode == 'avg': 159 | downsampler = nn.AvgPool2d(stride, stride) 160 | elif downsample_mode == 'max': 161 | downsampler = nn.MaxPool2d(stride, stride) 162 | elif downsample_mode in ['lanczos2', 'lanczos3']: 163 | downsampler = Downsampler(n_planes=out_f, factor=stride, kernel_type=downsample_mode, phase=0.5, preserve_size=True) 164 | else: 165 | assert False 166 | 167 | stride = 1 168 | 169 | padder = None 170 | to_pad = int((kernel_size - 1) / 2) 171 | if pad == 'reflection': 172 | padder = nn.ReflectionPad2d(to_pad) 173 | to_pad = 0 174 | 175 | convolver = nn.Conv2d(in_f, out_f, kernel_size, stride, padding=to_pad, groups= group,bias=bias) 176 | 177 | 178 | layers = filter(lambda x: x is not None, [padder, convolver, downsampler]) 179 | return nn.Sequential(*layers) 180 | 
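# --- Usage sketch (added for illustration; not part of the original file) ---
# The act()/bn()/conv() factories above compose into the conv -> norm -> act
# blocks used throughout this repo. A minimal example, assuming a 4-D input:
#
#   layer = nn.Sequential(conv(3, 48, kernel_size=3, stride=2, pad='reflection'),
#                         bn(48, mode='bn'),
#                         act('LeakyReLU'))
#   y = layer(torch.rand(1, 3, 64, 64))  # -> torch.Size([1, 48, 32, 32])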
-------------------------------------------------------------------------------- /models/downsampler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from utils.REDutils import fspecial_gauss 5 | 6 | 7 | class Downsampler(nn.Module): 8 | """ 9 | http://www.realitypixels.com/turk/computergraphics/ResamplingFilters.pdf 10 | """ 11 | 12 | def __init__(self, n_planes, factor, kernel_type, phase=0, kernel_width=None, support=None, sigma=None, 13 | preserve_size=False, pad_type='reflection', transpose_conv=False): 14 | super(Downsampler, self).__init__() 15 | 16 | assert phase in [0, 0.5], 'phase should be 0 or 0.5' 17 | 18 | if kernel_type == 'lanczos2': 19 | support = 2 20 | kernel_width = 4 * factor + 1 21 | kernel_type_ = 'lanczos' 22 | 23 | elif kernel_type == 'lanczos3': 24 | support = 3 25 | kernel_width = 6 * factor + 1 26 | kernel_type_ = 'lanczos' 27 | 28 | elif kernel_type == 'gauss12': 29 | kernel_width = 7 30 | sigma = 1 / 2 31 | kernel_type_ = 'gauss' 32 | 33 | elif kernel_type == 'gauss1sq2': 34 | kernel_width = 9 35 | sigma = 1. / np.sqrt(2) 36 | kernel_type_ = 'gauss' 37 | 38 | elif kernel_type == 'uniform_blur': 39 | kernel_width = 9 40 | kernel_type_ = 'uniform' 41 | pad_type = 'circular' 42 | 43 | elif kernel_type == 'gauss_blur': 44 | kernel_width = 25 45 | sigma = 1.6 46 | kernel_type_ = 'gauss' 47 | pad_type = 'circular' 48 | 49 | elif kernel_type in {'lanczos', 'gauss', 'box'}: 50 | kernel_type_ = kernel_type 51 | 52 | else: 53 | assert False, 'wrong name kernel' 54 | 55 | # note that `kernel width` will be different to actual size for phase = 1/2 56 | self.kernel = get_kernel(factor, kernel_type_, phase, kernel_width, support=support, sigma=sigma) 57 | if transpose_conv: 58 | if self.kernel.shape[0] % 2 == 1: 59 | pad = int((self.kernel.shape[0] - 1) // 2.) 60 | else: 61 | pad = int((self.kernel.shape[0] - factor) // 2.) 62 | downsampler = nn.ConvTranspose2d(n_planes, n_planes, kernel_size=self.kernel.shape, 63 | stride=factor, padding=pad) 64 | else: 65 | downsampler = nn.Conv2d(n_planes, n_planes, kernel_size=self.kernel.shape, stride=factor, padding=0) 66 | downsampler.weight.data[:] = 0 67 | downsampler.bias.data[:] = 0 68 | 69 | kernel_torch = torch.from_numpy(self.kernel) 70 | for i in range(n_planes): 71 | downsampler.weight.data[i, i] = kernel_torch 72 | 73 | self.downsampler_ = downsampler 74 | 75 | if preserve_size: 76 | if pad_type == 'circular': 77 | self.padding = lambda torch_in: pad_circular(torch_in, kernel_width // 2) 78 | elif pad_type == 'reflection': 79 | if self.kernel.shape[0] % 2 == 1: 80 | pad = int((self.kernel.shape[0] - 1) // 2.) 81 | else: 82 | pad = int((self.kernel.shape[0] - factor) // 2.) 
83 | self.padding = nn.ReplicationPad2d(pad) 84 | else: 85 | assert False, "pad_type supports only circular or reflection options" 86 | self.preserve_size = preserve_size 87 | 88 | def forward(self, input): 89 | if self.preserve_size: 90 | x = self.padding(input) 91 | else: 92 | x = input 93 | self.x = x 94 | return self.downsampler_(x) 95 | 96 | 97 | def get_kernel(factor, kernel_type, phase, kernel_width, support=None, sigma=None): 98 | assert kernel_type in ['lanczos', 'gauss', 'box', 'uniform', 'blur'] 99 | 100 | # factor = float(factor) 101 | if phase == 0.5 and kernel_type != 'box': 102 | kernel = np.zeros([kernel_width - 1, kernel_width - 1]) 103 | else: 104 | kernel = np.zeros([kernel_width, kernel_width]) 105 | 106 | if kernel_type == 'box': 107 | assert phase == 0.5, 'Box filter is always half-phased' 108 | kernel[:] = 1. / (kernel_width * kernel_width) 109 | 110 | elif kernel_type == 'gauss': 111 | assert sigma, 'sigma is not specified' 112 | assert phase != 0.5, 'phase 1/2 for gauss not implemented' 113 | return fspecial_gauss(kernel_width, sigma) 114 | 115 | elif kernel_type == 'uniform': 116 | kernel = np.ones([kernel_width, kernel_width]) 117 | 118 | elif kernel_type == 'lanczos': 119 | assert support, 'support is not specified' 120 | center = (kernel_width + 1) / 2. 121 | 122 | for i in range(1, kernel.shape[0] + 1): 123 | for j in range(1, kernel.shape[1] + 1): 124 | 125 | if phase == 0.5: 126 | di = abs(i + 0.5 - center) / factor 127 | dj = abs(j + 0.5 - center) / factor 128 | else: 129 | di = abs(i - center) / factor 130 | dj = abs(j - center) / factor 131 | 132 | pi_sq = np.pi * np.pi 133 | 134 | val = 1 135 | if di != 0: 136 | val = val * support * np.sin(np.pi * di) * np.sin(np.pi * di / support) 137 | val = val / (np.pi * np.pi * di * di) 138 | 139 | if dj != 0: 140 | val = val * support * np.sin(np.pi * dj) * np.sin(np.pi * dj / support) 141 | val = val / (np.pi * np.pi * dj * dj) 142 | kernel[i - 1][j - 1] = val 143 | else: 144 | assert False, 'wrong method name' 145 | kernel /= kernel.sum() 146 | return kernel 147 | 148 | 149 | def pad_circular(x, pad): 150 | """ 151 | :param x: pytorch tensor of shape: [batch, ch, h, w] 152 | :param pad: uint 153 | :return: 154 | """ 155 | x = torch.cat([x, x[:, :, 0:pad]], dim=2) 156 | x = torch.cat([x, x[:, :, :, 0:pad]], dim=3) 157 | x = torch.cat([x[:, :, -2 * pad:-pad], x], dim=2) 158 | x = torch.cat([x[:, :, :, -2 * pad:-pad], x], dim=3) 159 | return x -------------------------------------------------------------------------------- /models/iterBN.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reference: Iterative Normalization: Beyond Standardization towards Efficient Whitening, CVPR 2019 3 | 4 | - Paper: 5 | - Code: https://github.com/huangleiBuaa/IterNorm 6 | 7 | ***** 8 | This implementation allows the number of feature maps to not be divisible by the number of channels per group; e.g., one can use a group size of 64 when the channel number is 80.
(64 + 16) 9 | 10 | """ 11 | import torch.nn 12 | from torch.nn import Parameter 13 | 14 | # import extension._bcnn as bcnn 15 | 16 | __all__ = ['iterative_normalization_FlexGroup', 'IterNorm'] 17 | 18 | 19 | # 20 | # class iterative_normalization(torch.autograd.Function): 21 | # @staticmethod 22 | # def forward(ctx, *inputs): 23 | # result = bcnn.iterative_normalization_forward(*inputs) 24 | # ctx.save_for_backward(*result[:-1]) 25 | # return result[-1] 26 | # 27 | # @staticmethod 28 | # def backward(ctx, *grad_outputs): 29 | # grad, = grad_outputs 30 | # grad_input = bcnn.iterative_normalization_backward(grad, ctx.saved_variables) 31 | # return grad_input, None, None, None, None, None, None, None 32 | 33 | 34 | class iterative_normalization_py(torch.autograd.Function): 35 | @staticmethod 36 | def forward(ctx, *args, **kwargs): 37 | X, running_mean, running_wmat, nc, ctx.T, eps, momentum, training = args 38 | # change NxCxHxW to Dx(NxHxW), i.e., d*m 39 | ctx.g = X.size(1) // nc 40 | x = X.transpose(0, 1).contiguous().view(nc, -1) 41 | d, m = x.size() 42 | saved = [] 43 | if training: 44 | # calculate centered activation by subtracted mini-batch mean 45 | mean = x.mean(-1, keepdim=True) 46 | xc = x - mean 47 | saved.append(xc) 48 | # calculate covariance matrix 49 | P = [None] * (ctx.T + 1) 50 | P[0] = torch.eye(d).to(X) 51 | # Sigma = torch.addmm(eps, P[0], 1. / m, xc, xc.transpose(0, 1)) 52 | #beta = 1, mat, alpha = 1, mat1, mat2, out = None 53 | # input, mat1, mat2, *, beta=1, alpha=1, out=None 54 | Sigma = torch.addmm(P[0], xc, xc.transpose(0, 1), beta=eps, alpha=1. / m) 55 | 56 | # reciprocal of trace of Sigma: shape [g, 1, 1] 57 | rTr = (Sigma * P[0]).sum((0, 1), keepdim=True).reciprocal_() 58 | saved.append(rTr) 59 | Sigma_N = Sigma * rTr 60 | saved.append(Sigma_N) 61 | for k in range(ctx.T): 62 | P[k + 1] = torch.addmm(P[k], torch.matrix_power(P[k], 3), Sigma_N, beta=1.5, alpha=-0.5) 63 | # P[k + 1] = torch.addmm(1.5, P[k], -0.5, torch.matrix_power(P[k], 3), Sigma_N) 64 | saved.extend(P) 65 | wm = P[ctx.T].mul_(rTr.sqrt()) # whiten matrix: the matrix inverse of Sigma, i.e., Sigma^{-1/2} 66 | running_mean.copy_(momentum * mean + (1. - momentum) * running_mean) 67 | running_wmat.copy_(momentum * wm + (1. - momentum) * running_wmat) 68 | else: 69 | xc = x - running_mean 70 | wm = running_wmat 71 | xn = wm.mm(xc) 72 | Xn = xn.view(X.size(1), X.size(0), *X.size()[2:]).transpose(0, 1).contiguous() 73 | ctx.save_for_backward(*saved) 74 | return Xn 75 | 76 | @staticmethod 77 | def backward(ctx, *grad_outputs): 78 | grad, = grad_outputs 79 | saved = ctx.saved_variables 80 | xc = saved[0] # centered input 81 | rTr = saved[1] # trace of Sigma 82 | sn = saved[2].transpose(-2, -1) # normalized Sigma 83 | P = saved[3:] # middle result matrix, 84 | d, m = xc.size() 85 | 86 | g_ = grad.transpose(0, 1).contiguous().view_as(xc) 87 | g_wm = g_.mm(xc.transpose(-2, -1)) 88 | g_P = g_wm * rTr.sqrt() 89 | wm = P[ctx.T] 90 | g_sn = 0 91 | for k in range(ctx.T, 1, -1): 92 | P[k - 1].transpose_(-2, -1) 93 | P2 = P[k - 1].mm(P[k - 1]) 94 | g_sn += P2.mm(P[k - 1]).mm(g_P) 95 | g_tmp = g_P.mm(sn) 96 | g_P.addmm_(g_tmp, P2, beta=1.5, alpha=-0.5) 97 | g_P.addmm_(P2, g_tmp, beta=1, alpha=-0.5) 98 | g_P.addmm_(P[k - 1].mm(g_tmp), P[k - 1], beta=1, alpha=-0.5) 99 | g_sn += g_P 100 | # g_sn = g_sn * rTr.sqrt() 101 | g_tr = ((-sn.mm(g_sn) + g_wm.transpose(-2, -1).mm(wm)) * P[0]).sum((0, 1), keepdim=True) * P[0] 102 | g_sigma = (g_sn + g_sn.transpose(-2, -1) + 2. 
* g_tr) * (-0.5 / m * rTr) 103 | # g_sigma = g_sigma + g_sigma.transpose(-2, -1) 104 | g_x = torch.addmm(wm.mm(g_ - g_.mean(-1, keepdim=True)), g_sigma, xc) 105 | grad_input = g_x.view(grad.size(1), grad.size(0), *grad.size()[2:]).transpose(0, 1).contiguous() 106 | return grad_input, None, None, None, None, None, None, None 107 | 108 | 109 | class IterNorm_Single(torch.nn.Module): 110 | def __init__(self, num_features, num_groups=1, num_channels=None, T=5, dim=4, eps=1e-5, momentum=0.1, affine=True, 111 | *args, **kwargs): 112 | super(IterNorm_Single, self).__init__() 113 | # assert dim == 4, 'IterNorm is not support 2D' 114 | self.T = T 115 | self.eps = eps 116 | self.momentum = momentum 117 | self.num_features = num_features 118 | self.affine = affine 119 | self.dim = dim 120 | shape = [1] * dim 121 | shape[1] = self.num_features 122 | 123 | self.register_buffer('running_mean', torch.zeros(num_features, 1)) 124 | # running whiten matrix 125 | self.register_buffer('running_wm', torch.eye(num_features)) 126 | 127 | def forward(self, X: torch.Tensor): 128 | X_hat = iterative_normalization_py.apply(X, self.running_mean, self.running_wm, self.num_features, self.T, 129 | self.eps, self.momentum, self.training) 130 | return X_hat 131 | 132 | 133 | class IterNorm(torch.nn.Module): 134 | def __init__(self, num_features, num_groups=1, num_channels=None, T=5, dim=4, eps=1e-5, momentum=0.1, affine=True, 135 | *args, **kwargs): 136 | super(IterNorm, self).__init__() 137 | # assert dim == 4, 'IterNorm is not support 2D' 138 | self.T = T 139 | self.eps = eps 140 | self.momentum = momentum 141 | self.num_features = num_features 142 | self.num_channels = num_features 143 | #num_groups = (self.num_features - 1) // self.num_channels + 1 144 | self.num_groups = num_groups 145 | self.iterNorm_Groups = torch.nn.ModuleList( 146 | [IterNorm_Single(num_features=self.num_channels, eps=eps, momentum=momentum, T=T) for _ in 147 | range(self.num_groups - 1)] 148 | ) 149 | num_channels_last = self.num_features - self.num_channels * (self.num_groups - 1) 150 | self.iterNorm_Groups.append(IterNorm_Single(num_features=num_channels_last, eps=eps, momentum=momentum, T=T)) 151 | 152 | self.affine = affine 153 | self.dim = dim 154 | shape = [1] * dim 155 | shape[1] = self.num_features 156 | if self.affine: 157 | self.weight = Parameter(torch.Tensor(*shape)) 158 | self.bias = Parameter(torch.Tensor(*shape)) 159 | else: 160 | self.register_parameter('weight', None) 161 | self.register_parameter('bias', None) 162 | self.reset_parameters() 163 | 164 | def reset_parameters(self): 165 | # self.reset_running_stats() 166 | if self.affine: 167 | torch.nn.init.ones_(self.weight) 168 | torch.nn.init.zeros_(self.bias) 169 | 170 | def forward(self, X: torch.Tensor): 171 | X_splits = torch.split(X, self.num_channels, dim=1) 172 | X_hat_splits = [] 173 | for i in range(self.num_groups): 174 | X_hat_tmp = self.iterNorm_Groups[i](X_splits[i]) 175 | X_hat_splits.append(X_hat_tmp) 176 | X_hat = torch.cat(X_hat_splits, dim=1) 177 | # affine 178 | if self.affine: 179 | return X_hat * self.weight + self.bias 180 | else: 181 | return X_hat 182 | 183 | def extra_repr(self): 184 | return '{num_features}, num_channels={num_channels}, T={T}, eps={eps}, ' \ 185 | 'momentum={momentum}, affine={affine}'.format(**self.__dict__) 186 | 187 | 188 | if __name__ == '__main__': 189 | ItN = IterNorm(16, num_channels=4, T=10, momentum=1, affine=False) 190 | print(ItN) 191 | ItN.train() 192 | # x = torch.randn(32, 64, 14, 14) 193 | x = torch.randn(32, 16) 194 | 
x.requires_grad_() 195 | y = ItN(x) 196 | z = y.transpose(0, 1).contiguous().view(x.size(1), -1) 197 | print(z.matmul(z.t()) / z.size(1)) 198 | 199 | y.sum().backward() 200 | print('x grad', x.grad.size()) 201 | 202 | ItN.eval() 203 | y = ItN(x) 204 | z = y.transpose(0, 1).contiguous().view(x.size(1), -1) 205 | print(z.matmul(z.t()) / z.size(1)) 206 | -------------------------------------------------------------------------------- /models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from numpy.random import normal 4 | from numpy.linalg import svd 5 | from math import sqrt 6 | import torch.nn.init 7 | from .common import * 8 | 9 | class ResidualSequential(nn.Sequential): 10 | def __init__(self, *args): 11 | super(ResidualSequential, self).__init__(*args) 12 | 13 | def forward(self, x): 14 | out = super(ResidualSequential, self).forward(x) 15 | # print(x.size(), out.size()) 16 | x_ = None 17 | if out.size(2) != x.size(2) or out.size(3) != x.size(3): 18 | diff2 = x.size(2) - out.size(2) 19 | diff3 = x.size(3) - out.size(3) 20 | # print(1) 21 | x_ = x[:, :, diff2 // 2:out.size(2) + diff2 // 2, diff3 // 2:out.size(3) + diff3 // 2] 22 | else: 23 | x_ = x 24 | return out + x_ 25 | 26 | def eval(self): 27 | print(2) 28 | for m in self.modules(): 29 | m.eval() 30 | exit() 31 | 32 | 33 | def get_block(num_channels, norm_layer, act_fun): 34 | layers = [ 35 | nn.Conv2d(num_channels, num_channels, 3, 1, 1, bias=False), 36 | norm_layer(num_channels, affine=True), 37 | act(act_fun), 38 | nn.Conv2d(num_channels, num_channels, 3, 1, 1, bias=False), 39 | norm_layer(num_channels, affine=True), 40 | ] 41 | return layers 42 | 43 | 44 | class ResNet(nn.Module): 45 | def __init__(self, num_input_channels, num_output_channels, num_blocks, num_channels, need_residual=True, act_fun='LeakyReLU', need_sigmoid=True, norm_layer=nn.BatchNorm2d, pad='reflection'): 46 | ''' 47 | pad = 'start|zero|replication' 48 | ''' 49 | super(ResNet, self).__init__() 50 | 51 | if need_residual: 52 | s = ResidualSequential 53 | else: 54 | s = nn.Sequential 55 | 56 | stride = 1 57 | # First layers 58 | layers = [ 59 | # nn.ReplicationPad2d(num_blocks * 2 * stride + 3), 60 | conv(num_input_channels, num_channels, 3, stride=1, bias=True, pad=pad), 61 | act(act_fun) 62 | ] 63 | # Residual blocks 64 | # layers_residual = [] 65 | for i in range(num_blocks): 66 | layers += [s(*get_block(num_channels, norm_layer, act_fun))] 67 | 68 | layers += [ 69 | nn.Conv2d(num_channels, num_channels, 3, 1, 1), 70 | norm_layer(num_channels, affine=True) 71 | ] 72 | 73 | # if need_residual: 74 | # layers += [ResidualSequential(*layers_residual)] 75 | # else: 76 | # layers += [Sequential(*layers_residual)] 77 | 78 | # if factor >= 2: 79 | # # Do upsampling if needed 80 | # layers += [ 81 | # nn.Conv2d(num_channels, num_channels * 82 | # factor ** 2, 3, 1), 83 | # nn.PixelShuffle(factor), 84 | # act(act_fun) 85 | # ] 86 | layers += [ 87 | conv(num_channels, num_output_channels, 3, 1, bias=True, pad=pad), 88 | nn.Sigmoid() 89 | ] 90 | self.model = nn.Sequential(*layers) 91 | 92 | def forward(self, input): 93 | return self.model(input) 94 | 95 | def eval(self): 96 | self.model.eval() 97 | -------------------------------------------------------------------------------- /models/skip.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .common import * 4 | 5 | def skip( 6 | num_input_channels=2,
num_output_channels=3, 7 | num_channels_down=[16, 32, 64, 128, 128], num_channels_up=[16, 32, 64, 128, 128], num_channels_skip=[4, 4, 4, 4, 4], 8 | filter_size_down=3, filter_size_up=3, filter_skip_size=1, 9 | need_sigmoid=True, need_bias=True, 10 | pad='zero', upsample_mode='nearest', downsample_mode='stride', act_fun='LeakyReLU', 11 | need1x1_up=True): 12 | """Assembles encoder-decoder with skip connections. 13 | 14 | Arguments: 15 | act_fun: Either string 'LeakyReLU|Swish|ELU|none' or module (e.g. nn.ReLU) 16 | pad (string): zero|reflection (default: 'zero') 17 | upsample_mode (string): 'nearest|bilinear' (default: 'nearest') 18 | downsample_mode (string): 'stride|avg|max|lanczos2' (default: 'stride') 19 | 20 | """ 21 | assert len(num_channels_down) == len(num_channels_up) == len(num_channels_skip) 22 | 23 | n_scales = len(num_channels_down) 24 | 25 | if not (isinstance(upsample_mode, list) or isinstance(upsample_mode, tuple)) : 26 | upsample_mode = [upsample_mode]*n_scales 27 | 28 | if not (isinstance(downsample_mode, list)or isinstance(downsample_mode, tuple)): 29 | downsample_mode = [downsample_mode]*n_scales 30 | 31 | if not (isinstance(filter_size_down, list) or isinstance(filter_size_down, tuple)) : 32 | filter_size_down = [filter_size_down]*n_scales 33 | 34 | if not (isinstance(filter_size_up, list) or isinstance(filter_size_up, tuple)) : 35 | filter_size_up = [filter_size_up]*n_scales 36 | 37 | last_scale = n_scales - 1 38 | 39 | cur_depth = None 40 | 41 | model = nn.Sequential() 42 | model_tmp = model 43 | 44 | input_depth = num_input_channels 45 | for i in range(len(num_channels_down)): 46 | 47 | deeper = nn.Sequential() 48 | skip = nn.Sequential() 49 | 50 | if num_channels_skip[i] != 0: 51 | model_tmp.add(Concat(1, skip, deeper)) 52 | else: 53 | model_tmp.add(deeper) 54 | 55 | model_tmp.add(bn(num_channels_skip[i] + (num_channels_up[i + 1] if i < last_scale else num_channels_down[i]))) 56 | 57 | if num_channels_skip[i] != 0: 58 | skip.add(conv(input_depth, num_channels_skip[i], filter_skip_size, bias=need_bias, pad=pad)) 59 | skip.add(bn(num_channels_skip[i])) 60 | skip.add(act(act_fun)) 61 | 62 | # skip.add(Concat(2, GenNoise(nums_noise[i]), skip_part)) 63 | 64 | deeper.add(conv(input_depth, num_channels_down[i], filter_size_down[i], 2, bias=need_bias, pad=pad, downsample_mode=downsample_mode[i])) 65 | deeper.add(bn(num_channels_down[i])) 66 | deeper.add(act(act_fun)) 67 | 68 | deeper.add(conv(num_channels_down[i], num_channels_down[i], filter_size_down[i], bias=need_bias, pad=pad)) 69 | deeper.add(bn(num_channels_down[i])) 70 | deeper.add(act(act_fun)) 71 | 72 | deeper_main = nn.Sequential() 73 | 74 | if i == len(num_channels_down) - 1: 75 | # The deepest 76 | k = num_channels_down[i] 77 | else: 78 | deeper.add(deeper_main) 79 | k = num_channels_up[i + 1] 80 | 81 | deeper.add(nn.Upsample(scale_factor=2, mode=upsample_mode[i])) 82 | 83 | model_tmp.add(conv(num_channels_skip[i] + k, num_channels_up[i], filter_size_up[i], 1, bias=need_bias, pad=pad)) 84 | model_tmp.add(bn(num_channels_up[i])) 85 | model_tmp.add(act(act_fun)) 86 | 87 | 88 | if need1x1_up: 89 | model_tmp.add(conv(num_channels_up[i], num_channels_up[i], 1, bias=need_bias, pad=pad)) 90 | model_tmp.add(bn(num_channels_up[i])) 91 | model_tmp.add(act(act_fun)) 92 | 93 | input_depth = num_channels_down[i] 94 | model_tmp = deeper_main 95 | 96 | # model.add(conv(num_channels_up[0], num_output_channels, 1, bias=need_bias, pad=pad)) 97 | model.add(conv(num_channels_up[0], num_output_channels, 1, bias=need_bias, pad=pad)) 
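    # (Added comment.) The final 1x1 conv projects the finest-scale decoder
    # features down to num_output_channels; the commented-out line above it is
    # an identical leftover.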
98 | if need_sigmoid: 99 | model.add(nn.Sigmoid()) 100 | 101 | return model 102 | 103 | 104 | def skip_with_code( 105 | num_input_channels=2, num_output_channels=3, num_code_channels=3, 106 | num_channels_down=[16, 32, 64, 128, 128], num_channels_up=[16, 32, 64, 128, 128], num_channels_skip=[4, 4, 4, 4, 4], 107 | filter_size_down=3, filter_size_up=3, filter_skip_size=1, 108 | need_sigmoid=True, need_bias=True, 109 | pad='zero', upsample_mode='nearest', downsample_mode='stride', act_fun='LeakyReLU', 110 | need1x1_up=True): 111 | """Assembles encoder-decoder with skip connections. 112 | 113 | Arguments: 114 | act_fun: Either string 'LeakyReLU|Swish|ELU|none' or module (e.g. nn.ReLU) 115 | pad (string): zero|reflection (default: 'zero') 116 | upsample_mode (string): 'nearest|bilinear' (default: 'nearest') 117 | downsample_mode (string): 'stride|avg|max|lanczos2' (default: 'stride') 118 | 119 | """ 120 | assert len(num_channels_down) == len(num_channels_up) == len(num_channels_skip) 121 | 122 | n_scales = len(num_channels_down) 123 | 124 | if not (isinstance(upsample_mode, list) or isinstance(upsample_mode, tuple)) : 125 | upsample_mode = [upsample_mode]*n_scales 126 | 127 | if not (isinstance(downsample_mode, list)or isinstance(downsample_mode, tuple)): 128 | downsample_mode = [downsample_mode]*n_scales 129 | 130 | if not (isinstance(filter_size_down, list) or isinstance(filter_size_down, tuple)) : 131 | filter_size_down = [filter_size_down]*n_scales 132 | 133 | if not (isinstance(filter_size_up, list) or isinstance(filter_size_up, tuple)) : 134 | filter_size_up = [filter_size_up]*n_scales 135 | 136 | last_scale = n_scales - 1 137 | 138 | cur_depth = None 139 | 140 | model = nn.Sequential() 141 | model_tmp = model 142 | 143 | input_depth = num_input_channels 144 | for i in range(len(num_channels_down)): 145 | 146 | deeper = nn.Sequential() 147 | skip = nn.Sequential() 148 | 149 | if num_channels_skip[i] != 0: 150 | model_tmp.add(Concat(1, skip, deeper)) 151 | else: 152 | model_tmp.add(deeper) 153 | 154 | model_tmp.add(bn(num_channels_skip[i] + (num_channels_up[i + 1] if i < last_scale else num_channels_down[i]))) 155 | 156 | if num_channels_skip[i] != 0: 157 | skip.add(conv(input_depth, num_channels_skip[i], filter_skip_size, bias=need_bias, pad=pad)) 158 | skip.add(bn(num_channels_skip[i])) 159 | skip.add(act(act_fun)) 160 | 161 | # skip.add(Concat(2, GenNoise(nums_noise[i]), skip_part)) 162 | 163 | deeper.add(conv(input_depth, num_channels_down[i], filter_size_down[i], 2, bias=need_bias, pad=pad, downsample_mode=downsample_mode[i])) 164 | deeper.add(bn(num_channels_down[i])) 165 | deeper.add(act(act_fun)) 166 | 167 | deeper.add(conv(num_channels_down[i], num_channels_down[i], filter_size_down[i], bias=need_bias, pad=pad)) 168 | deeper.add(bn(num_channels_down[i])) 169 | deeper.add(act(act_fun)) 170 | 171 | deeper_main = nn.Sequential() 172 | 173 | if i == len(num_channels_down) - 1: 174 | # The deepest 175 | k = num_channels_down[i] 176 | else: 177 | deeper.add(deeper_main) 178 | k = num_channels_up[i + 1] 179 | 180 | deeper.add(nn.Upsample(scale_factor=2, mode=upsample_mode[i])) 181 | 182 | model_tmp.add(conv(num_channels_skip[i] + k, num_channels_up[i], filter_size_up[i], 1, bias=need_bias, pad=pad)) 183 | model_tmp.add(bn(num_channels_up[i])) 184 | model_tmp.add(act(act_fun)) 185 | 186 | 187 | if need1x1_up: 188 | model_tmp.add(conv(num_channels_up[i], num_channels_up[i], 1, bias=need_bias, pad=pad)) 189 | model_tmp.add(bn(num_channels_up[i])) 190 | model_tmp.add(act(act_fun)) 191 | 192 
| input_depth = num_channels_down[i] 193 | model_tmp = deeper_main 194 | 195 | model.add(conv(num_channels_up[0], num_output_channels, 1, bias=need_bias, pad=pad)) 196 | if need_sigmoid: 197 | model.add(nn.Sigmoid()) 198 | 199 | return model 200 | -------------------------------------------------------------------------------- /models/snet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from models.common import * 4 | import math 5 | 6 | class snet_test(nn.Module): 7 | def __init__(self, img_ch=3, out_ch=3, act_type="LeakyReLU", pad='reflection', bn_mode = "bn"): 8 | super(snet_test, self).__init__() 9 | self.conv1 = conv(img_ch, 64, kernel_size=5, stride=1, bias=False, pad=pad) 10 | self.conv2 = conv(64, 32, kernel_size=5, stride=1, bias=False, pad=pad) 11 | self.conv3 = conv(32, out_ch, kernel_size=5, stride=1, bias=False, pad=pad) 12 | self.conv_ = nn.Sequential( 13 | self.conv1, bn(64, bn_mode), act(act_type), 14 | bn(64, bn_mode), act(act_type), self.conv2,) 15 | 16 | self.conv__ = nn.Sequential( 17 | bn(32, bn_mode), act(act_type), self.conv3) 18 | self.Sig = nn.Sigmoid() 19 | 20 | 21 | def forward(self, x): 22 | x = (x - x.mean()) / (x.std()) 23 | d0 = self.conv_(x) 24 | d0 = self.conv__(d0) 25 | return self.Sig(d0 / 30) 26 | 27 | class snet_4x(nn.Module): 28 | def __init__(self, img_ch=3, out_ch=3, act_type="LeakyReLU", pad='reflection', bn_mode = "bn"): 29 | super(snet_4x, self).__init__() 30 | self.conv1 = conv(img_ch * 4, 64, kernel_size=5, stride=1, bias=False, pad=pad) 31 | self.conv2 = conv(64, 32, kernel_size=5, stride=1, bias=False, pad=pad) 32 | self.conv3 = conv(32, out_ch, kernel_size=5, stride=1, bias=False, pad=pad) 33 | self.conv_ = nn.Sequential( 34 | self.conv1, bn(64, bn_mode), act(act_type), 35 | bn(64, bn_mode), act(act_type), self.conv2,) 36 | 37 | self.conv__ = nn.Sequential( 38 | bn(32, bn_mode), act(act_type), self.conv3) 39 | self.Sig = nn.Sigmoid() 40 | 41 | def forward(self, x): 42 | x = (x - x.mean()) / (x.std()) 43 | x_bu = torch.flip(x, dims=[2]) 44 | x_rl = torch.flip(x, dims=[3]) 45 | x_all = torch.flip(x, dims=[2,3]) 46 | x = torch.cat([x, x_bu, x_rl, x_all], dim=1) 47 | 48 | d0 = self.conv_(x) 49 | d0 = self.conv__(d0) 50 | return self.Sig(d0 / 30) 51 | 52 | 53 | 54 | class S2SnetW(nn.Module): 55 | def __init__(self, img_ch=3, output_ch=3, act_type="LeakyReLU", pad='reflection'): 56 | super(S2SnetW, self).__init__() 57 | self.net1 = snet_test(img_ch, output_ch, act_type, pad) 58 | self.net2 = snet_test(img_ch, output_ch, act_type, pad) 59 | 60 | def forward(self, x): 61 | return self.net2(self.net1(x)) 62 | 63 | 64 | class S2SnetT(nn.Module): 65 | def __init__(self, img_ch=3, output_ch=3, act_type="LeakyReLU", pad='reflection'): 66 | super(S2SnetT, self).__init__() 67 | self.net = snet_test(img_ch, output_ch, act_type, pad) 68 | 69 | def forward(self, x): 70 | return self.net(self.net(x)) 71 | 72 | if __name__ == '__main__': 73 | x = torch.rand([1, 3, 481, 321]) 74 | 75 | net = snet_test() 76 | print(net(x).shape) 77 | -------------------------------------------------------------------------------- /models/texture_nets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .common import * 4 | 5 | 6 | normalization = nn.BatchNorm2d 7 | 8 | 9 | def conv(in_f, out_f, kernel_size, stride=1, bias=True, pad='zero'): 10 | if pad == 'zero': 11 | return nn.Conv2d(in_f, out_f, kernel_size, 
stride, padding=(kernel_size - 1) // 2, bias=bias) 12 | elif pad == 'reflection': 13 | layers = [nn.ReflectionPad2d((kernel_size - 1) // 2), 14 | nn.Conv2d(in_f, out_f, kernel_size, stride, padding=0, bias=bias)] 15 | return nn.Sequential(*layers) 16 | 17 | def get_texture_nets(inp=3, ratios = [32, 16, 8, 4, 2, 1], fill_noise=False, pad='zero', need_sigmoid=False, conv_num=8, upsample_mode='nearest'): 18 | 19 | 20 | for i in range(len(ratios)): 21 | j = i + 1 22 | 23 | seq = nn.Sequential() 24 | 25 | tmp = nn.AvgPool2d(ratios[i], ratios[i]) 26 | 27 | seq.add(tmp) 28 | if fill_noise: 29 | seq.add(GenNoise(inp)) 30 | 31 | seq.add(conv(inp, conv_num, 3, pad=pad)) 32 | seq.add(normalization(conv_num)) 33 | seq.add(act()) 34 | 35 | seq.add(conv(conv_num, conv_num, 3, pad=pad)) 36 | seq.add(normalization(conv_num)) 37 | seq.add(act()) 38 | 39 | seq.add(conv(conv_num, conv_num, 1, pad=pad)) 40 | seq.add(normalization(conv_num)) 41 | seq.add(act()) 42 | 43 | if i == 0: 44 | seq.add(nn.Upsample(scale_factor=2, mode=upsample_mode)) 45 | cur = seq 46 | else: 47 | 48 | cur_temp = cur 49 | 50 | cur = nn.Sequential() 51 | 52 | # Batch norm before merging 53 | seq.add(normalization(conv_num)) 54 | cur_temp.add(normalization(conv_num * (j - 1))) 55 | 56 | cur.add(Concat(1, cur_temp, seq)) 57 | 58 | cur.add(conv(conv_num * j, conv_num * j, 3, pad=pad)) 59 | cur.add(normalization(conv_num * j)) 60 | cur.add(act()) 61 | 62 | cur.add(conv(conv_num * j, conv_num * j, 3, pad=pad)) 63 | cur.add(normalization(conv_num * j)) 64 | cur.add(act()) 65 | 66 | cur.add(conv(conv_num * j, conv_num * j, 1, pad=pad)) 67 | cur.add(normalization(conv_num * j)) 68 | cur.add(act()) 69 | 70 | if i == len(ratios) - 1: 71 | cur.add(conv(conv_num * j, 3, 1, pad=pad)) 72 | else: 73 | cur.add(nn.Upsample(scale_factor=2, mode=upsample_mode)) 74 | 75 | model = cur 76 | if need_sigmoid: 77 | model.add(nn.Sigmoid()) 78 | 79 | return model 80 | -------------------------------------------------------------------------------- /models/unet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .common import * 6 | 7 | class ListModule(nn.Module): 8 | def __init__(self, *args): 9 | super(ListModule, self).__init__() 10 | idx = 0 11 | for module in args: 12 | self.add_module(str(idx), module) 13 | idx += 1 14 | 15 | def __getitem__(self, idx): 16 | if idx >= len(self._modules): 17 | raise IndexError('index {} is out of range'.format(idx)) 18 | if idx < 0: 19 | idx = len(self) + idx 20 | 21 | it = iter(self._modules.values()) 22 | for i in range(idx): 23 | next(it) 24 | return next(it) 25 | 26 | def __iter__(self): 27 | return iter(self._modules.values()) 28 | 29 | def __len__(self): 30 | return len(self._modules) 31 | 32 | class UNet(nn.Module): 33 | ''' 34 | upsample_mode in ['deconv', 'nearest', 'bilinear'] 35 | pad in ['zero', 'replication', 'none'] 36 | ''' 37 | def __init__(self, num_input_channels=3, num_output_channels=3, 38 | feature_scale=4, more_layers=0, concat_x=False, 39 | upsample_mode='deconv', pad='zero', norm_layer=nn.InstanceNorm2d, need_sigmoid=True, need_bias=True): 40 | super(UNet, self).__init__() 41 | 42 | self.feature_scale = feature_scale 43 | self.more_layers = more_layers 44 | self.concat_x = concat_x 45 | 46 | 47 | filters = [64, 128, 256, 512, 1024] 48 | filters = [x // self.feature_scale for x in filters] 49 | 50 | self.start = unetConv2(num_input_channels, filters[0] if not
concat_x else filters[0] - num_input_channels, norm_layer, need_bias, pad) 51 | 52 | self.down1 = unetDown(filters[0], filters[1] if not concat_x else filters[1] - num_input_channels, norm_layer, need_bias, pad) 53 | self.down2 = unetDown(filters[1], filters[2] if not concat_x else filters[2] - num_input_channels, norm_layer, need_bias, pad) 54 | self.down3 = unetDown(filters[2], filters[3] if not concat_x else filters[3] - num_input_channels, norm_layer, need_bias, pad) 55 | self.down4 = unetDown(filters[3], filters[4] if not concat_x else filters[4] - num_input_channels, norm_layer, need_bias, pad) 56 | 57 | # more downsampling layers 58 | if self.more_layers > 0: 59 | self.more_downs = [ 60 | unetDown(filters[4], filters[4] if not concat_x else filters[4] - num_input_channels , norm_layer, need_bias, pad) for i in range(self.more_layers)] 61 | self.more_ups = [unetUp(filters[4], upsample_mode, need_bias, pad, same_num_filt =True) for i in range(self.more_layers)] 62 | 63 | self.more_downs = ListModule(*self.more_downs) 64 | self.more_ups = ListModule(*self.more_ups) 65 | 66 | self.up4 = unetUp(filters[3], upsample_mode, need_bias, pad) 67 | self.up3 = unetUp(filters[2], upsample_mode, need_bias, pad) 68 | self.up2 = unetUp(filters[1], upsample_mode, need_bias, pad) 69 | self.up1 = unetUp(filters[0], upsample_mode, need_bias, pad) 70 | 71 | self.final = conv(filters[0], num_output_channels, 1, bias=need_bias, pad=pad) 72 | 73 | if need_sigmoid: 74 | self.final = nn.Sequential(self.final, nn.Sigmoid()) 75 | 76 | def forward(self, inputs): 77 | 78 | # Downsample 79 | downs = [inputs] 80 | down = nn.AvgPool2d(2, 2) 81 | for i in range(4 + self.more_layers): 82 | downs.append(down(downs[-1])) 83 | 84 | in64 = self.start(inputs) 85 | if self.concat_x: 86 | in64 = torch.cat([in64, downs[0]], 1) 87 | 88 | down1 = self.down1(in64) 89 | if self.concat_x: 90 | down1 = torch.cat([down1, downs[1]], 1) 91 | 92 | down2 = self.down2(down1) 93 | if self.concat_x: 94 | down2 = torch.cat([down2, downs[2]], 1) 95 | 96 | down3 = self.down3(down2) 97 | if self.concat_x: 98 | down3 = torch.cat([down3, downs[3]], 1) 99 | 100 | down4 = self.down4(down3) 101 | if self.concat_x: 102 | down4 = torch.cat([down4, downs[4]], 1) 103 | 104 | if self.more_layers > 0: 105 | prevs = [down4] 106 | for kk, d in enumerate(self.more_downs): 107 | # print(prevs[-1].size()) 108 | out = d(prevs[-1]) 109 | if self.concat_x: 110 | out = torch.cat([out, downs[kk + 5]], 1) 111 | 112 | prevs.append(out) 113 | 114 | up_ = self.more_ups[-1](prevs[-1], prevs[-2]) 115 | for idx in range(self.more_layers - 1): 116 | l = self.more_ups[self.more_layers - idx - 2] 117 | up_= l(up_, prevs[self.more_layers - idx - 2]) 118 | else: 119 | up_= down4 120 | 121 | up4= self.up4(up_, down3) 122 | up3= self.up3(up4, down2) 123 | up2= self.up2(up3, down1) 124 | up1= self.up1(up2, in64) 125 | 126 | return self.final(up1) 127 | 128 | 129 | 130 | class unetConv2(nn.Module): 131 | def __init__(self, in_size, out_size, norm_layer, need_bias, pad): 132 | super(unetConv2, self).__init__() 133 | 134 | print(pad) 135 | if norm_layer is not None: 136 | self.conv1= nn.Sequential(conv(in_size, out_size, 3, bias=need_bias, pad=pad), 137 | norm_layer(out_size), 138 | nn.ReLU(),) 139 | self.conv2= nn.Sequential(conv(out_size, out_size, 3, bias=need_bias, pad=pad), 140 | norm_layer(out_size), 141 | nn.ReLU(),) 142 | else: 143 | self.conv1= nn.Sequential(conv(in_size, out_size, 3, bias=need_bias, pad=pad), 144 | nn.ReLU(),) 145 | self.conv2= nn.Sequential(conv(out_size, out_size, 3,
bias=need_bias, pad=pad), 146 | nn.ReLU(),) 147 | def forward(self, inputs): 148 | outputs= self.conv1(inputs) 149 | outputs= self.conv2(outputs) 150 | return outputs 151 | 152 | 153 | class unetDown(nn.Module): 154 | def __init__(self, in_size, out_size, norm_layer, need_bias, pad): 155 | super(unetDown, self).__init__() 156 | self.conv= unetConv2(in_size, out_size, norm_layer, need_bias, pad) 157 | self.down= nn.MaxPool2d(2, 2) 158 | 159 | def forward(self, inputs): 160 | outputs= self.down(inputs) 161 | outputs= self.conv(outputs) 162 | return outputs 163 | 164 | 165 | class unetUp(nn.Module): 166 | def __init__(self, out_size, upsample_mode, need_bias, pad, same_num_filt=False): 167 | super(unetUp, self).__init__() 168 | 169 | num_filt = out_size if same_num_filt else out_size * 2 170 | if upsample_mode == 'deconv': 171 | self.up= nn.ConvTranspose2d(num_filt, out_size, 4, stride=2, padding=1) 172 | self.conv= unetConv2(out_size * 2, out_size, None, need_bias, pad) 173 | elif upsample_mode=='bilinear' or upsample_mode=='nearest': 174 | self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode=upsample_mode, align_corners=True), 175 | conv(num_filt, out_size, 3, bias=need_bias, pad=pad)) 176 | self.conv= unetConv2(out_size * 2, out_size, None, need_bias, pad) 177 | else: 178 | assert False 179 | 180 | def forward(self, inputs1, inputs2): 181 | in1_up= self.up(inputs1) 182 | 183 | if (inputs2.size(2) != in1_up.size(2)) or (inputs2.size(3) != in1_up.size(3)): 184 | diff2 = (inputs2.size(2) - in1_up.size(2)) // 2 185 | diff3 = (inputs2.size(3) - in1_up.size(3)) // 2 186 | inputs2_ = inputs2[:, :, diff2 : diff2 + in1_up.size(2), diff3 : diff3 + in1_up.size(3)] 187 | else: 188 | inputs2_ = inputs2 189 | 190 | output= self.conv(torch.cat([in1_up, inputs2_], 1)) 191 | 192 | return output 193 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | @EXPLICIT 5 | https://repo.anaconda.com/pkgs/main/linux-64/_pytorch_select-0.1-cpu_0.conda 6 | https://repo.anaconda.com/pkgs/main/linux-64/_libgcc_mutex-0.1-main.conda 7 | https://repo.anaconda.com/pkgs/main/linux-64/blas-1.0-mkl.conda 8 | https://repo.anaconda.com/pkgs/main/linux-64/ca-certificates-2020.7.22-0.conda 9 | https://repo.anaconda.com/pkgs/main/linux-64/intel-openmp-2019.4-243.conda 10 | https://repo.anaconda.com/pkgs/main/linux-64/ld_impl_linux-64-2.33.1-h53a641e_7.conda 11 | https://repo.anaconda.com/pkgs/main/linux-64/libgfortran-ng-7.3.0-hdf63c60_0.conda 12 | https://repo.anaconda.com/pkgs/main/linux-64/libstdcxx-ng-9.1.0-hdf63c60_0.conda 13 | https://repo.anaconda.com/pkgs/main/linux-64/libgcc-ng-9.1.0-hdf63c60_0.conda 14 | https://repo.anaconda.com/pkgs/main/linux-64/mkl-2019.4-243.conda 15 | https://repo.anaconda.com/pkgs/main/linux-64/bzip2-1.0.8-h7b6447c_0.conda 16 | https://repo.anaconda.com/pkgs/main/linux-64/c-ares-1.16.1-h7b6447c_0.conda 17 | https://repo.anaconda.com/pkgs/main/linux-64/cudatoolkit-10.2.89-hfd86e86_1.conda 18 | https://repo.anaconda.com/pkgs/main/linux-64/expat-2.2.9-he6710b0_2.conda 19 | https://repo.anaconda.com/pkgs/main/linux-64/freeglut-3.0.0-hf484d3e_5.conda 20 | https://repo.anaconda.com/pkgs/main/linux-64/graphite2-1.3.14-h23475e2_0.conda 21 | https://repo.anaconda.com/pkgs/main/linux-64/icu-58.2-he6710b0_3.conda 22 | 
https://repo.anaconda.com/pkgs/main/linux-64/jpeg-9b-h024ee3a_2.conda 23 | https://repo.anaconda.com/pkgs/main/linux-64/libffi-3.3-he6710b0_2.conda 24 | https://repo.anaconda.com/pkgs/main/linux-64/libglu-9.0.0-hf484d3e_1.conda 25 | https://repo.anaconda.com/pkgs/main/linux-64/libopus-1.3.1-h7b6447c_0.conda 26 | https://repo.anaconda.com/pkgs/main/linux-64/libuuid-1.0.3-h1bed415_2.conda 27 | https://repo.anaconda.com/pkgs/main/linux-64/libvpx-1.7.0-h439df22_0.conda 28 | https://repo.anaconda.com/pkgs/main/linux-64/libxcb-1.14-h7b6447c_0.conda 29 | https://repo.anaconda.com/pkgs/main/linux-64/lz4-c-1.9.2-he6710b0_1.conda 30 | https://repo.anaconda.com/pkgs/main/linux-64/ncurses-6.2-he6710b0_1.conda 31 | https://repo.anaconda.com/pkgs/main/linux-64/openssl-1.1.1h-h7b6447c_0.conda 32 | https://repo.anaconda.com/pkgs/main/linux-64/pcre-8.44-he6710b0_0.conda 33 | https://repo.anaconda.com/pkgs/main/linux-64/pixman-0.40.0-h7b6447c_0.conda 34 | https://repo.anaconda.com/pkgs/main/linux-64/xz-5.2.5-h7b6447c_0.conda 35 | https://repo.anaconda.com/pkgs/main/linux-64/yaml-0.2.5-h7b6447c_0.conda 36 | https://repo.anaconda.com/pkgs/main/linux-64/zlib-1.2.11-h7b6447c_3.conda 37 | https://repo.anaconda.com/pkgs/main/linux-64/glib-2.65.0-h3eb4bd4_0.conda 38 | https://repo.anaconda.com/pkgs/main/linux-64/hdf5-1.10.2-hba1933b_1.conda 39 | https://repo.anaconda.com/pkgs/main/linux-64/jasper-2.0.14-h07fcdf6_1.conda 40 | https://repo.anaconda.com/pkgs/main/linux-64/libedit-3.1.20191231-h14c3975_1.conda 41 | https://repo.anaconda.com/pkgs/main/linux-64/libpng-1.6.37-hbc83047_0.conda 42 | https://repo.anaconda.com/pkgs/main/linux-64/libprotobuf-3.13.0-hd408876_0.conda 43 | https://repo.anaconda.com/pkgs/main/linux-64/libxml2-2.9.10-he19cac6_1.conda 44 | https://repo.anaconda.com/pkgs/main/linux-64/readline-8.0-h7b6447c_0.conda 45 | https://repo.anaconda.com/pkgs/main/linux-64/tk-8.6.10-hbc83047_0.conda 46 | https://repo.anaconda.com/pkgs/main/linux-64/zstd-1.4.5-h9ceee32_0.conda 47 | https://repo.anaconda.com/pkgs/main/linux-64/dbus-1.13.16-hb2f20db_0.conda 48 | https://repo.anaconda.com/pkgs/main/linux-64/freetype-2.10.2-h5ab3b9f_0.conda 49 | https://repo.anaconda.com/pkgs/main/linux-64/gstreamer-1.14.0-hb31296c_0.conda 50 | https://repo.anaconda.com/pkgs/main/linux-64/libtiff-4.1.0-h2733197_1.conda 51 | https://repo.anaconda.com/pkgs/main/linux-64/sqlite-3.33.0-h62c20be_0.conda 52 | https://repo.anaconda.com/pkgs/main/linux-64/ffmpeg-4.0-hcdf2ecd_0.conda 53 | https://repo.anaconda.com/pkgs/main/linux-64/fontconfig-2.13.0-h9420a91_0.conda 54 | https://repo.anaconda.com/pkgs/main/linux-64/gst-plugins-base-1.14.0-hbbd80ab_1.conda 55 | https://repo.anaconda.com/pkgs/main/linux-64/lcms2-2.11-h396b838_0.conda 56 | https://repo.anaconda.com/pkgs/main/linux-64/python-3.7.9-h7579374_0.conda 57 | https://repo.anaconda.com/pkgs/main/linux-64/async-timeout-3.0.1-py37_0.conda 58 | https://repo.anaconda.com/pkgs/main/noarch/attrs-20.2.0-py_0.conda 59 | https://repo.anaconda.com/pkgs/main/linux-64/blinker-1.4-py37_0.conda 60 | https://repo.anaconda.com/pkgs/main/noarch/cachetools-4.1.1-py_0.conda 61 | https://repo.anaconda.com/pkgs/main/linux-64/cairo-1.14.12-h8948797_3.conda 62 | https://repo.anaconda.com/pkgs/main/linux-64/certifi-2020.6.20-py37_0.conda 63 | https://repo.anaconda.com/pkgs/main/linux-64/chardet-3.0.4-py37_1003.conda 64 | https://repo.anaconda.com/pkgs/main/noarch/click-7.1.2-py_0.conda 65 | https://repo.anaconda.com/pkgs/main/noarch/cloudpickle-1.6.0-py_0.conda 66 | 
https://repo.anaconda.com/pkgs/main/noarch/decorator-4.4.2-py_0.conda 67 | https://repo.anaconda.com/pkgs/main/noarch/idna-2.10-py_0.conda 68 | https://repo.anaconda.com/pkgs/main/linux-64/invoke-1.4.1-py37_0.conda 69 | https://repo.anaconda.com/pkgs/main/linux-64/kiwisolver-1.2.0-py37hfd86e86_0.conda 70 | https://repo.anaconda.com/pkgs/main/linux-64/multidict-4.7.6-py37h7b6447c_1.conda 71 | https://repo.anaconda.com/pkgs/main/linux-64/ninja-1.10.1-py37hfd86e86_0.conda 72 | https://repo.anaconda.com/pkgs/main/linux-64/olefile-0.46-py37_0.conda 73 | https://repo.anaconda.com/pkgs/main/noarch/pyasn1-0.4.8-py_0.conda 74 | https://repo.anaconda.com/pkgs/main/noarch/pycparser-2.20-py_2.conda 75 | https://repo.anaconda.com/pkgs/main/noarch/pyparsing-2.4.7-py_0.conda 76 | https://repo.anaconda.com/pkgs/main/linux-64/pysocks-1.7.1-py37_1.conda 77 | https://repo.anaconda.com/pkgs/main/noarch/pytz-2020.1-py_0.conda 78 | https://repo.anaconda.com/pkgs/main/linux-64/pyyaml-5.3.1-py37h7b6447c_1.conda 79 | https://repo.anaconda.com/pkgs/main/linux-64/qt-5.9.7-h5867ecd_1.conda 80 | https://repo.anaconda.com/pkgs/main/linux-64/sip-4.19.8-py37hf484d3e_0.conda 81 | https://repo.anaconda.com/pkgs/main/noarch/six-1.15.0-py_0.conda 82 | https://repo.anaconda.com/pkgs/main/noarch/toolz-0.11.1-py_0.conda 83 | https://repo.anaconda.com/pkgs/main/linux-64/tornado-6.0.4-py37h7b6447c_1.conda 84 | https://repo.anaconda.com/pkgs/main/noarch/werkzeug-1.0.1-py_0.conda 85 | https://repo.anaconda.com/pkgs/main/noarch/wheel-0.35.1-py_0.conda 86 | https://repo.anaconda.com/pkgs/main/noarch/zipp-3.1.0-py_0.conda 87 | https://repo.anaconda.com/pkgs/main/linux-64/absl-py-0.10.0-py37_0.conda 88 | https://repo.anaconda.com/pkgs/main/linux-64/cffi-1.14.3-py37he30daa8_0.conda 89 | https://repo.anaconda.com/pkgs/main/linux-64/cycler-0.10.0-py37_0.conda 90 | https://repo.anaconda.com/pkgs/main/linux-64/cytoolz-0.11.0-py37h7b6447c_0.conda 91 | https://repo.anaconda.com/pkgs/main/noarch/dask-core-2.28.0-py_0.conda 92 | https://repo.anaconda.com/pkgs/main/linux-64/harfbuzz-1.8.8-hffaf4a1_0.conda 93 | https://repo.anaconda.com/pkgs/main/linux-64/importlib-metadata-1.7.0-py37_0.conda 94 | https://repo.anaconda.com/pkgs/main/linux-64/mkl-service-2.3.0-py37he904b0f_0.conda 95 | https://repo.anaconda.com/pkgs/main/linux-64/pillow-7.2.0-py37hb39fc2d_0.conda 96 | https://repo.anaconda.com/pkgs/main/noarch/pyasn1-modules-0.2.8-py_0.conda 97 | https://repo.anaconda.com/pkgs/main/linux-64/pyqt-5.9.2-py37h05f1152_2.conda 98 | https://repo.anaconda.com/pkgs/main/noarch/python-dateutil-2.8.1-py_0.conda 99 | https://repo.anaconda.com/pkgs/main/noarch/rsa-4.6-py_0.conda 100 | https://repo.anaconda.com/pkgs/main/linux-64/setuptools-49.6.0-py37_1.conda 101 | https://repo.anaconda.com/pkgs/main/linux-64/yarl-1.5.1-py37h7b6447c_0.conda 102 | https://repo.anaconda.com/pkgs/main/linux-64/aiohttp-3.6.2-py37h7b6447c_0.conda 103 | https://repo.anaconda.com/pkgs/main/linux-64/brotlipy-0.7.0-py37h7b6447c_1000.conda 104 | https://repo.anaconda.com/pkgs/main/linux-64/cryptography-3.1.1-py37h1ba5d50_0.conda 105 | https://repo.anaconda.com/pkgs/main/linux-64/grpcio-1.31.0-py37hf8bcb03_0.conda 106 | https://repo.anaconda.com/pkgs/main/linux-64/libopencv-3.4.2-hb342d67_1.conda 107 | https://repo.anaconda.com/pkgs/main/linux-64/markdown-3.2.2-py37_0.conda 108 | https://repo.anaconda.com/pkgs/main/noarch/networkx-2.5-py_0.conda 109 | https://repo.anaconda.com/pkgs/main/linux-64/numpy-base-1.19.1-py37hfa32c7d_0.conda 110 | 
https://repo.anaconda.com/pkgs/main/linux-64/pip-20.2.3-py37_0.conda 111 | https://repo.anaconda.com/pkgs/main/linux-64/protobuf-3.13.0-py37hf484d3e_1.conda 112 | https://repo.anaconda.com/pkgs/main/noarch/tensorboard-plugin-wit-1.6.0-py_0.conda 113 | https://repo.anaconda.com/pkgs/main/noarch/google-auth-1.22.0-py_0.conda 114 | https://repo.anaconda.com/pkgs/r/linux-64/mkl_random-1.0.4-py37hd81dba3_0.tar.bz2 115 | https://repo.anaconda.com/pkgs/main/linux-64/pyjwt-1.7.1-py37_0.conda 116 | https://repo.anaconda.com/pkgs/main/noarch/pyopenssl-19.1.0-py_1.conda 117 | https://repo.anaconda.com/pkgs/main/noarch/oauthlib-3.1.0-py_0.conda 118 | https://repo.anaconda.com/pkgs/main/noarch/urllib3-1.25.10-py_0.conda 119 | https://repo.anaconda.com/pkgs/main/noarch/requests-2.24.0-py_0.conda 120 | https://repo.anaconda.com/pkgs/main/noarch/requests-oauthlib-1.3.0-py_0.conda 121 | https://repo.anaconda.com/pkgs/main/noarch/google-auth-oauthlib-0.4.1-py_2.conda 122 | https://repo.anaconda.com/pkgs/main/linux-64/matplotlib-3.3.1-0.conda 123 | https://repo.anaconda.com/pkgs/main/linux-64/matplotlib-base-3.3.1-py37h817c723_0.conda 124 | https://repo.anaconda.com/pkgs/main/linux-64/mkl_fft-1.2.0-py37h23d657b_0.conda 125 | https://repo.anaconda.com/pkgs/main/linux-64/numpy-1.19.1-py37hbc911f0_0.conda 126 | https://repo.anaconda.com/pkgs/main/noarch/imageio-2.9.0-py_0.conda 127 | https://repo.anaconda.com/pkgs/main/linux-64/pandas-1.1.2-py37he6710b0_0.conda 128 | https://repo.anaconda.com/pkgs/main/linux-64/py-opencv-3.4.2-py37hb342d67_1.conda 129 | https://conda.anaconda.org/pytorch/linux-64/pytorch-1.6.0-py3.7_cuda10.2.89_cudnn7.6.5_0.tar.bz2 130 | https://repo.anaconda.com/pkgs/main/linux-64/pywavelets-1.1.1-py37h7b6447c_2.conda 131 | https://repo.anaconda.com/pkgs/main/linux-64/scipy-1.5.2-py37h0b6359f_0.conda 132 | https://repo.anaconda.com/pkgs/main/noarch/tensorboard-2.2.1-pyh532a8cf_0.conda 133 | https://repo.anaconda.com/pkgs/main/linux-64/opencv-3.4.2-py37h6fd60c2_1.conda 134 | https://repo.anaconda.com/pkgs/main/linux-64/scikit-image-0.16.2-py37h0573a6f_0.conda 135 | https://conda.anaconda.org/pytorch/linux-64/torchvision-0.7.0-py37_cu102.tar.bz2 136 | -------------------------------------------------------------------------------- /tasks.py: -------------------------------------------------------------------------------- 1 | from invoke import task 2 | import pandas as pd 3 | 4 | @task 5 | def showtable(c, csv_dir, prefix = ""): 6 | import os 7 | csv_list = os.listdir(csv_dir) 8 | csv_list = sorted(csv_list) 9 | 10 | print_result = lambda i, tmp: print(i[:-4], "\t\t\t", ("optimal stopping : %.2f,\t" + "%.2f/%.2f \t| ZCSC : %.2f, \t %.2f/%.2f | STE %.2f/%.2f") % 11 | (tmp["max_step"], tmp["max_psnr"], tmp["max_lpips"] * 10, 12 | tmp["final_ep"], tmp["final_psnr"], tmp["final_lpips"] * 10, tmp["final_psnr_avg"], tmp["final_lpips_avg"] * 10,)) 13 | 14 | for i in csv_list: 15 | tmp = pd.read_csv(os.path.join(csv_dir, i)) 16 | tmp_mean = tmp.mean() 17 | print_result(i, tmp_mean) 18 | -------------------------------------------------------------------------------- /testset/BSD68/test001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test001.png -------------------------------------------------------------------------------- /testset/BSD68/test002.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test002.png -------------------------------------------------------------------------------- /testset/BSD68/test003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test003.png -------------------------------------------------------------------------------- /testset/BSD68/test004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test004.png -------------------------------------------------------------------------------- /testset/BSD68/test005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test005.png -------------------------------------------------------------------------------- /testset/BSD68/test006.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test006.png -------------------------------------------------------------------------------- /testset/BSD68/test007.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test007.png -------------------------------------------------------------------------------- /testset/BSD68/test008.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test008.png -------------------------------------------------------------------------------- /testset/BSD68/test009.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test009.png -------------------------------------------------------------------------------- /testset/BSD68/test010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test010.png -------------------------------------------------------------------------------- /testset/BSD68/test011.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test011.png -------------------------------------------------------------------------------- /testset/BSD68/test012.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test012.png -------------------------------------------------------------------------------- /testset/BSD68/test013.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test013.png -------------------------------------------------------------------------------- /testset/BSD68/test014.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test014.png -------------------------------------------------------------------------------- /testset/BSD68/test015.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test015.png -------------------------------------------------------------------------------- /testset/BSD68/test016.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test016.png -------------------------------------------------------------------------------- /testset/BSD68/test017.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test017.png -------------------------------------------------------------------------------- /testset/BSD68/test018.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test018.png -------------------------------------------------------------------------------- /testset/BSD68/test019.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test019.png -------------------------------------------------------------------------------- /testset/BSD68/test020.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test020.png -------------------------------------------------------------------------------- /testset/BSD68/test021.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test021.png -------------------------------------------------------------------------------- /testset/BSD68/test022.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test022.png -------------------------------------------------------------------------------- /testset/BSD68/test023.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test023.png -------------------------------------------------------------------------------- /testset/BSD68/test024.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test024.png -------------------------------------------------------------------------------- /testset/BSD68/test025.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test025.png -------------------------------------------------------------------------------- /testset/BSD68/test026.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test026.png -------------------------------------------------------------------------------- /testset/BSD68/test027.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test027.png -------------------------------------------------------------------------------- /testset/BSD68/test028.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test028.png -------------------------------------------------------------------------------- /testset/BSD68/test029.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test029.png -------------------------------------------------------------------------------- /testset/BSD68/test030.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test030.png -------------------------------------------------------------------------------- /testset/BSD68/test031.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test031.png -------------------------------------------------------------------------------- /testset/BSD68/test032.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test032.png -------------------------------------------------------------------------------- /testset/BSD68/test033.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test033.png -------------------------------------------------------------------------------- /testset/BSD68/test034.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test034.png -------------------------------------------------------------------------------- /testset/BSD68/test035.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test035.png -------------------------------------------------------------------------------- /testset/BSD68/test036.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test036.png -------------------------------------------------------------------------------- /testset/BSD68/test037.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test037.png -------------------------------------------------------------------------------- /testset/BSD68/test038.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test038.png -------------------------------------------------------------------------------- /testset/BSD68/test039.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test039.png -------------------------------------------------------------------------------- /testset/BSD68/test040.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test040.png -------------------------------------------------------------------------------- /testset/BSD68/test041.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test041.png -------------------------------------------------------------------------------- /testset/BSD68/test042.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test042.png -------------------------------------------------------------------------------- /testset/BSD68/test043.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test043.png -------------------------------------------------------------------------------- /testset/BSD68/test044.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test044.png -------------------------------------------------------------------------------- /testset/BSD68/test045.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test045.png -------------------------------------------------------------------------------- /testset/BSD68/test046.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test046.png -------------------------------------------------------------------------------- /testset/BSD68/test047.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test047.png -------------------------------------------------------------------------------- /testset/BSD68/test048.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test048.png -------------------------------------------------------------------------------- /testset/BSD68/test049.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test049.png -------------------------------------------------------------------------------- /testset/BSD68/test050.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test050.png -------------------------------------------------------------------------------- /testset/BSD68/test051.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test051.png -------------------------------------------------------------------------------- /testset/BSD68/test052.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test052.png -------------------------------------------------------------------------------- /testset/BSD68/test053.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test053.png -------------------------------------------------------------------------------- /testset/BSD68/test054.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test054.png -------------------------------------------------------------------------------- /testset/BSD68/test055.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test055.png -------------------------------------------------------------------------------- /testset/BSD68/test056.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test056.png -------------------------------------------------------------------------------- /testset/BSD68/test057.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test057.png -------------------------------------------------------------------------------- /testset/BSD68/test058.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test058.png -------------------------------------------------------------------------------- /testset/BSD68/test059.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test059.png -------------------------------------------------------------------------------- /testset/BSD68/test060.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test060.png -------------------------------------------------------------------------------- /testset/BSD68/test061.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test061.png -------------------------------------------------------------------------------- /testset/BSD68/test062.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test062.png -------------------------------------------------------------------------------- /testset/BSD68/test063.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test063.png -------------------------------------------------------------------------------- /testset/BSD68/test064.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test064.png -------------------------------------------------------------------------------- /testset/BSD68/test065.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test065.png -------------------------------------------------------------------------------- /testset/BSD68/test066.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test066.png -------------------------------------------------------------------------------- /testset/BSD68/test067.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test067.png -------------------------------------------------------------------------------- /testset/BSD68/test068.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/BSD68/test068.png -------------------------------------------------------------------------------- /testset/CSet9/image_Baboon512rgb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/CSet9/image_Baboon512rgb.png -------------------------------------------------------------------------------- /testset/CSet9/image_F16_512rgb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/CSet9/image_F16_512rgb.png -------------------------------------------------------------------------------- /testset/CSet9/image_House256rgb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/CSet9/image_House256rgb.png -------------------------------------------------------------------------------- /testset/CSet9/image_Lena512rgb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/CSet9/image_Lena512rgb.png -------------------------------------------------------------------------------- /testset/CSet9/image_Peppers512rgb.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/CSet9/image_Peppers512rgb.png -------------------------------------------------------------------------------- /testset/CSet9/kodim01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/CSet9/kodim01.png -------------------------------------------------------------------------------- /testset/CSet9/kodim02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/CSet9/kodim02.png -------------------------------------------------------------------------------- /testset/CSet9/kodim03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/CSet9/kodim03.png -------------------------------------------------------------------------------- /testset/CSet9/kodim12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/CSet9/kodim12.png -------------------------------------------------------------------------------- /testset/MNIST/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/MNIST/1.png -------------------------------------------------------------------------------- /testset/MNIST/10.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/MNIST/10.png -------------------------------------------------------------------------------- /testset/MNIST/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/MNIST/2.png -------------------------------------------------------------------------------- /testset/MNIST/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/MNIST/3.png -------------------------------------------------------------------------------- /testset/MNIST/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/MNIST/4.png -------------------------------------------------------------------------------- /testset/MNIST/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/MNIST/5.png -------------------------------------------------------------------------------- /testset/MNIST/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/MNIST/6.png -------------------------------------------------------------------------------- /testset/MNIST/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/MNIST/7.png -------------------------------------------------------------------------------- /testset/MNIST/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/MNIST/8.png -------------------------------------------------------------------------------- /testset/MNIST/9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/MNIST/9.png -------------------------------------------------------------------------------- /testset/Set12/01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/Set12/01.png -------------------------------------------------------------------------------- /testset/Set12/02.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/Set12/02.png -------------------------------------------------------------------------------- /testset/Set12/03.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/Set12/03.png -------------------------------------------------------------------------------- 
/testset/Set12/04.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/Set12/04.png -------------------------------------------------------------------------------- /testset/Set12/05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/Set12/05.png -------------------------------------------------------------------------------- /testset/Set12/06.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/Set12/06.png -------------------------------------------------------------------------------- /testset/Set12/07.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/Set12/07.png -------------------------------------------------------------------------------- /testset/Set12/08.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/Set12/08.png -------------------------------------------------------------------------------- /testset/Set12/09.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/Set12/09.png -------------------------------------------------------------------------------- /testset/Set12/10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/Set12/10.png -------------------------------------------------------------------------------- /testset/Set12/11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/Set12/11.png -------------------------------------------------------------------------------- /testset/Set12/12.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/testset/Set12/12.png -------------------------------------------------------------------------------- /utils/PerceptualSimilarity/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from . 
import dist_model 4 | 5 | 6 | class PerceptualLoss(torch.nn.Module): 7 | def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0], 8 | version='0.1'): # VGG using our perceptually-learned weights (LPIPS metric) 9 | # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss 10 | super(PerceptualLoss, self).__init__() 11 | print('Setting up Perceptual loss...') 12 | self.use_gpu = use_gpu 13 | self.spatial = spatial 14 | self.gpu_ids = gpu_ids 15 | self.model = dist_model.DistModel() 16 | self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, 17 | gpu_ids=gpu_ids, version=version) 18 | print('...[%s] initialized' % self.model.name()) 19 | print('...Done') 20 | 21 | def forward(self, pred, target, normalize=False): 22 | """ 23 | Pred and target are Variables. 24 | If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] 25 | If normalize is False, assumes the images are already between [-1,+1] 26 | 27 | Inputs pred and target are Nx3xHxW 28 | Output pytorch Variable N long 29 | """ 30 | 31 | if normalize: 32 | target = 2 * target - 1 33 | pred = 2 * pred - 1 34 | 35 | return self.model.forward(target, pred) 36 | 37 | 38 |
-------------------------------------------------------------------------------- /utils/PerceptualSimilarity/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import torch 4 | 5 | 6 | class BaseModel: 7 | def __init__(self): 8 | pass 9 | 10 | def name(self): 11 | return 'BaseModel' 12 | 13 | def initialize(self, use_gpu=True, gpu_ids=[0]): 14 | self.use_gpu = use_gpu 15 | self.gpu_ids = gpu_ids 16 | 17 | def forward(self): 18 | pass 19 | 20 | def get_image_paths(self): 21 | return self.image_paths 22 | 23 | def optimize_parameters(self): 24 | pass 25 | 26 | def get_current_visuals(self): 27 | return self.input 28 | 29 | def get_current_errors(self): 30 | return {} 31 | 32 | def save(self, label): 33 | pass 34 | 35 | # helper saving function that can be used by subclasses 36 | def save_network(self, network, path, network_label, epoch_label): 37 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 38 | save_path = os.path.join(path, save_filename) 39 | torch.save(network.state_dict(), save_path) 40 | 41 | # helper loading function that can be used by subclasses 42 | def load_network(self, network, network_label, epoch_label): 43 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 44 | save_path = os.path.join(self.save_dir, save_filename) 45 | print('Loading network from %s' % save_path) 46 | network.load_state_dict(torch.load(save_path)) 47 | 48 | def update_learning_rate(self): 49 | pass 50 | 51 | def save_done(self, flag=False): 52 | np.save(os.path.join(self.save_dir, 'done_flag'), flag) 53 | np.savetxt(os.path.join(self.save_dir, 'done_flag'), [flag, ], fmt='%i') 54 |
-------------------------------------------------------------------------------- /utils/PerceptualSimilarity/dist_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os 4 | from collections import OrderedDict 5 | from torch.autograd import Variable 6 | from .base_model import BaseModel 7 | from scipy.ndimage import zoom 8 | from tqdm import tqdm 9 | 10 | from .
import networks_basic as networks 11 | from . import util as util 12 | 13 | 14 | class DistModel(BaseModel): 15 | def name(self): 16 | return self.model_name 17 | 18 | def initialize(self, model='net-lin', net='alex', colorspace='Lab', pnet_rand=False, pnet_tune=False, 19 | model_path=None, 20 | use_gpu=True, printNet=False, spatial=False, 21 | is_train=False, lr=.0001, beta1=0.5, version='0.1', gpu_ids=[0]): 22 | ''' 23 | INPUTS 24 | model - ['net-lin'] for linearly calibrated network 25 | ['net'] for off-the-shelf network 26 | ['L2'] for L2 distance in Lab colorspace 27 | ['SSIM'] for ssim in RGB colorspace 28 | net - ['squeeze','alex','vgg'] 29 | model_path - if None, will look in weights/[NET_NAME].pth 30 | colorspace - ['Lab','RGB'] colorspace to use for L2 and SSIM 31 | use_gpu - bool - whether or not to use a GPU 32 | printNet - bool - whether or not to print network architecture out 33 | spatial - bool - whether to output an array containing varying distances across spatial dimensions 34 | (when spatial is True, the per-layer distance maps are upsampled bilinearly to the 35 | spatial size of the input images before being summed; see networks_basic.upsample 36 | and PNetLin.forward) 37 | is_train - bool - [True] for training mode 38 | lr - float - initial learning rate 39 | beta1 - float - initial momentum term for adam 40 | version - 0.1 for latest, 0.0 was original (with a bug) 41 | gpu_ids - int array - [0] by default, gpus to use 42 | ''' 43 | BaseModel.initialize(self, use_gpu=use_gpu, gpu_ids=gpu_ids) 44 | 45 | self.model = model 46 | self.net = net 47 | self.is_train = is_train 48 | self.spatial = spatial 49 | self.gpu_ids = gpu_ids 50 | self.model_name = '%s [%s]' % (model, net) 51 | 52 | if (self.model == 'net-lin'): # pretrained net + linear layer 53 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_tune=pnet_tune, pnet_type=net, 54 | use_dropout=True, spatial=spatial, version=version, lpips=True) 55 | kw = {} 56 | if not use_gpu: 57 | kw['map_location'] = 'cpu' 58 | if (model_path is None): 59 | import inspect 60 | model_path = os.path.abspath( 61 | os.path.join(inspect.getfile(self.initialize), '..', 'weights/v%s/%s.pth' % (version, net))) 62 | 63 | if (not is_train): 64 | print('Loading model from: %s' % model_path) 65 | self.net.load_state_dict(torch.load(model_path, **kw), strict=False) 66 | 67 | elif (self.model == 'net'): # pretrained network 68 | self.net = networks.PNetLin(pnet_rand=pnet_rand, pnet_type=net, lpips=False) 69 | elif (self.model in ['L2', 'l2']): 70 | self.net = networks.L2(use_gpu=use_gpu, colorspace=colorspace) # not really a network, only for testing 71 | self.model_name = 'L2' 72 | elif (self.model in ['DSSIM', 'dssim', 'SSIM', 'ssim']): 73 | self.net = networks.DSSIM(use_gpu=use_gpu, colorspace=colorspace) 74 | self.model_name = 'SSIM' 75 | else: 76 | raise ValueError("Model [%s] not recognized."
% self.model) 77 | 78 | self.parameters = list(self.net.parameters()) 79 | 80 | if self.is_train: # training mode 81 | # extra network on top to go from distances (d0,d1) => predicted human judgment (h*) 82 | self.rankLoss = networks.BCERankingLoss() 83 | self.parameters += list(self.rankLoss.net.parameters()) 84 | self.lr = lr 85 | self.old_lr = lr 86 | self.optimizer_net = torch.optim.Adam(self.parameters, lr=lr, betas=(beta1, 0.999)) 87 | else: # test mode 88 | self.net.eval() 89 | 90 | if (use_gpu): 91 | self.net = self.net.to('cuda') 92 | if (self.is_train): 93 | self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0 94 | 95 | if (printNet): 96 | print('---------- Networks initialized -------------') 97 | networks.print_network(self.net) 98 | print('-----------------------------------------------') 99 | 100 | def forward(self, in0, in1, retPerLayer=False): 101 | ''' Function computes the distance between image patches in0 and in1 102 | INPUTS 103 | in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch scaled to [-1,1] 104 | OUTPUT 105 | computed distances between in0 and in1 106 | ''' 107 | 108 | return self.net.forward(in0, in1, retPerLayer=retPerLayer) 109 | 110 | # ***** TRAINING FUNCTIONS ***** 111 | def optimize_parameters(self): 112 | self.forward_train() 113 | self.optimizer_net.zero_grad() 114 | self.backward_train() 115 | self.optimizer_net.step() 116 | self.clamp_weights() 117 | 118 | def clamp_weights(self): 119 | for module in self.net.modules(): 120 | if (hasattr(module, 'weight') and module.kernel_size == (1, 1)): 121 | module.weight.data = torch.clamp(module.weight.data, min=0) 122 | 123 | def set_input(self, data): 124 | self.input_ref = data['ref'] 125 | self.input_p0 = data['p0'] 126 | self.input_p1 = data['p1'] 127 | self.input_judge = data['judge'] 128 | 129 | if (self.use_gpu): 130 | self.input_ref = self.input_ref.to(device=self.gpu_ids[0]) 131 | self.input_p0 = self.input_p0.to(device=self.gpu_ids[0]) 132 | self.input_p1 = self.input_p1.to(device=self.gpu_ids[0]) 133 | self.input_judge = self.input_judge.to(device=self.gpu_ids[0]) 134 | 135 | self.var_ref = Variable(self.input_ref, requires_grad=True) 136 | self.var_p0 = Variable(self.input_p0, requires_grad=True) 137 | self.var_p1 = Variable(self.input_p1, requires_grad=True) 138 | 139 | def forward_train(self): # run forward pass 140 | # print(self.net.module.scaling_layer.shift) 141 | # print(torch.norm(self.net.module.net.slice1[0].weight).item(), torch.norm(self.net.module.lin0.model[1].weight).item()) 142 | 143 | self.d0 = self.forward(self.var_ref, self.var_p0) 144 | self.d1 = self.forward(self.var_ref, self.var_p1) 145 | self.acc_r = self.compute_accuracy(self.d0, self.d1, self.input_judge) 146 | 147 | self.var_judge = Variable(1. * self.input_judge).view(self.d0.size()) 148 | 149 | self.loss_total = self.rankLoss.forward(self.d0, self.d1, self.var_judge * 2. - 1.) 
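# NOTE (annotation): 'judge' holds the fraction of human raters preferring patch p1, in [0, 1]; BCERankingLoss expects a target in [-1, 1], hence the judge * 2 - 1 mapping above (the loss maps it back via (judge + 1) / 2).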
150 | 151 | return self.loss_total 152 | 153 | def backward_train(self): 154 | torch.mean(self.loss_total).backward() 155 | 156 | def compute_accuracy(self, d0, d1, judge): 157 | ''' d0, d1 are Variables, judge is a Tensor ''' 158 | d1_lt_d0 = (d1 < d0).cpu().data.numpy().flatten() 159 | judge_per = judge.cpu().numpy().flatten() 160 | return d1_lt_d0 * judge_per + (1 - d1_lt_d0) * (1 - judge_per) 161 | 162 | def get_current_errors(self): 163 | retDict = OrderedDict([('loss_total', self.loss_total.data.cpu().numpy()), 164 | ('acc_r', self.acc_r)]) 165 | 166 | for key in retDict.keys(): 167 | retDict[key] = np.mean(retDict[key]) 168 | 169 | return retDict 170 | 171 | def get_current_visuals(self): 172 | zoom_factor = 256 / self.var_ref.data.size()[2] 173 | 174 | ref_img = util.tensor2im(self.var_ref.data) 175 | p0_img = util.tensor2im(self.var_p0.data) 176 | p1_img = util.tensor2im(self.var_p1.data) 177 | 178 | ref_img_vis = zoom(ref_img, [zoom_factor, zoom_factor, 1], order=0) 179 | p0_img_vis = zoom(p0_img, [zoom_factor, zoom_factor, 1], order=0) 180 | p1_img_vis = zoom(p1_img, [zoom_factor, zoom_factor, 1], order=0) 181 | 182 | return OrderedDict([('ref', ref_img_vis), 183 | ('p0', p0_img_vis), 184 | ('p1', p1_img_vis)]) 185 | 186 | def save(self, path, label): 187 | if (self.use_gpu): 188 | self.save_network(getattr(self.net, 'module', self.net), path, '', label) # self.net is not wrapped in DataParallel in this version, so guard the .module access 189 | else: 190 | self.save_network(self.net, path, '', label) 191 | self.save_network(self.rankLoss.net, path, 'rank', label) 192 | 193 | def update_learning_rate(self, nepoch_decay): 194 | lrd = self.lr / nepoch_decay 195 | lr = self.old_lr - lrd 196 | 197 | for param_group in self.optimizer_net.param_groups: 198 | param_group['lr'] = lr 199 | 200 | print('update lr [%s] decay: %f -> %f' % (self.model_name, self.old_lr, lr)) 201 | self.old_lr = lr 202 | 203 | 204 | def score_2afc_dataset(data_loader, func, name=''): 205 | ''' Function computes Two Alternative Forced Choice (2AFC) score using 206 | distance function 'func' in dataset 'data_loader' 207 | INPUTS 208 | data_loader - CustomDatasetDataLoader object - contains a TwoAFCDataset inside 209 | func - callable distance function - calling d=func(in0,in1) should take 2 210 | pytorch tensors with shape Nx3xXxY, and return numpy array of length N 211 | OUTPUTS 212 | [0] - 2AFC score in [0,1], fraction of time func agrees with human evaluators 213 | [1] - dictionary with following elements 214 | d0s,d1s - N arrays containing distances between reference patch to perturbed patches 215 | gts - N array in [0,1], preferred patch selected by human evaluators 216 | (closer to "0" for left patch p0, "1" for right patch p1, 217 | "0.6" means 60pct people preferred right patch, 40pct preferred left) 218 | scores - N array in [0,1], corresponding to what percentage function agreed with humans 219 | CONSTS 220 | N - number of test triplets in data_loader 221 | ''' 222 | 223 | d0s = [] 224 | d1s = [] 225 | gts = [] 226 | 227 | for data in tqdm(data_loader.load_data(), desc=name): 228 | d0s += func(data['ref'], data['p0']).data.cpu().numpy().flatten().tolist() 229 | d1s += func(data['ref'], data['p1']).data.cpu().numpy().flatten().tolist() 230 | gts += data['judge'].cpu().numpy().flatten().tolist() 231 | 232 | d0s = np.array(d0s) 233 | d1s = np.array(d1s) 234 | gts = np.array(gts) 235 | scores = (d0s < d1s) * (1.
- gts) + (d1s < d0s) * gts + (d1s == d0s) * .5 236 | 237 | return (np.mean(scores), dict(d0s=d0s, d1s=d1s, gts=gts, scores=scores)) 238 | 239 | 240 | def score_jnd_dataset(data_loader, func, name=''): 241 | ''' Function computes JND score using distance function 'func' in dataset 'data_loader' 242 | INPUTS 243 | data_loader - CustomDatasetDataLoader object - contains a JNDDataset inside 244 | func - callable distance function - calling d=func(in0,in1) should take 2 245 | pytorch tensors with shape Nx3xXxY, and return pytorch array of length N 246 | OUTPUTS 247 | [0] - JND score in [0,1], mAP score (area under precision-recall curve) 248 | [1] - dictionary with following elements 249 | ds - N array containing distances between two patches shown to human evaluator 250 | sames - N array containing fraction of people who thought the two patches were identical 251 | CONSTS 252 | N - number of test triplets in data_loader 253 | ''' 254 | 255 | ds = [] 256 | gts = [] 257 | 258 | for data in tqdm(data_loader.load_data(), desc=name): 259 | ds += func(data['p0'], data['p1']).data.cpu().numpy().tolist() 260 | gts += data['same'].cpu().numpy().flatten().tolist() 261 | 262 | sames = np.array(gts) 263 | ds = np.array(ds) 264 | 265 | sorted_inds = np.argsort(ds) 266 | ds_sorted = ds[sorted_inds] 267 | sames_sorted = sames[sorted_inds] 268 | 269 | TPs = np.cumsum(sames_sorted) 270 | FPs = np.cumsum(1 - sames_sorted) 271 | FNs = np.sum(sames_sorted) - TPs 272 | 273 | precs = TPs / (TPs + FPs) 274 | recs = TPs / (TPs + FNs) 275 | score = util.voc_ap(recs, precs) 276 | 277 | return (score, dict(ds=ds, sames=sames)) 278 | -------------------------------------------------------------------------------- /utils/PerceptualSimilarity/networks_basic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | from . import pretrained_networks as pn 5 | 6 | from . import util as util 7 | 8 | 9 | def spatial_average(in_tens, keepdim=True): 10 | return in_tens.mean([2, 3], keepdim=keepdim) 11 | 12 | 13 | def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W 14 | in_H = in_tens.shape[2] 15 | scale_factor = 1. 
* out_H / in_H 16 | 17 | return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens) 18 | 19 | 20 | # Learned perceptual metric 21 | class PNetLin(nn.Module): 22 | def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, 23 | version='0.1', lpips=True): 24 | super(PNetLin, self).__init__() 25 | 26 | self.pnet_type = pnet_type 27 | self.pnet_tune = pnet_tune 28 | self.pnet_rand = pnet_rand 29 | self.spatial = spatial 30 | self.lpips = lpips 31 | self.version = version 32 | self.scaling_layer = ScalingLayer() 33 | 34 | if (self.pnet_type in ['vgg', 'vgg16']): 35 | net_type = pn.vgg16 36 | self.chns = [64, 128, 256, 512, 512] 37 | elif (self.pnet_type == 'alex'): 38 | net_type = pn.alexnet 39 | self.chns = [64, 192, 384, 256, 256] 40 | elif (self.pnet_type == 'squeeze'): 41 | net_type = pn.squeezenet 42 | self.chns = [64, 128, 256, 384, 384, 512, 512] 43 | self.L = len(self.chns) 44 | 45 | self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) 46 | 47 | if (lpips): 48 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 49 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 50 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 51 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 52 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 53 | self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 54 | if (self.pnet_type == 'squeeze'): # 7 layers for squeezenet 55 | self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) 56 | self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) 57 | self.lins += [self.lin5, self.lin6] 58 | 59 | def forward(self, in0, in1, retPerLayer=False): 60 | # v0.0 - original release had a bug, where input was not scaled 61 | in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version == '0.1' else ( 62 | in0, in1) 63 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) 64 | feats0, feats1, diffs = {}, {}, {} 65 | 66 | for kk in range(self.L): 67 | feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk]) 68 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 69 | 70 | if (self.lpips): 71 | if (self.spatial): 72 | res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)] 73 | else: 74 | res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)] 75 | else: 76 | if (self.spatial): 77 | res = [upsample(diffs[kk].sum(dim=1, keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)] 78 | else: 79 | res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) for kk in range(self.L)] 80 | 81 | val = res[0] 82 | for l in range(1, self.L): 83 | val += res[l] 84 | 85 | if (retPerLayer): 86 | return (val, res) 87 | else: 88 | return val 89 | 90 | 91 | class ScalingLayer(nn.Module): 92 | def __init__(self): 93 | super(ScalingLayer, self).__init__() 94 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 95 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 96 | 97 | def forward(self, inp): 98 | return (inp - self.shift) / self.scale 99 | 100 | 101 | class NetLinLayer(nn.Module): 102 | ''' A single linear layer which does a 1x1 conv ''' 103 | 104 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 105 | super(NetLinLayer, self).__init__() 106 | 107 | 
layers = [nn.Dropout(), ] if (use_dropout) else [] 108 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 109 | self.model = nn.Sequential(*layers) 110 | 111 | 112 | class Dist2LogitLayer(nn.Module): 113 | ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) ''' 114 | 115 | def __init__(self, chn_mid=32, use_sigmoid=True): 116 | super(Dist2LogitLayer, self).__init__() 117 | 118 | layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True), ] 119 | layers += [nn.LeakyReLU(0.2, True), ] 120 | layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True), ] 121 | layers += [nn.LeakyReLU(0.2, True), ] 122 | layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True), ] 123 | if (use_sigmoid): 124 | layers += [nn.Sigmoid(), ] 125 | self.model = nn.Sequential(*layers) 126 | 127 | def forward(self, d0, d1, eps=0.1): 128 | return self.model.forward(torch.cat((d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)), dim=1)) 129 | 130 | 131 | class BCERankingLoss(nn.Module): 132 | def __init__(self, chn_mid=32): 133 | super(BCERankingLoss, self).__init__() 134 | self.net = Dist2LogitLayer(chn_mid=chn_mid) 135 | # self.parameters = list(self.net.parameters()) 136 | self.loss = torch.nn.BCELoss() 137 | 138 | def forward(self, d0, d1, judge): 139 | per = (judge + 1.) / 2. 140 | self.logit = self.net.forward(d0, d1) 141 | return self.loss(self.logit, per) 142 | 143 | 144 | # L2, DSSIM metrics 145 | class FakeNet(nn.Module): 146 | def __init__(self, use_gpu=True, colorspace='Lab'): 147 | super(FakeNet, self).__init__() 148 | self.use_gpu = use_gpu 149 | self.colorspace = colorspace 150 | 151 | 152 | class L2(FakeNet): 153 | 154 | def forward(self, in0, in1, retPerLayer=None): 155 | assert (in0.size()[0] == 1) # currently only supports batchSize 1 156 | 157 | if (self.colorspace == 'RGB'): 158 | (N, C, X, Y) = in0.size() 159 | value = torch.mean(torch.mean(torch.mean((in0 - in1) ** 2, dim=1).view(N, 1, X, Y), dim=2).view(N, 1, 1, Y), 160 | dim=3).view(N) 161 | return value 162 | elif (self.colorspace == 'Lab'): 163 | value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data, to_norm=False)), 164 | util.tensor2np(util.tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype('float') 165 | ret_var = Variable(torch.Tensor((value,))) 166 | if (self.use_gpu): 167 | ret_var = ret_var.cuda() 168 | return ret_var 169 | 170 | 171 | class DSSIM(FakeNet): 172 | 173 | def forward(self, in0, in1, retPerLayer=None): 174 | assert (in0.size()[0] == 1) # currently only supports batchSize 1 175 | 176 | if (self.colorspace == 'RGB'): 177 | value = util.dssim(1. * util.tensor2im(in0.data), 1. 
* util.tensor2im(in1.data), range=255.).astype('float') 178 | elif (self.colorspace == 'Lab'): 179 | value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data, to_norm=False)), 180 | util.tensor2np(util.tensor2tensorlab(in1.data, to_norm=False)), range=100.).astype( 181 | 'float') 182 | ret_var = Variable(torch.Tensor((value,))) 183 | if (self.use_gpu): 184 | ret_var = ret_var.cuda() 185 | return ret_var 186 | 187 | 188 | def print_network(net): 189 | num_params = 0 190 | for param in net.parameters(): 191 | num_params += param.numel() 192 | print('Network', net) 193 | print('Total number of parameters: %d' % num_params) 194 | -------------------------------------------------------------------------------- /utils/PerceptualSimilarity/pretrained_networks.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torchvision import models as tv 4 | 5 | 6 | class squeezenet(torch.nn.Module): 7 | def __init__(self, requires_grad=False, pretrained=True): 8 | super(squeezenet, self).__init__() 9 | pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features 10 | self.slice1 = torch.nn.Sequential() 11 | self.slice2 = torch.nn.Sequential() 12 | self.slice3 = torch.nn.Sequential() 13 | self.slice4 = torch.nn.Sequential() 14 | self.slice5 = torch.nn.Sequential() 15 | self.slice6 = torch.nn.Sequential() 16 | self.slice7 = torch.nn.Sequential() 17 | self.N_slices = 7 18 | for x in range(2): 19 | self.slice1.add_module(str(x), pretrained_features[x]) 20 | for x in range(2, 5): 21 | self.slice2.add_module(str(x), pretrained_features[x]) 22 | for x in range(5, 8): 23 | self.slice3.add_module(str(x), pretrained_features[x]) 24 | for x in range(8, 10): 25 | self.slice4.add_module(str(x), pretrained_features[x]) 26 | for x in range(10, 11): 27 | self.slice5.add_module(str(x), pretrained_features[x]) 28 | for x in range(11, 12): 29 | self.slice6.add_module(str(x), pretrained_features[x]) 30 | for x in range(12, 13): 31 | self.slice7.add_module(str(x), pretrained_features[x]) 32 | if not requires_grad: 33 | for param in self.parameters(): 34 | param.requires_grad = False 35 | 36 | def forward(self, X): 37 | h = self.slice1(X) 38 | h_relu1 = h 39 | h = self.slice2(h) 40 | h_relu2 = h 41 | h = self.slice3(h) 42 | h_relu3 = h 43 | h = self.slice4(h) 44 | h_relu4 = h 45 | h = self.slice5(h) 46 | h_relu5 = h 47 | h = self.slice6(h) 48 | h_relu6 = h 49 | h = self.slice7(h) 50 | h_relu7 = h 51 | vgg_outputs = namedtuple("SqueezeOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5', 'relu6', 'relu7']) 52 | out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7) 53 | 54 | return out 55 | 56 | 57 | class alexnet(torch.nn.Module): 58 | def __init__(self, requires_grad=False, pretrained=True): 59 | super(alexnet, self).__init__() 60 | alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features 61 | self.slice1 = torch.nn.Sequential() 62 | self.slice2 = torch.nn.Sequential() 63 | self.slice3 = torch.nn.Sequential() 64 | self.slice4 = torch.nn.Sequential() 65 | self.slice5 = torch.nn.Sequential() 66 | self.N_slices = 5 67 | for x in range(2): 68 | self.slice1.add_module(str(x), alexnet_pretrained_features[x]) 69 | for x in range(2, 5): 70 | self.slice2.add_module(str(x), alexnet_pretrained_features[x]) 71 | for x in range(5, 8): 72 | self.slice3.add_module(str(x), alexnet_pretrained_features[x]) 73 | for x in range(8, 10): 74 | self.slice4.add_module(str(x), 
alexnet_pretrained_features[x]) 75 | for x in range(10, 12): 76 | self.slice5.add_module(str(x), alexnet_pretrained_features[x]) 77 | if not requires_grad: 78 | for param in self.parameters(): 79 | param.requires_grad = False 80 | 81 | def forward(self, X): 82 | h = self.slice1(X) 83 | h_relu1 = h 84 | h = self.slice2(h) 85 | h_relu2 = h 86 | h = self.slice3(h) 87 | h_relu3 = h 88 | h = self.slice4(h) 89 | h_relu4 = h 90 | h = self.slice5(h) 91 | h_relu5 = h 92 | alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) 93 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) 94 | 95 | return out 96 | 97 | 98 | class vgg16(torch.nn.Module): 99 | def __init__(self, requires_grad=False, pretrained=True): 100 | super(vgg16, self).__init__() 101 | vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features 102 | self.slice1 = torch.nn.Sequential() 103 | self.slice2 = torch.nn.Sequential() 104 | self.slice3 = torch.nn.Sequential() 105 | self.slice4 = torch.nn.Sequential() 106 | self.slice5 = torch.nn.Sequential() 107 | self.N_slices = 5 108 | for x in range(4): 109 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 110 | for x in range(4, 9): 111 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 112 | for x in range(9, 16): 113 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 114 | for x in range(16, 23): 115 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 116 | for x in range(23, 30): 117 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 118 | if not requires_grad: 119 | for param in self.parameters(): 120 | param.requires_grad = False 121 | 122 | def forward(self, X): 123 | h = self.slice1(X) 124 | h_relu1_2 = h 125 | h = self.slice2(h) 126 | h_relu2_2 = h 127 | h = self.slice3(h) 128 | h_relu3_3 = h 129 | h = self.slice4(h) 130 | h_relu4_3 = h 131 | h = self.slice5(h) 132 | h_relu5_3 = h 133 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 134 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 135 | 136 | return out 137 | 138 | 139 | class resnet(torch.nn.Module): 140 | def __init__(self, requires_grad=False, pretrained=True, num=18): 141 | super(resnet, self).__init__() 142 | if (num == 18): 143 | self.net = tv.resnet18(pretrained=pretrained) 144 | elif (num == 34): 145 | self.net = tv.resnet34(pretrained=pretrained) 146 | elif (num == 50): 147 | self.net = tv.resnet50(pretrained=pretrained) 148 | elif (num == 101): 149 | self.net = tv.resnet101(pretrained=pretrained) 150 | elif (num == 152): 151 | self.net = tv.resnet152(pretrained=pretrained) 152 | self.N_slices = 5 153 | 154 | self.conv1 = self.net.conv1 155 | self.bn1 = self.net.bn1 156 | self.relu = self.net.relu 157 | self.maxpool = self.net.maxpool 158 | self.layer1 = self.net.layer1 159 | self.layer2 = self.net.layer2 160 | self.layer3 = self.net.layer3 161 | self.layer4 = self.net.layer4 162 | 163 | def forward(self, X): 164 | h = self.conv1(X) 165 | h = self.bn1(h) 166 | h = self.relu(h) 167 | h_relu1 = h 168 | h = self.maxpool(h) 169 | h = self.layer1(h) 170 | h_conv2 = h 171 | h = self.layer2(h) 172 | h_conv3 = h 173 | h = self.layer3(h) 174 | h_conv4 = h 175 | h = self.layer4(h) 176 | h_conv5 = h 177 | 178 | outputs = namedtuple("Outputs", ['relu1', 'conv2', 'conv3', 'conv4', 'conv5']) 179 | out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) 180 | 181 | return out 182 | 
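# Illustrative usage of the wrappers above (a minimal sketch; the random input
# is a placeholder, and LPIPS itself feeds scaled/normalized images):
#
#     import torch
#     net = vgg16(pretrained=True, requires_grad=False).eval()
#     x = torch.randn(1, 3, 64, 64)
#     feats = net(x)                   # namedtuple with relu1_2 ... relu5_3
#     print([f.shape for f in feats])  # five feature maps at decreasing resolution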
-------------------------------------------------------------------------------- /utils/PerceptualSimilarity/util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from skimage.measure import compare_ssim 4 | 5 | 6 | def normalize_tensor(in_feat, eps=1e-10): 7 | norm_factor = torch.sqrt(torch.sum(in_feat ** 2, dim=1, keepdim=True)) 8 | return in_feat / (norm_factor + eps) 9 | 10 | 11 | def l2(p0, p1, range=255.): 12 | return .5 * np.mean((p0 / range - p1 / range) ** 2) 13 | 14 | 15 | def psnr(p0, p1, peak=255.): 16 | return 10 * np.log10(peak ** 2 / np.mean((1. * p0 - 1. * p1) ** 2)) 17 | 18 | 19 | def dssim(p0, p1, range=255.): 20 | return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2. 21 | 22 | 23 | def rgb2lab(in_img, mean_cent=False): 24 | from skimage import color 25 | img_lab = color.rgb2lab(in_img) 26 | if (mean_cent): 27 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50 28 | return img_lab 29 | 30 | 31 | def tensor2np(tensor_obj): 32 | # change dimension of a tensor object into a numpy array 33 | return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0)) 34 | 35 | 36 | def np2tensor(np_obj): 37 | # change dimenion of np array into tensor array 38 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 39 | 40 | 41 | def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False): 42 | # image tensor to lab tensor 43 | from skimage import color 44 | 45 | img = tensor2im(image_tensor) 46 | img_lab = color.rgb2lab(img) 47 | if (mc_only): 48 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50 49 | if (to_norm and not mc_only): 50 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50 51 | img_lab = img_lab / 100. 52 | 53 | return np2tensor(img_lab) 54 | 55 | 56 | def tensorlab2tensor(lab_tensor, return_inbnd=False): 57 | from skimage import color 58 | import warnings 59 | warnings.filterwarnings("ignore") 60 | 61 | lab = tensor2np(lab_tensor) * 100. 62 | lab[:, :, 0] = lab[:, :, 0] + 50 63 | 64 | rgb_back = 255. * np.clip(color.lab2rgb(lab.astype('float')), 0, 1) 65 | if (return_inbnd): 66 | # convert back to lab, see if we match 67 | lab_back = color.rgb2lab(rgb_back.astype('uint8')) 68 | mask = 1. * np.isclose(lab_back, lab, atol=2.) 69 | mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis]) 70 | return (im2tensor(rgb_back), mask) 71 | else: 72 | return im2tensor(rgb_back) 73 | 74 | 75 | def rgb2lab(input): 76 | from skimage import color 77 | return color.rgb2lab(input / 255.) 78 | 79 | 80 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255. / 2.): 81 | image_numpy = image_tensor[0].cpu().float().numpy() 82 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 83 | return image_numpy.astype(imtype) 84 | 85 | 86 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.): 87 | return torch.Tensor((image / factor - cent) 88 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 89 | 90 | 91 | def tensor2vec(vector_tensor): 92 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 93 | 94 | 95 | def voc_ap(rec, prec, use_07_metric=False): 96 | """ ap = voc_ap(rec, prec, [use_07_metric]) 97 | Compute VOC AP given precision and recall. 98 | If use_07_metric is true, uses the 99 | VOC 07 11 point method (default:False). 100 | """ 101 | if use_07_metric: 102 | # 11 point metric 103 | ap = 0. 104 | for t in np.arange(0., 1.1, 0.1): 105 | if np.sum(rec >= t) == 0: 106 | p = 0 107 | else: 108 | p = np.max(prec[rec >= t]) 109 | ap = ap + p / 11. 
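# For example, with recall thresholds t = 0.0, 0.1, ..., 1.0 the interpolated
# precision max(prec[rec >= t]) is averaged over the 11 points, so a detector
# with precision 1 at every recall level scores ap == 1.0.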
110 | else: 111 | # correct AP calculation 112 | # first append sentinel values at the end 113 | mrec = np.concatenate(([0.], rec, [1.])) 114 | mpre = np.concatenate(([0.], prec, [0.])) 115 | 116 | # compute the precision envelope 117 | for i in range(mpre.size - 1, 0, -1): 118 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 119 | 120 | # to calculate area under PR curve, look for points 121 | # where X axis (recall) changes value 122 | i = np.where(mrec[1:] != mrec[:-1])[0] 123 | 124 | # and sum (\Delta recall) * prec 125 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 126 | return ap 127 | 128 | 129 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255. / 2.): 130 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): 131 | image_numpy = image_tensor[0].cpu().float().numpy() 132 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 133 | return image_numpy.astype(imtype) 134 | 135 | 136 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255. / 2.): 137 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): 138 | return torch.Tensor((image / factor - cent) 139 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 140 | -------------------------------------------------------------------------------- /utils/PerceptualSimilarity/weights/v0.0/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/utils/PerceptualSimilarity/weights/v0.0/alex.pth -------------------------------------------------------------------------------- /utils/PerceptualSimilarity/weights/v0.0/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/utils/PerceptualSimilarity/weights/v0.0/squeeze.pth -------------------------------------------------------------------------------- /utils/PerceptualSimilarity/weights/v0.0/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/utils/PerceptualSimilarity/weights/v0.0/vgg.pth -------------------------------------------------------------------------------- /utils/PerceptualSimilarity/weights/v0.1/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/utils/PerceptualSimilarity/weights/v0.1/alex.pth -------------------------------------------------------------------------------- /utils/PerceptualSimilarity/weights/v0.1/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/utils/PerceptualSimilarity/weights/v0.1/squeeze.pth -------------------------------------------------------------------------------- /utils/PerceptualSimilarity/weights/v0.1/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gistvision/DIP-denosing/853faac97a451e6430b47f4d4da54c6d08a7ee50/utils/PerceptualSimilarity/weights/v0.1/vgg.pth -------------------------------------------------------------------------------- /utils/REDutils.py: -------------------------------------------------------------------------------- 1 
| import matplotlib.pyplot as plt 2 | import numpy as np 3 | import torch 4 | import PIL 5 | from PIL import Image 6 | from skimage.measure import compare_psnr 7 | 8 | 9 | # ---- Scaling image ---- # 10 | def pil_resize(pil_img, factor, downscale=True): 11 | if downscale: 12 | new_size = [pil_img.size[0] // factor, pil_img.size[1] // factor] 13 | else: 14 | new_size = [pil_img.size[0] * factor, pil_img.size[1] * factor] 15 | new_pil_img = pil_img.resize(new_size, Image.ANTIALIAS) 16 | return new_pil_img, pil_to_np(new_pil_img) 17 | 18 | 19 | # ----------- gauss kernel ----------- 20 | def fspecial_gauss(size, sigma): 21 | """Function to mimic the 'fspecial' gaussian MATLAB function 22 | """ 23 | x, y = np.mgrid[-size // 2 + 1:size // 2 + 1, -size // 2 + 1:size // 2 + 1] 24 | g = np.exp(-((x ** 2 + y ** 2) / (2.0 * sigma ** 2))) 25 | return g / g.sum() 26 | 27 | 28 | # -------- Load image and crop it if needed ------ # 29 | def load_and_crop_image(fname, d=1): 30 | """Make dimensions divisible by `d`""" 31 | img = Image.open(fname) 32 | if d == 1: return img, pil_to_np(img) 33 | new_size = (img.size[0] - img.size[0] % d, 34 | img.size[1] - img.size[1] % d) 35 | if new_size[0] == img.size[0] and new_size[1] == img.size[1]: 36 | return img, pil_to_np(img) 37 | bbox = [ 38 | int((img.size[0] - new_size[0]) / 2), 39 | int((img.size[1] - new_size[1]) / 2), 40 | int((img.size[0] + new_size[0]) / 2), 41 | int((img.size[1] + new_size[1]) / 2), 42 | ] 43 | img_cropped = img.crop(bbox) 44 | return img_cropped, pil_to_np(img_cropped) 45 | 46 | 47 | # ------- Working with numpy / pil / torch images auxiliary functions -------- # 48 | def save_np(np_img, file, ext='.png'): 49 | """ saves a numpy image as png (default) """ 50 | pil_img = np_to_pil(np_img) 51 | pil_img.save(file + ext) 52 | 53 | 54 | # ---------- compare_psnr ------------ 55 | def compare_PSNR(org, est, on_y=False, gray_scale=False): 56 | assert (on_y==False or gray_scale==False), "Is your image RGB or gray? please choose and try again" 57 | if on_y: 58 | return compare_psnr_y(np_to_pil(org), np_to_pil(est)) 59 | if gray_scale: 60 | return compare_psnr(np.mean(org, axis=0), np.mean(est, axis=0)) 61 | return compare_psnr(org, est) 62 | 63 | 64 | def load_and_compare_psnr(fclean, fnoisy, crop_factor=1, on_y=False, eng=None): 65 | # matlab: 66 | if eng is not None: 67 | return eng.compare_psnr_y("../" + fclean, "../" + fnoisy, on_y, nargout=1) 68 | # load: 69 | _, img_np = load_and_crop_image(fclean, crop_factor) 70 | _, img_noisy_np = load_and_crop_image(fnoisy, crop_factor) 71 | # rgba -> rgb 72 | if img_np.shape[0] == 4: img_np = img_np[:3, :, :] 73 | if img_noisy_np.shape[0] == 4: img_noisy_np = img_noisy_np[:3, :, :] 74 | return compare_PSNR(img_np, img_noisy_np, on_y=on_y) 75 | 76 | 77 | def get_p_signal(im): 78 | return 10 * np.log10(np.mean(np.square(im))) 79 | 80 | 81 | def compare_SNR(im_true, im_test): 82 | return compare_psnr(im_true, im_test, 1) + get_p_signal(im_true) 83 | 84 | 85 | def rgb2ycbcr(img): 86 | """ 87 | Image to Y (ycbcr) 88 | Input: 89 | PIL IMAGE, in range [0, 255] 90 | Output: 91 | Numpy Y Ch. 
in range [0, 1] 92 | """ 93 | y = np.array(img, np.float32) 94 | if len(y.shape) == 3 and y.shape[2] == 3: 95 | y = np.dot(y, [65.481, 128.553, 24.966]) / 255.0 + 16.0 96 | return y.round() / 255.0 97 | 98 | 99 | def rgb2gray(img): 100 | """ 101 | RGB image to gray scale 102 | Input: 103 | PIL IMAGE, in range [0, 255] 104 | Output: 105 | Numpy 3 x Gray Scale in range [0, 1] 106 | Following the matlab code at: https://www.mathworks.com/help/matlab/ref/rgb2gray.html 107 | The formula: 0.2989 * R + 0.5870 * G + 0.1140 * B 108 | """ 109 | img = np.array(img, np.float32) 110 | if len(img.shape) == 3 and img.shape[2] == 3: 111 | img = np.dot(img, [0.2989, 0.5870, 0.1140]) 112 | return np.array([img.round() / 255.0]*3, dtype=np.float32) 113 | 114 | 115 | def compare_psnr_y(org_pil, est_pil): 116 | return compare_psnr(rgb2ycbcr(org_pil), rgb2ycbcr(est_pil)) 117 | 118 | 119 | # - transformation functions pil <-> numpy <-> torch 120 | def pil_to_np(img_PIL): 121 | """Converts image in PIL format to np.array. 122 | 123 | From W x H x C [0...255] to C x W x H [0..1] 124 | """ 125 | ar = np.array(img_PIL, np.float32) 126 | 127 | if len(ar.shape) == 3: 128 | ar = ar.transpose(2, 0, 1) 129 | else: 130 | ar = ar[None, ...] 131 | 132 | return ar / 255. 133 | 134 | 135 | def np_to_pil(img_np): 136 | """Converts image in np.array format to PIL image. 137 | 138 | From C x W x H [0..1] to W x H x C [0...255] 139 | """ 140 | ar = np.clip(np.rint(img_np * 255), 0, 255).astype(np.uint8) 141 | 142 | if img_np.shape[0] == 1: 143 | ar = ar[0] 144 | else: 145 | ar = ar.transpose(1, 2, 0) 146 | 147 | return Image.fromarray(ar) 148 | 149 | 150 | def np_to_torch(img_np): 151 | """Converts image in numpy.array to torch.Tensor. 152 | 153 | From C x W x H [0..1] to C x W x H [0..1] 154 | """ 155 | return torch.from_numpy(img_np)[None, :] 156 | 157 | 158 | def torch_to_np(img_var): 159 | """Converts an image in torch.Tensor format to np.array. 160 | 161 | From 1 x C x W x H [0..1] to C x W x H [0..1] 162 | """ 163 | return img_var.detach().cpu().numpy()[0] 164 | 165 | 166 | def put_in_center(img_np, target_size): 167 | img_out = np.zeros([3, target_size[0], target_size[1]]) 168 | 169 | bbox = [ 170 | int((target_size[0] - img_np.shape[1]) / 2), 171 | int((target_size[1] - img_np.shape[2]) / 2), 172 | int((target_size[0] + img_np.shape[1]) / 2), 173 | int((target_size[1] + img_np.shape[2]) / 2), 174 | ] 175 | 176 | img_out[:, bbox[0]:bbox[2], bbox[1]:bbox[3]] = img_np 177 | 178 | return img_out 179 | 180 | 181 | # --------- get noise ---------- # 182 | def fill_noise(x, noise_type): 183 | """Fills tensor `x` with noise of type `noise_type`.""" 184 | if noise_type == 'u': 185 | x.uniform_() 186 | elif noise_type == 'n': 187 | x.normal_() 188 | else: 189 | assert False 190 | 191 | 192 | def get_noise(input_depth, method, spatial_size, noise_type='u', var=1. / 10): 193 | """Returns a pytorch.Tensor of size (1 x `input_depth` x `spatial_size[0]` x `spatial_size[1]`) 194 | initialized in a specific way. 195 | Args: 196 | input_depth: number of channels in the tensor 197 | method: `noise` for filling tensor with noise; `meshgrid` for np.meshgrid 198 | spatial_size: spatial size of the tensor to initialize 199 | noise_type: 'u' for uniform; 'n' for normal 200 | var: a factor, a noise will be multiplied by. Basically it is standard deviation scalar. 
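    Example (illustrative):
        net_input = get_noise(32, 'noise', (256, 256), noise_type='u', var=0.1)
        # -> tensor of shape (1, 32, 256, 256); uniform noise scaled by 0.1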
201 | """ 202 | if isinstance(spatial_size, int): 203 | spatial_size = (spatial_size, spatial_size) 204 | if method == 'noise': 205 | shape = [1, input_depth, spatial_size[0], spatial_size[1]] 206 | net_input = torch.zeros(shape) 207 | 208 | fill_noise(net_input, noise_type) 209 | net_input *= var 210 | elif method == 'meshgrid': 211 | assert input_depth == 2 212 | X, Y = np.meshgrid(np.arange(0, spatial_size[1]) / float(spatial_size[1] - 1), 213 | np.arange(0, spatial_size[0]) / float(spatial_size[0] - 1)) 214 | meshgrid = np.concatenate([X[None, :], Y[None, :]]) 215 | net_input = np_to_torch(meshgrid) 216 | else: 217 | assert False 218 | return net_input 219 | 220 | 221 | # ---------- plot functions ---------- 222 | def plot_dict(data_dict): 223 | i, columns = 0, len(data_dict) 224 | scale = columns * 10 # you can play with it 225 | plt.figure(figsize=(scale, scale)) 226 | for key, data in data_dict.items(): 227 | i, ax = i + 1, plt.subplot(1, columns, i + 1) 228 | plt.imshow(np_to_pil(data.img), cmap='gray') 229 | ax.text(0.5, -0.15, key + (" psnr: %.2f" % (data.psnr) if data.psnr is not None else ""), 230 | size=36, ha="center", transform=ax.transAxes) 231 | plt.show() 232 | 233 | 234 | def matplot_plot_graphs(graphs, x_labels, y_labels): 235 | total = len(graphs) 236 | for i, graph in enumerate(graphs): 237 | plt.figure(figsize=(25, 6)) 238 | ax = plt.subplot(1, total, i + 1) 239 | plt.plot(graph) 240 | plt.xlabel(x_labels[i]) 241 | plt.ylabel(y_labels[i], multialignment='center') 242 | plt.show() 243 | 244 | 245 | # -------- numpy gray to color ----- 246 | def np_gray_to_color(img): 247 | """ 1 x w x h => 3 x w x h 248 | """ 249 | img = np.stack([img, img, img], ) 250 | return img 251 | 252 | 253 | # ------- used for bokeh plots ------- 254 | def np_to_rgba(np_img): 255 | """ ch x w x h => W x H x (ch+1), for alpha 256 | """ 257 | img = np_img.transpose(1, 2, 0) 258 | if img.shape[2] == 3: # 3D image (3, w, h) 259 | img = 255 * np.dstack([img, np.ones(img.shape[:2])]) 260 | else: # 2D image (1, w, h) 261 | img = 255 * np.dstack([img, img, img, np.ones(img.shape[:2])]) 262 | return img.astype(np.uint8) 263 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .common_utils import * 2 | from .denoising_utils import * 3 | -------------------------------------------------------------------------------- /utils/blur_utils.py: -------------------------------------------------------------------------------- 1 | from .REDutils import * 2 | from models.downsampler import Downsampler 3 | 4 | 5 | # - blur image - exactly like the NCSR is doing it - 6 | def get_fft_h(im, blur_type): 7 | assert blur_type in ['uniform_blur', 'gauss_blur'], "blur_type can be or 'uniform' or 'gauss'" 8 | ch, h, w = im.shape 9 | fft_h = np.zeros((h,w),) 10 | if blur_type=='uniform_blur': 11 | t = 4 # 9//2 12 | fft_h[h//2-t:h//2+1+t, w//2-t:w//2+1+t] = 1/81 13 | fft_h = np.fft.fft2(np.fft.fftshift(fft_h)) 14 | else: # gauss_blur 15 | psf = fspecial_gauss(25, 1.6) 16 | t = 12 # 25 // 2 17 | fft_h[h//2-t:h//2+1+t, w//2-t:w//2+1+t] = psf 18 | fft_h = np.fft.fft2(np.fft.fftshift(fft_h)) 19 | return fft_h 20 | 21 | 22 | def blur(im, blur_type): 23 | fft_h = get_fft_h(im, blur_type) 24 | imout = np.zeros_like(im) 25 | for i in range(im.shape[0]): 26 | im_f = np.fft.fft2(im[i, :, :]) 27 | z_f = fft_h*im_f # .* of matlab 28 | z = np.real(np.fft.ifft2(z_f)) 29 | imout[i, :, :] = z 30 | 
30 |     return imout 31 | 32 | 33 | # - the inverse function H - 34 | def get_h(n_ch, blur_type, use_fourier, dtype): 35 |     assert blur_type in ['uniform_blur', 'gauss_blur'], "blur_type must be 'uniform_blur' or 'gauss_blur'" 36 |     if not use_fourier: 37 |         return Downsampler(n_ch, 1, blur_type, preserve_size=True).type(dtype) 38 |     return lambda im: torch_blur(im, blur_type, dtype) 39 | 40 | 41 | def torch_blur(im, blur_type, dtype): 42 |     fft_h = get_fft_h(torch_to_np(im), blur_type) 43 |     fft_h_torch = torch.unsqueeze(torch.from_numpy(np.real(fft_h)).type(dtype), 2) 44 |     fft_h_torch = torch.cat([fft_h_torch, fft_h_torch], 2) 45 |     z = [] 46 |     for i in range(im.shape[1]): 47 |         im_torch = torch.unsqueeze(im[0, i, :, :], 2) 48 |         im_torch = torch.cat([im_torch, im_torch], 2) 49 |         im_f = torch.fft(im_torch, 2) 50 |         z_f = torch.mul(torch.unsqueeze(fft_h_torch, 0), torch.unsqueeze(im_f, 0)) # .* of matlab 51 |         z.append(torch.ifft(z_f, 2)) 52 |     z = torch.cat(z, 0) 53 |     return torch.unsqueeze(z[:, :, :, 0], 0) 54 | -------------------------------------------------------------------------------- /utils/common_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | from utils.blur_utils import * # blur functions 4 | import torch 5 | import numpy as np 6 | 7 | def fill_noise(x, noise_type): 8 |     """Fills tensor `x` with noise of type `noise_type`.""" 9 |     if noise_type == 'u': 10 |         x.uniform_() 11 |     elif noise_type == 'n': 12 |         x.normal_() 13 |     else: 14 |         assert False 15 | 16 | def get_noise(input_depth, method, spatial_size, noise_type='u', var=1./10): 17 |     """Returns a pytorch.Tensor of size (1 x `input_depth` x `spatial_size[0]` x `spatial_size[1]`) 18 |     initialized in a specific way. 19 |     Args: 20 |         input_depth: number of channels in the tensor 21 |         method: `noise` for filling tensor with noise; `meshgrid` for np.meshgrid 22 |         spatial_size: spatial size of the tensor to initialize 23 |         noise_type: 'u' for uniform; 'n' for normal 24 |         var: a factor the noise will be multiplied by. Basically it is a standard deviation scalar. 25 |     """ 26 |     if isinstance(spatial_size, int): 27 |         spatial_size = (spatial_size, spatial_size) 28 |     if method == 'noise': 29 |         shape = [1, input_depth, spatial_size[0], spatial_size[1]] 30 |         net_input = torch.zeros(shape) 31 | 32 |         fill_noise(net_input, noise_type) 33 |         net_input *= var 34 |     elif method == 'meshgrid': 35 |         assert input_depth == 2 36 |         X, Y = np.meshgrid(np.arange(0, spatial_size[1])/float(spatial_size[1]-1), np.arange(0, spatial_size[0])/float(spatial_size[0]-1)) 37 |         meshgrid = np.concatenate([X[None,:], Y[None,:]]) 38 |         net_input = np_to_torch(meshgrid) 39 |     else: 40 |         assert False 41 | 42 |     return net_input 43 | 44 | def np_to_torch(img_np): 45 |     '''Converts image in numpy.array to torch.Tensor. 46 | 47 |     From C x W x H [0..1] to C x W x H [0..1] 48 |     ''' 49 |     return torch.from_numpy(img_np)[None, :] 50 | 51 | def torch_to_np(img_var): 52 |     '''Converts an image in torch.Tensor format to np.array.
53 | 54 |     From 1 x C x W x H [0..1] to C x W x H [0..1] 55 |     ''' 56 |     return img_var.detach().cpu().numpy()[0] 57 | 58 | def cv2_to_torch(cv2_img, dtype=None): 59 |     if dtype is None: 60 |         out = np_to_torch(cv2_img).float() / 255.0 61 |     else: 62 |         out = np_to_torch(cv2_img).type(dtype) / 255.0 63 |     return out 64 | 65 | def torch_to_cv2(torch_tensor, clip=True): 66 |     out = torch_to_np(torch_tensor) 67 |     if clip: 68 |         out = np.clip(out, 0, 1) 69 |     return np.squeeze(out * 255).astype(np.uint8) 70 | 71 | def save_CHW_np(fname, CHW): 72 |     if len(CHW.shape) == 2: 73 |         cv2.imwrite(fname, CHW) 74 |     else: 75 |         cv2.imwrite(fname, CHW.transpose([1, 2, 0])) 76 | 77 | def load_image_pair(fname, task, args): 78 |     """ 79 |     Select the degradation to remove. 80 |     We follow the notation Y = H * X + N. 81 |     :return: X, the original (clean) image, and Y, the degraded image. 82 |     """ 83 |     if task == "deblur": 84 |         clean_img, degradation_img = deblur_loader(fname, args.blur_type, args.sigma) 85 |     elif task == "denoising": 86 |         clean_img, degradation_img = denoise_loader(fname, args.sigma) 87 |     elif task == "poisson": 88 |         clean_img, degradation_img = poisson_loader(fname, args.scale) 89 |     else: 90 |         raise NotImplementedError 91 | 92 |     print("[!] clean image domain : [%.2f, %.2f]" %(clean_img.min(), clean_img.max())) 93 |     print("[!] noisy image domain : [%.2f, %.2f]" %(degradation_img.min(), degradation_img.max())) 94 |     return clean_img, degradation_img 95 | 96 | 97 | def poisson_loader(fname, scale): 98 |     img_np = read_image_np(fname) 99 |     img_noisy_np = scale*np.random.poisson(img_np/255.0/scale)* 255.0 100 |     return img_np, img_noisy_np 101 | 102 | def denoise_loader(fname, sigma): 103 |     img_np = read_image_np(fname) 104 |     if "mean" in fname: 105 |         img_noisy_np = read_image_np(fname.replace("mean", "real")) 106 |     else: 107 |         img_noisy_np = img_np + np.random.randn(*img_np.shape) * sigma 108 |     return img_np, img_noisy_np 109 | 110 | def read_image_np(path): 111 |     """ 112 |     :param path: image file name. 113 |     Reads a GRAY or COLOR image, depending on the file. 114 |     :return: C x H x W image as a float numpy array. 115 |     """ 116 |     img_np = cv2.imread(path, -1) 117 |     if len(img_np.shape) == 2: 118 |         print("[*] read GRAY image.") 119 |         img_np = img_np[np.newaxis,:] 120 |     else: 121 |         print("[*] read COLOR image.") 122 |         img_np = img_np.transpose([2, 0, 1]) 123 |     return img_np.astype(np.float64) # float, so noise can be added. 124 | 125 | def read_noise_np(path, sigma): 126 |     # read noise instance same as eSURE. 127 |     try: 128 |         dir_name = os.path.join(os.path.dirname(path), "sigma%s" % sigma) 129 |         file_name = path.split("/")[-1][:-4] 130 |         file_name = file_name + ".npy" 131 |         new_path = os.path.join(dir_name, file_name) 132 |         noisy_np = np.load(new_path)[0].transpose([2, 0, 1]) 133 |     except: 134 |         raise FileNotFoundError 135 |     return noisy_np 136 | 137 | def deblur_loader(fname, blur_type, noise_sigma, GRAY_SCALE = False): 138 |     """ Loads an image, adds blur and then Gaussian noise. 139 |     Args: 140 |         fname: path to the image 141 |         blur_type: 'g' for gauss_blur, anything else for uniform_blur 142 |         noise_sigma: noise added after blur 143 |         GRAY_SCALE: should we convert to gray scale image?
144 |     Out: 145 |         img_np: the clean image as a numpy array 146 |         blurred: the blurred image with additive Gaussian noise 147 |     """ 148 |     BLUR_TYPE = 'gauss_blur' if blur_type == 'g' else 'uniform_blur' 149 |     img_np = read_image_np(fname) # load image as a C x H x W numpy array 150 |     # if GRAY_SCALE: 151 |     #     img_np = rgb2gray(img_pil) 152 |     blurred = blur(img_np, BLUR_TYPE) # blur, and the line below adds noise 153 |     blurred = blurred + np.random.normal(scale=noise_sigma, size=blurred.shape) 154 |     return img_np, blurred 155 | 156 | if __name__ == "__main__": 157 |     fname = "./testset/CBSD68/3096.png" 158 |     gt, noise = denoise_loader(fname, 10) 159 |     print(gt.shape, gt.min(), gt.max()) 160 |     print(noise.shape, noise.min(), noise.max()) 161 |     fname = "./testset/BSD68/test001.png" 162 |     gt, noise = denoise_loader(fname, 10) 163 |     print(gt.shape, gt.min(), gt.max()) 164 |     print(noise.shape, noise.min(), noise.max()) 165 | 166 | -------------------------------------------------------------------------------- /utils/denoising_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from utils.common_utils import * 3 | from skimage.restoration import denoise_nl_means 4 | from skimage.metrics import structural_similarity 5 | from skimage.metrics import peak_signal_noise_ratio as compare_psnr 6 | 7 | def get_noisy_image(img_np, sigma): 8 |     """Adds Gaussian noise to an image. 9 | 10 |     Args: 11 |         img_np: image, np.array with values from 0 to 1 12 |         sigma: std of the noise 13 |     """ 14 |     img_noisy_np = np.clip(img_np + np.random.normal(scale=sigma, size=img_np.shape), 0, 1).astype(np.float32) 15 |     img_noisy_pil = np_to_pil(img_noisy_np) 16 | 17 |     return img_noisy_pil, img_noisy_np 18 | 19 | 20 | def non_local_means(noisy_np_img, sigma, fast_mode=True): 21 |     """ Takes a noisy numpy image and 22 |     returns a denoised numpy image using Non-Local-Means. 23 |     """ 24 |     sigma = sigma / 255. 25 |     h = 0.6 * sigma if fast_mode else 0.8 * sigma 26 |     patch_kw = dict(h=h,                   # Cut-off distance, a higher h results in a smoother image 27 |                     sigma=sigma,           # sigma provided 28 |                     fast_mode=fast_mode,   # If True, a fast version is used. If False, the original version is used. 29 |                     patch_size=5,          # 5x5 patches (Size of patches used for denoising.)
30 | patch_distance=6, # 13x13 search area 31 | multichannel=False) 32 | denoised_img = [] 33 | n_channels = noisy_np_img.shape[0] 34 | for c in range(n_channels): 35 | denoise_fast = denoise_nl_means(noisy_np_img[c, :, :], **patch_kw) 36 | denoised_img += [denoise_fast] 37 | return np.array(denoised_img, dtype=np.float32) 38 | 39 | def compare_ssim(a, b): 40 | if a.shape[0] == 3: 41 | a = np.mean(a, axis=0) 42 | b = np.mean(b, axis=0) 43 | elif a.shape[2] == 3: 44 | a = np.mean(a, axis=2) 45 | b = np.mean(b, axis=2) 46 | else: 47 | a,b = a[0], b[0] 48 | return structural_similarity(a,b) 49 | 50 | 51 | 52 | import math 53 | import cv2 54 | # ---------- 55 | # PSNR 56 | # ---------- 57 | def calculate_psnr(img1, img2, border=0): 58 | # img1 and img2 have range [0, 255] 59 | img1 = np.squeeze(img1) 60 | img2 = np.squeeze(img2) 61 | if not img1.shape == img2.shape: 62 | raise ValueError('Input images must have the same dimensions.') 63 | h, w = img1.shape[:2] 64 | img1 = img1[border:h-border, border:w-border] 65 | img2 = img2[border:h-border, border:w-border] 66 | 67 | img1 = img1.astype(np.float64) 68 | img2 = img2.astype(np.float64) 69 | mse = np.mean((img1 - img2)**2) 70 | if mse == 0: 71 | return float('inf') 72 | return 20 * math.log10(255.0 / math.sqrt(mse)) 73 | 74 | 75 | # ---------- 76 | # SSIM 77 | # ---------- 78 | def calculate_ssim(img1, img2, border=0): 79 | '''calculate SSIM 80 | the same outputs as MATLAB's 81 | img1, img2: [0, 255] 82 | ''' 83 | img1 = np.squeeze(img1) 84 | img2 = np.squeeze(img2) 85 | if not img1.shape == img2.shape: 86 | raise ValueError('Input images must have the same dimensions.') 87 | h, w = img1.shape[:2] 88 | img1 = img1[border:h-border, border:w-border] 89 | img2 = img2[border:h-border, border:w-border] 90 | 91 | if img1.ndim == 2: 92 | return ssim(img1, img2) 93 | elif img1.ndim == 3: 94 | if img1.shape[0] == 3: 95 | ssims = [] 96 | for i in range(3): 97 | ssims.append(ssim(img1[i], img2[i])) 98 | return np.array(ssims).mean() 99 | elif img1.shape[0] == 1: 100 | return ssim(np.squeeze(img1), np.squeeze(img2)) 101 | else: 102 | raise ValueError('Wrong input image dimensions.') 103 | 104 | 105 | def ssim(img1, img2): 106 | C1 = (0.01 * 255)**2 107 | C2 = (0.03 * 255)**2 108 | 109 | img1 = img1.astype(np.float64) 110 | img2 = img2.astype(np.float64) 111 | kernel = cv2.getGaussianKernel(11, 1.5) 112 | window = np.outer(kernel, kernel.transpose()) 113 | 114 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 115 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 116 | mu1_sq = mu1**2 117 | mu2_sq = mu2**2 118 | mu1_mu2 = mu1 * mu2 119 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 120 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 121 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 122 | 123 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 124 | (sigma1_sq + sigma2_sq + C2)) 125 | return ssim_map.mean() 126 | 127 | 128 | from .PerceptualSimilarity import PerceptualLoss 129 | 130 | def get_lpips(device="cuda"): 131 | return PerceptualLoss(model='net-lin', net='alex', use_gpu=(device == 'cuda')) 132 | 133 | def calculate_lpips(img1_, img2_, LPIPS= None, device="cuda", color= "BGR"): 134 | if img1_.shape[0] < 3: 135 | make_color = lambda x: cv2.cvtColor(x, cv2.COLOR_GRAY2BGR) 136 | img1 = make_color(img1_[0].astype(np.uint8)) 137 | img2 = make_color(img2_.astype(np.uint8)) 138 | else: 139 | img1 = img1_.transpose([1,2,0]).astype(np.uint8) 140 | img2 
= img2_.transpose([1,2,0]).astype(np.uint8) 141 |     if LPIPS is None: 142 |         LPIPS = PerceptualLoss(model='net-lin', net='alex', use_gpu=(device == 'cuda')) 143 |     if color == "BGR": 144 |         img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB) 145 |         img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2RGB) 146 |     img1 = torch.tensor(img1.transpose([2,0,1])) / 255.0 147 |     img2 = torch.tensor(img2.transpose([2,0,1])) / 255.0 148 |     if device == "cuda": 149 |         img1 = img1.cuda() 150 |         img2 = img2.cuda() 151 |     return LPIPS(img1, img2, normalize=True).item() 152 | 153 | 154 | if __name__ == "__main__": 155 |     import numpy as np 156 |     img1 = np.random.rand(255,255) 157 |     img2 = np.random.rand(255,255) 158 |     print(compare_psnr(img1, img2)) 159 |     print(compare_ssim(img1, img2)) 160 |     min_ = min(img1.min(), img2.min()) 161 |     max_ = max(img1.max(), img2.max()) 162 |     img1 = ((img1 - min_) / (max_ - min_) * 255).astype(np.uint8) 163 |     img2 = ((img2 - min_) / (max_ - min_) * 255).astype(np.uint8) 164 |     print(compare_psnr(img1, img2), calculate_psnr(img1, img2)) 165 |     print(compare_ssim(img1, img2), calculate_ssim(img1, img2)) 166 | -------------------------------------------------------------------------------- /utils/parse_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | from argparse import ArgumentParser 3 | 4 | def load_parse(f_name): 5 |     if '.' not in f_name: 6 |         f_name += '.json' 7 |     parser = ArgumentParser() 8 |     args = parser.parse_args([]) # parse an empty list so real CLI arguments are not consumed or rejected 9 |     with open(f_name, 'r') as f: 10 |         args.__dict__ = json.load(f) 11 |     return args 12 | 13 | def save_parse(f_name, args): 14 |     if '.' not in f_name: 15 |         f_name += '.json' 16 |     with open(f_name, 'w') as f: 17 |         json.dump(args.__dict__, f, indent=2) 18 | 19 | 20 | if __name__ == '__main__': 21 |     parser = ArgumentParser() 22 |     parser.add_argument('--seed', type=int, default=8) 23 |     parser.add_argument('--resume', type=str, default='a/b/c.ckpt') 24 |     parser.add_argument('--surgery', type=str, default='190', choices=['190', '417']) 25 |     args = parser.parse_args() 26 | 27 |     save_parse("test", args) 28 |     args_t = load_parse("test") 29 |     print(args) 30 |     print(args_t) --------------------------------------------------------------------------------