├── .DS_Store
├── DIP
│   ├── .DS_Store
│   ├── __init__.py
│   ├── data
│   │   ├── .DS_Store
│   │   ├── denoising
│   │   │   └── .gitignore
│   │   ├── inpainting
│   │   │   └── .gitignore
│   │   └── sr
│   │       └── .gitignore
│   ├── denoising-test.py
│   ├── denoising.py
│   ├── inpainting-test.py
│   ├── inpainting.py
│   ├── models
│   │   ├── .DS_Store
│   │   ├── __init__.py
│   │   ├── common.py
│   │   ├── common_test.py
│   │   ├── cross_skip.py
│   │   ├── downsampler.py
│   │   ├── gen_upsample_layer.py
│   │   ├── model_denoising.py
│   │   ├── model_inpainting.py
│   │   ├── model_sr.py
│   │   ├── ref.py
│   │   ├── resnet.py
│   │   ├── skip.py
│   │   ├── skip_search_up.py
│   │   ├── texture_nets.py
│   │   ├── unet.py
│   │   └── unet_search_up.py
│   ├── super-resolution-test.py
│   ├── super-resolution.py
│   └── utils
│       ├── __init__.py
│       ├── common_utils.py
│       ├── denoising_utils.py
│       ├── feature_inversion_utils.py
│       ├── inpainting_utils.py
│       ├── load_image.py
│       ├── matcher.py
│       ├── perceptual_loss
│       │   ├── __init__.py
│       │   ├── matcher.py
│       │   ├── perceptual_loss.py
│       │   └── vgg_modified.py
│       ├── sr_utils.py
│       └── timer.py
├── NAS
│   ├── __init__.py
│   ├── demo.py
│   ├── gen_id.py
│   ├── gen_upsample_layer-prev.py
│   ├── gen_upsample_layer.py
│   ├── genotypes-prev.py
│   ├── genotypes.py
│   ├── index_to_model_mapping.log
│   ├── model.py
│   ├── model_gen.py
│   ├── operations-prev.py
│   ├── operations.py
│   └── utils.py
├── README.md
└── img
    └── teaser.png

/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YunChunChen/NAS-DIP-pytorch/3dfb4cf6312599097a5a193d22fd8591467e1a6f/.DS_Store
--------------------------------------------------------------------------------
/DIP/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YunChunChen/NAS-DIP-pytorch/3dfb4cf6312599097a5a193d22fd8591467e1a6f/DIP/.DS_Store
--------------------------------------------------------------------------------
/DIP/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YunChunChen/NAS-DIP-pytorch/3dfb4cf6312599097a5a193d22fd8591467e1a6f/DIP/__init__.py
--------------------------------------------------------------------------------
/DIP/data/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YunChunChen/NAS-DIP-pytorch/3dfb4cf6312599097a5a193d22fd8591467e1a6f/DIP/data/.DS_Store
--------------------------------------------------------------------------------
/DIP/data/denoising/.gitignore:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YunChunChen/NAS-DIP-pytorch/3dfb4cf6312599097a5a193d22fd8591467e1a6f/DIP/data/denoising/.gitignore
--------------------------------------------------------------------------------
/DIP/data/inpainting/.gitignore:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YunChunChen/NAS-DIP-pytorch/3dfb4cf6312599097a5a193d22fd8591467e1a6f/DIP/data/inpainting/.gitignore
--------------------------------------------------------------------------------
/DIP/data/sr/.gitignore:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YunChunChen/NAS-DIP-pytorch/3dfb4cf6312599097a5a193d22fd8591467e1a6f/DIP/data/sr/.gitignore
--------------------------------------------------------------------------------
/DIP/denoising-test.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt

import os
import cv2
import random
import pickle
import argparse
import numpy as np

# skimage >= 0.18 removed compare_psnr; fall back to the renamed function there.
try:
    from skimage.measure import compare_psnr
except ImportError:
    from skimage.metrics import peak_signal_noise_ratio as compare_psnr

from utils.denoising_utils import *
from utils.timer import Timer

import torch
import torch.optim

import warnings
warnings.filterwarnings("ignore")

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True
dtype = torch.cuda.FloatTensor

def parse_args():
    parser = argparse.ArgumentParser(description='NAS-DIP Denoising')

    parser.add_argument('--optimizer', dest='optimizer', default='adam', type=str)
    parser.add_argument('--num_iter', dest='num_iter', default=3000, type=int)
    parser.add_argument('--show_every', dest='show_every', default=50, type=int)
    parser.add_argument('--lr', dest='lr', default=0.01, type=float)
    parser.add_argument('--plot', dest='plot', default=False, type=bool)
    parser.add_argument('--noise_method', dest='noise_method', default='noise', type=str)
    parser.add_argument('--input_depth', dest='input_depth', default=32, type=int)
    parser.add_argument('--output_path', dest='output_path', default='results/denoising', type=str)
    parser.add_argument('--batch_size', dest='batch_size', default=1, type=int)
    parser.add_argument('--reg_noise_std', dest='reg_noise_std', default=1./30., type=float)
    parser.add_argument('--sigma', dest='sigma', default=25, type=float)
    parser.add_argument('--save_png', dest='save_png', default=0, type=int)
    parser.add_argument('--exp_weight', dest='exp_weight', default=0.99, type=float)
    parser.add_argument('--image_name', type=str)

    args = parser.parse_args()
    return args


if __name__ == '__main__':

    args = parse_args()

    # Output directory for the PSNR pickle and optional intermediate PNGs
    global_path = args.output_path
    if not os.path.exists(global_path):
        os.makedirs(global_path)
    image_name = os.path.splitext(args.image_name)[0]
    if args.save_png == 1 and not os.path.exists(os.path.join(global_path, image_name)):
        os.makedirs(os.path.join(global_path, image_name))

    img_path = 'data/denoising/' + args.image_name

    img_pil = crop_image(get_image(img_path, -1)[0], 32)
    img_np = pil_to_np(img_pil)

    img_noisy_pil, img_noisy_np = get_noisy_image(img_np, args.sigma / 255.)

    from models.model_denoising import Model
    net = Model()

    net = net.type(dtype)

    net_input = get_noise(args.input_depth, args.noise_method, (img_pil.size[1], img_pil.size[0])).type(dtype).detach()

    mse = torch.nn.MSELoss().type(dtype)

    img_noisy_torch = np_to_torch(img_noisy_np).type(dtype)

    net_input_saved = net_input.detach().clone()
    noise = net_input.detach().clone()
    out_avg = None
    last_net = None
    psnr_noisy_last = 0
    psnr_gt_best = 0

    i = 0
    PSNR_list = []

    _t = {'im_detect': Timer(), 'misc': Timer()}

    def closure():

        global i, out_avg, psnr_noisy_last, last_net, net_input, psnr_gt_best, PSNR_list

        _t['im_detect'].tic()

        if args.reg_noise_std > 0:
            net_input = net_input_saved + (noise.normal_() * args.reg_noise_std)

        out = net(net_input)

        if out_avg is None:
            out_avg = out.detach()
        else:
            out_avg = out_avg * args.exp_weight + out.detach() * (1 - args.exp_weight)

        total_loss = mse(out, img_noisy_torch)
        total_loss.backward()

        psnr_noisy = compare_psnr(img_noisy_np, out.detach().cpu().numpy()[0])
        psnr_gt = compare_psnr(img_np, out_avg.detach().cpu().numpy()[0])

        PSNR_list.append(psnr_gt)

        if psnr_gt > psnr_gt_best:
            psnr_gt_best = psnr_gt

        _t['im_detect'].toc()

        print ('Iteration %05d Loss %f PSNR_noisy: %f PSNR_gt: %f Time %.3f' % (i, total_loss.item(), psnr_noisy, psnr_gt, _t['im_detect'].total_time), '\n', end='')

        if i % args.show_every == 0:
            out_np = torch_to_np(out)
            if args.save_png == 1:
                cv2.imwrite(os.path.join(global_path, image_name, str(i) + '.png'),\
                            np.clip(out_np, 0, 1).transpose(1, 2, 0)[:,:,::-1] * 255)

            if args.plot:
                plot_image_grid([np.clip(out_np, 0, 1)], factor=4, nrow=1)

        # Backtracking (runs on iterations that are not multiples of show_every)
        if i % args.show_every:
            if psnr_noisy - psnr_noisy_last < -5:
                print('Falling back to previous checkpoint.')

                for new_param, net_param in zip(last_net, net.parameters()):
                    net_param.data.copy_(new_param.cuda())

                return total_loss*0
            else:
                last_net = [x.detach().cpu() for x in net.parameters()]
                psnr_noisy_last = psnr_noisy

        i += 1

        return total_loss

    p = get_params('net', net, net_input)
    optimize(args.optimizer, p, closure, args.lr, args.num_iter)

    # Save the per-iteration PSNR trace and report the best result
    PSNR_mat = np.array(PSNR_list).reshape(1, args.num_iter)
    pickle.dump(PSNR_mat, open(os.path.join(global_path, 'PSNR.pkl'), "wb"))

    print('Best PSNR_gt: %.2f' % psnr_gt_best)
    print('Finish optimization\n')
--------------------------------------------------------------------------------
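Note: both denoising scripts score reconstructions with scikit-image's compare_psnr. For reference, a minimal stand-in (assuming float images scaled to [0, 1]; the name psnr is illustrative, not from the repo) is:

    import numpy as np

    def psnr(reference, estimate, data_range=1.0):
        # Peak signal-to-noise ratio in dB; higher means a closer match.
        mse = np.mean((np.asarray(reference, np.float64) - np.asarray(estimate, np.float64)) ** 2)
        return 10.0 * np.log10(data_range ** 2 / mse)
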
/DIP/denoising.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt

import os
import cv2
import random
import pickle
import argparse
import numpy as np

# skimage >= 0.18 removed compare_psnr; fall back to the renamed function there.
try:
    from skimage.measure import compare_psnr
except ImportError:
    from skimage.metrics import peak_signal_noise_ratio as compare_psnr
#from torchviz import make_dot, make_dot_from_trace
# from torchvision import transforms, utils
# from torch.utils.data import Dataset, DataLoader


from utils.denoising_utils import *
from utils.timer import Timer


import torch
import torch.optim

import warnings
warnings.filterwarnings("ignore")

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True
dtype = torch.cuda.FloatTensor

def parse_args():
    parser = argparse.ArgumentParser(description='NAS-DIP Denoising')

    parser.add_argument('--optimizer', dest='optimizer', default='adam', type=str)
    parser.add_argument('--num_iter', dest='num_iter', default=3000, type=int)
    parser.add_argument('--show_every', dest='show_every', default=50, type=int)
    parser.add_argument('--lr', dest='lr', default=0.01, type=float)
    parser.add_argument('--plot', dest='plot', default=False, type=bool)
    parser.add_argument('--noise_method', dest='noise_method', default='noise', type=str)
    parser.add_argument('--input_depth', dest='input_depth', default=32, type=int)
    parser.add_argument('--output_path', dest='output_path', default='results/denoising', type=str)
    parser.add_argument('--batch_size', dest='batch_size', default=1, type=int)
    parser.add_argument('--random_seed', dest='random_seed', default=0, type=int)
    parser.add_argument('--net', dest='net', default='default', type=str)
    parser.add_argument('--reg_noise_std', dest='reg_noise_std', default=1./30., type=float)
    parser.add_argument('--sigma', dest='sigma', default=25, type=float)
    parser.add_argument('--i_NAS', dest='i_NAS', default=-1, type=int)
    parser.add_argument('--save_png', dest='save_png', default=0, type=int)
    parser.add_argument('--exp_weight', dest='exp_weight', default=0.99, type=float)

    args = parser.parse_args()
    return args


if __name__ == '__main__':

    args = parse_args()


    # Create the output_path if it does not exist
    if args.net == 'default':
        global_path = args.output_path + '_' + args.net
        if not os.path.exists(global_path):
            os.makedirs(global_path)
    elif args.net == 'NAS':
        global_path = args.output_path + '_' + args.net + '_' + str(args.i_NAS)
        if not os.path.exists(global_path):
            os.makedirs(global_path)
    elif args.net == 'Multiscale':
        from gen_skip_index import skip_index
        skip_connect = skip_index()
        token = skip_connect.flatten()
        token = ''.join(str(x) for x in token)
        global_path = args.output_path + '_' + args.net + '_' + str(args.i_NAS) + '_' + token
        if not os.path.exists(global_path):
            os.makedirs(global_path)
        pickle.dump(skip_connect, open(os.path.join(global_path, 'skip_connect.pkl'), 'wb'))
    else:
        assert False, 'Please choose among default, NAS, and Multiscale'

    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)

    # #batch x #iter
    PSNR_mat = np.empty((0, args.num_iter), dtype=np.float32)

    # Choose figure
    img_path_list = ['F16', 'Baboon', 'House', 'kodim01', 'kodim02', 'kodim03', 'kodim12', 'Lena', 'Peppers']
    psnr_gt_best_list = []

    for image_name in img_path_list:

        if args.save_png == 1 and not os.path.exists(os.path.join(global_path, image_name)):
            os.makedirs(os.path.join(global_path, image_name))

        # Choose figure
        img_path = 'data/denoising/' + image_name + '.png'

        # Add synthetic noise
        img_pil = crop_image(get_image(img_path, -1)[0], 32)
        img_np = pil_to_np(img_pil) # (3, 512, 512)  pixel value range: [0, 1]
        img_noisy_pil, img_noisy_np = get_noisy_image(img_np, args.sigma / 255.) # (3, 512, 512) [0, 1]

        # Visualization
        if args.plot:
            plot_image_grid([img_np, img_noisy_np], 4, 6)


        if args.net == 'default':
            from models.skip import skip
            net = skip(num_input_channels=args.input_depth,
                       num_output_channels=3,
                       num_channels_down=[128] * 5,
                       num_channels_up=[128] * 5,
                       num_channels_skip=[4] * 5,
                       upsample_mode='bilinear',
                       downsample_mode='stride',
                       need_sigmoid=True,
                       need_bias=True,
                       pad='reflection',
                       act_fun='LeakyReLU')

        elif args.net == 'NAS':
            from models.skip_search_up import skip
            if args.i_NAS in [249, 250, 251]:
                exit(1)
            net = skip(model_index=args.i_NAS,
                       num_input_channels=args.input_depth,
                       num_output_channels=3,
                       num_channels_down=[128] * 5,
                       num_channels_up=[128] * 5,
                       num_channels_skip=[4] * 5,
                       upsample_mode='bilinear',
                       downsample_mode='stride',
                       need_sigmoid=True,
                       need_bias=True,
                       pad='reflection',
                       act_fun='LeakyReLU')

        elif args.net == 'Multiscale':
            from models.cross_skip import skip
            net = skip(model_index=args.i_NAS,
                       skip_index=skip_connect,
                       num_input_channels=args.input_depth,
                       num_output_channels=3,
                       num_channels_down=[128] * 5,
                       num_channels_up=[128] * 5,
                       num_channels_skip=[4] * 5,
                       upsample_mode='bilinear',
                       downsample_mode='stride',
                       need_sigmoid=True,
                       need_bias=True,
                       pad='reflection',
                       act_fun='LeakyReLU')
        else:
            assert False, 'Please choose among default, NAS, and Multiscale'

        net = net.type(dtype)

        # z torch.Size([1, 32, 512, 512])
        net_input = get_noise(args.input_depth, args.noise_method, (img_pil.size[1], img_pil.size[0])).type(dtype).detach()


        #dot = make_dot(net(net_input), params=dict(net.named_parameters()))
        #dot.format = 'svg'
        #dot.render(args.output_path + 'best_model', view=False)
        #exit(-1)

        # Compute number of parameters
        s = sum([np.prod(list(p.size())) for p in net.parameters()])
        print ('Number of params: %d' % s)

        # Loss
        mse = torch.nn.MSELoss().type(dtype)

        # x0 torch.Size([1, 3, 512, 512])
        img_noisy_torch = np_to_torch(img_noisy_np).type(dtype)

        net_input_saved = net_input.detach().clone()
        noise = net_input.detach().clone()
        out_avg = None
        last_net = None
        psnr_noisy_last = 0
        psnr_gt_best = 0

        # Main
        i = 0
        PSNR_list = []

        _t = {'im_detect': Timer(), 'misc': Timer()}

        def closure():

            global i, out_avg, psnr_noisy_last, last_net, net_input, psnr_gt_best, PSNR_list

            _t['im_detect'].tic()

            # Add variation
            if args.reg_noise_std > 0:
                net_input = net_input_saved + (noise.normal_() * args.reg_noise_std)

            out = net(net_input)

            # Smoothing
            if out_avg is None:
                out_avg = out.detach()
            else:
                out_avg = out_avg * args.exp_weight + out.detach() * (1 - args.exp_weight)


            total_loss = mse(out, img_noisy_torch)
            total_loss.backward()

            psnr_noisy = compare_psnr(img_noisy_np, out.detach().cpu().numpy()[0])
            psnr_gt = compare_psnr(img_np, out_avg.detach().cpu().numpy()[0])

            PSNR_list.append(psnr_gt)

            if psnr_gt > psnr_gt_best:
                psnr_gt_best = psnr_gt

            _t['im_detect'].toc()

            print ('Iteration %05d Loss %f PSNR_noisy: %f PSNR_gt: %f Time %.3f' % (i, total_loss.item(), psnr_noisy, psnr_gt, _t['im_detect'].total_time), '\n', end='')

            if i % args.show_every == 0:
                out_np = torch_to_np(out)
                if args.save_png == 1:
                    cv2.imwrite(os.path.join(global_path, image_name, str(i) + '.png'),\
                                np.clip(out_np, 0, 1).transpose(1, 2, 0)[:,:,::-1] * 255)

                if args.plot:
                    plot_image_grid([np.clip(out_np, 0, 1)], factor=4, nrow=1)

            # Backtracking (runs on iterations that are not multiples of show_every)
            if i % args.show_every:
                if psnr_noisy - psnr_noisy_last < -5:
                    print('Falling back to previous checkpoint.')

                    for new_param, net_param in zip(last_net, net.parameters()):
                        net_param.data.copy_(new_param.cuda())

                    return total_loss*0
                else:
                    last_net = [x.detach().cpu() for x in net.parameters()]
                    psnr_noisy_last = psnr_noisy

            i += 1

            return total_loss

        p = get_params('net', net, net_input)
        optimize(args.optimizer, p, closure, args.lr, args.num_iter)

        PSNR_mat = np.concatenate((PSNR_mat, np.array(PSNR_list).reshape(1, args.num_iter)), axis=0)
        pickle.dump(PSNR_mat, open(os.path.join(global_path, 'PSNR.pkl'), "wb"))

        psnr_gt_best_list.append(psnr_gt_best)

        print('Finish optimization\n')

    for idx, image_name in enumerate(img_path_list):
        print ('Image: %8s PSNR: %.2f' % (image_name, psnr_gt_best_list[idx]), '\n', end='')
    print ('Averaged PSNR: %.2f' % (np.mean(psnr_gt_best_list)), '\n', end='')
--------------------------------------------------------------------------------
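The training loops above delegate to get_params and optimize, pulled in through the utils star-imports (they live in the repo's common utilities, as in the upstream DIP code). A minimal sketch of the contract the call sites rely on, assuming only the 'adam' path:

    import torch

    def optimize(optimizer_type, parameters, closure, lr, num_iter):
        # Run Adam for num_iter steps; closure() computes the loss and calls backward().
        assert optimizer_type == 'adam'
        optimizer = torch.optim.Adam(parameters, lr=lr)
        for _ in range(num_iter):
            optimizer.zero_grad()
            closure()
            optimizer.step()
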
/DIP/inpainting-test.py:
--------------------------------------------------------------------------------
from __future__ import print_function
#import matplotlib
#matplotlib.use('agg')
#import matplotlib.pyplot as plt

import os
import cv2
import argparse
import numpy as np

# skimage >= 0.18 removed compare_psnr; fall back to the renamed function there.
try:
    from skimage.measure import compare_psnr
except ImportError:
    from skimage.metrics import peak_signal_noise_ratio as compare_psnr

from utils.inpainting_utils import *
from utils.timer import Timer

import torch
import torch.optim

import warnings
warnings.filterwarnings("ignore")

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True
dtype = torch.cuda.FloatTensor

def parse_args():
    parser = argparse.ArgumentParser(description='NAS-DIP Inpainting')

    parser.add_argument('--optimizer', dest='optimizer', default='adam', type=str)
    parser.add_argument('--num_iter', dest='num_iter', default=11000, type=int)
    parser.add_argument('--show_every', dest='show_every', default=50, type=int)
    parser.add_argument('--lr', dest='lr', default=0.001, type=float)
    parser.add_argument('--plot', dest='plot', default=False, type=bool)
    parser.add_argument('--noise_method', dest='noise_method', default='noise', type=str)
    parser.add_argument('--input_depth', dest='input_depth', default=32, type=int)
    parser.add_argument('--output_path', dest='output_path', default='results/inpainting', type=str)
    parser.add_argument('--reg_noise_std', dest='reg_noise_std', default=0.03, type=float)
    parser.add_argument('--image_name', type=str)

    args = parser.parse_args()
    return args


if __name__ == '__main__':

    args = parse_args()

    img_path = 'data/inpainting/' + args.image_name

    img_pil, img_np = get_image(img_path, -1)
    img_pil = np_to_pil(img_np)

    img_mask = get_bernoulli_mask(img_pil, 0.50)
    img_mask_np = pil_to_np(img_mask)

    img_masked = img_np * img_mask_np
    mask_var = np_to_torch(img_mask_np).type(dtype)

    # Visualization
    if args.plot:
        plot_image_grid([img_np, img_mask_np, img_mask_np * img_np], 3, 11)

    from models.model_inpainting import Model
    net = Model()

    net = net.type(dtype)

    net_input = get_noise(args.input_depth, args.noise_method, img_np.shape[1:]).type(dtype).detach()

    mse = torch.nn.MSELoss().type(dtype)

    img_var = np_to_torch(img_np).type(dtype)

    net_input_saved = net_input.detach().clone()
    noise = net_input.detach().clone()

    last_net = None
    psnr_masked_last = 0
    psnr_gt_best = 0

    # Main
    i = 0
    PSNR_list = []

    _t = {'im_detect': Timer(), 'misc': Timer()}

    def closure():

        global i, psnr_masked_last, last_net, net_input, psnr_gt_best, PSNR_list

        _t['im_detect'].tic()

        # Add variation
        if args.reg_noise_std > 0:
            net_input = net_input_saved + (noise.normal_() * args.reg_noise_std)
        out = net(net_input)

        total_loss = mse(out * mask_var, img_var * mask_var)
        total_loss.backward()

        psnr_masked = compare_psnr(img_masked, out.detach().cpu().numpy()[0] * img_mask_np)
        psnr_gt = compare_psnr(img_np, out.detach().cpu().numpy()[0])

        PSNR_list.append(psnr_gt)

        if psnr_gt > psnr_gt_best:
            psnr_gt_best = psnr_gt

        _t['im_detect'].toc()

        print ('Iteration %05d Loss %f PSNR_masked %f PSNR %f Time %.3f' % (i, total_loss.item(), psnr_masked, psnr_gt, _t['im_detect'].total_time), '\r', end='')

        # Visualization (compute out_np before plotting it)
        if args.plot and i % args.show_every == 0:
            out_np = torch_to_np(out)
            plot_image_grid([np.clip(out_np, 0, 1)], factor=4, nrow=1)

        # Backtracking
        if psnr_masked - psnr_masked_last < -5:
            print('Falling back to previous checkpoint.')

            for new_param, net_param in zip(last_net, net.parameters()):
                net_param.data.copy_(new_param.cuda())

            return total_loss*0
        else:
            last_net = [x.cpu() for x in net.parameters()]
            psnr_masked_last = psnr_masked

        i += 1

        return total_loss

    p = get_params('net', net, net_input)
    optimize(args.optimizer, p, closure, args.lr, args.num_iter)

    print('Finish optimization\n')
--------------------------------------------------------------------------------
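get_bernoulli_mask comes from utils/inpainting_utils.py; conceptually it keeps each pixel independently with probability p. A sketch under that assumption (the function body here is illustrative, not the repo's exact implementation):

    import numpy as np
    from PIL import Image

    def bernoulli_mask(img_pil, p=0.5, seed=0):
        # 255 = pixel observed, 0 = pixel dropped; drawn i.i.d. per pixel.
        rng = np.random.default_rng(seed)
        keep = rng.random((img_pil.size[1], img_pil.size[0])) <= p
        return Image.fromarray(keep.astype(np.uint8) * 255)
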
/DIP/inpainting.py:
--------------------------------------------------------------------------------
from __future__ import print_function
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as plt

import os
import cv2
import random
import pickle
import argparse
import numpy as np

# skimage >= 0.18 removed compare_psnr; fall back to the renamed function there.
try:
    from skimage.measure import compare_psnr
except ImportError:
    from skimage.metrics import peak_signal_noise_ratio as compare_psnr
# from torchviz import make_dot, make_dot_from_trace
# from torchvision import transforms, utils
# from torch.utils.data import Dataset, DataLoader


from utils.inpainting_utils import *
from utils.timer import Timer


import torch
import torch.nn as nn  # used below for ReflectionPad2d
import torch.optim


import warnings
warnings.filterwarnings("ignore")

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = True
dtype = torch.cuda.FloatTensor

def parse_args():
    parser = argparse.ArgumentParser(description='NAS-DIP Inpainting')

    parser.add_argument('--optimizer', dest='optimizer', default='adam', type=str)
    parser.add_argument('--num_iter', dest='num_iter', default=11000, type=int)
    parser.add_argument('--show_every', dest='show_every', default=50, type=int)
    parser.add_argument('--lr', dest='lr', default=0.001, type=float)
    parser.add_argument('--plot', dest='plot', default=False, type=bool)
    parser.add_argument('--noise_method', dest='noise_method', default='noise', type=str)
    parser.add_argument('--input_depth', dest='input_depth', default=32, type=int)
    parser.add_argument('--output_path', dest='output_path', default='results/restoration', type=str)
    parser.add_argument('--batch_size', dest='batch_size', default=1, type=int)
    parser.add_argument('--random_seed', dest='random_seed', default=0, type=int)
    parser.add_argument('--net', dest='net', default='default', type=str)
    parser.add_argument('--reg_noise_std', dest='reg_noise_std', default=0.03, type=float)
    parser.add_argument('--i_NAS', dest='i_NAS', default=-1, type=int)
    parser.add_argument('--save_png', dest='save_png', default=0, type=int)

    args = parser.parse_args()
    return args


if __name__ == '__main__':

    args = parse_args()

    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)

    if args.net == 'default':
        global_path = args.output_path + '_' + args.net
    elif args.net == 'NAS':
        global_path = args.output_path + '_' + args.net + '_' + str(args.i_NAS)
    elif args.net == 'Multiscale':
        global_path = args.output_path + '_' + args.net + '_' + str(args.i_NAS)
    else:
        assert False, 'Please choose among default, NAS, and Multiscale'

    # Create the output_path if it does not exist
    if not os.path.exists(global_path):
        os.makedirs(global_path)

    # #batch x #iter
    PSNR_mat = np.empty((0, args.num_iter), dtype=np.float32)

    # Choose figure
    img_path_list = ['barbara', 'boat', 'house', 'Lena512', 'peppers256', 'Cameraman256', 'couple', 'fingerprint', 'hill', 'man', 'montage']
    psnr_gt_best_list = []

    for image_name in img_path_list:

        if args.save_png == 1 and not os.path.exists(os.path.join(global_path, image_name)):
            os.makedirs(os.path.join(global_path, image_name))

        # Choose figure
        img_path = 'data/inpainting/' + image_name + '_GT.png'

        # Load image
        img_pil, img_np = get_image(img_path, -1)
        img_np = nn.ReflectionPad2d(1)(np_to_torch(img_np))[0].numpy()
        img_pil = np_to_pil(img_np)

        img_mask = get_bernoulli_mask(img_pil, 0.50)
        img_mask_np = pil_to_np(img_mask)

        img_masked = img_np * img_mask_np
        mask_var = np_to_torch(img_mask_np).type(dtype)

        # Visualization
        if args.plot:
            plot_image_grid([img_np, img_mask_np, img_mask_np * img_np], 3, 11)


        if args.net == 'default':
            from models.skip import skip
            net = skip(num_input_channels=args.input_depth,
                       num_output_channels=1,
                       num_channels_down=[128] * 5,
                       num_channels_up=[128] * 5,
                       num_channels_skip=[4] * 5,
                       upsample_mode='bilinear',
                       downsample_mode='stride',
                       need_sigmoid=True,
                       need_bias=True,
                       pad='reflection',
                       act_fun='LeakyReLU')

        elif args.net == 'NAS':
            from models.skip_search_up import skip
            if args.i_NAS in [249, 250, 251]:
                exit(1)
            net = skip(model_index=args.i_NAS,
                       num_input_channels=args.input_depth,
                       num_output_channels=1,
                       num_channels_down=[128] * 5,
                       num_channels_up=[128] * 5,
                       num_channels_skip=[4] * 5,
                       upsample_mode='bilinear',
                       downsample_mode='stride',
                       need_sigmoid=True,
                       need_bias=True,
                       pad='reflection',
                       act_fun='LeakyReLU')

        elif args.net == 'Multiscale':
            from models.cross_skip import skip
            from gen_skip_index import skip_index
            skip_connect = skip_index()
            net = skip(model_index=args.i_NAS,
                       skip_index=skip_connect,
                       num_input_channels=args.input_depth,
                       num_output_channels=1,
                       num_channels_down=[128] * 5,
                       num_channels_up=[128] * 5,
                       num_channels_skip=[4] * 5,
                       upsample_mode='bilinear',
                       downsample_mode='stride',
                       need_sigmoid=True,
                       need_bias=True,
                       pad='reflection',
                       act_fun='LeakyReLU')

        else:
            assert False, 'Please choose among default, NAS, and Multiscale'

        net = net.type(dtype)

        # z torch.Size([1, 32, 512, 512])
        net_input = get_noise(args.input_depth, args.noise_method, img_np.shape[1:]).type(dtype).detach()


        # Loss
        mse = torch.nn.MSELoss().type(dtype)

        # x0
        img_var = np_to_torch(img_np).type(dtype)

        net_input_saved = net_input.detach().clone()
        noise = net_input.detach().clone()

        last_net = None
        psnr_masked_last = 0
        psnr_gt_best = 0

        # Main
        i = 0
        PSNR_list = []

        _t = {'im_detect': Timer(), 'misc': Timer()}

        def closure():

            global i, psnr_masked_last, last_net, net_input, psnr_gt_best, PSNR_list

            _t['im_detect'].tic()

            # Add variation
            if args.reg_noise_std > 0:
                net_input = net_input_saved + (noise.normal_() * args.reg_noise_std)
            out = net(net_input)

            total_loss = mse(out * mask_var, img_var * mask_var)
            total_loss.backward()

            psnr_masked = compare_psnr(img_masked, out.detach().cpu().numpy()[0] * img_mask_np)
            psnr_gt = compare_psnr(img_np, out.detach().cpu().numpy()[0])

            PSNR_list.append(psnr_gt)

            if psnr_gt > psnr_gt_best:
                psnr_gt_best = psnr_gt

            _t['im_detect'].toc()

            print ('Iteration %05d Loss %f PSNR_masked %f PSNR %f Time %.3f' % (i, total_loss.item(), psnr_masked, psnr_gt, _t['im_detect'].total_time), '\r', end='')

            if i % args.show_every == 0:
                # Compute out_np once for both saving and plotting
                out_np = torch_to_np(out)
                if args.save_png == 1:
                    cv2.imwrite(os.path.join(global_path, image_name, str(i) + '.png'),\
                                np.clip(out_np, 0, 1).transpose(1, 2, 0)[:,:,::-1] * 255)

                if args.plot:
                    plot_image_grid([np.clip(out_np, 0, 1)], factor=4, nrow=1)

            # Backtracking
            if psnr_masked - psnr_masked_last < -5:
                print('Falling back to previous checkpoint.')

                for new_param, net_param in zip(last_net, net.parameters()):
                    net_param.data.copy_(new_param.cuda())

                return total_loss*0
            else:
                last_net = [x.cpu() for x in net.parameters()]
                psnr_masked_last = psnr_masked

            i += 1

            return total_loss

        p = get_params('net', net, net_input)
        optimize(args.optimizer, p, closure, args.lr, args.num_iter)

        PSNR_mat = np.concatenate((PSNR_mat, np.array(PSNR_list).reshape(1, args.num_iter)), axis=0)
        pickle.dump(PSNR_mat, open(os.path.join(global_path, 'PSNR.pkl'), "wb"))

        psnr_gt_best_list.append(psnr_gt_best)

        print('Finish optimization\n')

    for idx, image_name in enumerate(img_path_list):
        print ('Image: %8s PSNR: %.2f' % (image_name, psnr_gt_best_list[idx]), '\n', end='')
    print ('Averaged PSNR: %.2f' % (np.mean(psnr_gt_best_list)), '\n', end='')
--------------------------------------------------------------------------------
/DIP/models/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/YunChunChen/NAS-DIP-pytorch/3dfb4cf6312599097a5a193d22fd8591467e1a6f/DIP/models/.DS_Store
--------------------------------------------------------------------------------
/DIP/models/__init__.py:
--------------------------------------------------------------------------------
from .skip import skip
from .texture_nets import get_texture_nets
from .resnet import ResNet
from .unet import UNet

import torch.nn as nn

def get_net(input_depth, NET_TYPE, pad, upsample_mode, n_channels=3, act_fun='LeakyReLU', skip_n33d=128, skip_n33u=128, skip_n11=4, num_scales=5, downsample_mode='stride'):
    if NET_TYPE == 'ResNet':
        # TODO
        net = ResNet(input_depth, 3, 10, 16, 1, nn.BatchNorm2d, False)
    elif NET_TYPE == 'skip':
        net = skip(input_depth, n_channels,
                   num_channels_down=[skip_n33d]*num_scales if isinstance(skip_n33d, int) else skip_n33d,
                   num_channels_up=[skip_n33u]*num_scales if isinstance(skip_n33u, int) else skip_n33u,
                   num_channels_skip=[skip_n11]*num_scales if isinstance(skip_n11, int) else skip_n11,
                   upsample_mode=upsample_mode, downsample_mode=downsample_mode,
                   need_sigmoid=True, need_bias=True, pad=pad, act_fun=act_fun)

    elif NET_TYPE == 'texture_nets':
        net = get_texture_nets(inp=input_depth, ratios=[32, 16, 8, 4, 2, 1], fill_noise=False, pad=pad)

    elif NET_TYPE == 'UNet':
        net = UNet(num_input_channels=input_depth, num_output_channels=3,
                   feature_scale=4, more_layers=0, concat_x=False,
                   upsample_mode=upsample_mode, pad=pad, norm_layer=nn.BatchNorm2d, need_sigmoid=True, need_bias=True)
    elif NET_TYPE == 'identity':
        assert input_depth == 3
        net = nn.Sequential()
    else:
        assert False

    return net
--------------------------------------------------------------------------------
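A quick usage sketch for the get_net factory above (argument values are illustrative, not prescribed by the repo):

    net = get_net(input_depth=32, NET_TYPE='skip', pad='reflection',
                  upsample_mode='bilinear', n_channels=3, num_scales=5)
    print(sum(p.numel() for p in net.parameters()), 'parameters')
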
/DIP/models/common.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
from .downsampler import Downsampler

def add_module(self, module):
    self.add_module(str(len(self) + 1), module)

torch.nn.Module.add = add_module

class Concat(nn.Module):
    def __init__(self, dim, *args):
        super(Concat, self).__init__()
        self.dim = dim

        for idx, module in enumerate(args):
            self.add_module(str(idx), module)

    def forward(self, input):
        inputs = []
        for module in self._modules.values():
            inputs.append(module(input))

        inputs_shapes2 = [x.shape[2] for x in inputs]
        inputs_shapes3 = [x.shape[3] for x in inputs]

        if np.all(np.array(inputs_shapes2) == min(inputs_shapes2)) and np.all(np.array(inputs_shapes3) == min(inputs_shapes3)):
            inputs_ = inputs
        else:
            # Center-crop every branch to the smallest spatial size
            target_shape2 = min(inputs_shapes2)
            target_shape3 = min(inputs_shapes3)

            inputs_ = []
            for inp in inputs:
                diff2 = (inp.size(2) - target_shape2) // 2
                diff3 = (inp.size(3) - target_shape3) // 2
                inputs_.append(inp[:, :, diff2: diff2 + target_shape2, diff3:diff3 + target_shape3])

        return torch.cat(inputs_, dim=self.dim)

    def __len__(self):
        return len(self._modules)


class GenNoise(nn.Module):
    def __init__(self, dim2):
        super(GenNoise, self).__init__()
        self.dim2 = dim2

    def forward(self, input):
        a = list(input.size())
        a[1] = self.dim2
        # print (input.data.type())

        b = torch.zeros(a).type_as(input.data)
        b.normal_()

        x = torch.autograd.Variable(b)

        return x


class Swish(nn.Module):
    """
    https://arxiv.org/abs/1710.05941
    The hype was so huge that I could not help but try it
    """
    def __init__(self):
        super(Swish, self).__init__()
        self.s = nn.Sigmoid()

    def forward(self, x):
        return x * self.s(x)


def act(act_fun='LeakyReLU'):
    '''
    Either string defining an activation function or module (e.g. nn.ReLU)
    '''
    if isinstance(act_fun, str):
        if act_fun == 'LeakyReLU':
            return nn.LeakyReLU(0.2, inplace=True)
        elif act_fun == 'Swish':
            return Swish()
        elif act_fun == 'ELU':
            return nn.ELU()
        elif act_fun == 'none':
            return nn.Sequential()
        else:
            assert False
    else:
        return act_fun()


def bn(num_features):
    return nn.BatchNorm2d(num_features)


def conv(in_f, out_f, kernel_size, stride=1, bias=True, pad='zero', downsample_mode='stride'):
    downsampler = None
    if stride != 1 and downsample_mode != 'stride':

        if downsample_mode == 'avg':
            downsampler = nn.AvgPool2d(stride, stride)
        elif downsample_mode == 'max':
            downsampler = nn.MaxPool2d(stride, stride)
        elif downsample_mode in ['lanczos2', 'lanczos3']:
            downsampler = Downsampler(n_planes=out_f, factor=stride, kernel_type=downsample_mode, phase=0.5, preserve_size=True)
        else:
            assert False

        stride = 1

    padder = None
    to_pad = int((kernel_size - 1) / 2)
    if pad == 'reflection':
        padder = nn.ReflectionPad2d(to_pad)
        to_pad = 0

    convolver = nn.Conv2d(in_f, out_f, kernel_size, stride, padding=to_pad, bias=bias)


    layers = filter(lambda x: x is not None, [padder, convolver, downsampler])
    return nn.Sequential(*layers)
--------------------------------------------------------------------------------
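Concat center-crops branch outputs to the smallest spatial size before concatenating along self.dim, which is what lets skip branches with slightly different output sizes meet. A small check (the branch modules here are arbitrary examples):

    import torch
    import torch.nn as nn

    branch_a = nn.Identity()                   # keeps 16x16
    branch_b = nn.Conv2d(8, 8, kernel_size=3)  # no padding: shrinks to 14x14
    cat = Concat(1, branch_a, branch_b)
    out = cat(torch.randn(1, 8, 16, 16))
    print(out.shape)  # torch.Size([1, 16, 14, 14]): branch_a was center-cropped
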
/DIP/models/common_test.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np
from downsampler import Downsampler

def add_module(self, module):
    self.add_module(str(len(self) + 1), module)

torch.nn.Module.add = add_module

class Concat(nn.Module):
    def __init__(self, dim, *args):
        super(Concat, self).__init__()
        self.dim = dim

        for idx, module in enumerate(args):
            self.add_module(str(idx), module)

    def forward(self, input):
        inputs = []
        for module in self._modules.values():
            inputs.append(module(input))

        inputs_shapes2 = [x.shape[2] for x in inputs]
        inputs_shapes3 = [x.shape[3] for x in inputs]

        if np.all(np.array(inputs_shapes2) == min(inputs_shapes2)) and np.all(np.array(inputs_shapes3) == min(inputs_shapes3)):
            inputs_ = inputs
        else:
            target_shape2 = min(inputs_shapes2)
            target_shape3 = min(inputs_shapes3)

            inputs_ = []
            for inp in inputs:
                diff2 = (inp.size(2) - target_shape2) // 2
                diff3 = (inp.size(3) - target_shape3) // 2
                inputs_.append(inp[:, :, diff2: diff2 + target_shape2, diff3:diff3 + target_shape3])

        return torch.cat(inputs_, dim=self.dim)

    def __len__(self):
        return len(self._modules)


class GenNoise(nn.Module):
    def __init__(self, dim2):
        super(GenNoise, self).__init__()
        self.dim2 = dim2

    def forward(self, input):
        a = list(input.size())
        a[1] = self.dim2
        # print (input.data.type())

        b = torch.zeros(a).type_as(input.data)
        b.normal_()

        x = torch.autograd.Variable(b)

        return x


class Swish(nn.Module):
    """
    https://arxiv.org/abs/1710.05941
    The hype was so huge that I could not help but try it
    """
    def __init__(self):
        super(Swish, self).__init__()
        self.s = nn.Sigmoid()

    def forward(self, x):
        return x * self.s(x)


def act(act_fun='LeakyReLU'):
    '''
    Either string defining an activation function or module (e.g. nn.ReLU)
    '''
    if isinstance(act_fun, str):
        if act_fun == 'LeakyReLU':
            return nn.LeakyReLU(0.2, inplace=True)
        elif act_fun == 'Swish':
            return Swish()
        elif act_fun == 'ELU':
            return nn.ELU()
        elif act_fun == 'none':
            return nn.Sequential()
        else:
            assert False
    else:
        return act_fun()


def bn(num_features):
    return nn.BatchNorm2d(num_features)


def conv(in_f, out_f, kernel_size, stride=1, bias=True, pad='zero', downsample_mode='stride'):
    downsampler = None
    if stride != 1 and downsample_mode != 'stride':

        if downsample_mode == 'avg':
            downsampler = nn.AvgPool2d(stride, stride)
        elif downsample_mode == 'max':
            downsampler = nn.MaxPool2d(stride, stride)
        elif downsample_mode in ['lanczos2', 'lanczos3']:
            downsampler = Downsampler(n_planes=out_f, factor=stride, kernel_type=downsample_mode, phase=0.5, preserve_size=True)
        else:
            assert False

        stride = 1

    padder = None
    to_pad = int((kernel_size - 1) / 2)
    if pad == 'reflection':
        padder = nn.ReflectionPad2d(to_pad)
        to_pad = 0

    convolver = nn.Conv2d(in_f, out_f, kernel_size, stride, padding=to_pad, bias=bias)


    layers = filter(lambda x: x is not None, [padder, convolver, downsampler])
    return nn.Sequential(*layers)
--------------------------------------------------------------------------------
/DIP/models/downsampler.py:
--------------------------------------------------------------------------------
import numpy as np
import torch
import torch.nn as nn

class Downsampler(nn.Module):
    '''
    http://www.realitypixels.com/turk/computergraphics/ResamplingFilters.pdf
    '''
    def __init__(self, n_planes, factor, kernel_type, phase=0, kernel_width=None, support=None, sigma=None, preserve_size=False):
        super(Downsampler, self).__init__()

        assert phase in [0, 0.5], 'phase should be 0 or 0.5'

        if kernel_type == 'lanczos2':
            support = 2
            kernel_width = 4 * factor + 1
            kernel_type_ = 'lanczos'

        elif kernel_type == 'lanczos3':
            support = 3
            kernel_width = 6 * factor + 1
            kernel_type_ = 'lanczos'

        elif kernel_type == 'gauss12':
            kernel_width = 7
            sigma = 1/2
            kernel_type_ = 'gauss'

        elif kernel_type == 'gauss1sq2':
            kernel_width = 9
            sigma = 1./np.sqrt(2)
            kernel_type_ = 'gauss'

        elif kernel_type in ['lanczos', 'gauss', 'box']:
            kernel_type_ = kernel_type

        else:
            assert False, 'wrong kernel name'


        # note that `kernel width` will be different to actual size for phase = 1/2
        self.kernel = get_kernel(factor, kernel_type_, phase, kernel_width, support=support, sigma=sigma)

        downsampler = nn.Conv2d(n_planes, n_planes, kernel_size=self.kernel.shape, stride=factor, padding=0)
        downsampler.weight.data[:] = 0
        downsampler.bias.data[:] = 0

        kernel_torch = torch.from_numpy(self.kernel)
        for i in range(n_planes):
            downsampler.weight.data[i, i] = kernel_torch

        self.downsampler_ = downsampler

        if preserve_size:

            if self.kernel.shape[0] % 2 == 1:
                pad = int((self.kernel.shape[0] - 1) / 2.)
            else:
                pad = int((self.kernel.shape[0] - factor) / 2.)

            self.padding = nn.ReplicationPad2d(pad)

        self.preserve_size = preserve_size

    def forward(self, input):
        if self.preserve_size:
            x = self.padding(input)
        else:
            x = input
        self.x = x
        return self.downsampler_(x)

def get_kernel(factor, kernel_type, phase, kernel_width, support=None, sigma=None):
    assert kernel_type in ['lanczos', 'gauss', 'box']

    # factor = float(factor)
    if phase == 0.5 and kernel_type != 'box':
        kernel = np.zeros([kernel_width - 1, kernel_width - 1])
    else:
        kernel = np.zeros([kernel_width, kernel_width])


    if kernel_type == 'box':
        assert phase == 0.5, 'Box filter is always half-phased'
        kernel[:] = 1./(kernel_width * kernel_width)

    elif kernel_type == 'gauss':
        assert sigma, 'sigma is not specified'
        assert phase != 0.5, 'phase 1/2 for gauss not implemented'

        center = (kernel_width + 1.)/2.
        print(center, kernel_width)
        sigma_sq = sigma * sigma

        for i in range(1, kernel.shape[0] + 1):
            for j in range(1, kernel.shape[1] + 1):
                di = (i - center)/2.
                dj = (j - center)/2.
                kernel[i - 1][j - 1] = np.exp(-(di * di + dj * dj)/(2 * sigma_sq))
                kernel[i - 1][j - 1] = kernel[i - 1][j - 1]/(2. * np.pi * sigma_sq)
    elif kernel_type == 'lanczos':
        assert support, 'support is not specified'
        center = (kernel_width + 1) / 2.

        for i in range(1, kernel.shape[0] + 1):
            for j in range(1, kernel.shape[1] + 1):

                if phase == 0.5:
                    di = abs(i + 0.5 - center) / factor
                    dj = abs(j + 0.5 - center) / factor
                else:
                    di = abs(i - center) / factor
                    dj = abs(j - center) / factor


                pi_sq = np.pi * np.pi

                val = 1
                if di != 0:
                    val = val * support * np.sin(np.pi * di) * np.sin(np.pi * di / support)
                    val = val / (np.pi * np.pi * di * di)

                if dj != 0:
                    val = val * support * np.sin(np.pi * dj) * np.sin(np.pi * dj / support)
                    val = val / (np.pi * np.pi * dj * dj)

                kernel[i - 1][j - 1] = val


    else:
        assert False, 'wrong method name'

    kernel /= kernel.sum()

    return kernel

#a = Downsampler(n_planes=3, factor=2, kernel_type='lanczos2', phase='1', preserve_size=True)






#################
# Learnable downsampler

# KS = 32
# dow = nn.Sequential(nn.ReplicationPad2d(int((KS - factor) / 2.)), nn.Conv2d(1,1,KS,factor))

# class Apply(nn.Module):
#     def __init__(self, what, dim, *args):
#         super(Apply, self).__init__()
#         self.dim = dim

#         self.what = what

#     def forward(self, input):
#         inputs = []
#         for i in range(input.size(self.dim)):
#             inputs.append(self.what(input.narrow(self.dim, i, 1)))

#         return torch.cat(inputs, dim=self.dim)

#     def __len__(self):
#         return len(self._modules)

# downs = Apply(dow, 1)
# downs.type(dtype)(net_input.type(dtype)).size()
--------------------------------------------------------------------------------
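get_kernel always normalizes the kernel to unit sum, and for phase=0.5 the returned array is one sample smaller than kernel_width in each dimension. A sanity check matching the lanczos2/factor-2 setup that conv() in common.py uses:

    # kernel_width = 4 * factor + 1 = 9 for lanczos2, so phase=0.5 yields an 8x8 kernel
    k = get_kernel(factor=2, kernel_type='lanczos', phase=0.5, kernel_width=9, support=2)
    print(k.shape, k.sum())  # (8, 8), sum == 1.0 up to float rounding
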
/DIP/models/gen_upsample_layer.py:
--------------------------------------------------------------------------------
import random
import torch
import torch.nn as nn

try:
    from NAS import genotypes
except ImportError:
    import genotypes

try:
    from NAS import model
except ImportError:
    import model

try:
    from NAS import operations
except ImportError:
    import operations

def gen_layer(C_in, C_out, model_index):

    swap = False

    """ Bilinear """
    if model_index >= 0 and model_index <= 251:
        prim_index = model_index // 63
        model_index = model_index % 63
        conv_index = ((model_index // len(genotypes.ACTIVATION)) // len(genotypes.KERNEL_SIZE)) % len(genotypes.UPSAMPLE_PRIMITIVE)
        kernel_index = (model_index // len(genotypes.ACTIVATION)) % len(genotypes.KERNEL_SIZE)
        act_index = model_index % len(genotypes.ACTIVATION)

        if (model_index >= 60 and model_index <= 62):
            conv_index = 5

    """ DepthToSpace - Second """
    if model_index >= 252 and model_index <= 311:
        swap = True
        prim_index = (model_index - 63) // 63
        model_index = model_index % 63
        conv_index = ((model_index // len(genotypes.ACTIVATION)) // len(genotypes.KERNEL_SIZE)) % len(genotypes.UPSAMPLE_PRIMITIVE)
        kernel_index = (model_index // len(genotypes.ACTIVATION)) % len(genotypes.KERNEL_SIZE)
        act_index = model_index % len(genotypes.ACTIVATION)

    """ Transposed Convolution """
    if model_index >= 312 and model_index <= 323:
        prim_index = 4
        conv_index = 5
        kernel_index = (model_index // len(genotypes.ACTIVATION)) % len(genotypes.KERNEL_SIZE)
        act_index = model_index % len(genotypes.ACTIVATION)

    prim_op = genotypes.UPSAMPLE_PRIMITIVE[prim_index]  # adjusts the spatial size
    conv_op = genotypes.UPSAMPLE_CONV[conv_index]       # conv applied around the upsampling primitive
    kernel_size = genotypes.KERNEL_SIZE[kernel_index]   # select the kernel size
    act_op = genotypes.ACTIVATION[act_index]            # select the activation function

    #print('prim op:', prim_op)
    #print('conv op:', conv_op)
    #print('kernel size:', kernel_size)
    #print('act op:', act_op)
    #return

    if prim_op == 'pixel_shuffle':
        if not swap:
            prim_op_layer = operations.UPSAMPLE_PRIMITIVE_OPS[prim_op](C_in=C_in,
                                                                       C_out=C_out,
                                                                       kernel_size=kernel_size,
                                                                       act_op=act_op)

            conv_op_layer = operations.UPSAMPLE_CONV_OPS[conv_op](C_in=int(C_in/4),
                                                                  C_out=C_out,
                                                                  kernel_size=kernel_size,
                                                                  act_op=act_op)
            return nn.Sequential(prim_op_layer, conv_op_layer)

        else:
            conv_op_layer = operations.UPSAMPLE_CONV_OPS[conv_op](C_in=C_in,
                                                                  C_out=int(C_out*4),
                                                                  kernel_size=kernel_size,
                                                                  act_op=act_op)

            prim_op_layer = operations.UPSAMPLE_PRIMITIVE_OPS[prim_op](C_in=C_in,
                                                                       C_out=C_out,
                                                                       kernel_size=kernel_size,
                                                                       act_op=act_op)
            return nn.Sequential(conv_op_layer, prim_op_layer)

    else:
        prim_op_layer = operations.UPSAMPLE_PRIMITIVE_OPS[prim_op](C_in=C_in,
                                                                   C_out=C_out,
                                                                   kernel_size=kernel_size,
                                                                   act_op=act_op)

        conv_op_layer = operations.UPSAMPLE_CONV_OPS[conv_op](C_in=C_in,
                                                              C_out=C_out,
                                                              kernel_size=kernel_size,
                                                              act_op=act_op)
        return nn.Sequential(prim_op_layer, conv_op_layer)

#for i in range(321):
#    gen_layer(0, 0, model_index=i)
#gen_layer(0, 0, model_index=189)
--------------------------------------------------------------------------------
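For the bilinear-family indices (0-251), gen_layer decodes model_index into four choices by mixed-radix arithmetic over the genotype lists. A standalone mirror of that arithmetic (the list lengths are whatever genotypes.py defines, so they are passed in here rather than assumed):

    def decode_bilinear_index(model_index, n_act, n_kernel, n_prim):
        # Mirrors gen_layer() for 0 <= model_index <= 251.
        prim_index = model_index // 63
        rem = model_index % 63
        conv_index = (rem // n_act // n_kernel) % n_prim
        kernel_index = (rem // n_act) % n_kernel
        act_index = rem % n_act
        return prim_index, conv_index, kernel_index, act_index
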
/DIP/models/model_sr.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from .common import *

import models.gen_upsample_layer

class OutputBlock(nn.Module):

    def __init__(self,
                 in_channel,
                 out_channel,
                 kernel_size,
                 bias,
                 pad,
                 need_sigmoid):

        super(OutputBlock, self).__init__()

        if need_sigmoid:
            self.op = nn.Sequential(
                conv(in_f=in_channel,
                     out_f=out_channel,
                     kernel_size=kernel_size,
                     bias=bias,
                     pad=pad),
                nn.Sigmoid(),
            )

        else:
            self.op = nn.Sequential(
                conv(in_f=in_channel,
                     out_f=out_channel,
                     kernel_size=kernel_size,
                     bias=bias,
                     pad=pad),
            )

    def forward(self, data):
        return self.op(data)


class UpsampleBlock(nn.Module):

    def __init__(self,
                 in_channel,
                 out_channel,
                 model_index):

        super(UpsampleBlock, self).__init__()

        self.op = models.gen_upsample_layer.gen_layer(
            C_in=in_channel,
            C_out=out_channel,
            model_index=model_index
        )

    def forward(self, data):
        return self.op(data)


class DownsampleBlock(nn.Module):

    def __init__(self,
                 in_channel,
                 out_channel,
                 kernel_size,
                 bias,
                 pad,
                 act_fun,
                 downsample_mode):

        super(DownsampleBlock, self).__init__()

        self.op = nn.Sequential(
            conv(in_f=in_channel, out_f=out_channel, kernel_size=kernel_size, stride=2, bias=bias, pad=pad, downsample_mode=downsample_mode),
            bn(num_features=out_channel),
            act(act_fun=act_fun)
        )

    def forward(self, data):
        return self.op(data)


class SkipBlock(nn.Module):

    def __init__(self,
                 in_channel,
                 out_channel,
                 kernel_size,
                 bias,
                 pad,
                 act_fun):

        super(SkipBlock, self).__init__()

        self.op = nn.Sequential(
            conv(in_f=in_channel,
                 out_f=out_channel,
                 kernel_size=kernel_size,
                 bias=bias,
                 pad=pad),
            bn(num_features=out_channel),
            act(act_fun=act_fun)
        )

    def forward(self, data):
        return self.op(data)


class EncoderBlock(nn.Module):

    def __init__(self,
                 in_channel,
                 out_channel,
                 kernel_size,
                 bias,
                 pad,
                 act_fun,
                 downsample_mode):

        super(EncoderBlock, self).__init__()

        self.op = nn.Sequential(
            conv(in_f=in_channel, out_f=out_channel, kernel_size=kernel_size, stride=2, bias=bias, pad=pad, downsample_mode=downsample_mode),
            bn(num_features=out_channel),
            act(act_fun=act_fun),
            conv(in_f=out_channel, out_f=out_channel, kernel_size=kernel_size, bias=bias, pad=pad),
            bn(num_features=out_channel),
            act(act_fun=act_fun),
        )

    def forward(self, data):
        return self.op(data)


class DecoderBlock(nn.Module):

    def __init__(self,
                 in_channel,
                 out_channel,
                 kernel_size,
                 bias,
                 pad,
                 act_fun,
                 need1x1_up):

        super(DecoderBlock, self).__init__()

        if need1x1_up:
            self.op = nn.Sequential(
                bn(num_features=out_channel),
                conv(in_f=in_channel, out_f=out_channel, kernel_size=kernel_size, stride=1, bias=bias, pad=pad),
                bn(num_features=out_channel),
                act(act_fun=act_fun),
                conv(in_f=out_channel, out_f=out_channel, kernel_size=kernel_size, bias=bias, pad=pad),
                bn(num_features=out_channel),
                act(act_fun=act_fun),
                conv(in_f=out_channel, out_f=out_channel, kernel_size=1, bias=bias, pad=pad),
                bn(num_features=out_channel),
                act(act_fun=act_fun),
            )

        else:
            self.op = nn.Sequential(
                bn(num_features=out_channel),
                conv(in_f=in_channel, out_f=out_channel, kernel_size=kernel_size, stride=1, bias=bias, pad=pad),
                bn(num_features=out_channel),
                act(act_fun=act_fun),
                conv(in_f=out_channel, out_f=out_channel, kernel_size=kernel_size, bias=bias, pad=pad),
                bn(num_features=out_channel),
                act(act_fun=act_fun),
            )

    def forward(self, data):
        return self.op(data)


class Model(nn.Module):

    def __init__(self,
                 model_index=119,
                 num_input_channels=32,
                 num_output_channels=3,
                 num_channels_down=[128, 128, 128, 128, 128],
                 num_channels_up=[128, 128, 128, 128, 128],
                 num_channels_skip=[4, 4, 4, 4, 4],
                 filter_size_down=3,
                 filter_size_up=3,
                 filter_skip_size=1,
                 need_sigmoid=True,
                 need_bias=True,
                 pad='reflection',
                 upsample_mode='nearest',
                 downsample_mode='stride',
                 act_fun='LeakyReLU',
                 need1x1_up=True):

        super(Model, self).__init__()

        self.enc1 = EncoderBlock(in_channel=num_input_channels,
                                 out_channel=num_channels_down[0],
                                 kernel_size=filter_size_down,
                                 bias=need_bias,
                                 pad=pad,
                                 act_fun=act_fun,
                                 downsample_mode=downsample_mode)

        self.enc2 = EncoderBlock(in_channel=num_channels_down[0],
                                 out_channel=num_channels_down[1],
                                 kernel_size=filter_size_down,
                                 bias=need_bias,
                                 pad=pad,
                                 act_fun=act_fun,
                                 downsample_mode=downsample_mode)

        self.enc3 = EncoderBlock(in_channel=num_channels_down[1],
                                 out_channel=num_channels_down[2],
                                 kernel_size=filter_size_down,
                                 bias=need_bias,
                                 pad=pad,
                                 act_fun=act_fun,
                                 downsample_mode=downsample_mode)

        self.enc4 = EncoderBlock(in_channel=num_channels_down[2],
                                 out_channel=num_channels_down[3],
                                 kernel_size=filter_size_down,
                                 bias=need_bias,
                                 pad=pad,
                                 act_fun=act_fun,
                                 downsample_mode=downsample_mode)

        self.enc5 = EncoderBlock(in_channel=num_channels_down[3],
                                 out_channel=num_channels_down[4],
                                 kernel_size=filter_size_down,
                                 bias=need_bias,
                                 pad=pad,
                                 act_fun=act_fun,
                                 downsample_mode=downsample_mode)

        self.skip1 = SkipBlock(in_channel=num_input_channels,
                               out_channel=num_channels_up[0],
                               kernel_size=1,
                               bias=need_bias,
                               pad=pad,
                               act_fun=act_fun)

        self.skip2 = SkipBlock(in_channel=num_channels_down[0],
                               out_channel=num_channels_up[1],
                               kernel_size=1,
                               bias=need_bias,
                               pad=pad,
                               act_fun=act_fun)

        self.skip3 = SkipBlock(in_channel=num_channels_down[1],
                               out_channel=num_channels_up[2],
                               kernel_size=1,
                               bias=need_bias,
                               pad=pad,
                               act_fun=act_fun)

        self.skip4 = SkipBlock(in_channel=num_channels_down[2],
                               out_channel=num_channels_up[3],
                               kernel_size=1,
                               bias=need_bias,
                               pad=pad,
                               act_fun=act_fun)

        self.skip5 = SkipBlock(in_channel=num_channels_down[3],
                               out_channel=num_channels_up[4],
                               kernel_size=1,
                               bias=need_bias,
                               pad=pad,
                               act_fun=act_fun)

        self.skip_up_5_4 = UpsampleBlock(in_channel=num_channels_down[4],
                                         out_channel=num_channels_up[3],
                                         model_index=model_index)

        self.skip_up_4_3 = UpsampleBlock(in_channel=num_channels_down[3],
                                         out_channel=num_channels_up[2],
                                         model_index=model_index)

        self.skip_up_3_2 = UpsampleBlock(in_channel=num_channels_down[2],
                                         out_channel=num_channels_up[1],
                                         model_index=model_index)

        self.skip_up_2_1 = UpsampleBlock(in_channel=num_channels_down[1],
                                         out_channel=num_channels_up[0],
                                         model_index=model_index)

        self.skip_down_1_2 = DownsampleBlock(in_channel=num_input_channels,
                                             out_channel=num_channels_up[0],
                                             kernel_size=filter_size_down,
                                             bias=need_bias,
                                             pad=pad,
                                             act_fun=act_fun,
                                             downsample_mode=downsample_mode)

        self.skip_down_2_3 = DownsampleBlock(in_channel=num_channels_down[0],
                                             out_channel=num_channels_up[1],
                                             kernel_size=filter_size_down,
                                             bias=need_bias,
                                             pad=pad,
                                             act_fun=act_fun,
                                             downsample_mode=downsample_mode)

        self.skip_down_3_4 = DownsampleBlock(in_channel=num_channels_down[1],
                                             out_channel=num_channels_up[2],
                                             kernel_size=filter_size_down,
                                             bias=need_bias,
                                             pad=pad,
                                             act_fun=act_fun,
                                             downsample_mode=downsample_mode)

        self.skip_down_4_5 = DownsampleBlock(in_channel=num_channels_down[2],
                                             out_channel=num_channels_up[3],
                                             kernel_size=filter_size_down,
                                             bias=need_bias,
                                             pad=pad,
                                             act_fun=act_fun,
                                             downsample_mode=downsample_mode)

        self.up5 = UpsampleBlock(in_channel=num_channels_up[4],
                                 out_channel=num_channels_up[4],
                                 model_index=model_index)

        self.up4 = UpsampleBlock(in_channel=num_channels_up[3],
                                 out_channel=num_channels_up[3],
                                 model_index=model_index)
| model_index=model_index) 332 | 333 | self.up3 = UpsampleBlock(in_channel=num_channels_up[2], 334 | out_channel=num_channels_up[2], 335 | model_index=model_index) 336 | 337 | self.up2 = UpsampleBlock(in_channel=num_channels_up[1], 338 | out_channel=num_channels_up[1], 339 | model_index=model_index) 340 | 341 | self.up1 = UpsampleBlock(in_channel=num_channels_up[0], 342 | out_channel=num_channels_up[0], 343 | model_index=model_index) 344 | 345 | self.dec5 = DecoderBlock(in_channel=num_channels_down[4], 346 | out_channel=num_channels_up[4], 347 | kernel_size=filter_size_up, 348 | bias=need_bias, 349 | pad=pad, 350 | act_fun=act_fun, 351 | need1x1_up=need1x1_up) 352 | 353 | self.dec4 = DecoderBlock(in_channel=num_channels_up[3], 354 | out_channel=num_channels_up[3], 355 | kernel_size=filter_size_up, 356 | bias=need_bias, 357 | pad=pad, 358 | act_fun=act_fun, 359 | need1x1_up=need1x1_up) 360 | 361 | self.dec3 = DecoderBlock(in_channel=num_channels_up[2], 362 | out_channel=num_channels_up[2], 363 | kernel_size=filter_size_up, 364 | bias=need_bias, 365 | pad=pad, 366 | act_fun=act_fun, 367 | need1x1_up=need1x1_up) 368 | 369 | self.dec2 = DecoderBlock(in_channel=num_channels_up[1], 370 | out_channel=num_channels_up[1], 371 | kernel_size=filter_size_up, 372 | bias=need_bias, 373 | pad=pad, 374 | act_fun=act_fun, 375 | need1x1_up=need1x1_up) 376 | 377 | self.dec1 = DecoderBlock(in_channel=num_channels_up[0], 378 | out_channel=num_channels_up[0], 379 | kernel_size=filter_size_up, 380 | bias=need_bias, 381 | pad=pad, 382 | act_fun=act_fun, 383 | need1x1_up=need1x1_up) 384 | 385 | self.output = OutputBlock(in_channel=num_channels_up[0], 386 | out_channel=num_output_channels, 387 | kernel_size=1, 388 | bias=need_bias, 389 | pad=pad, 390 | need_sigmoid=need_sigmoid) 391 | 392 | def forward(self, data): 393 | 394 | enc1 = self.enc1(data) # H/2 x W/2 x 128 395 | enc2 = self.enc2(enc1) # H/4 x W/4 x 128 396 | enc3 = self.enc3(enc2) # H/8 x W/8 x 128 397 | enc4 = self.enc4(enc3) # H/16 x W/16 x 128 398 | enc5 = self.enc5(enc4) # H/32 x W/32 x 128 399 | 400 | add5 = self.up5(enc5) + self.skip_down_4_5(enc3) + self.skip5(enc4) 401 | dec5 = self.dec5(add5) 402 | 403 | add4 = self.up4(dec5) + self.skip_down_3_4(enc2) + self.skip4(enc3) 404 | dec4 = self.dec4(add4) 405 | 406 | add3 = self.up3(dec4) + self.skip_down_2_3(enc1) + self.skip3(enc2) 407 | dec3 = self.dec3(add3) 408 | 409 | add2 = self.up2(dec3) + self.skip_down_1_2(data) + self.skip2(enc1) + self.skip_up_3_2(self.skip_up_4_3(self.skip_up_5_4(enc4))) 410 | dec2 = self.dec2(add2) 411 | 412 | add1 = self.up1(dec2) + self.skip1(data) + self.skip_up_2_1(self.skip_up_3_2(self.skip_up_4_3(enc3))) 413 | dec1 = self.dec1(add1) 414 | 415 | out = self.output(dec1) 416 | 417 | return out 418 | -------------------------------------------------------------------------------- /DIP/models/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from numpy.random import normal 4 | from numpy.linalg import svd 5 | from math import sqrt 6 | import torch.nn.init 7 | from .common import * 8 | 9 | class ResidualSequential(nn.Sequential): 10 | def __init__(self, *args): 11 | super(ResidualSequential, self).__init__(*args) 12 | 13 | def forward(self, x): 14 | out = super(ResidualSequential, self).forward(x) 15 | # print(x.size(), out.size()) 16 | x_ = None 17 | if out.size(2) != x.size(2) or out.size(3) != x.size(3): 18 | diff2 = x.size(2) - out.size(2) 19 | diff3 = x.size(3) - out.size(3) 20 
| # print(1)
 21 | x_ = x[:, :, diff2 // 2:out.size(2) + diff2 // 2, diff3 // 2:out.size(3) + diff3 // 2]
 22 | else:
 23 | x_ = x
 24 | return out + x_
 25 |
 26 | def eval(self):
 27 | # iterate over children() (modules() would yield self and recurse); return self to match nn.Module.eval()
 28 | for m in self.children():
 29 | m.eval()
 30 | return self
 31 |
 32 |
 33 | def get_block(num_channels, norm_layer, act_fun):
 34 | layers = [
 35 | nn.Conv2d(num_channels, num_channels, 3, 1, 1, bias=False),
 36 | norm_layer(num_channels, affine=True),
 37 | act(act_fun),
 38 | nn.Conv2d(num_channels, num_channels, 3, 1, 1, bias=False),
 39 | norm_layer(num_channels, affine=True),
 40 | ]
 41 | return layers
 42 |
 43 |
 44 | class ResNet(nn.Module):
 45 | def __init__(self, num_input_channels, num_output_channels, num_blocks, num_channels, need_residual=True, act_fun='LeakyReLU', need_sigmoid=True, norm_layer=nn.BatchNorm2d, pad='reflection'):
 46 | '''
 47 | pad = 'zero|reflection|replication'
 48 | '''
 49 | super(ResNet, self).__init__()
 50 |
 51 | if need_residual:
 52 | s = ResidualSequential
 53 | else:
 54 | s = nn.Sequential
 55 |
 56 | stride = 1
 57 | # First layers
 58 | layers = [
 59 | # nn.ReplicationPad2d(num_blocks * 2 * stride + 3),
 60 | conv(num_input_channels, num_channels, 3, stride=1, bias=True, pad=pad),
 61 | act(act_fun)
 62 | ]
 63 | # Residual blocks
 64 | # layers_residual = []
 65 | for i in range(num_blocks):
 66 | layers += [s(*get_block(num_channels, norm_layer, act_fun))]
 67 |
 68 | layers += [
 69 | nn.Conv2d(num_channels, num_channels, 3, 1, 1),
 70 | norm_layer(num_channels, affine=True)
 71 | ]
 72 |
 73 | # if need_residual:
 74 | # layers += [ResidualSequential(*layers_residual)]
 75 | # else:
 76 | # layers += [Sequential(*layers_residual)]
 77 |
 78 | # if factor >= 2:
 79 | # # Do upsampling if needed
 80 | # layers += [
 81 | # nn.Conv2d(num_channels, num_channels *
 82 | # factor ** 2, 3, 1),
 83 | # nn.PixelShuffle(factor),
 84 | # act(act_fun)
 85 | # ]
 86 | layers += [
 87 | conv(num_channels, num_output_channels, 3, 1, bias=True, pad=pad),
 88 | ]
 89 | if need_sigmoid:  # honor the flag instead of unconditionally appending a sigmoid
 90 | layers += [nn.Sigmoid()]
 91 | self.model = nn.Sequential(*layers)
 92 |
 93 | def forward(self, input):
 94 | return self.model(input)
 95 |
 96 | def eval(self):
 97 | self.model.eval()
 98 | -------------------------------------------------------------------------------- /DIP/models/skip.py: -------------------------------------------------------------------------------- 1 | import torch
 2 | import torch.nn as nn
 3 | from .common import *
 4 | import ipdb
 5 |
 6 |
 7 | def skip(
 8 | num_input_channels=2, num_output_channels=3,
 9 | num_channels_down=[16, 32, 64, 128, 128], num_channels_up=[16, 32, 64, 128, 128], num_channels_skip=[4, 4, 4, 4, 4],
 10 | filter_size_down=3, filter_size_up=3, filter_skip_size=1,
 11 | need_sigmoid=True, need_bias=True,
 12 | pad='zero', upsample_mode='nearest', downsample_mode='stride', act_fun='LeakyReLU',
 13 | need1x1_up=True):
 14 | """Assembles encoder-decoder with skip connections.
 15 |
 16 | Arguments:
 17 | act_fun: Either string 'LeakyReLU|Swish|ELU|none' or module (e.g.
nn.ReLU) 18 | pad (string): zero|reflection (default: 'zero') 19 | upsample_mode (string): 'nearest|bilinear' (default: 'nearest') 20 | downsample_mode (string): 'stride|avg|max|lanczos2' (default: 'stride') 21 | 22 | """ 23 | assert len(num_channels_down) == len(num_channels_up) == len(num_channels_skip) 24 | 25 | n_scales = len(num_channels_down) 26 | 27 | if not (isinstance(upsample_mode, list) or isinstance(upsample_mode, tuple)) : 28 | upsample_mode = [upsample_mode]*n_scales 29 | 30 | if not (isinstance(downsample_mode, list)or isinstance(downsample_mode, tuple)): 31 | downsample_mode = [downsample_mode]*n_scales 32 | 33 | if not (isinstance(filter_size_down, list) or isinstance(filter_size_down, tuple)) : 34 | filter_size_down = [filter_size_down]*n_scales 35 | 36 | if not (isinstance(filter_size_up, list) or isinstance(filter_size_up, tuple)) : 37 | filter_size_up = [filter_size_up]*n_scales 38 | # ipdb.set_trace() 39 | 40 | last_scale = n_scales - 1 41 | 42 | cur_depth = None 43 | # ipdb.set_trace() 44 | model = nn.Sequential() 45 | model_tmp = model 46 | 47 | input_depth = num_input_channels 48 | for i in range(len(num_channels_down)): 49 | 50 | deeper = nn.Sequential() 51 | skip = nn.Sequential() 52 | 53 | if num_channels_skip[i] != 0: 54 | model_tmp.add(Concat(1, skip, deeper)) 55 | else: 56 | model_tmp.add(deeper) 57 | 58 | model_tmp.add(bn(num_channels_skip[i] + (num_channels_up[i + 1] if i < last_scale else num_channels_down[i]))) 59 | 60 | if num_channels_skip[i] != 0: 61 | skip.add(conv(input_depth, num_channels_skip[i], filter_skip_size, bias=need_bias, pad=pad)) 62 | skip.add(bn(num_channels_skip[i])) 63 | skip.add(act(act_fun)) 64 | 65 | # skip.add(Concat(2, GenNoise(nums_noise[i]), skip_part)) 66 | 67 | deeper.add(conv(input_depth, num_channels_down[i], filter_size_down[i], 2, bias=need_bias, pad=pad, downsample_mode=downsample_mode[i])) 68 | deeper.add(bn(num_channels_down[i])) 69 | deeper.add(act(act_fun)) 70 | 71 | deeper.add(conv(num_channels_down[i], num_channels_down[i], filter_size_down[i], bias=need_bias, pad=pad)) 72 | deeper.add(bn(num_channels_down[i])) 73 | deeper.add(act(act_fun)) 74 | 75 | deeper_main = nn.Sequential() 76 | 77 | if i == len(num_channels_down) - 1: 78 | # The deepest 79 | k = num_channels_down[i] 80 | else: 81 | deeper.add(deeper_main) 82 | k = num_channels_up[i + 1] 83 | 84 | deeper.add(nn.Upsample(scale_factor=2, mode=upsample_mode[i])) 85 | 86 | model_tmp.add(conv(num_channels_skip[i] + k, num_channels_up[i], filter_size_up[i], 1, bias=need_bias, pad=pad)) 87 | model_tmp.add(bn(num_channels_up[i])) 88 | model_tmp.add(act(act_fun)) 89 | 90 | 91 | if need1x1_up: 92 | model_tmp.add(conv(num_channels_up[i], num_channels_up[i], 1, bias=need_bias, pad=pad)) 93 | model_tmp.add(bn(num_channels_up[i])) 94 | model_tmp.add(act(act_fun)) 95 | 96 | input_depth = num_channels_down[i] 97 | model_tmp = deeper_main 98 | 99 | model.add(conv(num_channels_up[0], num_output_channels, 1, bias=need_bias, pad=pad)) 100 | if need_sigmoid: 101 | model.add(nn.Sigmoid()) 102 | 103 | return model 104 | -------------------------------------------------------------------------------- /DIP/models/skip_search_up.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .common import * 4 | import ipdb 5 | 6 | from NAS import operations 7 | from NAS import gen_upsample_layer 8 | 9 | def skip(model_index, 10 | num_input_channels=2, 11 | num_output_channels=3, 12 | num_channels_down=[16, 32, 
64, 128, 128], 13 | num_channels_up=[16, 32, 64, 128, 128], 14 | num_channels_skip=[4, 4, 4, 4, 4], 15 | filter_size_down=3, 16 | filter_size_up=3, 17 | filter_skip_size=1, 18 | need_sigmoid=True, 19 | need_bias=True, 20 | pad='zero', 21 | upsample_mode='nearest', 22 | downsample_mode='stride', 23 | act_fun='LeakyReLU', 24 | need1x1_up=True): 25 | 26 | """Assembles encoder-decoder with skip connections. 27 | 28 | Arguments: 29 | act_fun: Either string 'LeakyReLU|Swish|ELU|none' or module (e.g. nn.ReLU) 30 | pad (string): zero|reflection (default: 'zero') 31 | upsample_mode (string): 'nearest|bilinear' (default: 'nearest') 32 | downsample_mode (string): 'stride|avg|max|lanczos2' (default: 'stride') 33 | 34 | """ 35 | assert len(num_channels_down) == len(num_channels_up) == len(num_channels_skip) 36 | 37 | n_scales = len(num_channels_down) 38 | 39 | if not (isinstance(upsample_mode, list) or isinstance(upsample_mode, tuple)) : 40 | upsample_mode = [upsample_mode] * n_scales 41 | 42 | if not (isinstance(downsample_mode, list)or isinstance(downsample_mode, tuple)): 43 | downsample_mode = [downsample_mode] * n_scales 44 | 45 | if not (isinstance(filter_size_down, list) or isinstance(filter_size_down, tuple)) : 46 | filter_size_down = [filter_size_down] * n_scales 47 | 48 | if not (isinstance(filter_size_up, list) or isinstance(filter_size_up, tuple)) : 49 | filter_size_up = [filter_size_up] * n_scales 50 | # ipdb.set_trace() 51 | 52 | last_scale = n_scales - 1 53 | 54 | cur_depth = None 55 | # ipdb.set_trace() 56 | model = nn.Sequential() 57 | model_tmp = model 58 | 59 | input_depth = num_input_channels 60 | for i in range(len(num_channels_down)): 61 | 62 | deeper = nn.Sequential() 63 | skip = nn.Sequential() 64 | 65 | if num_channels_skip[i] != 0: 66 | model_tmp.add(Concat(1, skip, deeper)) 67 | else: 68 | model_tmp.add(deeper) 69 | 70 | model_tmp.add(bn(num_channels_skip[i] + (num_channels_up[i + 1] if i < last_scale else num_channels_down[i]))) 71 | 72 | if num_channels_skip[i] != 0: 73 | skip.add(conv(input_depth, num_channels_skip[i], filter_skip_size, bias=need_bias, pad=pad)) 74 | skip.add(bn(num_channels_skip[i])) 75 | skip.add(act(act_fun)) 76 | 77 | # skip.add(Concat(2, GenNoise(nums_noise[i]), skip_part)) 78 | 79 | deeper.add(conv(input_depth, num_channels_down[i], filter_size_down[i], 2, bias=need_bias, pad=pad, downsample_mode=downsample_mode[i])) 80 | deeper.add(bn(num_channels_down[i])) 81 | deeper.add(act(act_fun)) 82 | 83 | deeper.add(conv(num_channels_down[i], num_channels_down[i], filter_size_down[i], bias=need_bias, pad=pad)) 84 | deeper.add(bn(num_channels_down[i])) 85 | deeper.add(act(act_fun)) 86 | 87 | deeper_main = nn.Sequential() 88 | 89 | if i == len(num_channels_down) - 1: 90 | # The deepest 91 | k = num_channels_down[i] 92 | 93 | else: 94 | deeper.add(deeper_main) 95 | k = num_channels_up[i+1] 96 | 97 | C_in = num_channels_down[i] 98 | C_out = num_channels_down[i] 99 | 100 | deeper.add( 101 | gen_upsample_layer.gen_layer( 102 | C_in=C_in, 103 | C_out=C_out, 104 | model_index=model_index 105 | ) 106 | ) 107 | 108 | model_tmp.add(conv(num_channels_skip[i] + k, num_channels_up[i], filter_size_up[i], 1, bias=need_bias, pad=pad)) 109 | model_tmp.add(bn(num_channels_up[i])) 110 | model_tmp.add(act(act_fun)) 111 | 112 | if need1x1_up: 113 | model_tmp.add(conv(num_channels_up[i], num_channels_up[i], 1, bias=need_bias, pad=pad)) 114 | model_tmp.add(bn(num_channels_up[i])) 115 | model_tmp.add(act(act_fun)) 116 | 117 | input_depth = num_channels_down[i] 118 | model_tmp = 
deeper_main 119 |
 120 | model.add(conv(num_channels_up[0], num_output_channels, 1, bias=need_bias, pad=pad))
 121 | if need_sigmoid:
 122 | model.add(nn.Sigmoid())
 123 |
 124 | return model
 125 | -------------------------------------------------------------------------------- /DIP/models/texture_nets.py: -------------------------------------------------------------------------------- 1 | import torch
 2 | import torch.nn as nn
 3 | from .common import *
 4 |
 5 |
 6 | normalization = nn.BatchNorm2d
 7 |
 8 |
 9 | def conv(in_f, out_f, kernel_size, stride=1, bias=True, pad='zero'):
 10 | if pad == 'zero':
 11 | return nn.Conv2d(in_f, out_f, kernel_size, stride, padding=(kernel_size - 1) // 2, bias=bias)  # integer division keeps padding an int on Python 3
 12 | elif pad == 'reflection':
 13 | layers = [nn.ReflectionPad2d((kernel_size - 1) // 2),  # integer division keeps the pad width an int on Python 3
 14 | nn.Conv2d(in_f, out_f, kernel_size, stride, padding=0, bias=bias)]
 15 | return nn.Sequential(*layers)
 16 |
 17 | def get_texture_nets(inp=3, ratios = [32, 16, 8, 4, 2, 1], fill_noise=False, pad='zero', need_sigmoid=False, conv_num=8, upsample_mode='nearest'):
 18 |
 19 |
 20 | for i in range(len(ratios)):
 21 | j = i + 1
 22 |
 23 | seq = nn.Sequential()
 24 |
 25 | tmp = nn.AvgPool2d(ratios[i], ratios[i])
 26 |
 27 | seq.add(tmp)
 28 | if fill_noise:
 29 | seq.add(GenNoise(inp))
 30 |
 31 | seq.add(conv(inp, conv_num, 3, pad=pad))
 32 | seq.add(normalization(conv_num))
 33 | seq.add(act())
 34 |
 35 | seq.add(conv(conv_num, conv_num, 3, pad=pad))
 36 | seq.add(normalization(conv_num))
 37 | seq.add(act())
 38 |
 39 | seq.add(conv(conv_num, conv_num, 1, pad=pad))
 40 | seq.add(normalization(conv_num))
 41 | seq.add(act())
 42 |
 43 | if i == 0:
 44 | seq.add(nn.Upsample(scale_factor=2, mode=upsample_mode))
 45 | cur = seq
 46 | else:
 47 |
 48 | cur_temp = cur
 49 |
 50 | cur = nn.Sequential()
 51 |
 52 | # Batch norm before merging
 53 | seq.add(normalization(conv_num))
 54 | cur_temp.add(normalization(conv_num * (j - 1)))
 55 |
 56 | cur.add(Concat(1, cur_temp, seq))
 57 |
 58 | cur.add(conv(conv_num * j, conv_num * j, 3, pad=pad))
 59 | cur.add(normalization(conv_num * j))
 60 | cur.add(act())
 61 |
 62 | cur.add(conv(conv_num * j, conv_num * j, 3, pad=pad))
 63 | cur.add(normalization(conv_num * j))
 64 | cur.add(act())
 65 |
 66 | cur.add(conv(conv_num * j, conv_num * j, 1, pad=pad))
 67 | cur.add(normalization(conv_num * j))
 68 | cur.add(act())
 69 |
 70 | if i == len(ratios) - 1:
 71 | cur.add(conv(conv_num * j, 3, 1, pad=pad))
 72 | else:
 73 | cur.add(nn.Upsample(scale_factor=2, mode=upsample_mode))
 74 |
 75 | model = cur
 76 | if need_sigmoid:
 77 | model.add(nn.Sigmoid())
 78 |
 79 | return model
 80 | -------------------------------------------------------------------------------- /DIP/models/unet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn
 2 | import torch
 3 | import torch.nn as nn
 4 | import torch.nn.functional as F
 5 | from .common import *
 6 |
 7 | class ListModule(nn.Module):
 8 | def __init__(self, *args):
 9 | super(ListModule, self).__init__()
 10 | idx = 0
 11 | for module in args:
 12 | self.add_module(str(idx), module)
 13 | idx += 1
 14 |
 15 | def __getitem__(self, idx):
 16 | if idx >= len(self._modules):
 17 | raise IndexError('index {} is out of range'.format(idx))
 18 | if idx < 0:
 19 | idx = len(self) + idx
 20 |
 21 | it = iter(self._modules.values())
 22 | for i in range(idx):
 23 | next(it)
 24 | return next(it)
 25 |
 26 | def __iter__(self):
 27 | return iter(self._modules.values())
 28 |
 29 | def __len__(self):
 30 | return len(self._modules)
 31 |
 32 | class UNet(nn.Module):
 33 | '''
 34 | upsample_mode in ['deconv', 'nearest', 'bilinear']
 35 | pad in ['zero', 'replication', 'none']
 36 | '''
 37 | def __init__(self, num_input_channels=3, num_output_channels=3,
 38 | feature_scale=4, more_layers=0, concat_x=False,
 39 | upsample_mode='deconv', pad='zero', norm_layer=nn.InstanceNorm2d, need_sigmoid=True, need_bias=True):
 40 | super(UNet, self).__init__()
 41 |
 42 | self.feature_scale = feature_scale
 43 | self.more_layers = more_layers
 44 | self.concat_x = concat_x
 45 |
 46 |
 47 | filters = [64, 128, 256, 512, 1024]
 48 | filters = [x // self.feature_scale for x in filters]
 49 |
 50 | self.start = unetConv2(num_input_channels, filters[0] if not concat_x else filters[0] - num_input_channels, norm_layer, need_bias, pad)
 51 |
 52 | self.down1 = unetDown(filters[0], filters[1] if not concat_x else filters[1] - num_input_channels, norm_layer, need_bias, pad)
 53 | self.down2 = unetDown(filters[1], filters[2] if not concat_x else filters[2] - num_input_channels, norm_layer, need_bias, pad)
 54 | self.down3 = unetDown(filters[2], filters[3] if not concat_x else filters[3] - num_input_channels, norm_layer, need_bias, pad)
 55 | self.down4 = unetDown(filters[3], filters[4] if not concat_x else filters[4] - num_input_channels, norm_layer, need_bias, pad)
 56 |
 57 | # more downsampling layers
 58 | if self.more_layers > 0:
 59 | self.more_downs = [
 60 | unetDown(filters[4], filters[4] if not concat_x else filters[4] - num_input_channels , norm_layer, need_bias, pad) for i in range(self.more_layers)]
 61 | self.more_ups = [unetUp(filters[4], upsample_mode, need_bias, pad, same_num_filt=True) for i in range(self.more_layers)]
 62 |
 63 | self.more_downs = ListModule(*self.more_downs)
 64 | self.more_ups = ListModule(*self.more_ups)
 65 |
 66 | self.up4 = unetUp(filters[3], upsample_mode, need_bias, pad)
 67 | self.up3 = unetUp(filters[2], upsample_mode, need_bias, pad)
 68 | self.up2 = unetUp(filters[1], upsample_mode, need_bias, pad)
 69 | self.up1 = unetUp(filters[0], upsample_mode, need_bias, pad)
 70 |
 71 | self.final = conv(filters[0], num_output_channels, 1, bias=need_bias, pad=pad)
 72 |
 73 | if need_sigmoid:
 74 | self.final = nn.Sequential(self.final, nn.Sigmoid())
 75 |
 76 | def forward(self, inputs):
 77 |
 78 | # Downsample
 79 | downs = [inputs]
 80 | down = nn.AvgPool2d(2, 2)
 81 | for i in range(4 + self.more_layers):
 82 | downs.append(down(downs[-1]))
 83 |
 84 | in64 = self.start(inputs)
 85 | if self.concat_x:
 86 | in64 = torch.cat([in64, downs[0]], 1)
 87 |
 88 | down1 = self.down1(in64)
 89 | if self.concat_x:
 90 | down1 = torch.cat([down1, downs[1]], 1)
 91 |
 92 | down2 = self.down2(down1)
 93 | if self.concat_x:
 94 | down2 = torch.cat([down2, downs[2]], 1)
 95 |
 96 | down3 = self.down3(down2)
 97 | if self.concat_x:
 98 | down3 = torch.cat([down3, downs[3]], 1)
 99 |
 100 | down4 = self.down4(down3)
 101 | if self.concat_x:
 102 | down4 = torch.cat([down4, downs[4]], 1)
 103 |
 104 | if self.more_layers > 0:
 105 | prevs = [down4]
 106 | for kk, d in enumerate(self.more_downs):
 107 | # print(prevs[-1].size())
 108 | out = d(prevs[-1])
 109 | if self.concat_x:
 110 | out = torch.cat([out, downs[kk + 5]], 1)
 111 |
 112 | prevs.append(out)
 113 |
 114 | up_ = self.more_ups[-1](prevs[-1], prevs[-2])
 115 | for idx in range(self.more_layers - 1):
 116 | l = self.more_ups[self.more_layers - idx - 2]
 117 | up_= l(up_, prevs[self.more_layers - idx - 2])
 118 | else:
 119 | up_= down4
 120 |
 121 | up4= self.up4(up_, down3)
 122 | up3= self.up3(up4, down2)
 123 | up2= self.up2(up3, down1)
 124 | up1= self.up1(up2, in64)
 125 |
 126 | return self.final(up1)
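# A minimal, hedged usage sketch of this UNet as a deep-image-prior generator;
# the input depth (32), spatial size (256x256) and .cuda() placement are
# illustrative assumptions, not values fixed by this module:
#
#   net = UNet(num_input_channels=32, num_output_channels=3,
#              feature_scale=4, more_layers=0, concat_x=False,
#              upsample_mode='deconv', pad='zero',
#              need_sigmoid=True, need_bias=True).cuda()
#   z = torch.rand(1, 32, 256, 256).cuda()  # H, W divisible by 16 (four 2x poolings)
#   out = net(z)                            # -> torch.Size([1, 3, 256, 256])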
127 | 128 | 129 | 130 | class unetConv2(nn.Module): 131 | def __init__(self, in_size, out_size, norm_layer, need_bias, pad): 132 | super(unetConv2, self).__init__() 133 | 134 | if norm_layer is not None: 135 | self.conv1= nn.Sequential(conv(in_size, out_size, 3, bias=need_bias, pad=pad), 136 | norm_layer(out_size), 137 | nn.ReLU(),) 138 | self.conv2= nn.Sequential(conv(out_size, out_size, 3, bias=need_bias, pad=pad), 139 | norm_layer(out_size), 140 | nn.ReLU(),) 141 | else: 142 | self.conv1= nn.Sequential(conv(in_size, out_size, 3, bias=need_bias, pad=pad), 143 | nn.ReLU(),) 144 | self.conv2= nn.Sequential(conv(out_size, out_size, 3, bias=need_bias, pad=pad), 145 | nn.ReLU(),) 146 | def forward(self, inputs): 147 | outputs= self.conv1(inputs) 148 | outputs= self.conv2(outputs) 149 | return outputs 150 | 151 | 152 | class unetDown(nn.Module): 153 | def __init__(self, in_size, out_size, norm_layer, need_bias, pad): 154 | super(unetDown, self).__init__() 155 | self.conv= unetConv2(in_size, out_size, norm_layer, need_bias, pad) 156 | self.down= nn.MaxPool2d(2, 2) 157 | 158 | def forward(self, inputs): 159 | outputs= self.down(inputs) 160 | outputs= self.conv(outputs) 161 | return outputs 162 | 163 | 164 | class unetUp(nn.Module): 165 | def __init__(self, out_size, upsample_mode, need_bias, pad, same_num_filt=False): 166 | super(unetUp, self).__init__() 167 | 168 | num_filt = out_size if same_num_filt else out_size * 2 169 | if upsample_mode == 'deconv': 170 | self.up= nn.ConvTranspose2d(num_filt, out_size, 4, stride=2, padding=1) 171 | self.conv= unetConv2(out_size * 2, out_size, None, need_bias, pad) 172 | elif upsample_mode=='bilinear' or upsample_mode=='nearest': 173 | self.up = nn.Sequential(nn.Upsample(scale_factor=2, mode=upsample_mode), 174 | conv(num_filt, out_size, 3, bias=need_bias, pad=pad)) 175 | self.conv= unetConv2(out_size * 2, out_size, None, need_bias, pad) 176 | else: 177 | assert False 178 | 179 | def forward(self, inputs1, inputs2): 180 | in1_up= self.up(inputs1) 181 | 182 | if (inputs2.size(2) != in1_up.size(2)) or (inputs2.size(3) != in1_up.size(3)): 183 | diff2 = (inputs2.size(2) - in1_up.size(2)) // 2 184 | diff3 = (inputs2.size(3) - in1_up.size(3)) // 2 185 | inputs2_ = inputs2[:, :, diff2 : diff2 + in1_up.size(2), diff3 : diff3 + in1_up.size(3)] 186 | else: 187 | inputs2_ = inputs2 188 | 189 | output= self.conv(torch.cat([in1_up, inputs2_], 1)) 190 | 191 | return output 192 | -------------------------------------------------------------------------------- /DIP/models/unet_search_up.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .common import * 6 | from NAS import operations 7 | from NAS import gen_upsample_layer 8 | 9 | class ListModule(nn.Module): 10 | def __init__(self, *args): 11 | super(ListModule, self).__init__() 12 | idx = 0 13 | for module in args: 14 | self.add_module(str(idx), module) 15 | idx += 1 16 | 17 | def __getitem__(self, idx): 18 | if idx >= len(self._modules): 19 | raise IndexError('index {} is out of range'.format(idx)) 20 | if idx < 0: 21 | idx = len(self) + idx 22 | 23 | it = iter(self._modules.values()) 24 | for i in range(idx): 25 | next(it) 26 | return next(it) 27 | 28 | def __iter__(self): 29 | return iter(self._modules.values()) 30 | 31 | def __len__(self): 32 | return len(self._modules) 33 | 34 | class UNet(nn.Module): 35 | ''' 36 | upsample_mode in ['deconv', 'nearest', 'bilinear'] 37 | 
pad in ['zero', 'replication', 'none'] 38 | '''
 39 | def __init__(self,
 40 | model_index,
 41 | use_act,
 42 | num_input_channels=3,
 43 | num_output_channels=3,
 44 | feature_scale=4,
 45 | more_layers=0,
 46 | concat_x=False,
 47 | upsample_mode='deconv',
 48 | pad='zero',
 49 | norm_layer=nn.InstanceNorm2d,
 50 | need_sigmoid=True,
 51 | need_bias=True):
 52 |
 53 | super(UNet, self).__init__()
 54 |
 55 | self.feature_scale = feature_scale
 56 | self.more_layers = more_layers
 57 | self.concat_x = concat_x
 58 |
 59 |
 60 | filters = [64, 128, 256, 512, 1024]
 61 | filters = [x // self.feature_scale for x in filters]
 62 |
 63 | self.start = unetConv2(num_input_channels, filters[0] if not concat_x else filters[0] - num_input_channels, norm_layer, need_bias, pad)
 64 |
 65 | self.down1 = unetDown(filters[0], filters[1] if not concat_x else filters[1] - num_input_channels, norm_layer, need_bias, pad)
 66 | self.down2 = unetDown(filters[1], filters[2] if not concat_x else filters[2] - num_input_channels, norm_layer, need_bias, pad)
 67 | self.down3 = unetDown(filters[2], filters[3] if not concat_x else filters[3] - num_input_channels, norm_layer, need_bias, pad)
 68 | self.down4 = unetDown(filters[3], filters[4] if not concat_x else filters[4] - num_input_channels, norm_layer, need_bias, pad)
 69 |
 70 | # more downsampling layers
 71 | if self.more_layers > 0:
 72 | self.more_downs = [
 73 | unetDown(filters[4], filters[4] if not concat_x else filters[4] - num_input_channels , norm_layer, need_bias, pad) for i in range(self.more_layers)]
 74 | self.more_ups = [unetUp(filters[4], need_bias, pad, same_num_filt=True, model_index=model_index, use_act=use_act) for i in range(self.more_layers)]
 75 |
 76 | self.more_downs = ListModule(*self.more_downs)
 77 | self.more_ups = ListModule(*self.more_ups)
 78 |
 79 | self.up4 = unetUp(filters[3], need_bias, pad, model_index=model_index, use_act=use_act)
 80 | self.up3 = unetUp(filters[2], need_bias, pad, model_index=model_index, use_act=use_act)
 81 | self.up2 = unetUp(filters[1], need_bias, pad, model_index=model_index, use_act=use_act)
 82 | self.up1 = unetUp(filters[0], need_bias, pad, model_index=model_index, use_act=use_act)
 83 |
 84 | self.final = conv(filters[0], num_output_channels, 1, bias=need_bias, pad=pad)
 85 |
 86 | if need_sigmoid:
 87 | self.final = nn.Sequential(self.final, nn.Sigmoid())
 88 |
 89 | def forward(self, inputs):
 90 |
 91 | # Downsample
 92 | downs = [inputs]
 93 | down = nn.AvgPool2d(2, 2)
 94 | for i in range(4 + self.more_layers):
 95 | downs.append(down(downs[-1]))
 96 |
 97 | in64 = self.start(inputs)
 98 | if self.concat_x:
 99 | in64 = torch.cat([in64, downs[0]], 1)
 100 |
 101 | down1 = self.down1(in64)
 102 | if self.concat_x:
 103 | down1 = torch.cat([down1, downs[1]], 1)
 104 |
 105 | down2 = self.down2(down1)
 106 | if self.concat_x:
 107 | down2 = torch.cat([down2, downs[2]], 1)
 108 |
 109 | down3 = self.down3(down2)
 110 | if self.concat_x:
 111 | down3 = torch.cat([down3, downs[3]], 1)
 112 |
 113 | down4 = self.down4(down3)
 114 | if self.concat_x:
 115 | down4 = torch.cat([down4, downs[4]], 1)
 116 |
 117 | if self.more_layers > 0:
 118 | prevs = [down4]
 119 | for kk, d in enumerate(self.more_downs):
 120 | # print(prevs[-1].size())
 121 | out = d(prevs[-1])
 122 | if self.concat_x:
 123 | out = torch.cat([out, downs[kk + 5]], 1)
 124 |
 125 | prevs.append(out)
 126 |
 127 | up_ = self.more_ups[-1](prevs[-1], prevs[-2])
 128 | for idx in range(self.more_layers - 1):
 129 | l = self.more_ups[self.more_layers - idx - 2]
 130 | up_= l(up_, prevs[self.more_layers - idx - 2])
 131 | else:
 132 | up_= down4
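# up_ now holds the deepest decoder feature; each unetUp below upsamples it with
# the NAS-generated layer (gen_upsample_layer.gen_layer) and fuses it with the
# matching encoder output via concatenation.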
133 | 134 | up4= self.up4(up_, down3) 135 | up3= self.up3(up4, down2) 136 | up2= self.up2(up3, down1) 137 | up1= self.up1(up2, in64) 138 | 139 | return self.final(up1) 140 | 141 | 142 | 143 | class unetConv2(nn.Module): 144 | def __init__(self, in_size, out_size, norm_layer, need_bias, pad): 145 | super(unetConv2, self).__init__() 146 | 147 | # print(pad) 148 | if norm_layer is not None: 149 | self.conv1= nn.Sequential(conv(in_size, out_size, 3, bias=need_bias, pad=pad), 150 | norm_layer(out_size), 151 | nn.ReLU(),) 152 | self.conv2= nn.Sequential(conv(out_size, out_size, 3, bias=need_bias, pad=pad), 153 | norm_layer(out_size), 154 | nn.ReLU(),) 155 | else: 156 | self.conv1= nn.Sequential(conv(in_size, out_size, 3, bias=need_bias, pad=pad), 157 | nn.ReLU(),) 158 | self.conv2= nn.Sequential(conv(out_size, out_size, 3, bias=need_bias, pad=pad), 159 | nn.ReLU(),) 160 | def forward(self, inputs): 161 | outputs= self.conv1(inputs) 162 | outputs= self.conv2(outputs) 163 | return outputs 164 | 165 | 166 | class unetDown(nn.Module): 167 | def __init__(self, in_size, out_size, norm_layer, need_bias, pad): 168 | super(unetDown, self).__init__() 169 | self.conv= unetConv2(in_size, out_size, norm_layer, need_bias, pad) 170 | self.down= nn.MaxPool2d(2, 2) 171 | 172 | def forward(self, inputs): 173 | outputs= self.down(inputs) 174 | outputs= self.conv(outputs) 175 | return outputs 176 | 177 | 178 | class unetUp(nn.Module): 179 | def __init__(self, out_size, need_bias, pad, model_index, use_act, same_num_filt=False): 180 | super(unetUp, self).__init__() 181 | 182 | num_filt = out_size if same_num_filt else out_size * 2 183 | self.up = gen_upsample_layer.gen_layer(C_in=num_filt, C_out=out_size, use_act=use_act, model_index=model_index) 184 | self.conv= unetConv2(out_size * 2, out_size, None, need_bias, pad) 185 | 186 | def forward(self, inputs1, inputs2): 187 | in1_up= self.up(inputs1) 188 | 189 | if (inputs2.size(2) != in1_up.size(2)) or (inputs2.size(3) != in1_up.size(3)): 190 | diff2 = (inputs2.size(2) - in1_up.size(2)) // 2 191 | diff3 = (inputs2.size(3) - in1_up.size(3)) // 2 192 | inputs2_ = inputs2[:, :, diff2 : diff2 + in1_up.size(2), diff3 : diff3 + in1_up.size(3)] 193 | else: 194 | inputs2_ = inputs2 195 | 196 | output= self.conv(torch.cat([in1_up, inputs2_], 1)) 197 | 198 | return output 199 | -------------------------------------------------------------------------------- /DIP/super-resolution-test.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import matplotlib 3 | matplotlib.use('agg') 4 | import matplotlib.pyplot as plt 5 | 6 | import os 7 | import cv2 8 | import ipdb 9 | import random 10 | import pickle 11 | import argparse 12 | import numpy as np 13 | from skimage.measure import compare_psnr 14 | 15 | from models.downsampler import Downsampler 16 | from utils.sr_utils import * 17 | from utils.timer import Timer 18 | 19 | import torch 20 | import torch.optim 21 | 22 | def rgb2ycbcr(im_rgb): 23 | im_rgb = im_rgb.astype(np.float32) 24 | im_ycrcb = cv2.cvtColor(im_rgb, cv2.COLOR_RGB2YCR_CB) 25 | im_ycbcr = im_ycrcb[:,:,(0,2,1)].astype(np.float32) 26 | im_ycbcr[:,:,0] = (im_ycbcr[:,:,0]*(235-16)+16)/255.0 #to [16/255, 235/255] 27 | im_ycbcr[:,:,1:] = (im_ycbcr[:,:,1:]*(240-16)+16)/255.0 #to [16/255, 240/255] 28 | return im_ycbcr 29 | 30 | def compare_psnr_y(x, y): 31 | return compare_psnr(rgb2ycbcr(x.transpose(1,2,0))[:,:,0], rgb2ycbcr(y.transpose(1,2,0))[:,:,0]) 32 | 33 | import warnings 34 | 
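# Note: cudnn.benchmark and cudnn.deterministic are both enabled below;
# reproducibility guides typically pair deterministic=True with benchmark=False.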
warnings.filterwarnings("ignore") 35 |
 36 | torch.backends.cudnn.enabled = True
 37 | torch.backends.cudnn.benchmark = True
 38 | torch.backends.cudnn.deterministic = True
 39 | dtype = torch.cuda.FloatTensor
 40 |
 41 | def parse_args():
 42 | parser = argparse.ArgumentParser(description='NAS-DIP Super-resolution')
 43 |
 44 | parser.add_argument('--optimizer', dest='optimizer', default='adam', type=str)
 45 | parser.add_argument('--num_iter', dest='num_iter', default=2000, type=int)
 46 | parser.add_argument('--factor', dest='factor', default=4, type=int)
 47 | parser.add_argument('--show_every', dest='show_every', default=100, type=int)
 48 | parser.add_argument('--lr', dest='lr', default=0.01, type=float)
 49 | parser.add_argument('--plot', dest='plot', default=False, type=bool)
 50 | parser.add_argument('--noise_method', dest='noise_method', default='noise', type=str)
 51 | parser.add_argument('--input_depth', dest='input_depth', default=32, type=int)
 52 | parser.add_argument('--output_path', dest='output_path', default='results/sr', type=str)
 53 | parser.add_argument('--random_seed', dest='random_seed', default=0, type=int)
 54 | parser.add_argument('--net', dest='net', default='default', type=str)
 55 | parser.add_argument('--reg_noise_std', dest='reg_noise_std', default=0.03, type=float)
 56 | parser.add_argument('--i_NAS', dest='i_NAS', default=-1, type=int)
 57 | parser.add_argument('--job_index', dest='job_index', default=1, type=int)
 58 | parser.add_argument('--save_png', dest='save_png', default=0, type=int)
 59 | parser.add_argument('--image_name', type=str)
 60 |
 61 | args = parser.parse_args()
 62 | return args
 63 |
 64 | if __name__ == '__main__':
 65 |
 66 | args = parse_args()
 67 |
 68 | # output directory and image name used by closure() below (same convention as super-resolution.py)
 69 | image_name = args.image_name
 70 | global_path = args.output_path
 71 | if args.save_png == 1 and not os.path.exists(os.path.join(global_path, image_name)):
 72 | os.makedirs(os.path.join(global_path, image_name))
 73 |
 74 | img_path = 'data/sr/' + image_name
 75 |
 76 | imgs = load_LR_HR_imgs_sr(img_path, -1, args.factor, 'CROP')
 77 |
 78 | from models.model_sr import Model
 79 | net = Model()
 80 |
 81 | net = net.type(dtype)
 82 |
 83 | net_input = get_noise(args.input_depth, args.noise_method, (imgs['HR_pil'].size[1], imgs['HR_pil'].size[0])).type(dtype).detach()
 84 |
 85 | mse = torch.nn.MSELoss().type(dtype)
 86 |
 87 | img_LR_var = np_to_torch(imgs['LR_np']).type(dtype)
 88 | downsampler = Downsampler(n_planes=3, factor=args.factor, kernel_type='lanczos2', phase=0.5, preserve_size=True).type(dtype)
 89 |
 90 | psnr_gt_best = 0
 91 |
 92 | i = 0
 93 | PSNR_list = []
 94 |
 95 | _t = {'im_detect' : Timer(), 'misc' : Timer()}
 96 |
 97 | def closure():
 98 |
 99 | global i, net_input, psnr_gt_best, PSNR_list
 100 |
 101 | _t['im_detect'].tic()
 102 |
 103 | if args.reg_noise_std > 0:
 104 | net_input = net_input_saved + (noise.normal_() * args.reg_noise_std)
 105 |
 106 | out_HR = net(net_input) #torch.Size([1, 3, tH, tW]): x
 107 | out_LR = downsampler(out_HR) #torch.Size([1, 3, H, W])
 108 |
 109 | total_loss = mse(out_LR, img_LR_var)
 110 | total_loss.backward()
 111 |
 112 | q1 = torch_to_np(out_HR)[:3].sum(0)
 113 | t1 = np.where(q1.sum(0) > 0)[0]
 114 | t2 = np.where(q1.sum(1) > 0)[0]
 115 | psnr_HR = compare_psnr_y(imgs['HR_np'][:3,t2[0] + 4:t2[-1]-4,t1[0] + 4:t1[-1] - 4], \
 116 | torch_to_np(out_HR)[:3,t2[0] + 4:t2[-1]-4,t1[0] + 4:t1[-1] - 4])
 117 | PSNR_list.append(psnr_HR)
 118 |
 119 | if psnr_HR > psnr_gt_best:
 120 | psnr_gt_best = psnr_HR
 121 |
 122 | _t['im_detect'].toc()
 123 |
 124 | print ('Iteration %05d Loss %f PSNR_HR %.3f Time %.3f' % (i, total_loss.item(), psnr_HR, _t['im_detect'].total_time), '\r', end='')
 125 | if i % args.show_every == 0:
 126 | out_HR_np = torch_to_np(out_HR)  # needed by both the save and plot branches below
 127 | if args.save_png == 1:
 128 | cv2.imwrite(os.path.join(global_path, image_name, str(i) + '.png'),\
 129 | np.clip(out_HR_np, 0, 1).transpose(1, 2, 0)[:,:,::-1] * 255)
 130 |
 131 | if args.plot:
 132 | plot_image_grid([np.clip(out_HR_np, 0, 1)], factor=4, nrow=1)
 133 |
 134 | i += 1
 135 |
 136 | return total_loss
 137 |
 138 | net_input_saved = net_input.detach().clone()
 139 | noise = net_input.detach().clone()
 140 |
 141 | p = get_params('net', net, net_input)
 142 | optimize(args.optimizer, p, closure, args.lr, args.num_iter)
 143 |
 144 | print('\nBest PSNR: %.2f' % psnr_gt_best)  # report the best PSNR for this run
 145 |
 146 | print('Finish optimization') -------------------------------------------------------------------------------- /DIP/super-resolution.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function
 2 | import matplotlib
 3 | matplotlib.use('agg')
 4 | import matplotlib.pyplot as plt
 5 |
 6 | import os
 7 | import cv2
 8 | import ipdb
 9 | import random
 10 | import pickle
 11 | import argparse
 12 | import numpy as np
 13 | from skimage.measure import compare_psnr
 14 | # from torchviz import make_dot, make_dot_from_trace
 15 | # from torchvision import transforms, utils
 16 | # from torch.utils.data import Dataset, DataLoader
 17 |
 18 |
 19 | from models.downsampler import Downsampler
 20 | from utils.sr_utils import *
 21 | from utils.timer import Timer
 22 |
 23 |
 24 | import torch
 25 | import torch.optim
 26 |
 27 | def rgb2ycbcr(im_rgb):
 28 | im_rgb = im_rgb.astype(np.float32)
 29 | im_ycrcb = cv2.cvtColor(im_rgb, cv2.COLOR_RGB2YCR_CB)
 30 | im_ycbcr = im_ycrcb[:,:,(0,2,1)].astype(np.float32)
 31 | im_ycbcr[:,:,0] = (im_ycbcr[:,:,0]*(235-16)+16)/255.0 #to [16/255, 235/255]
 32 | im_ycbcr[:,:,1:] = (im_ycbcr[:,:,1:]*(240-16)+16)/255.0 #to [16/255, 240/255]
 33 | return im_ycbcr
 34 |
 35 | def compare_psnr_y(x, y):
 36 | return compare_psnr(rgb2ycbcr(x.transpose(1,2,0))[:,:,0], rgb2ycbcr(y.transpose(1,2,0))[:,:,0])
 37 |
 38 | import warnings
 39 | warnings.filterwarnings("ignore")
 40 |
 41 | torch.backends.cudnn.enabled = True
 42 | torch.backends.cudnn.benchmark = True
 43 | torch.backends.cudnn.deterministic = True
 44 | dtype = torch.cuda.FloatTensor
 45 |
 46 | def parse_args():
 47 | parser = argparse.ArgumentParser(description='NAS-DIP Super-resolution')
 48 |
 49 | parser.add_argument('--optimizer', dest='optimizer', default='adam', type=str)
 50 | parser.add_argument('--num_iter', dest='num_iter', default=2000, type=int)
 51 | parser.add_argument('--factor', dest='factor', default=4, type=int)
 52 | parser.add_argument('--show_every', dest='show_every', default=100, type=int)
 53 | parser.add_argument('--lr', dest='lr', default=0.01, type=float)
 54 | parser.add_argument('--plot', dest='plot', default=False, type=bool)
 55 | parser.add_argument('--noise_method', dest='noise_method', default='noise', type=str)
 56 | parser.add_argument('--input_depth', dest='input_depth', default=32, type=int)
 57 | parser.add_argument('--output_path', dest='output_path', default='results/sr', type=str)
 58 | parser.add_argument('--random_seed', dest='random_seed', default=0, type=int)
 59 | parser.add_argument('--net', dest='net', default='default', type=str)
 60 | parser.add_argument('--reg_noise_std', dest='reg_noise_std', default=0.03, type=float)
 61 | parser.add_argument('--i_NAS', dest='i_NAS', default=-1, type=int)
 62 | parser.add_argument('--job_index', dest='job_index', default=1, type=int)
 63 | parser.add_argument('--save_png', dest='save_png', default=0, type=int)
 64 |
 65 | args = parser.parse_args()
 66 | return args
 67 |
 68 |
 69 | if __name__ == '__main__':
 70 |
 71 |
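# Example invocation (hedged; the flags are the ones defined in parse_args above,
# and the chosen values are illustrative):
#
#   python super-resolution.py --net NAS --i_NAS 0 --factor 4 --num_iter 2000 --save_png 1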
args = parse_args() 72 | 73 | if args.net == 'default': 74 | global_path = args.output_path + '_' + args.net 75 | if not os.path.exists(global_path): 76 | os.makedirs(global_path) 77 | elif args.net == 'NAS': 78 | global_path = args.output_path + '_' + args.net + '_' + str(args.i_NAS) 79 | if not os.path.exists(global_path): 80 | os.makedirs(global_path) 81 | elif args.net == 'Multiscale': 82 | from gen_skip_index import skip_index 83 | skip_connect = skip_index() 84 | global_path = args.output_path + '_' + args.net + '_' + str(args.i_NAS) + '_' + str(args.job_index) 85 | if not os.path.exists(global_path): 86 | os.makedirs(global_path) 87 | pickle.dump(skip_connect, open( os.path.join(global_path, 'skip_connect.pkl'), "wb" ) ) 88 | else: 89 | assert False, 'Please choose between default and NAS' 90 | 91 | np.random.seed(args.random_seed) 92 | torch.manual_seed(args.random_seed) 93 | torch.cuda.manual_seed_all(args.random_seed) 94 | 95 | 96 | # #batch x #iter 97 | PSNR_mat = np.empty((0, args.num_iter), dtype=np.float32) 98 | 99 | # Choose figure 100 | img_path_list = ['baby', 'bird', 'butterfly', 'head', 'woman', 'baboon', 'barbara', 'bridge', 'coastguard', 'comic', 'face', \ 101 | 'flowers', 'foreman', 'lenna', 'man', 'monarch', 'pepper', 'ppt3', 'zebra'] 102 | 103 | psnr_gt_best_list = [] 104 | 105 | for image_name in img_path_list: 106 | 107 | if args.save_png == 1 and not os.path.exists(os.path.join(global_path, image_name)): 108 | os.makedirs(os.path.join(global_path, image_name)) 109 | 110 | # Choose figure 111 | img_path = 'data/sr/' + image_name + '_x4_GT.png' 112 | 113 | # Load image 114 | imgs = load_LR_HR_imgs_sr(img_path , -1, args.factor, 'CROP') 115 | 116 | # Visualization 117 | if args.plot: 118 | plot_image_grid([imgs['HR_np']], 4, 6) 119 | 120 | if args.net == 'default': 121 | from models.skip import skip 122 | net = skip(num_input_channels=args.input_depth, 123 | num_output_channels=3, 124 | num_channels_down=[128] * 5, 125 | num_channels_up=[128] * 5, 126 | num_channels_skip=[4] * 5, 127 | upsample_mode='bilinear', 128 | downsample_mode='stride', 129 | need_sigmoid=True, 130 | need_bias=True, 131 | pad='reflection', 132 | act_fun='LeakyReLU') 133 | 134 | elif args.net == 'NAS': 135 | from models.skip_search_up import skip 136 | if args.i_NAS in [249, 250, 251]: 137 | exit(1) 138 | net = skip(model_index=args.i_NAS, 139 | num_input_channels=args.input_depth, 140 | num_output_channels=3, 141 | num_channels_down=[128] * 5, 142 | num_channels_up=[128] * 5, 143 | num_channels_skip=[4] * 5, 144 | upsample_mode='bilinear', 145 | downsample_mode='stride', 146 | need_sigmoid=True, 147 | need_bias=True, 148 | pad='reflection', 149 | act_fun='LeakyReLU') 150 | 151 | elif args.net == 'Multiscale': 152 | from models.cross_skip import skip 153 | net = skip(model_index=args.i_NAS, 154 | skip_index=skip_connect, 155 | num_input_channels=args.input_depth, 156 | num_output_channels=3, 157 | num_channels_down=[128] * 5, 158 | num_channels_up=[128] * 5, 159 | num_channels_skip=[4] * 5, 160 | upsample_mode='bilinear', 161 | downsample_mode='stride', 162 | need_sigmoid=True, 163 | need_bias=True, 164 | pad='reflection', 165 | act_fun='LeakyReLU') 166 | 167 | else: 168 | assert False, 'Please choose between default and NAS' 169 | 170 | net = net.type(dtype) 171 | 172 | # z torch.Size([1, 32, tH, tW]) 173 | net_input = get_noise(args.input_depth, args.noise_method, (imgs['HR_pil'].size[1], imgs['HR_pil'].size[0])).type(dtype).detach() 174 | 175 | # Loss 176 | mse = 
torch.nn.MSELoss().type(dtype) 177 | 178 | # x0 torch.Size([1, 3, H, W]) 179 | img_LR_var = np_to_torch(imgs['LR_np']).type(dtype) 180 | downsampler = Downsampler(n_planes=3, factor=args.factor, kernel_type='lanczos2', phase=0.5, preserve_size=True).type(dtype) 181 | 182 | psnr_gt_best = 0 183 | 184 | # Main 185 | i = 0 186 | PSNR_list = [] 187 | 188 | _t = {'im_detect' : Timer(), 'misc' : Timer()} 189 | 190 | def closure(): 191 | 192 | global i, net_input, psnr_gt_best, PSNR_list 193 | 194 | _t['im_detect'].tic() 195 | 196 | # Add variation 197 | if args.reg_noise_std > 0: 198 | net_input = net_input_saved + (noise.normal_() * args.reg_noise_std) 199 | 200 | out_HR = net(net_input) #torch.Size([1, 3, tH, tW]): x 201 | out_LR = downsampler(out_HR) #torch.Size([1, 3, H, W]) 202 | 203 | total_loss = mse(out_LR, img_LR_var) 204 | total_loss.backward() 205 | 206 | # psnr_LR = compare_psnr(imgs['LR_np'], torch_to_np(out_LR)) 207 | # psnr_HR = compare_psnr(imgs['HR_np'], torch_to_np(out_HR)) 208 | 209 | q1 = torch_to_np(out_HR)[:3].sum(0) 210 | t1 = np.where(q1.sum(0) > 0)[0] 211 | t2 = np.where(q1.sum(1) > 0)[0] 212 | psnr_HR = compare_psnr_y(imgs['HR_np'][:3,t2[0] + 4:t2[-1]-4,t1[0] + 4:t1[-1] - 4], \ 213 | torch_to_np(out_HR)[:3,t2[0] + 4:t2[-1]-4,t1[0] + 4:t1[-1] - 4]) 214 | PSNR_list.append(psnr_HR) 215 | 216 | if psnr_HR > psnr_gt_best: 217 | psnr_gt_best = psnr_HR 218 | 219 | _t['im_detect'].toc() 220 | 221 | # print ('Iteration %05d Loss %f PSNR_LR %.3f PSNR_HR %.3f Time %.3f' % (i, total_loss.item(), psnr_LR, psnr_HR, _t['im_detect'].total_time), '\r', end='') 222 | print ('Iteration %05d Loss %f PSNR_HR %.3f Time %.3f' % (i, total_loss.item(), psnr_HR, _t['im_detect'].total_time), '\r', end='') 223 | if i % args.show_every == 0: 224 | if args.save_png == 1: 225 | out_HR_np = torch_to_np(out_HR) 226 | cv2.imwrite(os.path.join(global_path, image_name, str(i) + '.png'),\ 227 | np.clip(out_HR_np, 0, 1).transpose(1, 2, 0)[:,:,::-1] * 255) 228 | 229 | if args.plot: 230 | plot_image_grid([np.clip(out_HR_np, 0, 1)], factor=4, nrow=1) 231 | 232 | i += 1 233 | 234 | return total_loss 235 | 236 | net_input_saved = net_input.detach().clone() 237 | noise = net_input.detach().clone() 238 | 239 | p = get_params('net', net, net_input) 240 | optimize(args.optimizer, p, closure, args.lr, args.num_iter) 241 | 242 | PSNR_mat = np.concatenate((PSNR_mat, np.array(PSNR_list).reshape(1,args.num_iter)), axis=0) 243 | pickle.dump( PSNR_mat, open( os.path.join(global_path, 'PSNR.pkl'), "wb" ) ) 244 | 245 | psnr_gt_best_list.append(psnr_gt_best) 246 | 247 | print('Finish optimization') 248 | 249 | 250 | for idx, image_name in enumerate(img_path_list): 251 | print ('Image: %8s PSNR: %.2f' % (image_name, psnr_gt_best_list[idx]), '\n', end='') 252 | print ('Averaged PSNR: %.2f' % (np.mean(psnr_gt_best_list)), '\n', end='') 253 | print ('Averaged PSNR (Set5): %.2f' % (np.mean(psnr_gt_best_list[:5])), '\n', end='') 254 | print ('Averaged PSNR (Set14): %.2f' % (np.mean(psnr_gt_best_list[5:])), '\n', end='') 255 | -------------------------------------------------------------------------------- /DIP/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunChunChen/NAS-DIP-pytorch/3dfb4cf6312599097a5a193d22fd8591467e1a6f/DIP/utils/__init__.py -------------------------------------------------------------------------------- /DIP/utils/common_utils.py: -------------------------------------------------------------------------------- 1 | import 
torch 2 | import torch.nn as nn
 3 | import torchvision
 4 | import sys
 5 |
 6 | import numpy as np
 7 | from PIL import Image
 8 | import PIL
 9 | import numpy as np
 10 | import matplotlib.pyplot as plt
 11 |
 12 | def crop_image(img, d=32):
 13 | '''Make dimensions divisible by `d`'''
 14 |
 15 | new_size = (img.size[0] - img.size[0] % d,
 16 | img.size[1] - img.size[1] % d)
 17 |
 18 | bbox = [
 19 | int((img.size[0] - new_size[0])/2),
 20 | int((img.size[1] - new_size[1])/2),
 21 | int((img.size[0] + new_size[0])/2),
 22 | int((img.size[1] + new_size[1])/2),
 23 | ]
 24 |
 25 | img_cropped = img.crop(bbox)
 26 | return img_cropped
 27 |
 28 | def get_params(opt_over, net, net_input, downsampler=None):
 29 | '''Returns parameters that we want to optimize over.
 30 |
 31 | Args:
 32 | opt_over: comma separated list, e.g. "net,input" or "net"
 33 | net: network
 34 | net_input: torch.Tensor that stores input `z`
 35 | '''
 36 | opt_over_list = opt_over.split(',')
 37 | params = []
 38 |
 39 | for opt in opt_over_list:
 40 |
 41 | if opt == 'net':
 42 | params += [x for x in net.parameters() ]
 43 | elif opt=='down':
 44 | assert downsampler is not None
 45 | params = [x for x in downsampler.parameters()]
 46 | elif opt == 'input':
 47 | net_input.requires_grad = True
 48 | params += [net_input]
 49 | else:
 50 | assert False, 'what is it?'
 51 |
 52 | return params
 53 |
 54 | def get_image_grid(images_np, nrow=8):
 55 | '''Creates a grid from a list of images by concatenating them.'''
 56 | images_torch = [torch.from_numpy(x) for x in images_np]
 57 | torch_grid = torchvision.utils.make_grid(images_torch, nrow)
 58 |
 59 | return torch_grid.numpy()
 60 |
 61 | def plot_image_grid(images_np, nrow=8, factor=1, interpolation='lanczos'):
 62 | """Draws images in a grid
 63 |
 64 | Args:
 65 | images_np: list of images, each image is np.array of size 3xHxW or 1xHxW
 66 | nrow: how many images will be in one row
 67 | factor: size of the plt.figure
 68 | interpolation: interpolation used in plt.imshow
 69 | """
 70 | n_channels = max(x.shape[0] for x in images_np)
 71 | assert (n_channels == 3) or (n_channels == 1), "images should have 1 or 3 channels"
 72 |
 73 | images_np = [x if (x.shape[0] == n_channels) else np.concatenate([x, x, x], axis=0) for x in images_np]
 74 |
 75 | grid = get_image_grid(images_np, nrow)
 76 |
 77 | plt.figure(figsize=(len(images_np) + factor, 12 + factor))
 78 |
 79 | if images_np[0].shape[0] == 1:
 80 | plt.imshow(grid[0], cmap='gray', interpolation=interpolation)
 81 | else:
 82 | plt.imshow(grid.transpose(1, 2, 0), interpolation=interpolation)
 83 |
 84 | plt.show(block=False)
 85 | plt.pause(1)
 86 | plt.close()
 87 |
 88 | return grid
 89 |
 90 | def load(path):
 91 | """Load PIL image."""
 92 | img = Image.open(path)
 93 | return img
 94 |
 95 | def get_image(path, imsize=-1):
 96 | """Load an image and resize to a specific size.
97 |
 98 | Args:
 99 | path: path to image
 100 | imsize: tuple or scalar with dimensions; -1 for `no resize`
 101 | """
 102 | img = load(path)
 103 |
 104 | if isinstance(imsize, int):
 105 | imsize = (imsize, imsize)
 106 |
 107 | if imsize[0] != -1 and img.size != imsize:
 108 | if imsize[0] > img.size[0]:
 109 | img = img.resize(imsize, Image.BICUBIC)
 110 | else:
 111 | img = img.resize(imsize, Image.ANTIALIAS)
 112 |
 113 | img_np = pil_to_np(img)
 114 |
 115 | return img, img_np
 116 |
 117 |
 118 |
 119 | def fill_noise(x, noise_type):
 120 | """Fills tensor `x` with noise of type `noise_type`."""
 121 | if noise_type == 'u':
 122 | x.uniform_()
 123 | elif noise_type == 'n':
 124 | x.normal_()
 125 | else:
 126 | assert False
 127 |
 128 | def get_noise(input_depth, method, spatial_size, noise_type='u', var=1./10):
 129 | """Returns a torch.Tensor of size (1 x `input_depth` x `spatial_size[0]` x `spatial_size[1]`)
 130 | initialized in a specific way.
 131 | Args:
 132 | input_depth: number of channels in the tensor
 133 | method: `noise` for filling tensor with noise; `meshgrid` for np.meshgrid
 134 | spatial_size: spatial size of the tensor to initialize
 135 | noise_type: 'u' for uniform; 'n' for normal
 136 | var: a factor the noise will be multiplied by; effectively a standard-deviation scale.
 137 | """
 138 | if isinstance(spatial_size, int):
 139 | spatial_size = (spatial_size, spatial_size)
 140 | if method == 'noise':
 141 | shape = [1, input_depth, spatial_size[0], spatial_size[1]]
 142 | net_input = torch.zeros(shape)
 143 |
 144 | fill_noise(net_input, noise_type)
 145 | net_input *= var
 146 | elif method == 'meshgrid':
 147 | assert input_depth == 2
 148 | X, Y = np.meshgrid(np.arange(0, spatial_size[1])/float(spatial_size[1]-1), np.arange(0, spatial_size[0])/float(spatial_size[0]-1))
 149 | meshgrid = np.concatenate([X[None,:], Y[None,:]])
 150 | net_input= np_to_torch(meshgrid)
 151 | else:
 152 | assert False
 153 |
 154 | return net_input
 155 |
 156 | def pil_to_np(img_PIL):
 157 | '''Converts image in PIL format to np.array.
 158 |
 159 | From W x H x C [0...255] to C x W x H [0..1]
 160 | '''
 161 | ar = np.array(img_PIL)
 162 |
 163 | if len(ar.shape) == 3:
 164 | ar = ar.transpose(2,0,1)
 165 | elif len(ar.shape) == 4:
 166 | ar = ar.transpose(0,3,1,2)
 167 | else:
 168 | ar = ar[None, ...]
 169 |
 170 | return ar.astype(np.float32) / 255.
 171 |
 172 | def np_to_pil(img_np):
 173 | '''Converts image in np.array format to PIL image.
 174 |
 175 | From C x W x H [0..1] to W x H x C [0...255]
 176 | '''
 177 | ar = np.clip(img_np*255,0,255).astype(np.uint8)
 178 |
 179 | if img_np.shape[0] == 1:
 180 | ar = ar[0]
 181 | else:
 182 | ar = ar.transpose(1, 2, 0)
 183 |
 184 | return Image.fromarray(ar)
 185 |
 186 | def np_to_torch(img_np):
 187 | '''Converts image in numpy.array to torch.Tensor.
 188 |
 189 | From C x W x H [0..1] to C x W x H [0..1]
 190 | '''
 191 | return torch.from_numpy(img_np)[None, :]
 192 |
 193 | def torch_to_np(img_var):
 194 | '''Converts an image in torch.Tensor format to np.array.
 195 |
 196 | From 1 x C x W x H [0..1] to C x W x H [0..1]
 197 | '''
 198 | return img_var.detach().cpu().numpy()[0]
 199 |
 200 |
 201 | def optimize(optimizer_type, parameters, closure, LR, num_iter):
 202 | """Runs optimization loop.
203 |
 204 | Args:
 205 | optimizer_type: 'LBFGS', 'adam', 'sgd'
 206 | parameters: list of Tensors to optimize over
 207 | closure: function, that returns loss variable
 208 | LR: learning rate
 209 | num_iter: number of iterations
 210 | """
 211 | if optimizer_type == 'LBFGS':
 212 | # Do several steps with adam first
 213 | optimizer = torch.optim.Adam(parameters, lr=0.001)
 214 | for j in range(100):
 215 | optimizer.zero_grad()
 216 | closure()
 217 | optimizer.step()
 218 |
 219 | print('Starting optimization with LBFGS')
 220 | def closure2():
 221 | optimizer.zero_grad()
 222 | return closure()
 223 | optimizer = torch.optim.LBFGS(parameters, max_iter=num_iter, lr=LR, tolerance_grad=-1, tolerance_change=-1)
 224 | optimizer.step(closure2)
 225 |
 226 | elif optimizer_type == 'adam':
 227 | print('Starting optimization with ADAM')
 228 | optimizer = torch.optim.Adam(parameters, lr=LR)
 229 |
 230 | for j in range(num_iter):
 231 | optimizer.zero_grad()
 232 | closure()
 233 | optimizer.step()
 234 | elif optimizer_type == 'sgd':
 235 | print('Starting optimization with SGD')
 236 | optimizer = torch.optim.SGD(parameters, lr=LR, momentum=0.9)
 237 |
 238 | for j in range(num_iter):
 239 | optimizer.zero_grad()
 240 | closure()
 241 | optimizer.step()
 242 | else:
 243 | assert False
 244 | -------------------------------------------------------------------------------- /DIP/utils/denoising_utils.py: -------------------------------------------------------------------------------- 1 | import os
 2 | from .common_utils import *
 3 |
 4 |
 5 |
 6 | def get_noisy_image(img_np, sigma):
 7 | """Adds Gaussian noise to an image.
 8 |
 9 | Args:
 10 | img_np: image, np.array with values from 0 to 1
 11 | sigma: std of the noise
 12 | """
 13 | img_noisy_np = np.clip(img_np + np.random.normal(scale=sigma, size=img_np.shape), 0, 1).astype(np.float32)
 14 | img_noisy_pil = np_to_pil(img_noisy_np)
 15 |
 16 | return img_noisy_pil, img_noisy_np -------------------------------------------------------------------------------- /DIP/utils/feature_inversion_utils.py: -------------------------------------------------------------------------------- 1 | import torch
 2 | import torch.nn as nn
 3 | import torchvision.transforms as transforms
 4 | import torchvision.models as models
 5 | from .matcher import Matcher
 6 | import os
 7 | from collections import OrderedDict
 8 |
 9 | class View(nn.Module):
 10 | def __init__(self):
 11 | super(View, self).__init__()
 12 |
 13 | def forward(self, x):
 14 | return x.view(-1)
 15 |
 16 | def get_vanilla_vgg_features(cut_idx=-1):
 17 | if not os.path.exists('vgg_features.pth'):
 18 | os.system(
 19 | 'wget --no-check-certificate -N https://s3-us-west-2.amazonaws.com/jcjohns-models/vgg19-d01eb7cb.pth')
 20 | vgg_weights = torch.load('vgg19-d01eb7cb.pth')
 21 | # fix compatibility issues
 22 | map = {'classifier.6.weight':u'classifier.7.weight', 'classifier.6.bias':u'classifier.7.bias'}
 23 | vgg_weights = OrderedDict([(map[k] if k in map else k,v) for k,v in vgg_weights.items()])  # .items() works on both Python 2 and 3
 24 |
 25 |
 26 |
 27 | model = models.vgg19()
 28 | model.classifier = nn.Sequential(View(), *model.classifier._modules.values())
 29 |
 30 |
 31 | model.load_state_dict(vgg_weights)
 32 |
 33 | torch.save(model.features, 'vgg_features.pth')
 34 | torch.save(model.classifier, 'vgg_classifier.pth')
 35 |
 36 | vgg = torch.load('vgg_features.pth')
 37 | if cut_idx > 36:
 38 | vgg_classifier = torch.load('vgg_classifier.pth')
 39 | vgg = nn.Sequential(*(list(vgg._modules.values()) + list(vgg_classifier._modules.values())))  # wrap in list(): dict views cannot be concatenated with + on Python 3
 40 |
 41 | vgg.eval()
 42 |
 43 | return vgg
 44 |
 45 |
 46 | def
get_matcher(net, opt): 47 | idxs = [x for x in opt['layers'].split(',')]
 48 | matcher = Matcher(opt['what'])
 49 |
 50 | def hook(module, input, output):
 51 | matcher(module, output)
 52 |
 53 | for i in idxs:
 54 | net._modules[i].register_forward_hook(hook)
 55 |
 56 | return matcher
 57 |
 58 |
 59 |
 60 | def get_vgg(cut_idx=-1):
 61 | f = get_vanilla_vgg_features(cut_idx)
 62 |
 63 | if cut_idx > 0:
 64 | num_modules = len(f._modules)
 65 | keys_to_delete = list(f._modules.keys())[cut_idx:num_modules]  # list(): dict_keys is not indexable on Python 3
 66 | for k in keys_to_delete:
 67 | del f._modules[k]
 68 |
 69 | return f
 70 |
 71 | def vgg_preprocess_var(var):
 72 | (r, g, b) = torch.chunk(var, 3, dim=1)
 73 | bgr = torch.cat((b, g, r), 1)
 74 | out = bgr * 255 - torch.autograd.Variable(vgg_mean[None, ...]).type(var.type()).expand_as(bgr)
 75 | return out
 76 |
 77 | vgg_mean = torch.FloatTensor([103.939, 116.779, 123.680]).view(3, 1, 1)
 78 |
 79 |
 80 |
 81 | def get_preprocessor(imsize):
 82 | def vgg_preprocess(tensor):
 83 | (r, g, b) = torch.chunk(tensor, 3, dim=0)
 84 | bgr = torch.cat((b, g, r), 0)
 85 | out = bgr * 255 - vgg_mean.type(tensor.type()).expand_as(bgr)
 86 | return out
 87 | preprocess = transforms.Compose([
 88 | transforms.Resize(imsize),
 89 | transforms.ToTensor(),
 90 | transforms.Lambda(vgg_preprocess)
 91 | ])
 92 |
 93 | return preprocess
 94 |
 95 |
 96 | def get_deprocessor():
 97 | def vgg_deprocess(tensor):
 98 | bgr = (tensor + vgg_mean.expand_as(tensor)) / 255.0
 99 | (b, g, r) = torch.chunk(bgr, 3, dim=0)
 100 | rgb = torch.cat((r, g, b), 0)
 101 | return rgb
 102 | deprocess = transforms.Compose([
 103 | transforms.Lambda(vgg_deprocess),
 104 | transforms.Lambda(lambda x: torch.clamp(x, 0, 1)),
 105 | transforms.ToPILImage()
 106 | ])
 107 | return deprocess
 108 | -------------------------------------------------------------------------------- /DIP/utils/inpainting_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np
 2 | from PIL import Image
 3 | import PIL.ImageDraw as ImageDraw
 4 | import PIL.ImageFont as ImageFont
 5 | from .common_utils import *
 6 |
 7 | def get_text_mask(for_image, sz=20):
 8 | font_fname = '/usr/share/fonts/truetype/freefont/FreeSansBold.ttf'
 9 | font_size = sz
 10 | font = ImageFont.truetype(font_fname, font_size)
 11 |
 12 | img_mask = Image.fromarray(np.array(for_image)*0+255)
 13 | draw = ImageDraw.Draw(img_mask)
 14 | draw.text((128, 128), "hello world", font=font, fill='rgb(0, 0, 0)')
 15 |
 16 | return img_mask
 17 |
 18 | def get_bernoulli_mask(for_image, zero_fraction=0.95):
 19 | img_mask_np=(np.random.random_sample(size=pil_to_np(for_image).shape) > zero_fraction).astype(int)
 20 | img_mask = np_to_pil(img_mask_np)
 21 |
 22 | return img_mask
 23 | -------------------------------------------------------------------------------- /DIP/utils/load_image.py: -------------------------------------------------------------------------------- 1 | import os
 2 | import matplotlib.pyplot as plt
 3 | from skimage import io
 4 | import numpy as np
 5 | import pickle
 6 | import json
 7 | import glob
 8 | import torch
 9 | from torch.utils.data import Dataset, DataLoader
 10 |
 11 | def load_image_coco(num_image=64, crop_size=256):
 12 |
 13 | images = np.empty((0, crop_size, crop_size, 3), dtype=np.float32)
 14 |
 15 | coco_anno = json.load(open('/home/chengao/Dataset/annotations/instances_val2017.json'))
 16 | # img_pil = img_pil[[0, 4, 9, 14, 20, 30, 45, 50, 51, 52, 63, 65, 75, 83, 90, 91],:,:,:]
 17 | for idx in range(len(coco_anno['annotations'])):
 18 | GT =
coco_anno['annotations'][idx] 19 | image_id = GT['image_id'] 20 | H_box = [int(i) for i in GT['bbox']] 21 | category = GT['category_id'] 22 | 23 | if H_box[2] <= crop_size and H_box[2] > 150 / 256 * crop_size and H_box[3] <= crop_size and H_box[3] > 150 / 256 * crop_size: 24 | 25 | im_file = '/home/chengao/Dataset/val2017/' + (str(image_id)).zfill(12) + '.jpg' 26 | im_data = plt.imread(im_file) 27 | im_height, im_width, nbands = im_data.shape 28 | 29 | height_pad = crop_size - H_box[3] 30 | width_pad = crop_size - H_box[2] 31 | 32 | x0 = H_box[0] - width_pad // 2 33 | x1 = H_box[0] - width_pad // 2 + crop_size 34 | y0 = H_box[1] - height_pad // 2 35 | y1 = H_box[1] - height_pad // 2 + crop_size 36 | 37 | if x0 < 0: 38 | x1 = x1 - x0 39 | x0 = 0 40 | if x1 >= im_width: 41 | continue 42 | 43 | if y0 < 0: 44 | y1 = y1 - y0 45 | y0 = 0 46 | if y1 >= im_height: 47 | continue 48 | 49 | im_data = im_data[y0 : y1, x0 : x1, :].reshape(1, crop_size, crop_size, 3) 50 | 51 | images = np.concatenate((images, im_data), axis=0) 52 | 53 | if len(images) >= num_image: 54 | return images 55 | 56 | 57 | return images 58 | 59 | class DIV2KDataset(Dataset): 60 | 61 | def __init__(self, root_dir, transform=None): 62 | 63 | self.image_list = os.listdir(root_dir) 64 | self.root_dir = root_dir 65 | self.transform = transform 66 | 67 | def __len__(self): 68 | return len(self.image_list) 69 | 70 | def __getitem__(self, idx): 71 | img_name = os.path.join(self.root_dir, 72 | self.image_list[idx]) 73 | image = io.imread(img_name) 74 | image = image / 255. 75 | 76 | if self.transform: 77 | image = self.transform(image) 78 | 79 | return image 80 | 81 | class RandomCrop(object): 82 | 83 | def __init__(self, output_size): 84 | assert isinstance(output_size, (int, tuple)) 85 | if isinstance(output_size, int): 86 | self.output_size = (output_size, output_size) 87 | else: 88 | assert len(output_size) == 2 89 | self.output_size = output_size 90 | 91 | def __call__(self, image): 92 | 93 | h, w = image.shape[:2] 94 | new_h, new_w = self.output_size 95 | 96 | top = np.random.randint(0, h - new_h + 1)  # +1 so a crop equal to the image size is valid 97 | left = np.random.randint(0, w - new_w + 1) 98 | 99 | image = image[top: top + new_h, 100 | left: left + new_w] 101 | 102 | return image 103 | 104 | class ToTensor(object): 105 | """Convert ndarrays in sample to Tensors.""" 106 | 107 | def __call__(self, image): 108 | 109 | # swap color axis because 110 | # numpy image: H x W x C 111 | # torch image: C X H X W 112 | image = image.transpose((2, 0, 1)) 113 | return torch.from_numpy(image) -------------------------------------------------------------------------------- /DIP/utils/matcher.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Matcher: 5 | def __init__(self, how='gram_matrix', loss='mse'): 6 | self.mode = 'store' 7 | self.stored = {} 8 | self.losses = {} 9 | 10 | if how in all_features.keys(): 11 | self.get_statistics = all_features[how] 12 | else: 13 | assert False, 'Unknown statistics type: %s' % how 14 | pass 15 | 16 | if loss in all_losses.keys(): 17 | self.loss = all_losses[loss] 18 | else: 19 | assert False, 'Unknown loss type: %s' % loss 20 | 21 | def __call__(self, module, features): 22 | statistics = self.get_statistics(features) 23 | 24 | self.statistics = statistics 25 | if self.mode == 'store': 26 | self.stored[module] = statistics.detach().clone() 27 | elif self.mode == 'match': 28 | self.losses[module] = self.loss(statistics, self.stored[module]) 29 | 30 | def clean(self): 31 | self.losses = {} 32 | 33 | def gram_matrix(x): 34 | (b, ch, h, w) = 
x.size() 35 | features = x.view(b, ch, w * h) 36 | features_t = features.transpose(1, 2) 37 | gram = features.bmm(features_t) / (ch * h * w) 38 | return gram 39 | 40 | 41 | def features(x): 42 | return x 43 | 44 | 45 | all_features = { 46 | 'gram_matrix': gram_matrix, 47 | 'features': features, 48 | } 49 | 50 | all_losses = { 51 | 'mse': nn.MSELoss(), 52 | 'smoothL1': nn.SmoothL1Loss(), 53 | 'L1': nn.L1Loss(), 54 | } 55 | -------------------------------------------------------------------------------- /DIP/utils/perceptual_loss/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunChunChen/NAS-DIP-pytorch/3dfb4cf6312599097a5a193d22fd8591467e1a6f/DIP/utils/perceptual_loss/__init__.py -------------------------------------------------------------------------------- /DIP/utils/perceptual_loss/matcher.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Matcher: 6 | def __init__(self, how='gram_matrix', loss='mse', map_index=933): 7 | self.mode = 'store' 8 | self.stored = {} 9 | self.losses = {} 10 | 11 | if how in all_features.keys(): 12 | self.get_statistics = all_features[how] 13 | else: 14 | assert False, 'Unknown statistics type: %s' % how 15 | pass 16 | 17 | if loss in all_losses.keys(): 18 | self.loss = all_losses[loss] 19 | else: 20 | assert False, 'Unknown loss type: %s' % loss 21 | 22 | self.map_index = map_index 23 | self.method = 'match' 24 | self.window_size = 4  # NOTE: assumed default; the original never sets this attribute, but __call__ reads it below 25 | 26 | def __call__(self, module, features): 27 | statistics = self.get_statistics(features) 28 | 29 | self.statistics = statistics 30 | if self.mode == 'store': 31 | self.stored[module] = statistics.detach() 32 | 33 | elif self.mode == 'match': 34 | 35 | if statistics.ndimension() == 2: 36 | 37 | if self.method == 'maximize': 38 | self.losses[module] = - statistics[0, self.map_index] 39 | else: 40 | self.losses[module] = torch.abs(300 - statistics[0, self.map_index]) 41 | 42 | else: 43 | ws = self.window_size 44 | 45 | t = statistics.detach() * 0 46 | 47 | s_cc = statistics[:1, :, t.shape[2] // 2 - ws:t.shape[2] // 2 + ws, t.shape[3] // 2 - ws:t.shape[3] // 2 + ws] #* 1.0 48 | t_cc = t[:1, :, t.shape[2] // 2 - ws:t.shape[2] // 2 + ws, t.shape[3] // 2 - ws:t.shape[3] // 2 + ws] #* 1.0 49 | t_cc[:, self.map_index,...] 
= 1 50 | 51 | if self.method == 'maximize': 52 | self.losses[module] = -(s_cc * t_cc.contiguous()).sum() 53 | else: 54 | self.losses[module] = torch.abs(200 - (s_cc * t_cc.contiguous())).sum() 55 | 56 | 57 | def clean(self): 58 | self.losses = {} 59 | 60 | def gram_matrix(x): 61 | (b, ch, h, w) = x.size() 62 | features = x.view(b, ch, w * h) 63 | features_t = features.transpose(1, 2) 64 | gram = features.bmm(features_t) / (ch * h * w) 65 | return gram 66 | 67 | 68 | def features(x): 69 | return x 70 | 71 | 72 | all_features = { 73 | 'gram_matrix': gram_matrix, 74 | 'features': features, 75 | } 76 | 77 | all_losses = { 78 | 'mse': nn.MSELoss(), 79 | 'smoothL1': nn.SmoothL1Loss(), 80 | 'L1': nn.L1Loss(), 81 | } 82 | -------------------------------------------------------------------------------- /DIP/utils/perceptual_loss/perceptual_loss.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torchvision.transforms as transforms 5 | import torchvision.models as models 6 | from .matcher import Matcher 7 | from collections import OrderedDict 8 | 9 | # from torchvision.models.vgg import model_urls  # unused; removed in newer torchvision releases 10 | from torchvision.models import vgg19 11 | from torch.autograd import Variable 12 | 13 | from .vgg_modified import VGGModified 14 | 15 | def get_pretrained_net(name): 16 | """Loads pretrained network""" 17 | if name == 'alexnet_caffe': 18 | if not os.path.exists('alexnet-torch_py3.pth'): 19 | print('Downloading AlexNet') 20 | os.system('wget -O alexnet-torch_py3.pth --no-check-certificate -nc https://box.skoltech.ru/index.php/s/77xSWvrDN0CiQtK/download') 21 | return torch.load('alexnet-torch_py3.pth') 22 | elif name == 'vgg19_caffe': 23 | if not os.path.exists('vgg19-caffe-py3.pth'): 24 | print('Downloading VGG-19') 25 | os.system('wget -O vgg19-caffe-py3.pth --no-check-certificate -nc https://box.skoltech.ru/index.php/s/HPcOFQTjXxbmp4X/download') 26 | 27 | vgg = get_vgg19_caffe() 28 | 29 | return vgg 30 | elif name == 'vgg16_caffe': 31 | if not os.path.exists('vgg16-caffe-py3.pth'): 32 | print('Downloading VGG-16') 33 | os.system('wget -O vgg16-caffe-py3.pth --no-check-certificate -nc https://box.skoltech.ru/index.php/s/TUZ62HnPKWdxyLr/download') 34 | 35 | vgg = get_vgg16_caffe() 36 | 37 | return vgg 38 | elif name == 'vgg19_pytorch_modified': 39 | # os.system('wget -O data/feature_inversion/vgg19-caffe.pth --no-check-certificate -nc https://www.dropbox.com/s/xlbdo688dy4keyk/vgg19-caffe.pth?dl=1') 40 | 41 | model = VGGModified(vgg19(pretrained=False), 0.2) 42 | model.load_state_dict(torch.load('vgg_pytorch_modified.pkl')['state_dict']) 43 | 44 | return model 45 | else: 46 | assert False, 'Unknown network name: %s' % name 47 | 48 | 49 | class PerceptualLoss(nn.modules.loss._Loss): 50 | """ 51 | Assumes input image is in range [0,1] if `input_range` is 'sigmoid', [-1, 1] if 'tanh' 52 | """ 53 | def __init__(self, input_range='sigmoid', 54 | net_type = 'vgg_torch', 55 | input_preprocessing='corresponding', 56 | match=[{'layers': [11, 20, 29], 'what': 'features', 'map_idx': 933}]): 57 | 58 | if input_range not in ['sigmoid', 'tanh']: 59 | assert False, "input_range must be 'sigmoid' or 'tanh'" 60 | super(PerceptualLoss, self).__init__() 61 | self.net = get_pretrained_net(net_type).cuda() 62 | self.input_range = input_range 63 | self.matchers = [get_matcher(self.net, match_opts) for match_opts in match] 64 | 65 | preprocessing_correspondence = { 66 | 'vgg19_torch': vgg_preprocess_caffe, 67 | 'vgg16_torch': vgg_preprocess_caffe, 68 | 'vgg19_pytorch': vgg_preprocess_pytorch, 69 | 'vgg19_pytorch_modified': vgg_preprocess_pytorch, 70 | } 71 | 72 | if input_preprocessing == 
'corresponding': 73 | self.preprocess = preprocessing_correspondence[net_type] 74 | else: 75 | self.preprocess = preprocessing_correspondence[input_preprocessing] 76 | 77 | def preprocess_input(self, x): 78 | if self.input_range == 'tanh': 79 | x = (x + 1.) / 2. 80 | 81 | return self.preprocess(x) 82 | 83 | def __call__(self, x, y): 84 | 85 | for matcher in self.matchers: 86 | matcher.mode = 'store' 87 | self.net(self.preprocess_input(y)) 88 | 89 | for matcher in self.matchers: 90 | matcher.mode = 'match' 91 | self.net(self.preprocess_input(x)) 92 | return sum([sum(matcher.losses.values()) for matcher in self.matchers]) 93 | 94 | 95 | def get_vgg19_caffe(): 96 | model = vgg19() 97 | model.classifier = nn.Sequential(View(), *model.classifier._modules.values()) 98 | vgg = model.features 99 | vgg_classifier = model.classifier 100 | 101 | names = ['conv1_1','relu1_1','conv1_2','relu1_2','pool1', 102 | 'conv2_1','relu2_1','conv2_2','relu2_2','pool2', 103 | 'conv3_1','relu3_1','conv3_2','relu3_2','conv3_3','relu3_3','conv3_4','relu3_4','pool3', 104 | 'conv4_1','relu4_1','conv4_2','relu4_2','conv4_3','relu4_3','conv4_4','relu4_4','pool4', 105 | 'conv5_1','relu5_1','conv5_2','relu5_2','conv5_3','relu5_3','conv5_4','relu5_4','pool5', 106 | 'torch_view','fc6','relu6','drop6','fc7','relu7','drop7','fc8'] 107 | 108 | model = nn.Sequential() 109 | for n, m in zip(names, list(vgg) + list(vgg_classifier)): 110 | model.add_module(n, m) 111 | 112 | model.load_state_dict(torch.load('vgg19-caffe-py3.pth')) 113 | 114 | return model 115 | 116 | def get_vgg16_caffe(): 117 | vgg = torch.load('vgg16-caffe-py3.pth') 118 | 119 | names = ['conv1_1','relu1_1','conv1_2','relu1_2','pool1', 120 | 'conv2_1','relu2_1','conv2_2','relu2_2','pool2', 121 | 'conv3_1','relu3_1','conv3_2','relu3_2','conv3_3','relu3_3','pool3', 122 | 'conv4_1','relu4_1','conv4_2','relu4_2','conv4_3','relu4_3','pool4', 123 | 'conv5_1','relu5_1','conv5_2','relu5_2','conv5_3','relu5_3','pool5', 124 | 'torch_view','fc6','relu6','drop6','fc7','relu7','fc8'] 125 | 126 | model = nn.Sequential() 127 | for n, m in zip(names, list(vgg)): 128 | model.add_module(n, m) 129 | 130 | # model.load_state_dict(torch.load('vgg19-caffe-py3.pth')) 131 | 132 | return model 133 | 134 | 135 | class View(nn.Module): 136 | def __init__(self): 137 | super(View, self).__init__() 138 | 139 | def forward(self, x): 140 | return x.view(x.size(0), -1) 141 | 142 | 143 | def get_matcher(vgg, opt): 144 | # idxs = [int(x) for x in opt['layers'].split(',')] 145 | matcher = Matcher(opt['what'], 'mse', opt['map_idx']) 146 | 147 | def hook(module, input, output): 148 | matcher(module, output) 149 | 150 | for layer_name in opt['layers']: 151 | vgg._modules[layer_name].register_forward_hook(hook) 152 | 153 | return matcher 154 | 155 | 156 | def get_vgg(cut_idx=-1, vgg_type='pytorch'): 157 | f = get_vanilla_vgg_features(cut_idx, vgg_type) 158 | 159 | keys = list(f._modules.keys()) 160 | max_idx = max(keys.index(x) for x in opt_content['layers'].split(','))  # NOTE: `opt_content` (and `get_vanilla_vgg_features`) are not defined in this file; they are assumed to be supplied by the calling script 161 | for k in keys[max_idx+1:]: 162 | f._modules.pop(k) 163 | 164 | return f 165 | 166 | vgg_mean = torch.FloatTensor([103.939, 116.779, 123.680]).view(3, 1, 1) 167 | def vgg_preprocess_caffe(var): 168 | (r, g, b) = torch.chunk(var, 3, dim=1) 169 | bgr = torch.cat((b, g, r), 1) 170 | out = bgr * 255 - torch.autograd.Variable(vgg_mean).type(var.type()) 171 | return out 172 | 173 | 174 | 175 | mean_pytorch = Variable(torch.Tensor([0.485, 0.456, 0.406]).view(3, 1, 1)) 176 | std_pytorch = Variable(torch.Tensor([0.229, 0.224, 0.225]).view(3, 1, 1)) 
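# The mean/std constants above are the standard torchvision ImageNet
# normalization statistics, so `vgg_preprocess_pytorch` below is equivalent
# to torchvision's own transform (sketch, assuming a [0, 1]-ranged CHW tensor):
#
#   import torchvision.transforms as T
#   normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
#                           std=[0.229, 0.224, 0.225])
#   out = normalize(img)  # matches vgg_preprocess_pytorch(img) for a 3-channel input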
177 | 178 | def vgg_preprocess_pytorch(var): 179 | return (var - mean_pytorch.type_as(var))/std_pytorch.type_as(var) 180 | 181 | 182 | 183 | def get_preprocessor(imsize): 184 | def vgg_preprocess(tensor): 185 | (r, g, b) = torch.chunk(tensor, 3, dim=0) 186 | bgr = torch.cat((b, g, r), 0) 187 | out = bgr * 255 - vgg_mean.type(tensor.type()).expand_as(bgr) 188 | return out 189 | preprocess = transforms.Compose([ 190 | transforms.Resize(imsize), 191 | transforms.ToTensor(), 192 | transforms.Lambda(vgg_preprocess) 193 | ]) 194 | 195 | return preprocess 196 | 197 | 198 | def get_deprocessor(): 199 | def vgg_deprocess(tensor): 200 | bgr = (tensor + vgg_mean.expand_as(tensor)) / 255.0 201 | (b, g, r) = torch.chunk(bgr, 3, dim=0) 202 | rgb = torch.cat((r, g, b), 0) 203 | return rgb 204 | deprocess = transforms.Compose([ 205 | transforms.Lambda(vgg_deprocess), 206 | transforms.Lambda(lambda x: torch.clamp(x, 0, 1)), 207 | transforms.ToPILImage() 208 | ]) 209 | return deprocess 210 | 211 | -------------------------------------------------------------------------------- /DIP/utils/perceptual_loss/vgg_modified.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | class VGGModified(nn.Module): 4 | def __init__(self, vgg19_orig, slope=0.01): 5 | super(VGGModified, self).__init__() 6 | self.features = nn.Sequential() 7 | 8 | self.features.add_module(str(0), vgg19_orig.features[0]) 9 | self.features.add_module(str(1), nn.LeakyReLU(slope, True)) 10 | self.features.add_module(str(2), vgg19_orig.features[2]) 11 | self.features.add_module(str(3), nn.LeakyReLU(slope, True)) 12 | self.features.add_module(str(4), nn.AvgPool2d((2,2), (2,2))) 13 | 14 | self.features.add_module(str(5), vgg19_orig.features[5]) 15 | self.features.add_module(str(6), nn.LeakyReLU(slope, True)) 16 | self.features.add_module(str(7), vgg19_orig.features[7]) 17 | self.features.add_module(str(8), nn.LeakyReLU(slope, True)) 18 | self.features.add_module(str(9), nn.AvgPool2d((2,2), (2,2))) 19 | 20 | self.features.add_module(str(10), vgg19_orig.features[10]) 21 | self.features.add_module(str(11), nn.LeakyReLU(slope, True)) 22 | self.features.add_module(str(12), vgg19_orig.features[12]) 23 | self.features.add_module(str(13), nn.LeakyReLU(slope, True)) 24 | self.features.add_module(str(14), vgg19_orig.features[14]) 25 | self.features.add_module(str(15), nn.LeakyReLU(slope, True)) 26 | self.features.add_module(str(16), vgg19_orig.features[16]) 27 | self.features.add_module(str(17), nn.LeakyReLU(slope, True)) 28 | self.features.add_module(str(18), nn.AvgPool2d((2,2), (2,2))) 29 | 30 | self.features.add_module(str(19), vgg19_orig.features[19]) 31 | self.features.add_module(str(20), nn.LeakyReLU(slope, True)) 32 | self.features.add_module(str(21), vgg19_orig.features[21]) 33 | self.features.add_module(str(22), nn.LeakyReLU(slope, True)) 34 | self.features.add_module(str(23), vgg19_orig.features[23]) 35 | self.features.add_module(str(24), nn.LeakyReLU(slope, True)) 36 | self.features.add_module(str(25), vgg19_orig.features[25]) 37 | self.features.add_module(str(26), nn.LeakyReLU(slope, True)) 38 | self.features.add_module(str(27), nn.AvgPool2d((2,2), (2,2))) 39 | 40 | self.features.add_module(str(28), vgg19_orig.features[28]) 41 | self.features.add_module(str(29), nn.LeakyReLU(slope, True)) 42 | self.features.add_module(str(30), vgg19_orig.features[30]) 43 | self.features.add_module(str(31), nn.LeakyReLU(slope, True)) 44 | self.features.add_module(str(32), vgg19_orig.features[32]) 45 | 
self.features.add_module(str(33), nn.LeakyReLU(slope, True)) 46 | self.features.add_module(str(34), vgg19_orig.features[34]) 47 | self.features.add_module(str(35), nn.LeakyReLU(slope, True)) 48 | self.features.add_module(str(36), nn.AvgPool2d((2,2), (2,2))) 49 | 50 | self.classifier = nn.Sequential() 51 | 52 | self.classifier.add_module(str(0), vgg19_orig.classifier[0]) 53 | self.classifier.add_module(str(1), nn.LeakyReLU(slope, True)) 54 | self.classifier.add_module(str(2), nn.Dropout2d(p = 0.5)) 55 | self.classifier.add_module(str(3), vgg19_orig.classifier[3]) 56 | self.classifier.add_module(str(4), nn.LeakyReLU(slope, True)) 57 | self.classifier.add_module(str(5), nn.Dropout2d(p = 0.5)) 58 | self.classifier.add_module(str(6), vgg19_orig.classifier[6]) 59 | 60 | def forward(self, x): 61 | return self.classifier(self.features.forward(x)) -------------------------------------------------------------------------------- /DIP/utils/sr_utils.py: -------------------------------------------------------------------------------- 1 | from .common_utils import * 2 | 3 | def put_in_center(img_np, target_size): 4 | img_out = np.zeros([3, target_size[0], target_size[1]]) 5 | 6 | bbox = [ 7 | int((target_size[0] - img_np.shape[1]) / 2), 8 | int((target_size[1] - img_np.shape[2]) / 2), 9 | int((target_size[0] + img_np.shape[1]) / 2), 10 | int((target_size[1] + img_np.shape[2]) / 2), 11 | ] 12 | 13 | img_out[:, bbox[0]:bbox[2], bbox[1]:bbox[3]] = img_np 14 | 15 | return img_out 16 | 17 | 18 | def load_LR_HR_imgs_sr(fname, imsize, factor, enforse_div32=None): 19 | '''Loads an image, resizes it, center crops and downscales. 20 | 21 | Args: 22 | fname: path to the image 23 | imsize: new size for the image, -1 for no resizing 24 | factor: downscaling factor 25 | enforse_div32: if 'CROP' center crops an image, so that its dimensions are divisible by 32. 
26 | ''' 27 | img_orig_pil, img_orig_np = get_image(fname, -1) 28 | 29 | if imsize != -1: 30 | img_orig_pil, img_orig_np = get_image(fname, imsize) 31 | 32 | # For comparison with GT 33 | if enforse_div32 == 'CROP': 34 | new_size = (img_orig_pil.size[0] - img_orig_pil.size[0] % 32, 35 | img_orig_pil.size[1] - img_orig_pil.size[1] % 32) 36 | 37 | bbox = [  # integer division so the crop box has integer coordinates under Python 3 38 | (img_orig_pil.size[0] - new_size[0])//2, 39 | (img_orig_pil.size[1] - new_size[1])//2, 40 | (img_orig_pil.size[0] + new_size[0])//2, 41 | (img_orig_pil.size[1] + new_size[1])//2, 42 | ] 43 | 44 | img_HR_pil = img_orig_pil.crop(bbox) 45 | img_HR_np = pil_to_np(img_HR_pil) 46 | else: 47 | img_HR_pil, img_HR_np = img_orig_pil, img_orig_np 48 | 49 | LR_size = [ 50 | img_HR_pil.size[0] // factor, 51 | img_HR_pil.size[1] // factor 52 | ] 53 | 54 | img_LR_pil = img_HR_pil.resize(LR_size, Image.ANTIALIAS) 55 | img_LR_np = pil_to_np(img_LR_pil) 56 | 57 | # print('HR and LR resolutions: %s, %s' % (str(img_HR_pil.size), str(img_LR_pil.size))) 58 | 59 | return { 60 | 'orig_pil': img_orig_pil, 61 | 'orig_np': img_orig_np, 62 | 'LR_pil': img_LR_pil, 63 | 'LR_np': img_LR_np, 64 | 'HR_pil': img_HR_pil, 65 | 'HR_np': img_HR_np 66 | } 67 | 68 | 69 | def get_baselines(img_LR_pil, img_HR_pil): 70 | '''Gets `bicubic`, sharpened bicubic and `nearest` baselines.''' 71 | img_bicubic_pil = img_LR_pil.resize(img_HR_pil.size, Image.BICUBIC) 72 | img_bicubic_np = pil_to_np(img_bicubic_pil) 73 | 74 | img_nearest_pil = img_LR_pil.resize(img_HR_pil.size, Image.NEAREST) 75 | img_nearest_np = pil_to_np(img_nearest_pil) 76 | 77 | img_bic_sharp_pil = img_bicubic_pil.filter(PIL.ImageFilter.UnsharpMask()) 78 | img_bic_sharp_np = pil_to_np(img_bic_sharp_pil) 79 | 80 | return img_bicubic_np, img_bic_sharp_np, img_nearest_np 81 | 82 | 83 | 84 | def tv_loss(x, beta = 0.5): 85 | '''Calculates TV loss for an image `x`. 86 | 87 | Args: 88 | x: image, torch.Tensor (or autograd Variable) 89 | beta: See https://arxiv.org/abs/1412.0035 (fig. 2) to see effect of `beta` 90 | ''' 91 | dh = torch.pow(x[:,:,:,1:] - x[:,:,:,:-1], 2) 92 | dw = torch.pow(x[:,:,1:,:] - x[:,:,:-1,:], 2) 93 | 94 | return torch.sum(torch.pow(dh[:, :, :-1] + dw[:, :, :, :-1], beta)) 95 | -------------------------------------------------------------------------------- /DIP/utils/timer.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Fast R-CNN 3 | # Copyright (c) 2015 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Ross Girshick 6 | # -------------------------------------------------------- 7 | 8 | import time 9 | 10 | class Timer(object): 11 | """A simple timer.""" 12 | def __init__(self): 13 | self.total_time = 0. 14 | self.calls = 0 15 | self.start_time = 0. 16 | self.diff = 0. 17 | self.average_time = 0. 
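        # Usage sketch (hypothetical workload): accumulate a running average
        # over repeated timings of the same operation:
        #
        #   timer = Timer()
        #   for _ in range(10):
        #       timer.tic()
        #       run_step()         # hypothetical function being timed
        #       avg = timer.toc()  # running average; toc(average=False) returns the last diff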
18 | 19 | def tic(self): 20 | # using time.time instead of time.clock because time.clock 21 | # does not normalize for multithreading 22 | self.start_time = time.time() 23 | 24 | def toc(self, average=True): 25 | self.diff = time.time() - self.start_time 26 | self.total_time += self.diff 27 | self.calls += 1 28 | self.average_time = self.total_time / self.calls 29 | if average: 30 | return self.average_time 31 | else: 32 | return self.diff -------------------------------------------------------------------------------- /NAS/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunChunChen/NAS-DIP-pytorch/3dfb4cf6312599097a5a193d22fd8591467e1a6f/NAS/__init__.py -------------------------------------------------------------------------------- /NAS/demo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SepConv(nn.Module): 6 | 7 | def __init__(self, 8 | stride=2, 9 | mode='bicubic'): 10 | 11 | super(SepConv, self).__init__() 12 | 13 | self.op = nn.Sequential( 14 | nn.Upsample(scale_factor=stride, mode=mode),  # use the constructor arguments instead of hard-coded values 15 | nn.ReLU(inplace=False), 16 | ) 17 | 18 | def forward(self, x): 19 | return self.op(x) 20 | 21 | 22 | 23 | data = torch.rand((4, 32, 256, 256)).cuda() 24 | 25 | net = SepConv().cuda() 26 | 27 | out = net(data) 28 | 29 | print('input dim:', data.shape) 30 | print('output dim:', out.shape) 31 | -------------------------------------------------------------------------------- /NAS/gen_id.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.nn as nn 4 | 5 | try: 6 | from NAS import genotypes 7 | except ImportError: 8 | import genotypes 9 | 10 | try: 11 | from NAS import model 12 | except ImportError: 13 | import model 14 | 15 | try: 16 | from NAS import operations 17 | except ImportError: 18 | import operations 19 | 20 | def gen_layer(C_in, C_out, model_index): 21 | 22 | selected_index = model_index 23 | 24 | swap = False 25 | 26 | """ Bilinear """ 27 | if model_index >= 0 and model_index <= 251: 28 | prim_index = model_index // 63 29 | model_index = model_index % 63 30 | conv_index = ((model_index // len(genotypes.ACTIVATION)) // len(genotypes.KERNEL_SIZE)) % len(genotypes.UPSAMPLE_PRIMITIVE) 31 | kernel_index = (model_index // len(genotypes.ACTIVATION)) % len(genotypes.KERNEL_SIZE) 32 | act_index = model_index % len(genotypes.ACTIVATION) 33 | 34 | if (model_index >= 60 and model_index <= 62): 35 | conv_index = 5 36 | 37 | """ DepthToSpace - Second """ 38 | if model_index >= 252 and model_index <= 311: 39 | swap = True 40 | prim_index = (model_index - 63) // 63 41 | model_index = model_index % 63 42 | conv_index = ((model_index // len(genotypes.ACTIVATION)) // len(genotypes.KERNEL_SIZE)) % len(genotypes.UPSAMPLE_PRIMITIVE) 43 | kernel_index = (model_index // len(genotypes.ACTIVATION)) % len(genotypes.KERNEL_SIZE) 44 | act_index = model_index % len(genotypes.ACTIVATION) 45 | 46 | """ Transposed Convolution """ 47 | if model_index >= 312 and model_index <= 323: 48 | prim_index = 4 49 | conv_index = 5 50 | kernel_index = (model_index // len(genotypes.ACTIVATION)) % len(genotypes.KERNEL_SIZE) 51 | act_index = model_index % len(genotypes.ACTIVATION) 52 | 53 | prim_op = genotypes.UPSAMPLE_PRIMITIVE[prim_index] # adjust the spatial size 54 | conv_op = genotypes.UPSAMPLE_CONV[conv_index] # adjust the number of channels 55 | kernel_size = 
genotypes.KERNEL_SIZE[kernel_index] # select the kernel size 56 | act_op = genotypes.ACTIVATION[act_index] # select the activation 57 | 58 | if prim_op == 'pixel_shuffle' and conv_op == 'identity': 59 | return 60 | if act_op in ['none', 'ReLU']: 61 | return 62 | if kernel_size in ['1x1', '3x3', '7x7']: 63 | return 64 | else: 65 | print('{},'.format(selected_index), end="") 66 | 67 | #print(prim_op, conv_op, kernel_size, act_op) 68 | return 69 | #print('prim op:', prim_op) 70 | #print('conv op:', conv_op) 71 | #print('kernel size:', kernel_size) 72 | #print('act op:', act_op) 73 | #return 74 | 75 | if prim_op == 'pixel_shuffle': 76 | if not swap: 77 | prim_op_layer = operations.UPSAMPLE_PRIMITIVE_OPS[prim_op](C_in=C_in, 78 | C_out=C_out, 79 | kernel_size=kernel_size, 80 | act_op=act_op) 81 | 82 | conv_op_layer = operations.UPSAMPLE_CONV_OPS[conv_op](C_in=int(C_in/4), 83 | C_out=C_out, 84 | kernel_size=kernel_size, 85 | act_op=act_op) 86 | return nn.Sequential(prim_op_layer, conv_op_layer) 87 | 88 | else: 89 | conv_op_layer = operations.UPSAMPLE_CONV_OPS[conv_op](C_in=C_in, 90 | C_out=int(C_out*4), 91 | kernel_size=kernel_size, 92 | act_op=act_op) 93 | 94 | prim_op_layer = operations.UPSAMPLE_PRIMITIVE_OPS[prim_op](C_in=C_in, 95 | C_out=C_out, 96 | kernel_size=kernel_size, 97 | act_op=act_op) 98 | return nn.Sequential(conv_op_layer, prim_op_layer) 99 | 100 | else: 101 | prim_op_layer = operations.UPSAMPLE_PRIMITIVE_OPS[prim_op](C_in=C_in, 102 | C_out=C_out, 103 | kernel_size=kernel_size, 104 | act_op=act_op) 105 | 106 | conv_op_layer = operations.UPSAMPLE_CONV_OPS[conv_op](C_in=C_in, 107 | C_out=C_out, 108 | kernel_size=kernel_size, 109 | act_op=act_op) 110 | return nn.Sequential(prim_op_layer, conv_op_layer) 111 | 112 | for i in range(324): 113 | gen_layer(0, 0, model_index=i) 114 | #gen_layer(0, 0, model_index=189) 115 | -------------------------------------------------------------------------------- /NAS/gen_upsample_layer-prev.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.nn as nn 4 | 5 | try: 6 | from NAS import genotypes 7 | except ImportError: 8 | import genotypes 9 | 10 | try: 11 | from NAS import model 12 | except ImportError: 13 | import model 14 | 15 | try: 16 | from NAS import operations 17 | except ImportError: 18 | import operations 19 | 20 | #random.seed(1) 21 | #torch.manual_seed(1) 22 | 23 | 24 | # def gen_layer(C_in, C_out, use_act, model_index): 25 | 26 | # prim_index = int(model_index / 5) 27 | # conv_index = model_index % 5 28 | # kernel_index = 0 29 | 30 | # if model_index >= 15 and model_index < 23: 31 | # prim_index = 3 32 | # conv_index = (model_index - 3) % 4 33 | 34 | # prim_op = genotypes.UPSAMPLE_PRIMITIVE[prim_index] # adjust the spatial size 35 | # conv_op = genotypes.UPSAMPLE_CONV[conv_index] # adjust the spatial size 36 | # kernel_size = genotypes.KERNEL_SIZE[kernel_index] # select the kernel size 37 | 38 | # if prim_op == 'trans_conv': # only one single layer 39 | # prim_op_layer = operations.UPSAMPLE_PRIMITIVE_OPS[prim_op](C_in=C_in, 40 | # C_out=C_out, 41 | # kernel_size=kernel_size, 42 | # use_act=use_act) 43 | # return prim_op_layer 44 | 45 | # elif prim_op == 'pixel_shuffle': 46 | # if model_index >= 15 and model_index < 23: 47 | # prim_op_layer = operations.UPSAMPLE_PRIMITIVE_OPS[prim_op](C_in=C_in, 48 | # C_out=C_out, 49 | # kernel_size=kernel_size, 50 | # use_act=use_act) 51 | 52 | # conv_op_layer = operations.UPSAMPLE_CONV_OPS[conv_op](C_in=int(C_in/4), 53 | 
# C_out=C_out, 54 | # kernel_size=kernel_size, 55 | # use_act=use_act) 56 | # layer = nn.Sequential(prim_op_layer, conv_op_layer) 57 | 58 | # else: 59 | # conv_op_layer = operations.UPSAMPLE_CONV_OPS[conv_op](C_in=C_in, 60 | # C_out=int(C_out*4), 61 | # kernel_size=kernel_size, 62 | # use_act=use_act) 63 | 64 | # prim_op_layer = operations.UPSAMPLE_PRIMITIVE_OPS[prim_op](C_in=C_in, 65 | # C_out=C_out, 66 | # kernel_size=kernel_size, 67 | # use_act=use_act) 68 | # layer = nn.Sequential(conv_op_layer, prim_op_layer) 69 | 70 | # else: 71 | # prim_op_layer = operations.UPSAMPLE_PRIMITIVE_OPS[prim_op](C_in=C_in, 72 | # C_out=C_out, 73 | # kernel_size=kernel_size, 74 | # use_act=use_act) 75 | 76 | # conv_op_layer = operations.UPSAMPLE_CONV_OPS[conv_op](C_in=C_in, 77 | # C_out=C_out, 78 | # kernel_size=kernel_size, 79 | # use_act=use_act) 80 | 81 | # layer = nn.Sequential(prim_op_layer, conv_op_layer) 82 | 83 | # return layer 84 | 85 | 86 | def gen_layer(C_in, C_out, use_act, model_index): 87 | 88 | prim_index = int(model_index / 5) 89 | conv_index = model_index % 5 90 | kernel_index = 0 91 | 92 | prim_op = genotypes.UPSAMPLE_PRIMITIVE[prim_index] # adjust the spatial size 93 | conv_op = genotypes.UPSAMPLE_CONV[conv_index] # adjust the number of channels 94 | kernel_size = genotypes.KERNEL_SIZE[kernel_index] # select the kernel size 95 | 96 | prim_op_layer = operations.UPSAMPLE_PRIMITIVE_OPS[prim_op](C_in=C_in, 97 | C_out=C_out, 98 | kernel_size=kernel_size, 99 | use_act=use_act) 100 | 101 | conv_op_layer = operations.UPSAMPLE_CONV_OPS[conv_op](C_in=C_in, 102 | C_out=C_out, 103 | kernel_size=kernel_size, 104 | use_act=use_act) 105 | 106 | layer = nn.Sequential(conv_op_layer, prim_op_layer) 107 | 108 | return layer 109 | -------------------------------------------------------------------------------- /NAS/gen_upsample_layer.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.nn as nn 4 | 5 | try: 6 | from NAS import genotypes 7 | except ImportError: 8 | import genotypes 9 | 10 | try: 11 | from NAS import model 12 | except ImportError: 13 | import model 14 | 15 | try: 16 | from NAS import operations 17 | except ImportError: 18 | import operations 19 | 20 | def gen_layer(C_in, C_out, model_index): 21 | 22 | swap = False 23 | 24 | """ Bilinear """ 25 | if model_index >= 0 and model_index <= 251: 26 | prim_index = model_index // 63 27 | model_index = model_index % 63 28 | conv_index = ((model_index // len(genotypes.ACTIVATION)) // len(genotypes.KERNEL_SIZE)) % len(genotypes.UPSAMPLE_PRIMITIVE) 29 | kernel_index = (model_index // len(genotypes.ACTIVATION)) % len(genotypes.KERNEL_SIZE) 30 | act_index = model_index % len(genotypes.ACTIVATION) 31 | 32 | if (model_index >= 60 and model_index <= 62): 33 | conv_index = 5 34 | 35 | """ DepthToSpace - Second """ 36 | if model_index >= 252 and model_index <= 311: 37 | swap = True 38 | prim_index = (model_index - 63) // 63 39 | model_index = model_index % 63 40 | conv_index = ((model_index // len(genotypes.ACTIVATION)) // len(genotypes.KERNEL_SIZE)) % len(genotypes.UPSAMPLE_PRIMITIVE) 41 | kernel_index = (model_index // len(genotypes.ACTIVATION)) % len(genotypes.KERNEL_SIZE) 42 | act_index = model_index % len(genotypes.ACTIVATION) 43 | 44 | """ Transposed Convolution """ 45 | if model_index >= 312 and model_index <= 323: 46 | prim_index = 4 47 | conv_index = 5 48 | kernel_index = (model_index // len(genotypes.ACTIVATION)) % len(genotypes.KERNEL_SIZE) 49 | act_index = model_index % 
len(genotypes.ACTIVATION) 50 | 51 | prim_op = genotypes.UPSAMPLE_PRIMITIVE[prim_index] # adjust the spatial size 52 | conv_op = genotypes.UPSAMPLE_CONV[conv_index] # adjust the number of channels 53 | kernel_size = genotypes.KERNEL_SIZE[kernel_index] # select the kernel size 54 | act_op = genotypes.ACTIVATION[act_index] # select the activation 55 | 56 | #print('prim op:', prim_op) 57 | #print('conv op:', conv_op) 58 | #print('kernel size:', kernel_size) 59 | #print('act op:', act_op) 60 | #return 61 | 62 | if prim_op == 'pixel_shuffle': 63 | if not swap: 64 | prim_op_layer = operations.UPSAMPLE_PRIMITIVE_OPS[prim_op](C_in=C_in, 65 | C_out=C_out, 66 | kernel_size=kernel_size, 67 | act_op=act_op) 68 | 69 | conv_op_layer = operations.UPSAMPLE_CONV_OPS[conv_op](C_in=int(C_in/4), 70 | C_out=C_out, 71 | kernel_size=kernel_size, 72 | act_op=act_op) 73 | return nn.Sequential(prim_op_layer, conv_op_layer) 74 | 75 | else: 76 | conv_op_layer = operations.UPSAMPLE_CONV_OPS[conv_op](C_in=C_in, 77 | C_out=int(C_out*4), 78 | kernel_size=kernel_size, 79 | act_op=act_op) 80 | 81 | prim_op_layer = operations.UPSAMPLE_PRIMITIVE_OPS[prim_op](C_in=C_in, 82 | C_out=C_out, 83 | kernel_size=kernel_size, 84 | act_op=act_op) 85 | return nn.Sequential(conv_op_layer, prim_op_layer) 86 | 87 | else: 88 | prim_op_layer = operations.UPSAMPLE_PRIMITIVE_OPS[prim_op](C_in=C_in, 89 | C_out=C_out, 90 | kernel_size=kernel_size, 91 | act_op=act_op) 92 | 93 | conv_op_layer = operations.UPSAMPLE_CONV_OPS[conv_op](C_in=C_in, 94 | C_out=C_out, 95 | kernel_size=kernel_size, 96 | act_op=act_op) 97 | return nn.Sequential(prim_op_layer, conv_op_layer) 98 | 99 | #for i in range(321): 100 | # gen_layer(0, 0, model_index=i) 101 | #gen_layer(0, 0, model_index=189) 102 | -------------------------------------------------------------------------------- /NAS/genotypes-prev.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | 4 | #Genotype = namedtuple('Genotype', 'downsample_conv downsample_method downsample_concat upsample_conv upsample_method upsample_concat') 5 | Genotype = namedtuple('Genotype', 'upsample_prim_method upsample_conv upsample_kernel upsample_concat') 6 | 7 | 8 | PRIMITIVES = [ 9 | 'sep_conv_3x3', 10 | 'sep_conv_5x5', 11 | 'sep_conv_7x7', 12 | 'dil_conv_3x3', 13 | 'dil_conv_5x5', 14 | 'dil_conv_7x7', 15 | ] 16 | 17 | 18 | UPSAMPLE_METHOD = [ 19 | 'bilinear_conv_3x3', 20 | 'bilinear_conv_5x5', 21 | 'bilinear_conv_7x7', 22 | 'trans_conv_3x3', 23 | 'trans_conv_5x5', 24 | 'trans_conv_7x7', 25 | 'bilinear_additive_3x3', 26 | 'bilinear_additive_5x5', 27 | 'bilinear_additive_7x7', 28 | 'depth_to_space_3x3', 29 | 'depth_to_space_5x5', 30 | 'depth_to_space_7x7', 31 | 'factorized_reduce', 32 | #'none', 33 | ] 34 | 35 | 36 | DOWNSAMPLE_METHOD = [ 37 | 'conv_downsample_3x3_dilation_1', 38 | 'conv_downsample_3x3_dilation_2', 39 | 'conv_downsample_5x5_dilation_1', 40 | 'conv_downsample_5x5_dilation_2', 41 | 'conv_downsample_7x7_dilation_1', 42 | 'conv_downsample_7x7_dilation_2', 43 | 'avg_pool_3x3', 44 | 'avg_pool_5x5', 45 | 'avg_pool_7x7', 46 | 'max_pool_3x3', 47 | 'max_pool_5x5', 48 | 'max_pool_7x7', 49 | 'factorized_reduce', 50 | #'none', 51 | ] 52 | 53 | 54 | 55 | """ 56 | 57 | Newly added 58 | 59 | Description: 60 | - decompose the upsampling operations into two primitives 61 | 62 | Upsample = upsample primitive + conv 63 | - upsample primitive: change the spatial size 64 | - conv: maintain the channel size 65 | 66 | Goal: 67 | - search for the upsample operation 
68 | - replace the upsampling operation in DIP's network with the searched operation 69 | 70 | Note: 71 | - experiments are finished 72 | 73 | """ 74 | 75 | UPSAMPLE_PRIMITIVE = [ 76 | 'bilinear', 77 | 'bicubic', 78 | 'nearest', 79 | 'pixel_shuffle', 80 | 'trans_conv', 81 | ] 82 | 83 | UPSAMPLE_CONV = [ 84 | 'conv', 85 | 'trans_conv', 86 | 'split_stack_sum', # additive 87 | 'sep_conv', 88 | 'identity', 89 | ] 90 | 91 | 92 | KERNEL_SIZE = [ 93 | '3x3', 94 | #'4x4', 95 | #'5x5', 96 | #'7x7', 97 | ] 98 | 99 | 100 | DILATION_RATE = [ 101 | '1', 102 | '2', 103 | '3', 104 | ] 105 | -------------------------------------------------------------------------------- /NAS/genotypes.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | Genotype = namedtuple('Genotype', 'upsample_prim_method upsample_conv upsample_kernel upsample_concat')  # needed by model_gen.py; same fields as in genotypes-prev.py 4 | """ 5 | 6 | Newly added 7 | 8 | Description: 9 | - decompose the upsampling operations into two primitives 10 | 11 | Upsample = upsample primitive + conv 12 | - upsample primitive: change the spatial size 13 | - conv: maintain the channel size 14 | 15 | Goal: 16 | - search for the upsample operation 17 | - replace the upsampling operation in DIP's network with the searched operation 18 | 19 | Note: 20 | - do not need a separate point-wise convolution as it is only a 1x1 conv 21 | 22 | """ 23 | 24 | UPSAMPLE_PRIMITIVE = [ 25 | 'bilinear', 26 | 'bicubic', 27 | 'nearest', 28 | 'pixel_shuffle', 29 | 'trans_conv', # stride = 2 (see UPSAMPLE_PRIMITIVE_OPS in operations.py) 30 | ] 31 | 32 | UPSAMPLE_CONV = [ 33 | 'conv', 34 | 'trans_conv', # stride = 1 (see UPSAMPLE_CONV_OPS in operations.py) 35 | 'split_stack_sum', # additive 36 | 'sep_conv', 37 | 'depth_wise_conv', 38 | 'identity', 39 | ] 40 | 41 | 42 | KERNEL_SIZE = [ 43 | '1x1', 44 | '3x3', 45 | '5x5', 46 | '7x7', 47 | ] 48 | 49 | 50 | DILATION_RATE = [ 51 | '1', 52 | '2', 53 | '3', 54 | ] 55 | 56 | 57 | ACTIVATION = [ 58 | 'none', 59 | 'ReLU', 60 | 'LeakyReLU', 61 | ] 62 | -------------------------------------------------------------------------------- /NAS/index_to_model_mapping.log: -------------------------------------------------------------------------------- 1 | bilinear -> conv -> 1x1 -> none 2 | bilinear -> conv -> 1x1 -> ReLU 3 | bilinear -> conv -> 1x1 -> LeakyReLU 4 | bilinear -> conv -> 3x3 -> none 5 | bilinear -> conv -> 3x3 -> ReLU 6 | bilinear -> conv -> 3x3 -> LeakyReLU 7 | bilinear -> conv -> 5x5 -> none 8 | bilinear -> conv -> 5x5 -> ReLU 9 | bilinear -> conv -> 5x5 -> LeakyReLU 10 | bilinear -> conv -> 7x7 -> none 11 | bilinear -> conv -> 7x7 -> ReLU 12 | bilinear -> conv -> 7x7 -> LeakyReLU 13 | bilinear -> trans_conv -> 1x1 -> none 14 | bilinear -> trans_conv -> 1x1 -> ReLU 15 | bilinear -> trans_conv -> 1x1 -> LeakyReLU 16 | bilinear -> trans_conv -> 3x3 -> none 17 | bilinear -> trans_conv -> 3x3 -> ReLU 18 | bilinear -> trans_conv -> 3x3 -> LeakyReLU 19 | bilinear -> trans_conv -> 5x5 -> none 20 | bilinear -> trans_conv -> 5x5 -> ReLU 21 | bilinear -> trans_conv -> 5x5 -> LeakyReLU 22 | bilinear -> trans_conv -> 7x7 -> none 23 | bilinear -> trans_conv -> 7x7 -> ReLU 24 | bilinear -> trans_conv -> 7x7 -> LeakyReLU 25 | bilinear -> split_stack_sum -> 1x1 -> none 26 | bilinear -> split_stack_sum -> 1x1 -> ReLU 27 | bilinear -> split_stack_sum -> 1x1 -> LeakyReLU 28 | bilinear -> split_stack_sum -> 3x3 -> none 29 | bilinear -> split_stack_sum -> 3x3 -> ReLU 30 | bilinear -> split_stack_sum -> 3x3 -> LeakyReLU 31 | bilinear -> split_stack_sum -> 5x5 -> none 32 | bilinear -> split_stack_sum -> 5x5 -> ReLU 33 | bilinear -> split_stack_sum -> 5x5 -> LeakyReLU 34 | bilinear -> split_stack_sum -> 7x7 -> none 35 | 
bilinear -> split_stack_sum -> 7x7 -> ReLU 36 | bilinear -> split_stack_sum -> 7x7 -> LeakyReLU 37 | bilinear -> sep_conv -> 1x1 -> none 38 | bilinear -> sep_conv -> 1x1 -> ReLU 39 | bilinear -> sep_conv -> 1x1 -> LeakyReLU 40 | bilinear -> sep_conv -> 3x3 -> none 41 | bilinear -> sep_conv -> 3x3 -> ReLU 42 | bilinear -> sep_conv -> 3x3 -> LeakyReLU 43 | bilinear -> sep_conv -> 5x5 -> none 44 | bilinear -> sep_conv -> 5x5 -> ReLU 45 | bilinear -> sep_conv -> 5x5 -> LeakyReLU 46 | bilinear -> sep_conv -> 7x7 -> none 47 | bilinear -> sep_conv -> 7x7 -> ReLU 48 | bilinear -> sep_conv -> 7x7 -> LeakyReLU 49 | bilinear -> depth_wise_conv -> 1x1 -> none 50 | bilinear -> depth_wise_conv -> 1x1 -> ReLU 51 | bilinear -> depth_wise_conv -> 1x1 -> LeakyReLU 52 | bilinear -> depth_wise_conv -> 3x3 -> none 53 | bilinear -> depth_wise_conv -> 3x3 -> ReLU 54 | bilinear -> depth_wise_conv -> 3x3 -> LeakyReLU 55 | bilinear -> depth_wise_conv -> 5x5 -> none 56 | bilinear -> depth_wise_conv -> 5x5 -> ReLU 57 | bilinear -> depth_wise_conv -> 5x5 -> LeakyReLU 58 | bilinear -> depth_wise_conv -> 7x7 -> none 59 | bilinear -> depth_wise_conv -> 7x7 -> ReLU 60 | bilinear -> depth_wise_conv -> 7x7 -> LeakyReLU 61 | bilinear -> identity -> 1x1 -> none 62 | bilinear -> identity -> 1x1 -> ReLU 63 | bilinear -> identity -> 1x1 -> LeakyReLU 64 | bicubic -> conv -> 1x1 -> none 65 | bicubic -> conv -> 1x1 -> ReLU 66 | bicubic -> conv -> 1x1 -> LeakyReLU 67 | bicubic -> conv -> 3x3 -> none 68 | bicubic -> conv -> 3x3 -> ReLU 69 | bicubic -> conv -> 3x3 -> LeakyReLU 70 | bicubic -> conv -> 5x5 -> none 71 | bicubic -> conv -> 5x5 -> ReLU 72 | bicubic -> conv -> 5x5 -> LeakyReLU 73 | bicubic -> conv -> 7x7 -> none 74 | bicubic -> conv -> 7x7 -> ReLU 75 | bicubic -> conv -> 7x7 -> LeakyReLU 76 | bicubic -> trans_conv -> 1x1 -> none 77 | bicubic -> trans_conv -> 1x1 -> ReLU 78 | bicubic -> trans_conv -> 1x1 -> LeakyReLU 79 | bicubic -> trans_conv -> 3x3 -> none 80 | bicubic -> trans_conv -> 3x3 -> ReLU 81 | bicubic -> trans_conv -> 3x3 -> LeakyReLU 82 | bicubic -> trans_conv -> 5x5 -> none 83 | bicubic -> trans_conv -> 5x5 -> ReLU 84 | bicubic -> trans_conv -> 5x5 -> LeakyReLU 85 | bicubic -> trans_conv -> 7x7 -> none 86 | bicubic -> trans_conv -> 7x7 -> ReLU 87 | bicubic -> trans_conv -> 7x7 -> LeakyReLU 88 | bicubic -> split_stack_sum -> 1x1 -> none 89 | bicubic -> split_stack_sum -> 1x1 -> ReLU 90 | bicubic -> split_stack_sum -> 1x1 -> LeakyReLU 91 | bicubic -> split_stack_sum -> 3x3 -> none 92 | bicubic -> split_stack_sum -> 3x3 -> ReLU 93 | bicubic -> split_stack_sum -> 3x3 -> LeakyReLU 94 | bicubic -> split_stack_sum -> 5x5 -> none 95 | bicubic -> split_stack_sum -> 5x5 -> ReLU 96 | bicubic -> split_stack_sum -> 5x5 -> LeakyReLU 97 | bicubic -> split_stack_sum -> 7x7 -> none 98 | bicubic -> split_stack_sum -> 7x7 -> ReLU 99 | bicubic -> split_stack_sum -> 7x7 -> LeakyReLU 100 | bicubic -> sep_conv -> 1x1 -> none 101 | bicubic -> sep_conv -> 1x1 -> ReLU 102 | bicubic -> sep_conv -> 1x1 -> LeakyReLU 103 | bicubic -> sep_conv -> 3x3 -> none 104 | bicubic -> sep_conv -> 3x3 -> ReLU 105 | bicubic -> sep_conv -> 3x3 -> LeakyReLU 106 | bicubic -> sep_conv -> 5x5 -> none 107 | bicubic -> sep_conv -> 5x5 -> ReLU 108 | bicubic -> sep_conv -> 5x5 -> LeakyReLU 109 | bicubic -> sep_conv -> 7x7 -> none 110 | bicubic -> sep_conv -> 7x7 -> ReLU 111 | bicubic -> sep_conv -> 7x7 -> LeakyReLU 112 | bicubic -> depth_wise_conv -> 1x1 -> none 113 | bicubic -> depth_wise_conv -> 1x1 -> ReLU 114 | bicubic -> depth_wise_conv -> 1x1 -> LeakyReLU 
115 | bicubic -> depth_wise_conv -> 3x3 -> none 116 | bicubic -> depth_wise_conv -> 3x3 -> ReLU 117 | bicubic -> depth_wise_conv -> 3x3 -> LeakyReLU 118 | bicubic -> depth_wise_conv -> 5x5 -> none 119 | bicubic -> depth_wise_conv -> 5x5 -> ReLU 120 | bicubic -> depth_wise_conv -> 5x5 -> LeakyReLU 121 | bicubic -> depth_wise_conv -> 7x7 -> none 122 | bicubic -> depth_wise_conv -> 7x7 -> ReLU 123 | bicubic -> depth_wise_conv -> 7x7 -> LeakyReLU 124 | bicubic -> identity -> 1x1 -> none 125 | bicubic -> identity -> 1x1 -> ReLU 126 | bicubic -> identity -> 1x1 -> LeakyReLU 127 | nearest -> conv -> 1x1 -> none 128 | nearest -> conv -> 1x1 -> ReLU 129 | nearest -> conv -> 1x1 -> LeakyReLU 130 | nearest -> conv -> 3x3 -> none 131 | nearest -> conv -> 3x3 -> ReLU 132 | nearest -> conv -> 3x3 -> LeakyReLU 133 | nearest -> conv -> 5x5 -> none 134 | nearest -> conv -> 5x5 -> ReLU 135 | nearest -> conv -> 5x5 -> LeakyReLU 136 | nearest -> conv -> 7x7 -> none 137 | nearest -> conv -> 7x7 -> ReLU 138 | nearest -> conv -> 7x7 -> LeakyReLU 139 | nearest -> trans_conv -> 1x1 -> none 140 | nearest -> trans_conv -> 1x1 -> ReLU 141 | nearest -> trans_conv -> 1x1 -> LeakyReLU 142 | nearest -> trans_conv -> 3x3 -> none 143 | nearest -> trans_conv -> 3x3 -> ReLU 144 | nearest -> trans_conv -> 3x3 -> LeakyReLU 145 | nearest -> trans_conv -> 5x5 -> none 146 | nearest -> trans_conv -> 5x5 -> ReLU 147 | nearest -> trans_conv -> 5x5 -> LeakyReLU 148 | nearest -> trans_conv -> 7x7 -> none 149 | nearest -> trans_conv -> 7x7 -> ReLU 150 | nearest -> trans_conv -> 7x7 -> LeakyReLU 151 | nearest -> split_stack_sum -> 1x1 -> none 152 | nearest -> split_stack_sum -> 1x1 -> ReLU 153 | nearest -> split_stack_sum -> 1x1 -> LeakyReLU 154 | nearest -> split_stack_sum -> 3x3 -> none 155 | nearest -> split_stack_sum -> 3x3 -> ReLU 156 | nearest -> split_stack_sum -> 3x3 -> LeakyReLU 157 | nearest -> split_stack_sum -> 5x5 -> none 158 | nearest -> split_stack_sum -> 5x5 -> ReLU 159 | nearest -> split_stack_sum -> 5x5 -> LeakyReLU 160 | nearest -> split_stack_sum -> 7x7 -> none 161 | nearest -> split_stack_sum -> 7x7 -> ReLU 162 | nearest -> split_stack_sum -> 7x7 -> LeakyReLU 163 | nearest -> sep_conv -> 1x1 -> none 164 | nearest -> sep_conv -> 1x1 -> ReLU 165 | nearest -> sep_conv -> 1x1 -> LeakyReLU 166 | nearest -> sep_conv -> 3x3 -> none 167 | nearest -> sep_conv -> 3x3 -> ReLU 168 | nearest -> sep_conv -> 3x3 -> LeakyReLU 169 | nearest -> sep_conv -> 5x5 -> none 170 | nearest -> sep_conv -> 5x5 -> ReLU 171 | nearest -> sep_conv -> 5x5 -> LeakyReLU 172 | nearest -> sep_conv -> 7x7 -> none 173 | nearest -> sep_conv -> 7x7 -> ReLU 174 | nearest -> sep_conv -> 7x7 -> LeakyReLU 175 | nearest -> depth_wise_conv -> 1x1 -> none 176 | nearest -> depth_wise_conv -> 1x1 -> ReLU 177 | nearest -> depth_wise_conv -> 1x1 -> LeakyReLU 178 | nearest -> depth_wise_conv -> 3x3 -> none 179 | nearest -> depth_wise_conv -> 3x3 -> ReLU 180 | nearest -> depth_wise_conv -> 3x3 -> LeakyReLU 181 | nearest -> depth_wise_conv -> 5x5 -> none 182 | nearest -> depth_wise_conv -> 5x5 -> ReLU 183 | nearest -> depth_wise_conv -> 5x5 -> LeakyReLU 184 | nearest -> depth_wise_conv -> 7x7 -> none 185 | nearest -> depth_wise_conv -> 7x7 -> ReLU 186 | nearest -> depth_wise_conv -> 7x7 -> LeakyReLU 187 | nearest -> identity -> 1x1 -> none 188 | nearest -> identity -> 1x1 -> ReLU 189 | nearest -> identity -> 1x1 -> LeakyReLU 190 | pixel_shuffle -> conv -> 1x1 -> none 191 | pixel_shuffle -> conv -> 1x1 -> ReLU 192 | pixel_shuffle -> conv -> 1x1 -> LeakyReLU 193 | 
pixel_shuffle -> conv -> 3x3 -> none 194 | pixel_shuffle -> conv -> 3x3 -> ReLU 195 | pixel_shuffle -> conv -> 3x3 -> LeakyReLU 196 | pixel_shuffle -> conv -> 5x5 -> none 197 | pixel_shuffle -> conv -> 5x5 -> ReLU 198 | pixel_shuffle -> conv -> 5x5 -> LeakyReLU 199 | pixel_shuffle -> conv -> 7x7 -> none 200 | pixel_shuffle -> conv -> 7x7 -> ReLU 201 | pixel_shuffle -> conv -> 7x7 -> LeakyReLU 202 | pixel_shuffle -> trans_conv -> 1x1 -> none 203 | pixel_shuffle -> trans_conv -> 1x1 -> ReLU 204 | pixel_shuffle -> trans_conv -> 1x1 -> LeakyReLU 205 | pixel_shuffle -> trans_conv -> 3x3 -> none 206 | pixel_shuffle -> trans_conv -> 3x3 -> ReLU 207 | pixel_shuffle -> trans_conv -> 3x3 -> LeakyReLU 208 | pixel_shuffle -> trans_conv -> 5x5 -> none 209 | pixel_shuffle -> trans_conv -> 5x5 -> ReLU 210 | pixel_shuffle -> trans_conv -> 5x5 -> LeakyReLU 211 | pixel_shuffle -> trans_conv -> 7x7 -> none 212 | pixel_shuffle -> trans_conv -> 7x7 -> ReLU 213 | pixel_shuffle -> trans_conv -> 7x7 -> LeakyReLU 214 | pixel_shuffle -> split_stack_sum -> 1x1 -> none 215 | pixel_shuffle -> split_stack_sum -> 1x1 -> ReLU 216 | pixel_shuffle -> split_stack_sum -> 1x1 -> LeakyReLU 217 | pixel_shuffle -> split_stack_sum -> 3x3 -> none 218 | pixel_shuffle -> split_stack_sum -> 3x3 -> ReLU 219 | pixel_shuffle -> split_stack_sum -> 3x3 -> LeakyReLU 220 | pixel_shuffle -> split_stack_sum -> 5x5 -> none 221 | pixel_shuffle -> split_stack_sum -> 5x5 -> ReLU 222 | pixel_shuffle -> split_stack_sum -> 5x5 -> LeakyReLU 223 | pixel_shuffle -> split_stack_sum -> 7x7 -> none 224 | pixel_shuffle -> split_stack_sum -> 7x7 -> ReLU 225 | pixel_shuffle -> split_stack_sum -> 7x7 -> LeakyReLU 226 | pixel_shuffle -> sep_conv -> 1x1 -> none 227 | pixel_shuffle -> sep_conv -> 1x1 -> ReLU 228 | pixel_shuffle -> sep_conv -> 1x1 -> LeakyReLU 229 | pixel_shuffle -> sep_conv -> 3x3 -> none 230 | pixel_shuffle -> sep_conv -> 3x3 -> ReLU 231 | pixel_shuffle -> sep_conv -> 3x3 -> LeakyReLU 232 | pixel_shuffle -> sep_conv -> 5x5 -> none 233 | pixel_shuffle -> sep_conv -> 5x5 -> ReLU 234 | pixel_shuffle -> sep_conv -> 5x5 -> LeakyReLU 235 | pixel_shuffle -> sep_conv -> 7x7 -> none 236 | pixel_shuffle -> sep_conv -> 7x7 -> ReLU 237 | pixel_shuffle -> sep_conv -> 7x7 -> LeakyReLU 238 | pixel_shuffle -> depth_wise_conv -> 1x1 -> none 239 | pixel_shuffle -> depth_wise_conv -> 1x1 -> ReLU 240 | pixel_shuffle -> depth_wise_conv -> 1x1 -> LeakyReLU 241 | pixel_shuffle -> depth_wise_conv -> 3x3 -> none 242 | pixel_shuffle -> depth_wise_conv -> 3x3 -> ReLU 243 | pixel_shuffle -> depth_wise_conv -> 3x3 -> LeakyReLU 244 | pixel_shuffle -> depth_wise_conv -> 5x5 -> none 245 | pixel_shuffle -> depth_wise_conv -> 5x5 -> ReLU 246 | pixel_shuffle -> depth_wise_conv -> 5x5 -> LeakyReLU 247 | pixel_shuffle -> depth_wise_conv -> 7x7 -> none 248 | pixel_shuffle -> depth_wise_conv -> 7x7 -> ReLU 249 | pixel_shuffle -> depth_wise_conv -> 7x7 -> LeakyReLU 250 | pixel_shuffle -> identity -> 1x1 -> none 251 | pixel_shuffle -> identity -> 1x1 -> ReLU 252 | pixel_shuffle -> identity -> 1x1 -> LeakyReLU 253 | pixel_shuffle -> conv -> 1x1 -> none 254 | pixel_shuffle -> conv -> 1x1 -> ReLU 255 | pixel_shuffle -> conv -> 1x1 -> LeakyReLU 256 | pixel_shuffle -> conv -> 3x3 -> none 257 | pixel_shuffle -> conv -> 3x3 -> ReLU 258 | pixel_shuffle -> conv -> 3x3 -> LeakyReLU 259 | pixel_shuffle -> conv -> 5x5 -> none 260 | pixel_shuffle -> conv -> 5x5 -> ReLU 261 | pixel_shuffle -> conv -> 5x5 -> LeakyReLU 262 | pixel_shuffle -> conv -> 7x7 -> none 263 | pixel_shuffle -> conv -> 7x7 -> 
ReLU 264 | pixel_shuffle -> conv -> 7x7 -> LeakyReLU 265 | pixel_shuffle -> trans_conv -> 1x1 -> none 266 | pixel_shuffle -> trans_conv -> 1x1 -> ReLU 267 | pixel_shuffle -> trans_conv -> 1x1 -> LeakyReLU 268 | pixel_shuffle -> trans_conv -> 3x3 -> none 269 | pixel_shuffle -> trans_conv -> 3x3 -> ReLU 270 | pixel_shuffle -> trans_conv -> 3x3 -> LeakyReLU 271 | pixel_shuffle -> trans_conv -> 5x5 -> none 272 | pixel_shuffle -> trans_conv -> 5x5 -> ReLU 273 | pixel_shuffle -> trans_conv -> 5x5 -> LeakyReLU 274 | pixel_shuffle -> trans_conv -> 7x7 -> none 275 | pixel_shuffle -> trans_conv -> 7x7 -> ReLU 276 | pixel_shuffle -> trans_conv -> 7x7 -> LeakyReLU 277 | pixel_shuffle -> split_stack_sum -> 1x1 -> none 278 | pixel_shuffle -> split_stack_sum -> 1x1 -> ReLU 279 | pixel_shuffle -> split_stack_sum -> 1x1 -> LeakyReLU 280 | pixel_shuffle -> split_stack_sum -> 3x3 -> none 281 | pixel_shuffle -> split_stack_sum -> 3x3 -> ReLU 282 | pixel_shuffle -> split_stack_sum -> 3x3 -> LeakyReLU 283 | pixel_shuffle -> split_stack_sum -> 5x5 -> none 284 | pixel_shuffle -> split_stack_sum -> 5x5 -> ReLU 285 | pixel_shuffle -> split_stack_sum -> 5x5 -> LeakyReLU 286 | pixel_shuffle -> split_stack_sum -> 7x7 -> none 287 | pixel_shuffle -> split_stack_sum -> 7x7 -> ReLU 288 | pixel_shuffle -> split_stack_sum -> 7x7 -> LeakyReLU 289 | pixel_shuffle -> sep_conv -> 1x1 -> none 290 | pixel_shuffle -> sep_conv -> 1x1 -> ReLU 291 | pixel_shuffle -> sep_conv -> 1x1 -> LeakyReLU 292 | pixel_shuffle -> sep_conv -> 3x3 -> none 293 | pixel_shuffle -> sep_conv -> 3x3 -> ReLU 294 | pixel_shuffle -> sep_conv -> 3x3 -> LeakyReLU 295 | pixel_shuffle -> sep_conv -> 5x5 -> none 296 | pixel_shuffle -> sep_conv -> 5x5 -> ReLU 297 | pixel_shuffle -> sep_conv -> 5x5 -> LeakyReLU 298 | pixel_shuffle -> sep_conv -> 7x7 -> none 299 | pixel_shuffle -> sep_conv -> 7x7 -> ReLU 300 | pixel_shuffle -> sep_conv -> 7x7 -> LeakyReLU 301 | pixel_shuffle -> depth_wise_conv -> 1x1 -> none 302 | pixel_shuffle -> depth_wise_conv -> 1x1 -> ReLU 303 | pixel_shuffle -> depth_wise_conv -> 1x1 -> LeakyReLU 304 | pixel_shuffle -> depth_wise_conv -> 3x3 -> none 305 | pixel_shuffle -> depth_wise_conv -> 3x3 -> ReLU 306 | pixel_shuffle -> depth_wise_conv -> 3x3 -> LeakyReLU 307 | pixel_shuffle -> depth_wise_conv -> 5x5 -> none 308 | pixel_shuffle -> depth_wise_conv -> 5x5 -> ReLU 309 | pixel_shuffle -> depth_wise_conv -> 5x5 -> LeakyReLU 310 | pixel_shuffle -> depth_wise_conv -> 7x7 -> none 311 | pixel_shuffle -> depth_wise_conv -> 7x7 -> ReLU 312 | pixel_shuffle -> depth_wise_conv -> 7x7 -> LeakyReLU 313 | trans_conv -> identity -> 1x1 -> none 314 | trans_conv -> identity -> 1x1 -> ReLU 315 | trans_conv -> identity -> 1x1 -> LeakyReLU 316 | trans_conv -> identity -> 3x3 -> none 317 | trans_conv -> identity -> 3x3 -> ReLU 318 | trans_conv -> identity -> 3x3 -> LeakyReLU 319 | trans_conv -> identity -> 5x5 -> none 320 | trans_conv -> identity -> 5x5 -> ReLU 321 | trans_conv -> identity -> 5x5 -> LeakyReLU 322 | -------------------------------------------------------------------------------- /NAS/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | 5 | try: 6 | from NAS import operations 7 | except ImportError: 8 | import operations 9 | 10 | try: 11 | from NAS import utils 12 | except ImportError: 13 | import utils 14 | 15 | 16 | class Cell(nn.Module): 17 | 18 | def __init__(self, 19 | genotype, 20 | C_prev, 21 | C_curr, 22 | 
C_prev_prev=None, 23 | op_type='downsample'): 24 | 25 | super(Cell, self).__init__() 26 | 27 | self.op_type = op_type 28 | 29 | if self.op_type == 'downsample': 30 | self.preprocess0 = operations.ReLUConvBN(C_prev, C_curr, 1, 1, 0) 31 | conv_op_names, indices = zip(*genotype.downsample_conv) 32 | op_names, _ = zip(*genotype.downsample_method) 33 | concat = genotype.downsample_concat 34 | else: 35 | self.preprocess0 = operations.ReLUConvBN(C_prev, C_curr, 1, 1, 0) 36 | if C_prev_prev is not None: 37 | self.preprocess1 = operations.ReLUConvBN(C_prev_prev, C_curr, 1, 1, 0) 38 | conv_op_names, indices = zip(*genotype.upsample_conv) 39 | op_names, _ = zip(*genotype.upsample_method) 40 | concat = genotype.upsample_concat 41 | 42 | self._compile(C_curr=C_curr, 43 | conv_op_names=conv_op_names, 44 | op_names=op_names, 45 | indices=indices, 46 | concat=concat) 47 | 48 | 49 | def _compile(self, C_curr, conv_op_names, op_names, indices, concat): 50 | assert len(op_names) == len(indices) 51 | assert len(conv_op_names) == len(indices) 52 | 53 | self.num_of_nodes = len(op_names) // 2 54 | self._concat = concat 55 | self.multiplier = len(concat) 56 | 57 | self._ops = nn.ModuleList() 58 | 59 | for index in range(len(op_names)): 60 | 61 | if self.op_type == 'downsample': 62 | downsample_name = op_names[index] 63 | conv_name = conv_op_names[index] 64 | 65 | #stride = 2 if index < 2 else 1 66 | stride = 2 if indices[index] < 2 else 1 67 | 68 | downsample_op = operations.DOWNSAMPLE_OPS[downsample_name](C_in=C_curr, stride=stride) 69 | conv_op = operations.CONV_OPS[conv_name](C_in=C_curr, C_out=C_curr, affine=True) 70 | 71 | #print('\n\n[Downsample Op]:', downsample_op) 72 | #print('\n[Conv Op]:', conv_op) 73 | 74 | op = nn.Sequential(downsample_op, conv_op) 75 | 76 | #print('\n[combined Op]:', op) 77 | 78 | else: # upsample 79 | upsample_name = op_names[index] 80 | conv_name = conv_op_names[index] 81 | 82 | #stride = 2 if index < 2 else 1 83 | stride = 2 if indices[index] < 2 else 1 84 | 85 | upsample_op = operations.UPSAMPLE_OPS[upsample_name](C_in=C_curr, stride=stride) 86 | conv_op = operations.CONV_OPS[conv_name](C_in=C_curr, C_out=C_curr, affine=True) 87 | 88 | #print('\n\n[Upsample Op]:', upsample_op) 89 | #print('\n[Conv Op]:', conv_op) 90 | 91 | op = nn.Sequential(upsample_op, conv_op) 92 | 93 | #print('\n[combined Op]:', op) 94 | 95 | self._ops += [op] 96 | 97 | self._indices = indices 98 | 99 | 100 | def forward(self, s0, drop_prob, s1=None): 101 | 102 | #print('[Cell] before s0 shape:', s0.shape) 103 | s0 = self.preprocess0(s0) # C_prev 104 | #print('[Cell] after s0 shape:', s0.shape, '\n\n') 105 | if s1 is None: 106 | s1 = s0 107 | else: 108 | s1 = self.preprocess1(s1) 109 | 110 | states = [s0, s1] 111 | for i in range(self.num_of_nodes): 112 | h1 = states[self._indices[2*i]] 113 | h2 = states[self._indices[2*i+1]] 114 | op1 = self._ops[2*i] 115 | op2 = self._ops[2*i+1] 116 | h1 = op1(h1) 117 | h2 = op2(h2) 118 | if self.training and drop_prob > 0.: 119 | if not isinstance(op1, operations.Identity): 120 | h1 = utils.drop_path(h1, drop_prob) 121 | if not isinstance(op2, operations.Identity): 122 | h2 = utils.drop_path(h2, drop_prob) 123 | s = h1 + h2 124 | states += [s] 125 | 126 | out = torch.cat([states[i] for i in self._concat], dim=1) 127 | 128 | #print('[Cell] out dim:', out.shape) 129 | 130 | return out 131 | 132 | 133 | 134 | class NetworkDIP(nn.Module): 135 | 136 | def __init__(self, 137 | genotype, 138 | num_input_channel=3, 139 | num_output_channel=3, 140 | concat_x=False, 141 | 
need_bias=True, 142 | norm_layer=nn.InstanceNorm2d, 143 | pad='zero', 144 | filters=[64, 128, 256, 512, 1024], 145 | init_filters=3, # for the airplane case 146 | feature_scale=4, 147 | drop_path_prob=0.2): 148 | 149 | super(NetworkDIP, self).__init__() 150 | 151 | self._layers = len(filters) 152 | self.drop_path_prob = drop_path_prob 153 | 154 | filters = [x // feature_scale for x in filters] 155 | 156 | stem_output_channel = filters[0] if not concat_x else filters[0] - num_input_channel # stem's output channel 157 | 158 | self.stem = Stem(num_input_channel=num_input_channel, 159 | num_output_channel=stem_output_channel, 160 | norm_layer=norm_layer, 161 | need_bias=need_bias, 162 | pad=pad) 163 | 164 | self.cells = nn.ModuleList() 165 | 166 | """ Initialize downsample cells first """ 167 | op_type = 'downsample' 168 | C_prev = stem_output_channel # same as stem's output channel 169 | for i in range(self._layers): 170 | C_curr = filters[i] 171 | cell = Cell(genotype, C_prev=C_prev, C_curr=C_curr, op_type=op_type) 172 | self.cells += [cell] 173 | C_prev = cell.multiplier * C_curr 174 | 175 | 176 | """ Then initialize upsample cells """ 177 | op_type = 'upsample' 178 | up_mode = genotype.upsample_method 179 | C_prev_prev = None 180 | for i in range(self._layers-1, -1, -1): 181 | 182 | #print('[NetworkDIP] Upsample multiplier, filter:', cell.multiplier, filters[i]) 183 | 184 | C_prev = cell.multiplier * filters[i] 185 | if i > 0: 186 | C_curr = filters[i-1] 187 | else: 188 | C_curr = C_prev // 2 # output channel of the NetworkDIP 189 | 190 | 191 | #print('[NetworkDIP] C_prev_prev, C_prev, C_curr:', C_prev_prev, C_prev, C_curr, '\n\n') 192 | cell = Cell(genotype, C_prev_prev=C_prev_prev, C_prev=C_prev, C_curr=C_curr, op_type=op_type) 193 | self.cells += [cell] 194 | C_prev_prev = self.cells[i-1].multiplier * filters[i-1] 195 | 196 | C_curr = cell.multiplier * C_curr 197 | 198 | self.last_layer = nn.Conv2d(in_channels=C_curr, out_channels=num_output_channel, kernel_size=1) 199 | self.last_activ = nn.Sigmoid() 200 | 201 | 202 | def forward(self, data): 203 | 204 | s0 = self.stem(data) 205 | 206 | output_list = [] 207 | 208 | for i, cell in enumerate(self.cells): 209 | if i < len(self.cells) / 2 + 1: 210 | # no skip connection (encoder part and the first cell in the decoder) 211 | #s0 = cell(s0, drop_prob=self.drop_path_prob) 212 | s0 = cell(s0, drop_prob=0) 213 | output_list.append(s0) 214 | else: 215 | s1 = output_list[len(self.cells) - 1 - i] 216 | #s0 = cell(s0=s0, s1=s1, drop_prob=self.drop_path_prob) 217 | s0 = cell(s0=s0, s1=s1, drop_prob=0) 218 | 219 | s0 = self.last_layer(s0) 220 | s0 = self.last_activ(s0) 221 | 222 | return s0 223 | 224 | 225 | 226 | class Stem(nn.Module): 227 | 228 | def __init__(self, 229 | num_input_channel, 230 | num_output_channel, 231 | norm_layer, 232 | need_bias, 233 | pad): 234 | 235 | super(Stem, self).__init__() 236 | 237 | if norm_layer is not None: 238 | self.conv1 = nn.Sequential( 239 | conv(num_input_channel, num_output_channel, 3, bias=need_bias, pad=pad), 240 | norm_layer(num_output_channel), 241 | nn.ReLU(), 242 | ) 243 | 244 | self.conv2 = nn.Sequential( 245 | conv(num_output_channel, num_output_channel, 3, bias=need_bias, pad=pad), 246 | norm_layer(num_output_channel), 247 | nn.ReLU(), 248 | ) 249 | 250 | else: 251 | self.conv1 = nn.Sequential( 252 | conv(num_input_channel, num_output_channel, 3, bias=need_bias, pad=pad), 253 | nn.ReLU(), 254 | ) 255 | 256 | self.conv2 = nn.Sequential( 257 | conv(num_output_channel, num_output_channel, 3, 
bias=need_bias, pad=pad), 258 | nn.ReLU(), 259 | ) 260 | 261 | 262 | def forward(self, inputs): 263 | outputs = self.conv1(inputs) 264 | outputs = self.conv2(outputs) 265 | return outputs 266 | 267 | 268 | def conv(in_f, out_f, kernel_size, stride=1, bias=True, pad='zero', downsample_mode='stride'): 269 | 270 | downsampler = None 271 | if stride != 1 and downsample_mode != 'stride': 272 | 273 | if downsample_mode == 'avg': 274 | downsampler = nn.AvgPool2d(stride, stride) 275 | 276 | elif downsample_mode == 'max': 277 | downsampler = nn.MaxPool2d(stride, stride) 278 | 279 | elif downsample_mode in ['lanczos2', 'lanczos3']: 280 | downsampler = Downsampler(n_planes=out_f, factor=stride, 281 | kernel_type=downsample_mode, phase=0.5, preserve_size=True) 282 | else: 283 | assert False 284 | 285 | stride = 1 286 | 287 | padder = None 288 | to_pad = int((kernel_size - 1) / 2) 289 | 290 | if pad == 'reflection': 291 | padder = nn.ReflectionPad2d(to_pad) 292 | to_pad = 0 293 | 294 | convolver = nn.Conv2d(in_f, out_f, kernel_size, stride, padding=to_pad, bias=bias) 295 | 296 | 297 | layers = filter(lambda x: x is not None, [padder, convolver, downsampler]) 298 | 299 | return nn.Sequential(*layers) 300 | -------------------------------------------------------------------------------- /NAS/model_gen.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | try: 5 | from NAS import genotypes 6 | except ImportError: 7 | import genotypes 8 | 9 | try: 10 | from NAS import model 11 | except ImportError: 12 | import model 13 | 14 | 15 | #random.seed(1) 16 | #torch.manual_seed(1) 17 | 18 | 19 | def gen_ops(num_of_ops): 20 | 21 | prim_ops_list = [] 22 | conv_list = [] 23 | kernel_list = [] 24 | ops_concat = list(range(2, 2+int(num_of_ops/2))) 25 | 26 | # number of operations for the upsampling 27 | for i in range(num_of_ops): 28 | 29 | prim_op_idx = random.randint(0, len(genotypes.UPSAMPLE_PRIMITIVE)-1) 30 | sampled_prim_op = genotypes.UPSAMPLE_PRIMITIVE[prim_op_idx] # adjust the spatial size 31 | 32 | conv_op_idx = random.randint(0, len(genotypes.UPSAMPLE_CONV)-1) 33 | sampled_conv_op = genotypes.UPSAMPLE_CONV[conv_op_idx] # adjust the num of channels 34 | 35 | kernel_idx = random.randint(0, len(genotypes.KERNEL_SIZE)-1) 36 | sampled_kernel = genotypes.KERNEL_SIZE[kernel_idx] # kernel size 37 | 38 | max_node_id = int((i+2)/2) # upper bound of the input id 39 | input_id = random.randint(0, max_node_id) # sample an input id 40 | 41 | prim_ops_list.append((sampled_prim_op, input_id)) 42 | conv_list.append((sampled_conv_op, input_id)) 43 | 44 | if i == 0: 45 | kernel_list.append(sampled_kernel) 46 | 47 | if input_id in ops_concat: 48 | ops_concat.remove(input_id) 49 | 50 | return prim_ops_list, conv_list, kernel_list, ops_concat 51 | 52 | 53 | 54 | def random_search(num_of_ops): 55 | 56 | upsample_prim_method, upsample_conv, upsample_kernel, upsample_concat = gen_ops(num_of_ops) 57 | 58 | DIPGenotype = genotypes.Genotype( 59 | upsample_prim_method, 60 | upsample_conv, 61 | upsample_kernel, 62 | upsample_concat, 63 | ) 64 | 65 | return DIPGenotype 66 | 67 | 68 | 69 | def model_gen(search_type='random_search', num_input_channel=32, num_of_nodes=4): 70 | 71 | if search_type == 'random_search': 72 | num_of_ops = num_of_nodes * 2 73 | sampled_genotype = random_search(num_of_ops=num_of_ops) # a genotype 74 | 75 | net = model.NetworkDIP(genotype=sampled_genotype, 76 | num_input_channel=num_input_channel) 77 | 78 | return net, sampled_genotype 79 | 
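# A minimal illustration (with hypothetical sampled values) of what gen_ops(8)
# returns for num_of_nodes=4; each entry pairs a sampled op name with the id of
# the node it takes as input:
#
#   prim_ops_list = [('bilinear', 0), ('trans_conv', 1), ...]  # names drawn from genotypes.UPSAMPLE_PRIMITIVE
#   conv_list     = [('sep_conv', 0), ('identity', 1), ...]    # names drawn from genotypes.UPSAMPLE_CONV
#   kernel_list   = ['3x3']            # only the first sampled kernel is kept (the i == 0 branch above)
#   ops_concat    = [2, 3, 4, 5]       # intermediate node ids, minus any already consumed as inputs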
-------------------------------------------------------------------------------- /NAS/operations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | UPSAMPLE_PRIMITIVE_OPS = { 6 | 'bilinear': lambda C_in, C_out, kernel_size, act_op: BilinearOp(stride=2, upsample_mode='bilinear', act_op=act_op), 7 | 'bicubic': lambda C_in, C_out, kernel_size, act_op: BilinearOp(stride=2, upsample_mode='bicubic', act_op=act_op), 8 | 'nearest': lambda C_in, C_out, kernel_size, act_op: BilinearOp(stride=2, upsample_mode='nearest', act_op=act_op), 9 | 'trans_conv': lambda C_in, C_out, kernel_size, act_op: TransConvOp(C_in=C_in, C_out=C_out, kernel_size=kernel_size, act_op=act_op, stride=2), 10 | 'pixel_shuffle': lambda C_in, C_out, kernel_size, act_op: DepthToSpaceOp(act_op=act_op, stride=2), 11 | } 12 | 13 | 14 | UPSAMPLE_CONV_OPS = { 15 | 'conv': lambda C_in, C_out, kernel_size, act_op: ConvOp(C_in=C_in, C_out=C_out, kernel_size=kernel_size, act_op=act_op), 16 | 'trans_conv': lambda C_in, C_out, kernel_size, act_op: TransConvOp(C_in=C_in, C_out=C_out, kernel_size=kernel_size, act_op=act_op, stride=1), 17 | 'split_stack_sum': lambda C_in, C_out, kernel_size, act_op: SplitStackSum(C_in=C_in, C_out=C_out, kernel_size=kernel_size, act_op=act_op), 18 | 'sep_conv': lambda C_in, C_out, kernel_size, act_op: SepConvOp(C_in=C_in, C_out=C_out, kernel_size=kernel_size, act_op=act_op), 19 | 'depth_wise_conv': lambda C_in, C_out, kernel_size, act_op: DepthWiseConvOp(C_in=C_in, C_out=C_out, kernel_size=kernel_size, act_op=act_op), 20 | 'identity': lambda C_in, C_out, kernel_size, act_op: Identity(), 21 | } 22 | 23 | 24 | KERNEL_SIZE_OPS = { 25 | '1x1': 1, 26 | '3x3': 3, 27 | '4x4': 4, 28 | '5x5': 5, 29 | '7x7': 7, 30 | } 31 | 32 | 33 | DILATION_RATE_OPS = { 34 | '1': 1, 35 | '2': 2, 36 | '3': 3, 37 | } 38 | 39 | 40 | PADDING_OPS = { 41 | '1x1': 0, 42 | '3x3': 1, 43 | '5x5': 2, 44 | '7x7': 3, 45 | } 46 | 47 | 48 | ACTIVATION_OPS = { 49 | 'none': None, 50 | 'ReLU': nn.ReLU(), 51 | 'LeakyReLU': nn.LeakyReLU(0.2, inplace=False), 52 | } 53 | 54 | 55 | class BilinearOp(nn.Module): 56 | 57 | def __init__(self, 58 | stride, 59 | upsample_mode, 60 | act_op): 61 | 62 | super(BilinearOp, self).__init__() 63 | 64 | activation = ACTIVATION_OPS[act_op] 65 | 66 | if not activation: 67 | self.op = nn.Sequential( 68 | nn.Upsample(scale_factor=stride, mode=upsample_mode), 69 | ) 70 | 71 | else: 72 | self.op = nn.Sequential( 73 | nn.Upsample(scale_factor=stride, mode=upsample_mode), 74 | activation, 75 | ) 76 | 77 | def forward(self, x): 78 | return self.op(x) 79 | 80 | 81 | class DepthToSpaceOp(nn.Module): 82 | 83 | def __init__(self, 84 | stride, 85 | act_op, 86 | affine=True): 87 | 88 | super(DepthToSpaceOp, self).__init__() 89 | 90 | activation = ACTIVATION_OPS[act_op] 91 | 92 | if not activation: 93 | self.op = nn.Sequential( 94 | nn.PixelShuffle(stride), 95 | ) 96 | 97 | else: 98 | self.op = nn.Sequential( 99 | nn.PixelShuffle(stride), 100 | activation, 101 | ) 102 | 103 | 104 | def forward(self, x): 105 | return self.op(x) 106 | 107 | 108 | class TransConvOp(nn.Module): 109 | 110 | def __init__(self, 111 | C_in, 112 | C_out, 113 | kernel_size, 114 | stride, 115 | act_op, 116 | affine=True): 117 | 118 | super(TransConvOp, self).__init__() 119 | 120 | padding = PADDING_OPS[kernel_size] 121 | kernel_size = KERNEL_SIZE_OPS[kernel_size] 122 | activation = ACTIVATION_OPS[act_op] 123 | 124 | if not activation: 125 | self.op = nn.Sequential( 126 | 
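# With the odd kernel sizes in PADDING_OPS (padding = (k-1)/2), output_padding=stride-1
# makes this transposed convolution produce exactly stride x the input resolution.
# Note that PADDING_OPS has no '4x4' entry, so a sampled '4x4' kernel would raise
# a KeyError in the lookup above.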
nn.ConvTranspose2d(C_in, C_out, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=stride-1), 127 | ) 128 | 129 | else: 130 | self.op = nn.Sequential( 131 | nn.ConvTranspose2d(C_in, C_out, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=stride-1), 132 | activation, 133 | ) 134 | 135 | def forward(self, x): 136 | return self.op(x) 137 | 138 | 139 | class ConvOp(nn.Module): 140 | 141 | def __init__(self, 142 | C_in, 143 | C_out, 144 | kernel_size, 145 | act_op, 146 | affine=True): 147 | 148 | super(ConvOp, self).__init__() 149 | 150 | padding = PADDING_OPS[kernel_size] 151 | kernel_size = KERNEL_SIZE_OPS[kernel_size] 152 | activation = ACTIVATION_OPS[act_op] 153 | 154 | if not activation: 155 | self.op = nn.Sequential( 156 | nn.Conv2d(C_in, C_out, kernel_size=kernel_size, padding=padding, bias=False), 157 | ) 158 | 159 | else: 160 | self.op = nn.Sequential( 161 | nn.Conv2d(C_in, C_out, kernel_size=kernel_size, padding=padding, bias=False), 162 | activation, 163 | ) 164 | 165 | 166 | def forward(self, x): 167 | return self.op(x) 168 | 169 | 170 | class Identity(nn.Module): 171 | 172 | def __init__(self): 173 | 174 | super(Identity, self).__init__() 175 | 176 | def forward(self, x): 177 | return x 178 | 179 | 180 | class SplitStackSum(nn.Module): 181 | 182 | def __init__(self, 183 | C_in, 184 | C_out, 185 | kernel_size, 186 | act_op, 187 | split=4, 188 | affine=True): 189 | 190 | super(SplitStackSum, self).__init__() 191 | 192 | padding = PADDING_OPS[kernel_size] 193 | kernel_size = KERNEL_SIZE_OPS[kernel_size] 194 | activation = ACTIVATION_OPS[act_op] 195 | 196 | self.chunk_size = C_in // split 197 | 198 | if not activation: 199 | self.op = nn.Sequential( 200 | nn.Conv2d(C_in // split, C_out, kernel_size=kernel_size, padding=padding, bias=False), 201 | ) 202 | 203 | else: 204 | self.op = nn.Sequential( 205 | nn.Conv2d(C_in // split, C_out, kernel_size=kernel_size, padding=padding, bias=False), 206 | activation, 207 | ) 208 | 209 | 210 | def forward(self, x): 211 | split = torch.split(x, self.chunk_size, dim=1) # split the channels into `split` groups of C_in/split each 212 | stack = torch.stack(split, dim=1) 213 | out = torch.sum(stack, dim=1) # summing over the groups leaves C_in/split channels 214 | out = self.op(out) 215 | return out 216 | 217 | 218 | class SepConvOp(nn.Module): 219 | 220 | def __init__(self, 221 | C_in, 222 | C_out, 223 | kernel_size, 224 | act_op, 225 | affine=True): 226 | 227 | super(SepConvOp, self).__init__() 228 | 229 | padding = PADDING_OPS[kernel_size] 230 | kernel_size = KERNEL_SIZE_OPS[kernel_size] 231 | activation = ACTIVATION_OPS[act_op] 232 | 233 | if not activation: 234 | self.op = nn.Sequential( 235 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, padding=padding, groups=C_in, bias=False), # per-channel (depthwise) conv 236 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), # pointwise conv (1x1 conv) 237 | ) 238 | 239 | else: 240 | self.op = nn.Sequential( 241 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, padding=padding, groups=C_in, bias=False), # per-channel (depthwise) conv 242 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), # pointwise conv (1x1 conv) 243 | activation, 244 | ) 245 | 246 | def forward(self, x): 247 | return self.op(x) 248 | 249 | 250 | class DepthWiseConvOp(nn.Module): 251 | 252 | def __init__(self, 253 | C_in, 254 | C_out, 255 | kernel_size, 256 | act_op, 257 | affine=True): 258 | 259 | super(DepthWiseConvOp, self).__init__() 260 | 261 | padding = PADDING_OPS[kernel_size] 262 | kernel_size =
KERNEL_SIZE_OPS[kernel_size] 263 | activation = ACTIVATION_OPS[act_op] 264 | 265 | if not activation: 266 | self.op = nn.Sequential( 267 | nn.Conv2d(C_in, C_out, kernel_size=kernel_size, padding=padding, groups=C_out, bias=False), # per-channel conv (depthwise when C_in == C_out) 268 | ) 269 | 270 | else: 271 | self.op = nn.Sequential( 272 | nn.Conv2d(C_in, C_out, kernel_size=kernel_size, padding=padding, groups=C_out, bias=False), # per-channel conv (depthwise when C_in == C_out) 273 | activation, 274 | ) 275 | 276 | def forward(self, x): 277 | return self.op(x) 278 | -------------------------------------------------------------------------------- /NAS/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | 4 | 5 | def drop_path(x, drop_prob, use_cuda=True): # zero out whole samples with probability drop_prob and rescale the survivors (in place) 6 | if drop_prob > 0.: 7 | keep_prob = 1.-drop_prob 8 | #mask = Variable(torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob)) 9 | if use_cuda: 10 | mask = Variable(torch.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob)).cuda() 11 | else: 12 | mask = Variable(torch.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob)) 13 | x.div_(keep_prob) 14 | x.mul_(mask) 15 | return x 16 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # NAS-DIP: Learning Deep Image Prior with Neural Architecture Search 2 | 3 | This repository contains the source code for the paper NAS-DIP: Learning Deep Image Prior with Neural Architecture Search. 4 | 5 | 6 | 7 | ## Abstract 8 | Recent work has shown that the structure of deep convolutional neural networks can be used as a structured image prior for solving various inverse image restoration tasks. Instead of using hand-designed architectures, we propose to search for neural architectures that capture stronger image priors. Building upon a generic U-Net architecture, our core contribution lies in designing new search spaces for (1) an upsampling cell and (2) a pattern of cross-scale residual connections. We search for an improved network by leveraging an existing neural architecture search algorithm (using reinforcement learning with a recurrent neural network controller). We validate the effectiveness of our method via a wide variety of applications, including image restoration, dehazing, image-to-image translation, and matrix factorization. Extensive experimental results show that our algorithm performs favorably against state-of-the-art learning-free approaches and reaches competitive performance with existing learning-based methods in some cases.
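## Sampling a random architecture

As a minimal usage sketch (assuming the repository root is on `PYTHONPATH`; the call mirrors the entry point and default arguments in `NAS/model_gen.py`):

```python
from NAS import model_gen

# Sample a random genotype from the upsampling search space and build
# the corresponding DIP network.
net, genotype = model_gen.model_gen(search_type='random_search',
                                    num_input_channel=32,
                                    num_of_nodes=4)
print(genotype)
```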
9 | 10 | ## Citation 11 | If you find our code useful, please consider citing our work using the following bibtex: 12 | ``` 13 | @inproceedings{NAS-DIP, 14 | title={NAS-DIP: Learning Deep Image Prior with Neural Architecture Search}, 15 | author={Chen, Yun-Chun and Gao, Chen and Robb, Esther and Huang, Jia-Bin}, 16 | booktitle={European Conference on Computer Vision (ECCV)}, 17 | year={2020} 18 | } 19 | ``` 20 | 21 | ## Acknowledgement 22 | - This code is heavily borrowed from [Ulyanov et al.](https://github.com/DmitryUlyanov/deep-image-prior) 23 | -------------------------------------------------------------------------------- /img/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YunChunChen/NAS-DIP-pytorch/3dfb4cf6312599097a5a193d22fd8591467e1a6f/img/teaser.png --------------------------------------------------------------------------------