├── .gitignore
├── README.md
├── denoising
│   ├── datasets.py
│   ├── pytorch_ssim.py
│   ├── architect.py
│   ├── test_model.py
│   ├── utils.py
│   ├── train_model.py
│   ├── train_search.py
│   └── search_space.py
└── deraining
    ├── datasets.py
    ├── pytorch_ssim.py
    ├── architect.py
    ├── test_model.py
    ├── utils.py
    ├── train_model.py
    ├── train_search.py
    └── search_space.py

/.gitignore:
--------------------------------------------------------------------------------
/.DS_Store
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# CLEARER: Multi-Scale Neural Architecture Search for Image Restoration

PyTorch implementation of CLEARER (NeurIPS 2020) [[paper](https://papers.nips.cc/paper/2020/hash/c6e81542b125c36346d9167691b8bd09-Abstract.html)], based on [nvidia-docker](https://github.com/NVIDIA/nvidia-docker) and [ufoym/deepo](https://github.com/ufoym/deepo).

## Dependencies

* python == 3.6.8

* pytorch == 1.4.0

* torchvision == 0.5.0

* CUDA == 10.1
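
## Usage

A typical workflow (directory names are illustrative; substitute the timestamped folders the scripts actually create): from inside `denoising/` or `deraining/`, run `python train_search.py` to search an architecture, then `python train_model.py --ckt_path search-EXP-<timestamp>/last_weights.pt` to retrain the searched model from scratch, and finally `python test_model.py --model_path eval-EXP-<timestamp>/last_weights.pt` to evaluate. Dataset locations default to `./datasets/BSD500/` (denoising) and `./datasets/rain800/` (deraining) and can be overridden with `--data`.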
--------------------------------------------------------------------------------
/denoising/datasets.py:
--------------------------------------------------------------------------------
import glob
import random
from PIL import Image
import numpy as np
from torchvision import transforms
from torch.utils.data import Dataset

from utils import crop_patch, random_augmentation, crop_img

class PatchSimulateDataset(Dataset):
    def __init__(self, pth, length, patch_size, sigma):
        super(PatchSimulateDataset, self).__init__()
        self.len = length
        self.patch_size = patch_size
        self.sigma = sigma

        rgb2gray = transforms.Grayscale(1)
        self.images = []
        self.num_images = 0
        for p in glob.glob(pth+'*.jpg'):
            img = Image.open(p)
            self.images.append(np.array(rgb2gray(img))[..., np.newaxis])
            self.num_images += 1

        self.transform = transforms.ToTensor()

    def __getitem__(self, index):
        idx = random.randint(0, self.num_images-1)
        img_patch = crop_patch(self.images[idx], self.patch_size)

        # generate gaussian noise N(0, sigma^2)
        noise = np.random.randn(*(img_patch.shape))
        noise_patch = np.clip(img_patch + noise*self.sigma, 0, 255).astype(np.uint8)

        # augment both patches identically, then return (noisy, clean)
        aug_list = random_augmentation(img_patch, noise_patch)
        return self.transform(aug_list[1]), self.transform(aug_list[0])

    def __len__(self):
        return self.len
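
# A minimal usage sketch, assuming grayscale-convertible .jpg images under
# datasets/BSD500/train/ (the default location used by train_search.py):
# fetch one (noisy, clean) patch pair and check the tensor shapes.
if __name__ == '__main__':
    dataset = PatchSimulateDataset('datasets/BSD500/train/', length=4, patch_size=64, sigma=30)
    noisy, clean = dataset[0]
    print(noisy.shape, clean.shape)  # both torch.Size([1, 64, 64])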
--------------------------------------------------------------------------------
/deraining/datasets.py:
--------------------------------------------------------------------------------
import glob
import random
from PIL import Image
import numpy as np
from torchvision import transforms
from torch.utils.data import Dataset

from utils import random_augmentation

# Rain800 dataset: each .jpg stores the clear image in the left half and the
# rainy image in the right half.
class Rain800(Dataset):
    def __init__(self, pth, length, patch_size):
        super(Rain800, self).__init__()
        self.rain_images = []
        self.clear_images = []
        self.patch_size = patch_size
        self.len = length
        self.num_images = 0
        for p in glob.glob(pth+'*.jpg'):
            img = np.array(Image.open(p))
            assert img.shape[1]%2==0
            self.clear_images.append(img[:,:img.shape[1]//2,:])
            self.rain_images.append(img[:,img.shape[1]//2:,:])
            self.num_images += 1
            # Image.fromarray(img[:,:img.shape[1]//2,:]).save('datasets/rain800/training_split/clear/'+p.split('/')[-1])
            # Image.fromarray(img[:,img.shape[1]//2:,:]).save('datasets/rain800/training_split/rain/'+p.split('/')[-1])
        self.transform = transforms.ToTensor()

    def __getitem__(self, index):
        idx = random.randint(0, self.num_images-1)
        H = self.clear_images[idx].shape[0]
        W = self.clear_images[idx].shape[1]
        ind_H = random.randint(0, H-self.patch_size)
        ind_W = random.randint(0, W-self.patch_size)

        clear_patch = self.clear_images[idx][ind_H:ind_H+self.patch_size, ind_W:ind_W+self.patch_size]
        rain_patch = self.rain_images[idx][ind_H:ind_H+self.patch_size, ind_W:ind_W+self.patch_size]
        aug_list = random_augmentation(clear_patch, rain_patch)
        return self.transform(aug_list[1]), self.transform(aug_list[0])

    def __len__(self):
        return self.len
--------------------------------------------------------------------------------
/denoising/pytorch_ssim.py:
--------------------------------------------------------------------------------
'''
Code: https://github.com/Po-Hsun-Su/pytorch-ssim
'''
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp

def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
    return gauss/gauss.sum()

def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window

def _ssim(img1, img2, window, window_size, channel, size_average = True):
    mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
    mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1*mu2

    sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
    sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
    sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)

class SSIM(torch.nn.Module):
    def __init__(self, window_size = 11, size_average = True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)

            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)

# (B, C, H, W)
def ssim(img1, img2, window_size = 11, size_average = True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)
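
# Usage sketch, assuming inputs scaled to [0, 1] with shape (B, C, H, W):
# ssim(x, x) returns 1, and a noisy copy scores strictly lower.
if __name__ == '__main__':
    x = torch.rand(1, 1, 64, 64)
    y = torch.clamp(x + 0.1*torch.randn_like(x), 0, 1)
    print(ssim(x, x).item())  # 1.0
    print(ssim(x, y).item())  # < 1.0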
--------------------------------------------------------------------------------
/deraining/pytorch_ssim.py:
--------------------------------------------------------------------------------
'''
Code: https://github.com/Po-Hsun-Su/pytorch-ssim
'''
import torch
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from math import exp

def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
    return gauss/gauss.sum()

def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window

def _ssim(img1, img2, window, window_size, channel, size_average = True):
    mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
    mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1*mu2

    sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
    sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
    sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2

    C1 = 0.01**2
    C2 = 0.03**2

    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)

class SSIM(torch.nn.Module):
    def __init__(self, window_size = 11, size_average = True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)

            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)

# (B, C, H, W)
def ssim(img1, img2, window_size = 11, size_average = True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)
--------------------------------------------------------------------------------
/deraining/architect.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import torch.nn as nn

def _concat(xs):
    return torch.cat([x.view(-1) for x in xs])

class Architect():
    def __init__(self, model, args):
        self.network_momentum = args.momentum
        self.network_weight_decay = args.weight_decay
        self.model = model
        self.optimizer = torch.optim.Adam(self.model.arch_parameters(), lr=args.arch_learning_rate, betas=(0.5, 0.999), weight_decay=args.arch_weight_decay)

    def step(self, input_train, target_train, input_valid, target_valid, eta, network_optimizer, unrolled):
        self.optimizer.zero_grad()
        if unrolled:
            print('args.unrolled should be False')
            exit()
            # self._backward_step_unrolled(input_train, target_train, input_valid, target_valid, eta, network_optimizer)
        else:
            self._backward_step(input_valid, target_valid)
        self.optimizer.step()

    def _backward_step(self, input_valid, target_valid):
        loss = self.model.loss(input_valid, target_valid)
        loss.backward()

    def _backward_step_unrolled(self, input_train, target_train, input_valid, target_valid, eta, network_optimizer):
        unrolled_model = self._compute_unrolled_model(input_train, target_train, eta, network_optimizer)
        unrolled_loss = unrolled_model.loss(input_valid, target_valid)
        unrolled_loss.backward()

        dalpha = [v.grad for v in unrolled_model.arch_parameters()]
        vector = [v.grad.data for v in unrolled_model.parameters()]
        implicit_grads = self._hessian_vector_product(vector, input_train, target_train)

        for g, ig in zip(dalpha, implicit_grads):
            g.data.sub_(eta, ig.data)

        for v, g in zip(self.model.arch_parameters(), dalpha):
            if v.grad is None:
                v.grad = g.data
            else:
                v.grad.data.copy_(g.data)

    def _compute_unrolled_model(self, input, target, eta, network_optimizer):
        loss = self.model.loss(input, target)
        theta = _concat(self.model.parameters()).data
        try:
            moment = _concat(network_optimizer.state[v]['momentum_buffer'] for v in self.model.parameters()).mul_(self.network_momentum)
        except:
            moment = torch.zeros_like(theta)
        dtheta = _concat(torch.autograd.grad(loss, self.model.parameters(), retain_graph=True)).data + self.network_weight_decay*theta
        # one virtual SGD step on the weights: theta' = theta - eta * (moment + dtheta)
        unrolled_model = self._construct_model_from_theta(theta.sub(eta, moment+dtheta))
        return unrolled_model

    def _construct_model_from_theta(self, theta):
        model_new = self.model.new()
        model_dict = self.model.state_dict()

        params, offset = {}, 0
        for k, v in self.model.named_parameters():
            v_length = np.prod(v.size())
            params[k] = theta[offset: offset+v_length].view(v.size())
            offset += v_length

        assert offset == len(theta)
        model_dict.update(params)
        model_new.load_state_dict(model_dict)
        return model_new.cuda()

    def _hessian_vector_product(self, vector, input, target, r=1e-2):
        R = r / _concat(vector).norm()
        for p, v in zip(self.model.parameters(), vector):
            p.data.add_(R, v)
        loss = self.model.loss(input, target)
        grads_p = torch.autograd.grad(loss, self.model.arch_parameters())

        for p, v in zip(self.model.parameters(), vector):
            p.data.sub_(2*R, v)
        loss = self.model.loss(input, target)
        grads_n = torch.autograd.grad(loss, self.model.arch_parameters())

        for p, v in zip(self.model.parameters(), vector):
            p.data.add_(R, v)

        return [(x-y).div_(2*R) for x, y in zip(grads_p, grads_n)]
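
# Note on _hessian_vector_product: it approximates the second-order term
# grad_alpha( grad_w L(w, alpha) . v ) of the DARTS unrolled gradient with a
# central finite difference,
#     [grad_alpha L(w + R*v, alpha) - grad_alpha L(w - R*v, alpha)] / (2R),
# perturbing the weights in place and restoring them afterwards, so no
# second-order autograd graph has to be kept in memory.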
--------------------------------------------------------------------------------
/denoising/architect.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
import torch.nn as nn

def _concat(xs):
    return torch.cat([x.view(-1) for x in xs])

class Architect():
    def __init__(self, model, args):
        self.network_momentum = args.momentum
        self.network_weight_decay = args.weight_decay
        self.model = model
        self.optimizer = torch.optim.Adam(self.model.arch_parameters(), lr=args.arch_learning_rate, betas=(0.5, 0.999), weight_decay=args.arch_weight_decay)

    def step(self, input_train, target_train, input_valid, target_valid, eta, network_optimizer, unrolled):
        self.optimizer.zero_grad()
        if unrolled:
            print('args.unrolled should be False')
            exit()
            # self._backward_step_unrolled(input_train, target_train, input_valid, target_valid, eta, network_optimizer)
        else:
            self._backward_step(input_valid, target_valid)
        self.optimizer.step()

    def _backward_step(self, input_valid, target_valid):
        loss = self.model.loss(input_valid, target_valid)
        loss.backward()

    def _backward_step_unrolled(self, input_train, target_train, input_valid, target_valid, eta, network_optimizer):
        unrolled_model = self._compute_unrolled_model(input_train, target_train, eta, network_optimizer)
        unrolled_loss = unrolled_model.loss(input_valid, target_valid)
        unrolled_loss.backward()

        dalpha = [v.grad for v in unrolled_model.arch_parameters()]
        vector = [v.grad.data for v in unrolled_model.parameters()]
        implicit_grads = self._hessian_vector_product(vector, input_train, target_train)

        for g, ig in zip(dalpha, implicit_grads):
            g.data.sub_(eta, ig.data)

        for v, g in zip(self.model.arch_parameters(), dalpha):
            if v.grad is None:
                v.grad = g.data
            else:
                v.grad.data.copy_(g.data)

    def _compute_unrolled_model(self, input, target, eta, network_optimizer):
        loss = self.model.loss(input, target)
        theta = _concat(self.model.parameters()).data
        try:
            moment = _concat(network_optimizer.state[v]['momentum_buffer'] for v in self.model.parameters()).mul_(self.network_momentum)
        except:
            moment = torch.zeros_like(theta)
        dtheta = _concat(torch.autograd.grad(loss, self.model.parameters(), retain_graph=True)).data + self.network_weight_decay*theta
        # one virtual SGD step on the weights: theta' = theta - eta * (moment + dtheta)
        unrolled_model = self._construct_model_from_theta(theta.sub(eta, moment+dtheta))
        return unrolled_model

    def _construct_model_from_theta(self, theta):
        model_new = self.model.new()
        model_dict = self.model.state_dict()

        params, offset = {}, 0
        for k, v in self.model.named_parameters():
            v_length = np.prod(v.size())
            params[k] = theta[offset: offset+v_length].view(v.size())
            offset += v_length

        assert offset == len(theta)
        model_dict.update(params)
        model_new.load_state_dict(model_dict)
        return model_new.cuda()

    def _hessian_vector_product(self, vector, input, target, r=1e-2):
        R = r / _concat(vector).norm()
        for p, v in zip(self.model.parameters(), vector):
            p.data.add_(R, v)
        loss = self.model.loss(input, target)
        grads_p = torch.autograd.grad(loss, self.model.arch_parameters())

        for p, v in zip(self.model.parameters(), vector):
            p.data.sub_(2*R, v)
        loss = self.model.loss(input, target)
        grads_n = torch.autograd.grad(loss, self.model.arch_parameters())

        for p, v in zip(self.model.parameters(), vector):
            p.data.add_(R, v)

        return [(x-y).div_(2*R) for x, y in zip(grads_p, grads_n)]
--------------------------------------------------------------------------------
/deraining/test_model.py:
--------------------------------------------------------------------------------
import os
import sys
import numpy as np
import torch
import time
import logging
import glob
import argparse
import torchvision
from PIL import Image

import utils
import pytorch_ssim

parser = argparse.ArgumentParser("deraining")
parser.add_argument('--data', type=str, default='datasets/rain800/test_syn/*.jpg', help='location of the data corpus')
parser.add_argument('--patch_size', type=int, default=64, help='patch size')
parser.add_argument('--save', type=str, default='EXP', help='experiment name')
parser.add_argument('--gpu', type=str, default='1', help='gpu device id')
parser.add_argument('--model_path', type=str, default='eval-EXP-20201112-191755/last_weights.pt')
args = parser.parse_args()

os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

args.save = 'test-{}-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S"))
utils.create_exp_dir(args.save)
log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=log_format, datefmt='%m/%d %I:%M:%S %p')
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)

def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    logging.info("args = %s", args)

    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled=True

    model = torch.load(args.model_path)
    model = model.cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    psnr, ssim = infer(args.data, model)
    logging.info('psnr:%6f ssim:%6f', psnr, ssim)

def infer(data_path, model):
    psnr = utils.AvgrageMeter()
    ssim = utils.AvgrageMeter()

    model.eval()
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor()
    ])

    with torch.no_grad():
        for step, pt in enumerate(glob.glob(data_path)):
            image = np.array(Image.open(pt))

            clear_image = utils.crop_img(image[:,:image.shape[1]//2,:], base=args.patch_size)
            rain_image = utils.crop_img(image[:,image.shape[1]//2:,:], base=args.patch_size)

            # # Test on whole image
            # input = transforms(rain_image).unsqueeze(dim=0).cuda()
            # target = transforms(clear_image).unsqueeze(dim=0).cuda(async=True)
            # logits = model(input)
            # n = input.size(0)

            # Test on whole image with data augmentation (x8 self-ensemble)
            target = transforms(clear_image).unsqueeze(dim=0).cuda()
            for i in range(8):
                im = utils.data_augmentation(rain_image, i)
                input = transforms(im.copy()).unsqueeze(dim=0).cuda()
                begin_time = time.time()
                if i == 0:
                    logits = utils.inverse_augmentation(model(input).cpu().numpy().transpose(0,2,3,1)[0], i)
                else:
                    logits = logits + utils.inverse_augmentation(model(input).cpu().numpy().transpose(0,2,3,1)[0], i)
                end_time = time.time()
                n = input.size(0)
            logits = transforms(logits/8).unsqueeze(dim=0).cuda()

            # # Test on patches2patches
            # noise_patches = utils.slice_image2patches(rain_image, patch_size=args.patch_size)
            # image_patches = utils.slice_image2patches(clear_image, patch_size=args.patch_size)
            # input = torch.tensor(noise_patches.transpose(0,3,1,2)/255.0, dtype=torch.float32).cuda()
            # target = torch.tensor(image_patches.transpose(0,3,1,2)/255.0, dtype=torch.float32).cuda()
            # logits = model(input)
            # n = input.size(0)

            s = pytorch_ssim.ssim(torch.clamp(logits,0,1), target)
            p = utils.compute_psnr(np.clip(logits.detach().cpu().numpy(),0,1), target.detach().cpu().numpy())
            psnr.update(p, n)
            ssim.update(s, n)
            print('psnr:%6f ssim:%6f' % (p, s))

            # Image.fromarray(rain_image).save(args.save+'/'+str(step)+'_noise.png')
            # Image.fromarray(np.clip(logits[0].cpu().numpy().transpose(1,2,0)*255, 0, 255).astype(np.uint8)).save(args.save+'/'+str(step)+'_denoised.png')

    return psnr.avg, ssim.avg

if __name__ == '__main__':
    main()
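
# Sanity check for the x8 self-ensemble used in infer(): every augmentation
# mode must be exactly undone by its inverse before the eight outputs are
# averaged. A quick sketch:
#
#     im = np.random.rand(8, 8, 3)
#     for mode in range(8):
#         out = utils.inverse_augmentation(utils.data_augmentation(im, mode), mode)
#         assert np.array_equal(out, im)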
--------------------------------------------------------------------------------
/denoising/test_model.py:
--------------------------------------------------------------------------------
import os
import sys
import numpy as np
import torch
import time
import logging
import glob
import argparse
import torchvision
from PIL import Image
import utils
import pytorch_ssim

parser = argparse.ArgumentParser("denoising")
parser.add_argument('--data', type=str, default='./datasets/BSD500/test/*.jpg', help='location of the data')
parser.add_argument('--save', type=str, default='EXP', help='experiment name')
parser.add_argument('--gpu', type=str, default='1', help='gpu device id')
parser.add_argument('--model_path', type=str, default='eval-EXP-20201112-173215/best_loss_weights.pt')
parser.add_argument('--sigma', type=int, default=30, help='noise sigma')
args = parser.parse_args()

os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

args.save = 'test-{}-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S"))
utils.create_exp_dir(args.save)
log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=log_format, datefmt='%m/%d %I:%M:%S %p')
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)

def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    logging.info("args = %s", args)

    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled=True

    model = torch.load(args.model_path)
    model = model.cuda()
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    psnr, ssim, avg_time = infer(args.data, model)
    logging.info('psnr:%6f ssim:%6f time:%6f', psnr, ssim, avg_time)

def infer(data_path, model):
    psnr = utils.AvgrageMeter()
    ssim = utils.AvgrageMeter()
    times = utils.AvgrageMeter()

    model.eval()
    rgb2gray = torchvision.transforms.Grayscale(1)
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor()
    ])

    with torch.no_grad():
        for step, pt in enumerate(glob.glob(data_path)):
            image = utils.crop_img(np.array(rgb2gray(Image.open(pt)))[..., np.newaxis])
            noise_map = np.random.randn(*(image.shape))*args.sigma
            noise_img = np.clip(image+noise_map,0,255).astype(np.uint8)

            # # Test on whole image
            # input = transforms(noise_img).unsqueeze(dim=0).cuda()
            # target = transforms(image).unsqueeze(dim=0).cuda()
            # begin_time = time.time()
            # logits = model(input)
            # end_time = time.time()
            # n = input.size(0)

            # Test on whole image with data augmentation (x8 self-ensemble)
            target = transforms(image).unsqueeze(dim=0).cuda()
            for i in range(8):
                im = utils.data_augmentation(noise_img, i)
                input = transforms(im.copy()).unsqueeze(dim=0).cuda()
                begin_time = time.time()
                if i == 0:
                    logits = utils.inverse_augmentation(model(input).cpu().numpy().transpose(0,2,3,1)[0], i)
                else:
                    logits = logits + utils.inverse_augmentation(model(input).cpu().numpy().transpose(0,2,3,1)[0], i)
                end_time = time.time()
                n = input.size(0)
            logits = transforms(logits/8).unsqueeze(dim=0).cuda()

            # # Test on patches2patches
            # noise_patches = utils.slice_image2patches(noise_img, patch_size=64)
            # image_patches = utils.slice_image2patches(image, patch_size=64)
            # input = torch.tensor(noise_patches.transpose(0,3,1,2)/255.0, dtype=torch.float32).cuda()
            # target = torch.tensor(image_patches.transpose(0,3,1,2)/255.0, dtype=torch.float32).cuda()
            # begin_time = time.time()
            # logits = model(input)
            # end_time = time.time()
            # n = input.size(0)

            s = pytorch_ssim.ssim(torch.clamp(logits,0,1), target)
            p = utils.compute_psnr(np.clip(logits.detach().cpu().numpy(),0,1), target.detach().cpu().numpy())
            t = end_time-begin_time
            psnr.update(p, n)
            ssim.update(s, n)
            times.update(t, n)
            print('psnr:%6f ssim:%6f time:%6f' % (p, s, t))

            # Image.fromarray(noise_img[...,0]).save(args.save+'/'+str(step)+'_noise.png')
            # Image.fromarray(np.clip(logits[0,0].cpu().numpy()*255, 0, 255).astype(np.uint8)).save(args.save+'/'+str(step)+'_denoised.png')

    return psnr.avg, ssim.avg, times.avg

if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/denoising/utils.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import math
import torch
import random
import shutil
import torchvision.transforms as transforms
import torch.nn.functional as F
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from torch.nn import init

class AvgrageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.avg = 0
        self.sum = 0
        self.cnt = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.cnt += n
        self.avg = self.sum / self.cnt

def count_parameters_in_MB(model):
    return sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name)/1e6

def create_exp_dir(path, scripts_to_save=None):
    if not os.path.exists(path):
        os.mkdir(path)
    print('Experiment dir : {}'.format(path))

    if scripts_to_save is not None:
        os.mkdir(os.path.join(path, 'scripts'))
        for script in scripts_to_save:
            dst_file = os.path.join(path, 'scripts', os.path.basename(script))
            shutil.copyfile(script, dst_file)

# randomly crop a patch from an image
def crop_patch(im, pch_size):
    H = im.shape[0]
    W = im.shape[1]
    ind_H = random.randint(0, H-pch_size)
    ind_W = random.randint(0, W-pch_size)
    pch = im[ind_H:ind_H+pch_size, ind_W:ind_W+pch_size]
    return pch

# center-crop an image so both sides are a multiple of base
def crop_img(image, base=64):
    h = image.shape[0]
    w = image.shape[1]
    crop_h = h % base
    crop_w = w % base
    return image[crop_h//2:h-crop_h+crop_h//2, crop_w//2:w-crop_w+crop_w//2, :]

# image (H, W, C) -> patches (B, H, W, C)
def slice_image2patches(image, patch_size=64, overlap=0):
    assert image.shape[0] % patch_size == 0 and image.shape[1] % patch_size == 0
    H = image.shape[0]
    W = image.shape[1]
    patches = []
    image_padding = np.pad(image, ((overlap,overlap),(overlap,overlap),(0,0)), mode='edge')
    for h in range(H//patch_size):
        for w in range(W//patch_size):
            idx_h = [h*patch_size, (h+1)*patch_size+overlap]
            idx_w = [w*patch_size, (w+1)*patch_size+overlap]
            patches.append(np.expand_dims(image_padding[idx_h[0]:idx_h[1], idx_w[0]:idx_w[1], :], axis=0))
    return np.concatenate(patches, axis=0)

# patches (B, H, W, C) -> image (H, W, C)
def splice_patches2image(patches, image_size, overlap=0):
    assert len(image_size) > 1
    assert patches.shape[-3] == patches.shape[-2]
    H = image_size[0]
    W = image_size[1]
    patch_size = patches.shape[-2]-overlap
    image = np.zeros(image_size)
    idx = 0
    for h in range(H//patch_size):
        for w in range(W//patch_size):
            image[h*patch_size:(h+1)*patch_size, w*patch_size:(w+1)*patch_size, :] = patches[idx, overlap:patch_size+overlap, overlap:patch_size+overlap, :]
            idx += 1
    return image

# undo data_augmentation(image, mode)
def inverse_augmentation(image, mode):
    if mode == 0:
        out = image
    elif mode == 1:
        out = np.flipud(image)
    elif mode == 2:
        out = np.rot90(image, k=-1)
    elif mode == 3:
        out = np.flipud(image)
        out = np.rot90(out, k=-1)
    elif mode == 4:
        out = np.rot90(image, k=-2)
    elif mode == 5:
        out = np.flipud(image)
        out = np.rot90(out, k=-2)
    elif mode == 6:
        out = np.rot90(image, k=-3)
    elif mode == 7:
        out = np.flipud(image)
        out = np.rot90(out, k=-3)
    else:
        raise Exception('Invalid choice of image transformation')
    return out

def data_augmentation(image, mode):
    if mode == 0:
        # original
        out = image
    elif mode == 1:
        # flip up and down
        out = np.flipud(image)
    elif mode == 2:
        # rotate 90 degrees counterclockwise
        out = np.rot90(image)
    elif mode == 3:
        # rotate 90 degrees and flip up and down
        out = np.rot90(image)
        out = np.flipud(out)
    elif mode == 4:
        # rotate 180 degrees
        out = np.rot90(image, k=2)
    elif mode == 5:
        # rotate 180 degrees and flip
        out = np.rot90(image, k=2)
        out = np.flipud(out)
    elif mode == 6:
        # rotate 270 degrees
        out = np.rot90(image, k=3)
    elif mode == 7:
        # rotate 270 degrees and flip
        out = np.rot90(image, k=3)
        out = np.flipud(out)
    else:
        raise Exception('Invalid choice of image transformation')
    return out

def random_augmentation(*args):
    out = []
    flag_aug = random.randint(1, 7)
    for data in args:
        out.append(data_augmentation(data, flag_aug).copy())
    return out

def save(model, model_path):
    torch.save({'state_dict': model.state_dict(), 'arch_param': model.arch_parameters()[0]}, model_path)

# pred and tar: (B, C, H, W) arrays in [0, 1]
def compute_psnr(pred, tar):
    assert pred.shape == tar.shape
    pred = pred.transpose(0,2,3,1)
    tar = tar.transpose(0,2,3,1)
    psnr = 0
    for i in range(pred.shape[0]):
        psnr += compare_psnr(tar[i], pred[i])
    return psnr/pred.shape[0]
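
# Round-trip sketch for the patch helpers, assuming overlap=0 (the default,
# and the only case splice_patches2image stitches back verbatim):
if __name__ == '__main__':
    img = np.random.rand(128, 192, 3)
    patches = slice_image2patches(img, patch_size=64)    # (6, 64, 64, 3)
    restored = splice_patches2image(patches, img.shape)  # (128, 192, 3)
    assert np.array_equal(img, restored)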
--------------------------------------------------------------------------------
/deraining/utils.py:
--------------------------------------------------------------------------------
import os
import numpy as np
import math
import torch
import random
import shutil
import torchvision.transforms as transforms
import torch.nn.functional as F
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from torch.nn import init

class AvgrageMeter(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.avg = 0
        self.sum = 0
        self.cnt = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.cnt += n
        self.avg = self.sum / self.cnt

def count_parameters_in_MB(model):
    return sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name)/1e6

def create_exp_dir(path, scripts_to_save=None):
    if not os.path.exists(path):
        os.mkdir(path)
    print('Experiment dir : {}'.format(path))

    if scripts_to_save is not None:
        os.mkdir(os.path.join(path, 'scripts'))
        for script in scripts_to_save:
            dst_file = os.path.join(path, 'scripts', os.path.basename(script))
            shutil.copyfile(script, dst_file)

# randomly crop a patch from an image
def crop_patch(im, pch_size):
    H = im.shape[0]
    W = im.shape[1]
    ind_H = random.randint(0, H-pch_size)
    ind_W = random.randint(0, W-pch_size)
    pch = im[ind_H:ind_H+pch_size, ind_W:ind_W+pch_size]
    return pch

# center-crop an image so both sides are a multiple of base
def crop_img(image, base=64):
    h = image.shape[0]
    w = image.shape[1]
    crop_h = h % base
    crop_w = w % base
    return image[crop_h//2:h-crop_h+crop_h//2, crop_w//2:w-crop_w+crop_w//2, :]

# image (H, W, C) -> patches (B, H, W, C)
def slice_image2patches(image, patch_size=64, overlap=0):
    assert image.shape[0] % patch_size == 0 and image.shape[1] % patch_size == 0
    H = image.shape[0]
    W = image.shape[1]
    patches = []
    image_padding = np.pad(image, ((overlap,overlap),(overlap,overlap),(0,0)), mode='edge')
    for h in range(H//patch_size):
        for w in range(W//patch_size):
            idx_h = [h*patch_size, (h+1)*patch_size+overlap]
            idx_w = [w*patch_size, (w+1)*patch_size+overlap]
            patches.append(np.expand_dims(image_padding[idx_h[0]:idx_h[1], idx_w[0]:idx_w[1], :], axis=0))
    return np.concatenate(patches, axis=0)

# patches (B, H, W, C) -> image (H, W, C)
def splice_patches2image(patches, image_size, overlap=0):
    assert len(image_size) > 1
    assert patches.shape[-3] == patches.shape[-2]
    H = image_size[0]
    W = image_size[1]
    patch_size = patches.shape[-2]-overlap
    image = np.zeros(image_size)
    idx = 0
    for h in range(H//patch_size):
        for w in range(W//patch_size):
            image[h*patch_size:(h+1)*patch_size, w*patch_size:(w+1)*patch_size, :] = patches[idx, overlap:patch_size+overlap, overlap:patch_size+overlap, :]
            idx += 1
    return image

# undo data_augmentation(image, mode)
def inverse_augmentation(image, mode):
    if mode == 0:
        out = image
    elif mode == 1:
        out = np.flipud(image)
    elif mode == 2:
        out = np.rot90(image, k=-1)
    elif mode == 3:
        out = np.flipud(image)
        out = np.rot90(out, k=-1)
    elif mode == 4:
        out = np.rot90(image, k=-2)
    elif mode == 5:
        out = np.flipud(image)
        out = np.rot90(out, k=-2)
    elif mode == 6:
        out = np.rot90(image, k=-3)
    elif mode == 7:
        out = np.flipud(image)
        out = np.rot90(out, k=-3)
    else:
        raise Exception('Invalid choice of image transformation')
    return out

def data_augmentation(image, mode):
    if mode == 0:
        # original
        out = image
    elif mode == 1:
        # flip up and down
        out = np.flipud(image)
    elif mode == 2:
        # rotate 90 degrees counterclockwise
        out = np.rot90(image)
    elif mode == 3:
        # rotate 90 degrees and flip up and down
        out = np.rot90(image)
        out = np.flipud(out)
    elif mode == 4:
        # rotate 180 degrees
        out = np.rot90(image, k=2)
    elif mode == 5:
        # rotate 180 degrees and flip
        out = np.rot90(image, k=2)
        out = np.flipud(out)
    elif mode == 6:
        # rotate 270 degrees
        out = np.rot90(image, k=3)
    elif mode == 7:
        # rotate 270 degrees and flip
        out = np.rot90(image, k=3)
        out = np.flipud(out)
    else:
        raise Exception('Invalid choice of image transformation')
    return out

def random_augmentation(*args):
    out = []
    flag_aug = random.randint(1, 7)
    for data in args:
        out.append(data_augmentation(data, flag_aug).copy())
    return out

def save(model, model_path):
    torch.save({'state_dict': model.state_dict(), 'arch_param': model.arch_parameters()[0]}, model_path)

# pred and tar: (B, C, H, W) arrays in [0, 1]
def compute_psnr(pred, tar):
    assert pred.shape == tar.shape
    pred = pred.transpose(0,2,3,1)
    tar = tar.transpose(0,2,3,1)
    psnr = 0
    for i in range(pred.shape[0]):
        psnr += compare_psnr(tar[i], pred[i])
    return psnr/pred.shape[0]
--------------------------------------------------------------------------------
/deraining/train_model.py:
--------------------------------------------------------------------------------
import os
import sys
import time
import glob
import math
import numpy as np
import torch
import logging
import argparse
import torch.nn as nn

import utils
import pytorch_ssim
from search_space import HighResolutionNet
from datasets import Rain800

parser = argparse.ArgumentParser("deraining")
parser.add_argument('--data', type=str, default='./datasets/rain800/', help='location of the data')
parser.add_argument('--epochs', type=int, default=2000, help='num of training epochs')
parser.add_argument('--steps', type=int, default=100, help='steps of each epoch')
parser.add_argument('--batch_size', type=int, default=32, help='batch size')
parser.add_argument('--learning_rate', type=float, default=2e-3, help='init learning rate')
parser.add_argument('--save', type=str, default='EXP', help='experiment name')
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
parser.add_argument('--patch_size', type=int, default=64, help='patch size')
parser.add_argument('--gpu', type=str, default='3', help='gpu device ids')
parser.add_argument('--ckt_path', type=str, default='search-EXP-20201112-155305/last_weights.pt', help='checkpoint path of search')
args = parser.parse_args()

os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

args.save = 'eval-{}-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S"))
utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=log_format, datefmt='%m/%d %I:%M:%S %p')
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)

MSELoss = torch.nn.MSELoss().cuda()

def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    logging.info("args = %s", args)

    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled=True

    model = HighResolutionNet(ckt_path=args.ckt_path)
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
    model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), args.learning_rate)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs))

    train_samples = Rain800(args.data+'training/', args.steps*args.batch_size, args.patch_size)
    train_queue = torch.utils.data.DataLoader(train_samples, batch_size=args.batch_size, pin_memory=True)
    val_samples = Rain800(args.data+'test_syn/', 50*args.batch_size, args.patch_size)
    valid_queue = torch.utils.data.DataLoader(val_samples, batch_size=args.batch_size, pin_memory=True)

    best_psnr = 0
    best_psnr_epoch = 0
    best_ssim = 0
    best_ssim_epoch = 0
    best_loss = float("inf")
    best_loss_epoch = 0
    for epoch in range(args.epochs):
        logging.info('epoch %d/%d lr %e', epoch+1, args.epochs, scheduler.get_lr()[0])

        # training
        train(train_queue, model, optimizer)
        # validation
        psnr, ssim, loss = infer(valid_queue, model)

        if psnr > best_psnr and not math.isinf(psnr):
            torch.save(model, os.path.join(args.save, 'best_psnr_weights.pt'))
            best_psnr_epoch = epoch+1
            best_psnr = psnr
        if ssim > best_ssim:
            torch.save(model, os.path.join(args.save, 'best_ssim_weights.pt'))
            best_ssim_epoch = epoch+1
            best_ssim = ssim
        if loss < best_loss:
            torch.save(model, os.path.join(args.save, 'best_loss_weights.pt'))
            best_loss_epoch = epoch+1
            best_loss = loss

        scheduler.step()
        logging.info('psnr:%6f ssim:%6f loss:%6f -- best_psnr:%6f best_ssim:%6f best_loss:%6f', psnr, ssim, loss, best_psnr, best_ssim, best_loss)

    logging.info('BEST_LOSS(epoch):%6f(%d), BEST_PSNR(epoch):%6f(%d), BEST_SSIM(epoch):%6f(%d)', best_loss, best_loss_epoch, best_psnr, best_psnr_epoch, best_ssim, best_ssim_epoch)
    torch.save(model, os.path.join(args.save, 'last_weights.pt'))


def train(train_queue, model, optimizer):
    for step, (input, target) in enumerate(train_queue):
        print('--steps:%d--' % step)
        model.train()
        input = input.cuda()
        target = target.cuda()
        optimizer.zero_grad()
        logits = model(input)
        loss = MSELoss(logits, target)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

def infer(valid_queue, model):
    psnr = utils.AvgrageMeter()
    ssim = utils.AvgrageMeter()
    loss = utils.AvgrageMeter()

    model.eval()
    with torch.no_grad():
        for _, (input, target) in enumerate(valid_queue):
            input = input.cuda()
            target = target.cuda()
            logits = model(input)

            l = MSELoss(logits, target)
            s = pytorch_ssim.ssim(torch.clamp(logits,0,1), target)
            p = utils.compute_psnr(np.clip(logits.detach().cpu().numpy(),0,1), target.detach().cpu().numpy())
            n = input.size(0)
            psnr.update(p, n)
            ssim.update(s, n)
            loss.update(l, n)

    return psnr.avg, ssim.avg, loss.avg

if __name__ == '__main__':
    main()
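
# The scheduler above anneals the learning rate per epoch as
#     lr_t = lr_min + 0.5*(lr_0 - lr_min)*(1 + cos(pi*t/T)),
# with T = args.epochs and lr_min = 0 (CosineAnnealingLR's default eta_min),
# so the rate falls from 2e-3 to half that at mid-training and towards 0 at
# the end.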
--------------------------------------------------------------------------------
/denoising/train_model.py:
--------------------------------------------------------------------------------
import os
import sys
import time
import glob
import math
import numpy as np
import torch
import logging
import argparse
import torch.nn as nn

import utils
import pytorch_ssim
from search_space import HighResolutionNet
from datasets import PatchSimulateDataset

parser = argparse.ArgumentParser("Denoising")
parser.add_argument('--data', type=str, default='./datasets/BSD500/', help='location of the data corpus')
parser.add_argument('--epochs', type=int, default=1000, help='num of training epochs')
parser.add_argument('--steps', type=int, default=100, help='steps of each epoch')
parser.add_argument('--batch_size', type=int, default=32, help='batch size')
parser.add_argument('--learning_rate', type=float, default=1e-2, help='init learning rate')
parser.add_argument('--save', type=str, default='EXP', help='experiment name')
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
parser.add_argument('--patch_size', type=int, default=64, help='patch size')
parser.add_argument('--gpu', type=str, default='1', help='gpu device ids')
parser.add_argument('--sigma', type=int, default=30, help='noise level')
parser.add_argument('--ckt_path', type=str, default='search-EXP-20201112-143409/last_weights.pt', help='checkpoint path of search')
args = parser.parse_args()

os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

args.save = 'eval-{}-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S"))
utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=log_format, datefmt='%m/%d %I:%M:%S %p')
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)

MSELoss = torch.nn.MSELoss().cuda()

def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    logging.info("args = %s", args)

    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled=True

    model = HighResolutionNet(ckt_path=args.ckt_path)
    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
    model.cuda()

    optimizer = torch.optim.Adam(model.parameters(), args.learning_rate)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs))

    train_samples = PatchSimulateDataset(args.data+'train/', args.steps*args.batch_size, args.patch_size, sigma=args.sigma)
    val_samples = PatchSimulateDataset(args.data+'test/', 50*args.batch_size, args.patch_size, sigma=args.sigma)

    train_queue = torch.utils.data.DataLoader(train_samples, batch_size=args.batch_size, pin_memory=True)
    valid_queue = torch.utils.data.DataLoader(val_samples, batch_size=args.batch_size, pin_memory=True)

    best_psnr = 0
    best_psnr_epoch = 0
    best_ssim = 0
    best_ssim_epoch = 0
    best_loss = float("inf")
    best_loss_epoch = 0
    for epoch in range(args.epochs):
        logging.info('epoch %d/%d lr %e', epoch+1, args.epochs, scheduler.get_lr()[0])

        # training
        train(train_queue, model, optimizer)
        # validation
        psnr, ssim, loss = infer(valid_queue, model)

        if psnr > best_psnr and not math.isinf(psnr):
            torch.save(model, os.path.join(args.save, 'best_psnr_weights.pt'))
            best_psnr_epoch = epoch+1
            best_psnr = psnr
        if ssim > best_ssim:
            torch.save(model, os.path.join(args.save, 'best_ssim_weights.pt'))
            best_ssim_epoch = epoch+1
            best_ssim = ssim
        if loss < best_loss:
            torch.save(model, os.path.join(args.save, 'best_loss_weights.pt'))
            best_loss_epoch = epoch+1
            best_loss = loss

        scheduler.step()
        logging.info('psnr:%6f ssim:%6f loss:%6f -- best_psnr:%6f best_ssim:%6f best_loss:%6f', psnr, ssim, loss, best_psnr, best_ssim, best_loss)

    logging.info('BEST_LOSS(epoch):%6f(%d), BEST_PSNR(epoch):%6f(%d), BEST_SSIM(epoch):%6f(%d)', best_loss, best_loss_epoch, best_psnr, best_psnr_epoch, best_ssim, best_ssim_epoch)
    torch.save(model, os.path.join(args.save, 'last_weights.pt'))


def train(train_queue, model, optimizer):
    for step, (input, target) in enumerate(train_queue):
        print('--steps:%d--' % step)
        model.train()
        input = input.cuda()
        target = target.cuda()
        optimizer.zero_grad()
        logits = model(input)
        loss = MSELoss(logits, target)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

def infer(valid_queue, model):
    psnr = utils.AvgrageMeter()
    ssim = utils.AvgrageMeter()
    loss = utils.AvgrageMeter()

    model.eval()
    with torch.no_grad():
        for _, (input, target) in enumerate(valid_queue):
            input = input.cuda()
            target = target.cuda()
            logits = model(input)
            l = MSELoss(logits, target)
            s = pytorch_ssim.ssim(torch.clamp(logits,0,1), target)
            p = utils.compute_psnr(np.clip(logits.detach().cpu().numpy(),0,1), target.detach().cpu().numpy())
            n = input.size(0)
            psnr.update(p, n)
            ssim.update(s, n)
            loss.update(l, n)
    return psnr.avg, ssim.avg, loss.avg

if __name__ == '__main__':
    main()
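
# Checkpoint flow, with illustrative paths: train_search.py stores
# {'state_dict', 'arch_param'} via utils.save(); HighResolutionNet(ckt_path=...)
# consumes that file to instantiate the searched architecture; this script then
# saves whole-model .pt files that test_model.py reloads with torch.load().
#
#     model = HighResolutionNet(ckt_path='search-EXP-<timestamp>/last_weights.pt')
#     # ... train ...
#     torch.save(model, 'eval-EXP-<timestamp>/last_weights.pt')
#     model = torch.load('eval-EXP-<timestamp>/last_weights.pt')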
--------------------------------------------------------------------------------
/deraining/train_search.py:
--------------------------------------------------------------------------------
import os
import sys
import time
import glob
import numpy as np
import torch
import logging
import argparse
import torch.nn as nn
import math

import utils
import pytorch_ssim
from search_space import SearchSpace
from architect import Architect
from datasets import Rain800

parser = argparse.ArgumentParser("deraining")
parser.add_argument('--data', type=str, default='./datasets/rain800/', help='location of the data')
parser.add_argument('--epochs', type=int, default=200, help='num of training epochs')
parser.add_argument('--steps', type=int, default=50, help='steps of each epoch')
parser.add_argument('--batch_size', type=int, default=32, help='batch size')
parser.add_argument('--patch_size', type=int, default=64, help='patch size')
parser.add_argument('--learning_rate', type=float, default=0.025, help='init learning rate')
parser.add_argument('--learning_rate_min', type=float, default=0.001, help='min learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
parser.add_argument('--save', type=str, default='EXP', help='experiment name')
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
parser.add_argument('--unrolled', type=bool, default=False, help='use one-step unrolled validation loss')
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch encoding')
parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch encoding')
parser.add_argument('--gpu', type=str, default='1', help='gpu device ids')
args = parser.parse_args()

os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

args.save = 'search-{}-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S"))
utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=log_format, datefmt='%m/%d %I:%M:%S %p')
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)

def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    logging.info("args = %s", args)

    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled=True

    model = SearchSpace()
    model.cuda()

    optimizer = torch.optim.SGD(model.parameters(), args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs), eta_min=args.learning_rate_min)

    architect = Architect(model, args)

    train_samples = Rain800(args.data+'training/', args.steps*args.batch_size, args.patch_size)
    train_queue = torch.utils.data.DataLoader(train_samples, batch_size=args.batch_size, pin_memory=True)
    val_samples = Rain800(args.data+'test_syn/', 30*args.batch_size, args.patch_size)
    valid_queue = torch.utils.data.DataLoader(val_samples, batch_size=args.batch_size, pin_memory=True)

    best_psnr = 0
    best_psnr_epoch = 0
    best_ssim = 0
    best_ssim_epoch = 0
    best_loss = float("inf")
    best_loss_epoch = 0
    for epoch in range(args.epochs):
        lr = scheduler.get_lr()[0]
        logging.info('epoch %d/%d lr %e', epoch+1, args.epochs, lr)

        # training
        train(epoch, train_queue, valid_queue, model, architect, optimizer, lr)
        # validation
        psnr, ssim, loss = infer(valid_queue, model)

        if psnr > best_psnr and not math.isinf(psnr):
            utils.save(model, os.path.join(args.save, 'best_psnr_weights.pt'))
            best_psnr_epoch = epoch+1
            best_psnr = psnr
        if ssim > best_ssim:
            utils.save(model, os.path.join(args.save, 'best_ssim_weights.pt'))
            best_ssim_epoch = epoch+1
            best_ssim = ssim
        if loss < best_loss:
            utils.save(model, os.path.join(args.save, 'best_loss_weights.pt'))
            best_loss_epoch = epoch+1
            best_loss = loss

        scheduler.step()
        logging.info('psnr:%6f ssim:%6f loss:%6f -- best_psnr:%6f best_ssim:%6f best_loss:%6f', psnr, ssim, loss, best_psnr, best_ssim, best_loss)
        logging.info('arch:%s', torch.argmax(model.arch_parameters()[0], dim=1))

    logging.info('BEST_LOSS(epoch):%6f(%d), BEST_PSNR(epoch):%6f(%d), BEST_SSIM(epoch):%6f(%d)', best_loss, best_loss_epoch, best_psnr, best_psnr_epoch, best_ssim, best_ssim_epoch)
    utils.save(model, os.path.join(args.save, 'last_weights.pt'))

def train(epoch, train_queue, valid_queue, model, architect, optimizer, lr):
    for step, (input, target) in enumerate(train_queue):
        print('--steps:%d--' % step)

        input = input.cuda()
        target = target.cuda()

        # architecture updates start after a 10-epoch warm-up of the weights
        if (epoch+1) > 10:
            model.eval()
            input_search, target_search = next(iter(valid_queue))
            input_search = input_search.cuda()
            target_search = target_search.cuda()
            architect.step(input, target, input_search, target_search, lr, optimizer, unrolled=args.unrolled)

        model.train()
        optimizer.zero_grad()
        loss = model.loss(input, target)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

MSELoss = torch.nn.MSELoss().cuda()
def infer(valid_queue, model):
    psnr = utils.AvgrageMeter()
    ssim = utils.AvgrageMeter()
    loss = utils.AvgrageMeter()
    model.eval()

    with torch.no_grad():
        for _, (input, target) in enumerate(valid_queue):
            input = input.cuda()
            target = target.cuda()
            logits = model(input)

            l = MSELoss(logits, target)
            s = pytorch_ssim.ssim(torch.clamp(logits,0,1), target)
            p = utils.compute_psnr(np.clip(logits.detach().cpu().numpy(),0,1), target.detach().cpu().numpy())
            n = input.size(0)
            psnr.update(p, n)
            ssim.update(s, n)
            loss.update(l, n)

    return psnr.avg, ssim.avg, loss.avg

if __name__ == '__main__':
    main()
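
# Sketch of how the searched architecture is read off, assuming (as the
# DARTS-style logging above suggests) that arch_parameters()[0] holds one row
# of op logits per searchable edge:
#
#     alpha = model.arch_parameters()[0]     # (num_edges, num_ops)
#     weights = torch.softmax(alpha, dim=1)  # soft mixture during search
#     genotype = torch.argmax(alpha, dim=1)  # discrete op choice per edge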
--------------------------------------------------------------------------------
/denoising/train_search.py:
--------------------------------------------------------------------------------
import os
import sys
import time
import glob
import numpy as np
import torch
import logging
import argparse
import torch.nn as nn
import math

import utils
import pytorch_ssim
from search_space import SearchSpace
from architect import Architect
from datasets import PatchSimulateDataset

parser = argparse.ArgumentParser("denoising")
parser.add_argument('--data', type=str, default='./datasets/BSD500/', help='location of the data')
parser.add_argument('--epochs', type=int, default=200, help='num of training epochs')
parser.add_argument('--steps', type=int, default=50, help='steps of each epoch')
parser.add_argument('--batch_size', type=int, default=32, help='batch size')
parser.add_argument('--patch_size', type=int, default=64, help='patch size')
parser.add_argument('--learning_rate', type=float, default=0.025, help='init learning rate')
parser.add_argument('--learning_rate_min', type=float, default=0.001, help='min learning rate')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum')
parser.add_argument('--weight_decay', type=float, default=3e-4, help='weight decay')
parser.add_argument('--save', type=str, default='EXP', help='experiment name')
parser.add_argument('--grad_clip', type=float, default=5, help='gradient clipping')
parser.add_argument('--unrolled', type=bool, default=False, help='use one-step unrolled validation loss')
parser.add_argument('--arch_learning_rate', type=float, default=3e-4, help='learning rate for arch param')
parser.add_argument('--arch_weight_decay', type=float, default=1e-3, help='weight decay for arch param')
parser.add_argument('--gpu', type=str, default='1', help='gpu device ids')
parser.add_argument('--sigma', type=int, default=30, help='noise level')
args = parser.parse_args()

os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

args.save = 'search-{}-{}'.format(args.save, time.strftime("%Y%m%d-%H%M%S"))
utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO, format=log_format, datefmt='%m/%d %I:%M:%S %p')
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)

def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    logging.info("args = %s", args)

    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled=True

    model = SearchSpace()
    model.cuda()

    optimizer = torch.optim.SGD(model.parameters(), args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, float(args.epochs), eta_min=args.learning_rate_min)

    architect = Architect(model, args)

    train_samples = PatchSimulateDataset(args.data+'train/', args.steps*args.batch_size, args.patch_size, sigma=args.sigma)
    train_queue = torch.utils.data.DataLoader(train_samples, batch_size=args.batch_size, pin_memory=True)
    val_samples = PatchSimulateDataset(args.data+'val/', 30*args.batch_size, args.patch_size, sigma=args.sigma)
    valid_queue = torch.utils.data.DataLoader(val_samples, batch_size=args.batch_size, pin_memory=True)

    best_psnr = 0
    best_psnr_epoch = 0
    best_ssim = 0
    best_ssim_epoch = 0
    best_loss = float("inf")
    best_loss_epoch = 0
    for epoch in range(args.epochs):
        lr = scheduler.get_lr()[0]
        logging.info('epoch %d/%d lr %e', epoch+1, args.epochs, lr)

        # training
        train(epoch, train_queue, valid_queue, model, architect, optimizer, lr)
        # validation
        psnr, ssim, loss = infer(valid_queue, model)

        if psnr > best_psnr and not math.isinf(psnr):
            utils.save(model, os.path.join(args.save, 'best_psnr_weights.pt'))
            best_psnr_epoch = epoch+1
            best_psnr = psnr
        if ssim > best_ssim:
            utils.save(model, os.path.join(args.save, 'best_ssim_weights.pt'))
            best_ssim_epoch = epoch+1
            best_ssim = ssim
        if loss < best_loss:
            utils.save(model, os.path.join(args.save, 'best_loss_weights.pt'))
            best_loss_epoch = epoch+1
            best_loss = loss

        scheduler.step()
        logging.info('psnr:%6f ssim:%6f loss:%6f -- best_psnr:%6f best_ssim:%6f best_loss:%6f', psnr, ssim, loss, best_psnr, best_ssim, best_loss)
        logging.info('arch:%s', torch.argmax(model.arch_parameters()[0], dim=1))

    logging.info('BEST_LOSS(epoch):%6f(%d), BEST_PSNR(epoch):%6f(%d), BEST_SSIM(epoch):%6f(%d)', best_loss, best_loss_epoch, best_psnr, best_psnr_epoch, best_ssim, best_ssim_epoch)
    utils.save(model, os.path.join(args.save, 'last_weights.pt'))

def train(epoch, train_queue, valid_queue, model, architect, optimizer, lr):
    for step, (input, target) in enumerate(train_queue):
        print('--steps:%d--' % step)

        input = input.cuda()
        target = target.cuda()
        # architecture updates start after a 10-epoch warm-up of the weights
        if (epoch+1) > 10:
            model.eval()
            input_search, target_search = next(iter(valid_queue))
            input_search = input_search.cuda()
            target_search = target_search.cuda()
            architect.step(input, target, input_search, target_search, lr, optimizer, unrolled=args.unrolled)

        model.train()
        optimizer.zero_grad()
        loss = model.loss(input, target)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

MSELoss = torch.nn.MSELoss().cuda()
def infer(valid_queue, model):
    psnr = utils.AvgrageMeter()
    ssim = utils.AvgrageMeter()
    loss = utils.AvgrageMeter()
    model.eval()

    with torch.no_grad():
        for _, (input, target) in enumerate(valid_queue):
            input = input.cuda()
            target = target.cuda()

            logits = model(input)
            l = MSELoss(logits, target)

            s = pytorch_ssim.ssim(torch.clamp(logits,0,1), target)
            p = utils.compute_psnr(np.clip(logits.detach().cpu().numpy(),0,1), target.detach().cpu().numpy())
            n = input.size(0)
            psnr.update(p, n)
            ssim.update(s, n)
            loss.update(l, n)

    return psnr.avg, ssim.avg, loss.avg

if __name__ == '__main__':
    main()
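
# The training loop above is the first-order variant of DARTS' bilevel
# optimization (args.unrolled defaults to False): after a 10-epoch warm-up of
# the weights alone, each step first updates the architecture parameters on a
# validation batch via architect.step(), then updates the network weights on a
# training batch with SGD.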
architect, optimizer, lr): 106 | for step, (input, target) in enumerate(train_queue): 107 | print('--steps:%d--' % step) 108 | 109 | input = input.cuda() 110 | target = target.cuda() 111 | if (epoch+1) > 10: 112 | model.eval() 113 | input_search, target_search = next(iter(valid_queue)) 114 | input_search = input_search.cuda() 115 | target_search = target_search.cuda() 116 | architect.step(input, target, input_search, target_search, lr, optimizer, unrolled=args.unrolled) 117 | 118 | model.train() 119 | optimizer.zero_grad() 120 | loss = model.loss(input, target) 121 | loss.backward() 122 | nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) 123 | optimizer.step() 124 | 125 | MSELoss = torch.nn.MSELoss().cuda() 126 | def infer(valid_queue, model): 127 | psnr = utils.AvgrageMeter() 128 | ssim = utils.AvgrageMeter() 129 | loss = utils.AvgrageMeter() 130 | model.eval() 131 | 132 | with torch.no_grad(): 133 | for _, (input, target) in enumerate(valid_queue): 134 | input = input.cuda() 135 | target = target.cuda() 136 | 137 | logits = model(input) 138 | l = MSELoss(logits, target) 139 | 140 | s = pytorch_ssim.ssim(torch.clamp(logits,0,1), target) 141 | p = utils.compute_psnr(np.clip(logits.detach().cpu().numpy(),0,1), target.detach().cpu().numpy()) 142 | n = input.size(0) 143 | psnr.update(p, n) 144 | ssim.update(s, n) 145 | loss.update(l, n) 146 | 147 | return psnr.avg, ssim.avg, loss.avg 148 | 149 | if __name__ == '__main__': 150 | main() -------------------------------------------------------------------------------- /denoising/search_space.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Bottleneck(nn.Module): 5 | expansion = 4 6 | def __init__(self, inplanes, planes, stride=1, downsample=None): 7 | super(Bottleneck, self).__init__() 8 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 9 | self.bn1 = nn.BatchNorm2d(planes) 10 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 11 | self.bn2 = nn.BatchNorm2d(planes) 12 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 13 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 14 | self.relu = nn.ReLU(inplace=True) 15 | self.downsample = downsample 16 | self.stride = stride 17 | 18 | def forward(self, x): 19 | residual = x 20 | out = self.conv1(x) 21 | out = self.bn1(out) 22 | out = self.relu(out) 23 | out = self.conv2(out) 24 | out = self.bn2(out) 25 | out = self.relu(out) 26 | out = self.conv3(out) 27 | out = self.bn3(out) 28 | if self.downsample is not None: 29 | residual = self.downsample(x) 30 | out += residual 31 | out = self.relu(out) 32 | return out 33 | 34 | class ResidualBlock(nn.Module): 35 | def __init__(self, inplanes, planes, stride=1, downsample=None): 36 | super(ResidualBlock, self).__init__() 37 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 38 | self.bn1 = nn.BatchNorm2d(planes, momentum=0.1) 39 | self.relu = nn.ReLU(inplace=True) 40 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 41 | self.bn2 = nn.BatchNorm2d(planes, momentum=0.1) 42 | self.downsample = downsample 43 | 44 | def forward(self, x): 45 | residual = x 46 | out = self.conv1(x) 47 | out = self.bn1(out) 48 | out = self.relu(out) 49 | out = self.conv2(out) 50 | out = self.bn2(out) 51 | if self.downsample is not None: 52 | residual = self.downsample(x) 53 | out += residual 54 | out = 
self.relu(out) 55 | return out 56 | 57 | # parallel module 58 | class HorizontalModule(nn.Module): 59 | def __init__(self, num_inchannels, num_blocks=None): 60 | self.num_branches = len(num_inchannels) 61 | assert self.num_branches > 1 62 | if num_blocks is None: 63 | num_blocks = [1 for _ in range(self.num_branches)] 64 | else: 65 | assert self.num_branches==len(num_blocks) 66 | 67 | super(HorizontalModule, self).__init__() 68 | 69 | self.branches = nn.ModuleList() 70 | for i in range(self.num_branches): 71 | layers=[] 72 | for _ in range(num_blocks[i]): 73 | layers.append(ResidualBlock(num_inchannels[i], num_inchannels[i])) 74 | self.branches.append(nn.Sequential(*layers)) 75 | 76 | def forward(self, x): 77 | for i in range(self.num_branches): 78 | x[i] = self.branches[i](x[i]) 79 | return x 80 | 81 | class FusionModule(nn.Module): 82 | def __init__(self, num_inchannels, multi_scale_output=True): 83 | super(FusionModule, self).__init__() 84 | self.num_branches = len(num_inchannels) 85 | self.multi_scale_output = multi_scale_output 86 | assert self.num_branches > 1 87 | 88 | self.relu = nn.ReLU(True) 89 | self.fuse_layers = nn.ModuleList() 90 | for i in range(self.num_branches if self.multi_scale_output else 1): 91 | fuse_layer = [] 92 | for j in range(self.num_branches): 93 | if j > i: 94 | # upsample 95 | fuse_layer.append(nn.Sequential( 96 | nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False), 97 | nn.BatchNorm2d(num_inchannels[i]), 98 | nn.Upsample(scale_factor=2**(j-i), mode='nearest') 99 | )) 100 | elif j == i: 101 | # identity 102 | fuse_layer.append(None) 103 | else: 104 | # downsample 105 | conv3x3s = [] 106 | for k in range(i-j): 107 | if k == i - j - 1: 108 | num_outchannels_conv3x3 = num_inchannels[i] 109 | conv3x3s.append(nn.Sequential( 110 | nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False), 111 | nn.BatchNorm2d(num_outchannels_conv3x3) 112 | )) 113 | else: 114 | num_outchannels_conv3x3 = num_inchannels[j] 115 | conv3x3s.append(nn.Sequential( 116 | nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False), 117 | nn.BatchNorm2d(num_outchannels_conv3x3), 118 | nn.ReLU(True) 119 | )) 120 | fuse_layer.append(nn.Sequential(*conv3x3s)) 121 | self.fuse_layers.append(nn.ModuleList(fuse_layer)) 122 | 123 | def forward(self, x): 124 | x_fuse = [] 125 | for i in range(len(self.fuse_layers)): 126 | y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) 127 | for j in range(1, self.num_branches): 128 | if i == j: 129 | y = y + x[j] 130 | else: 131 | y = y + self.fuse_layers[i][j](x[j]) 132 | x_fuse.append(self.relu(y)) 133 | return x_fuse 134 | 135 | class TransitionModule(nn.Module): 136 | def __init__(self, num_channels_pre_layer, num_channels_cur_layer): 137 | super(TransitionModule, self).__init__() 138 | 139 | self.num_branches_cur = len(num_channels_cur_layer) 140 | self.num_branches_pre = len(num_channels_pre_layer) 141 | 142 | self.transition_layers = nn.ModuleList() 143 | for i in range(self.num_branches_cur): 144 | if i < self.num_branches_pre: 145 | # horizontal transition 146 | if num_channels_cur_layer[i] != num_channels_pre_layer[i]: 147 | # alter channels 148 | self.transition_layers.append(nn.Sequential( 149 | nn.Conv2d(num_channels_pre_layer[i], num_channels_cur_layer[i], 3, 1, 1, bias=False), 150 | nn.BatchNorm2d(num_channels_cur_layer[i]), 151 | nn.ReLU(inplace=True) 152 | )) 153 | else: 154 | # no operation 155 | self.transition_layers.append(None) 156 | else: 157 | # downsample transition (new scales) 158 |
conv3x3s = [] 159 | for j in range(i+1-self.num_branches_pre): 160 | inchannels = num_channels_pre_layer[-1] 161 | outchannels = num_channels_cur_layer[i] if j == i-self.num_branches_pre else inchannels 162 | conv3x3s.append(nn.Sequential( 163 | nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False), 164 | nn.BatchNorm2d(outchannels), 165 | nn.ReLU(inplace=True) 166 | )) 167 | self.transition_layers.append(nn.Sequential(*conv3x3s)) 168 | 169 | def forward(self, x): 170 | x_list = [] 171 | for i in range(self.num_branches_cur): 172 | if self.transition_layers[i] is not None: 173 | if i < self.num_branches_pre: 174 | x_list.append(self.transition_layers[i](x[i])) 175 | else: 176 | x_list.append(self.transition_layers[i](x[-1])) 177 | else: 178 | x_list.append(x[i]) 179 | return x_list 180 | 181 | class MixModule(nn.Module): 182 | def __init__(self, num_inchannels): 183 | super(MixModule, self).__init__() 184 | self.horizontal = HorizontalModule(num_inchannels) 185 | self.fuse = FusionModule(num_inchannels) 186 | self.softmax = nn.Softmax(dim=0) 187 | 188 | def forward(self, x, weights): 189 | h_x = self.horizontal(x) 190 | f_x = self.fuse(x) 191 | 192 | weights = self.softmax(weights) 193 | 194 | x_list = [] 195 | for i in range(len(x)): 196 | x_list.append(weights[0]*h_x[i]+weights[1]*f_x[i]) 197 | 198 | return x_list 199 | 200 | 201 | class SearchSpace(nn.Module): 202 | # list_channels of branches; list_modules of stages 203 | def __init__(self, in_channel=1, list_channels=[32, 64, 128, 256], list_modules=[2,4,4,4]): 204 | super(SearchSpace,self).__init__() 205 | self.list_channels = list_channels 206 | self.list_modules = list_modules 207 | 208 | self.stage1 = self.__make_layer(in_channel, list_channels[0], list_modules[0]) 209 | 210 | self.transition1 = TransitionModule([list_channels[0]*Bottleneck.expansion],list_channels[:2]) 211 | self.stage2 = self.__make_stage(list_channels[:2],list_modules[1]) 212 | 213 | self.transition2 = TransitionModule(list_channels[:2],list_channels[:3]) 214 | self.stage3 = self.__make_stage(list_channels[:3],list_modules[2]) 215 | 216 | self.transition3 = TransitionModule(list_channels[:3],list_channels[:4]) 217 | self.stage4 = self.__make_stage(list_channels[:4],list_modules[3]) 218 | 219 | self.final_fuse = FusionModule(list_channels, multi_scale_output=False) 220 | self.final_conv = nn.Sequential( 221 | nn.Conv2d(list_channels[0], in_channel, 1, 1, bias=False) 222 | ) 223 | 224 | self.__init_architecture() 225 | self.init_weights() 226 | 227 | self.MSELoss = torch.nn.MSELoss().cuda() 228 | 229 | 230 | def __make_layer(self, inplanes, planes, blocks, stride=1): 231 | layers = [] 232 | downsample = None 233 | if stride != 1 or inplanes != planes * Bottleneck.expansion: 234 | downsample = nn.Sequential( 235 | nn.Conv2d(inplanes, planes * Bottleneck.expansion, kernel_size=1, stride=stride, bias=False), 236 | nn.BatchNorm2d(planes * Bottleneck.expansion), 237 | ) 238 | layers.append(Bottleneck(inplanes, planes, stride, downsample)) 239 | inplanes = planes * Bottleneck.expansion 240 | for _ in range(1, blocks): 241 | layers.append(Bottleneck(inplanes, planes)) 242 | return nn.Sequential(*layers) 243 | 244 | def __make_stage(self, num_inchannels, num_module): 245 | modules = nn.ModuleList() 246 | for _ in range(num_module): 247 | modules.append(MixModule(num_inchannels)) 248 | return modules 249 | 250 | def new(self): 251 | model_new = SearchSpace().cuda() 252 | for x, y in zip(model_new.arch_parameters(), self.arch_parameters()): 253 | x.data.copy_(y.data) 254 
| return model_new 255 | 256 | def __init_architecture(self): 257 | self.arch_param = torch.randn(sum(self.list_modules[1:]), 2, dtype=torch.float).cuda()*1e-3 258 | self.arch_param.requires_grad = True 259 | 260 | def arch_parameters(self): 261 | return [self.arch_param] 262 | 263 | def init_weights(self): 264 | print('=> init weights from normal distribution') 265 | for m in self.modules(): 266 | if isinstance(m, nn.Conv2d): 267 | nn.init.normal_(m.weight, std=0.001) 268 | for name, _ in m.named_parameters(): 269 | if name in ['bias']: 270 | nn.init.constant_(m.bias, 0) 271 | elif isinstance(m, nn.BatchNorm2d): 272 | nn.init.constant_(m.weight, 1) 273 | nn.init.constant_(m.bias, 0) 274 | elif isinstance(m, nn.ConvTranspose2d): 275 | nn.init.normal_(m.weight, std=0.001) 276 | for name, _ in m.named_parameters(): 277 | if name in ['bias']: 278 | nn.init.constant_(m.bias, 0) 279 | 280 | def forward(self, x): 281 | x = self.stage1(x) 282 | 283 | x_list = self.transition1([x]) 284 | idx = 0 285 | for i in range(self.list_modules[1]): 286 | x_list = self.stage2[i](x_list, self.arch_param[idx+i]) 287 | #print([idx+i, self.arch_param[idx+i]]) 288 | 289 | x_list = self.transition2(x_list) 290 | idx = idx + self.list_modules[1] 291 | for i in range(self.list_modules[2]): 292 | x_list = self.stage3[i](x_list,self.arch_param[idx+i]) 293 | #print([idx+i, self.arch_param[idx+i]]) 294 | 295 | x_list = self.transition3(x_list) 296 | idx = idx + self.list_modules[2] 297 | for i in range(self.list_modules[3]): 298 | x_list = self.stage4[i](x_list,self.arch_param[idx+i]) 299 | #print([idx+i, self.arch_param[idx+i]]) 300 | 301 | x = self.final_fuse(x_list)[0] 302 | x = self.final_conv(x) 303 | return x 304 | 305 | def loss(self, input, target): 306 | logits = self(input) 307 | mse_loss = self.MSELoss(logits, target) 308 | arch_weights = torch.nn.functional.softmax(self.arch_param, dim=1) 309 | 310 | # regular_loss 311 | regular_loss = -arch_weights*torch.log10(arch_weights)-(1-arch_weights)*torch.log10(1-arch_weights) 312 | # latency loss: computing the average complexity (params num) of modules by running `python search_space.py` 313 | # latency_loss = torch.mean(arch_weights[:,0]*0.7+arch_weights[:,1]*0.3) 314 | 315 | return mse_loss + regular_loss.mean()*0.01 #+ latency_loss*0.1 316 | 317 | # multi-resolution network 318 | class HighResolutionNet(nn.Module): 319 | def __init__(self, ckt_path, in_channel=1, list_channels=[32, 64, 128, 256], list_modules=[2,4,4,4]): 320 | super(HighResolutionNet, self).__init__() 321 | self.list_channels = list_channels 322 | self.list_modules = list_modules 323 | 324 | arch = torch.load(ckt_path)['arch_param'] 325 | softmax = nn.Softmax(dim=1) 326 | arch_soft = softmax(arch) 327 | arch_hard = torch.argmax(arch_soft, dim=1) 328 | 329 | self.stage1 = self.__make_layer(in_channel, list_channels[0], list_modules[0]) 330 | 331 | self.transition1 = TransitionModule([list_channels[0]*Bottleneck.expansion],list_channels[:2]) 332 | idx = 0 333 | self.stage2 = self.__make_stage(arch_hard[idx:list_modules[1]], list_channels[:2], list_modules[1]) 334 | 335 | self.transition2 = TransitionModule(list_channels[:2],list_channels[:3]) 336 | idx += list_modules[1] 337 | self.stage3 = self.__make_stage(arch_hard[idx:idx+list_modules[2]], list_channels[:3],list_modules[2]) 338 | 339 | self.transition3 = TransitionModule(list_channels[:3],list_channels[:4]) 340 | idx += list_modules[2] 341 | self.stage4 = self.__make_stage(arch_hard[idx:idx+list_modules[3]], list_channels[:4],list_modules[3]) 
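# NOTE: idx is 0 when the stage2 slice above is taken, so arch_hard[idx:list_modules[1]] is equivalent to arch_hard[idx:idx+list_modules[1]]; the latter form (used for stage3 and stage4) would stay correct if the stage2 slice ever started at a nonzero offset.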
342 | 343 | self.final_fuse = FusionModule(list_channels, multi_scale_output=False) 344 | self.final_conv = nn.Sequential( 345 | nn.Conv2d(list_channels[0], in_channel, 1, 1, bias=False) 346 | ) 347 | 348 | self.init_weights() 349 | 350 | def __make_layer(self, inplanes, planes, blocks, stride=1): 351 | layers = [] 352 | downsample = None 353 | if stride != 1 or inplanes != planes * Bottleneck.expansion: 354 | downsample = nn.Sequential( 355 | nn.Conv2d(inplanes, planes * Bottleneck.expansion, kernel_size=1, stride=stride, bias=False), 356 | nn.BatchNorm2d(planes * Bottleneck.expansion), 357 | ) 358 | layers.append(Bottleneck(inplanes, planes, stride, downsample)) 359 | inplanes = planes * Bottleneck.expansion 360 | for _ in range(1, blocks): 361 | layers.append(Bottleneck(inplanes, planes)) 362 | return nn.Sequential(*layers) 363 | 364 | def __make_stage(self, arch, num_inchannels, num_module): 365 | modules = [] 366 | for i in range(num_module): 367 | if arch[i] == 0: 368 | modules.append(HorizontalModule(num_inchannels)) 369 | elif arch[i] == 1: 370 | modules.append(FusionModule(num_inchannels)) 371 | return nn.Sequential(*modules) 372 | 373 | def init_weights(self): 374 | print('=> init weights from normal distribution') 375 | for m in self.modules(): 376 | if isinstance(m, nn.Conv2d): 377 | nn.init.normal_(m.weight, std=0.001) 378 | for name, _ in m.named_parameters(): 379 | if name in ['bias']: 380 | nn.init.constant_(m.bias, 0) 381 | elif isinstance(m, nn.BatchNorm2d): 382 | nn.init.constant_(m.weight, 1) 383 | nn.init.constant_(m.bias, 0) 384 | elif isinstance(m, nn.ConvTranspose2d): 385 | nn.init.normal_(m.weight, std=0.001) 386 | for name, _ in m.named_parameters(): 387 | if name in ['bias']: 388 | nn.init.constant_(m.bias, 0) 389 | 390 | def forward(self, x): 391 | x = self.stage1(x) 392 | x_list = self.transition1([x]) 393 | x_list = self.stage2(x_list) 394 | x_list = self.transition2(x_list) 395 | x_list = self.stage3(x_list) 396 | x_list = self.transition3(x_list) 397 | x_list = self.stage4(x_list) 398 | x = self.final_fuse(x_list)[0] 399 | x = self.final_conv(x) 400 | return x 401 | 402 | if __name__=='__main__': 403 | ''' 404 | Computing the average complexity (params) of modules. 
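The printed ratio appears to be the source of the 0.7/0.3 weights in the commented-out latency_loss of SearchSpace.loss (parallel vs. fusion module).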
405 | ''' 406 | from utils import count_parameters_in_MB 407 | import numpy as np 408 | list_channels=[32, 64, 128, 256] 409 | run_time_ratio = [] 410 | params = np.zeros((2,3)) 411 | for i in range(1, len(list_channels)): 412 | H = HorizontalModule(list_channels[:i+1]).cuda() 413 | F = FusionModule(list_channels[:i+1]).cuda() 414 | params[0,i-1] = count_parameters_in_MB(H) # parallel module 415 | params[1,i-1] = count_parameters_in_MB(F) # fusion module 416 | comp = np.sum(params, axis=1)/np.sum(params) 417 | print('params ratio: %.1f %.1f' % (comp[0], comp[1])) -------------------------------------------------------------------------------- /deraining/search_space.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class Bottleneck(nn.Module): 5 | expansion = 4 6 | def __init__(self, inplanes, planes, stride=1, downsample=None): 7 | super(Bottleneck, self).__init__() 8 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 9 | self.bn1 = nn.BatchNorm2d(planes) 10 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 11 | self.bn2 = nn.BatchNorm2d(planes) 12 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 13 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 14 | self.relu = nn.ReLU(inplace=True) 15 | self.downsample = downsample 16 | self.stride = stride 17 | 18 | def forward(self, x): 19 | residual = x 20 | out = self.conv1(x) 21 | out = self.bn1(out) 22 | out = self.relu(out) 23 | out = self.conv2(out) 24 | out = self.bn2(out) 25 | out = self.relu(out) 26 | out = self.conv3(out) 27 | out = self.bn3(out) 28 | if self.downsample is not None: 29 | residual = self.downsample(x) 30 | out += residual 31 | out = self.relu(out) 32 | return out 33 | 34 | class ResidualBlock(nn.Module): 35 | def __init__(self, inplanes, planes, stride=1, downsample=None): 36 | super(ResidualBlock, self).__init__() 37 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 38 | self.bn1 = nn.BatchNorm2d(planes, momentum=0.1) 39 | self.relu = nn.ReLU(inplace=True) 40 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 41 | self.bn2 = nn.BatchNorm2d(planes, momentum=0.1) 42 | self.downsample = downsample 43 | 44 | def forward(self, x): 45 | residual = x 46 | out = self.conv1(x) 47 | out = self.bn1(out) 48 | out = self.relu(out) 49 | out = self.conv2(out) 50 | out = self.bn2(out) 51 | if self.downsample is not None: 52 | residual = self.downsample(x) 53 | out += residual 54 | out = self.relu(out) 55 | return out 56 | 57 | # parallel module 58 | class HorizontalModule(nn.Module): 59 | def __init__(self, num_inchannels, num_blocks=None): 60 | self.num_branches = len(num_inchannels) 61 | assert self.num_branches > 1 62 | if num_blocks is None: 63 | num_blocks = [1 for _ in range(self.num_branches)] 64 | else: 65 | assert self.num_branches==len(num_blocks) 66 | 67 | super(HorizontalModule, self).__init__() 68 | 69 | self.branches = nn.ModuleList() 70 | for i in range(self.num_branches): 71 | layers=[] 72 | for _ in range(num_blocks[i]): 73 | layers.append(ResidualBlock(num_inchannels[i], num_inchannels[i])) 74 | self.branches.append(nn.Sequential(*layers)) 75 | 76 | def forward(self, x): 77 | for i in range(self.num_branches): 78 | x[i] = self.branches[i](x[i]) 79 | return x 80 | 81 | class FusionModule(nn.Module): 82 | def __init__(self, num_inchannels, 
multi_scale_output=True): 83 | super(FusionModule, self).__init__() 84 | self.num_branches = len(num_inchannels) 85 | self.multi_scale_output = multi_scale_output 86 | assert self.num_branches > 1 87 | 88 | self.relu = nn.ReLU(True) 89 | self.fuse_layers = nn.ModuleList() 90 | for i in range(self.num_branches if self.multi_scale_output else 1): 91 | fuse_layer = [] 92 | for j in range(self.num_branches): 93 | if j > i: 94 | # upsample 95 | fuse_layer.append(nn.Sequential( 96 | nn.Conv2d(num_inchannels[j], num_inchannels[i], 1, 1, 0, bias=False), 97 | nn.BatchNorm2d(num_inchannels[i]), 98 | nn.Upsample(scale_factor=2**(j-i), mode='nearest') 99 | )) 100 | elif j == i: 101 | # identity 102 | fuse_layer.append(None) 103 | else: 104 | # downsample 105 | conv3x3s = [] 106 | for k in range(i-j): 107 | if k == i - j - 1: 108 | num_outchannels_conv3x3 = num_inchannels[i] 109 | conv3x3s.append(nn.Sequential( 110 | nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False), 111 | nn.BatchNorm2d(num_outchannels_conv3x3) 112 | )) 113 | else: 114 | num_outchannels_conv3x3 = num_inchannels[j] 115 | conv3x3s.append(nn.Sequential( 116 | nn.Conv2d(num_inchannels[j], num_outchannels_conv3x3, 3, 2, 1, bias=False), 117 | nn.BatchNorm2d(num_outchannels_conv3x3), 118 | nn.ReLU(True) 119 | )) 120 | fuse_layer.append(nn.Sequential(*conv3x3s)) 121 | self.fuse_layers.append(nn.ModuleList(fuse_layer)) 122 | 123 | def forward(self, x): 124 | x_fuse = [] 125 | for i in range(len(self.fuse_layers)): 126 | y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) 127 | for j in range(1, self.num_branches): 128 | if i == j: 129 | y = y + x[j] 130 | else: 131 | y = y + self.fuse_layers[i][j](x[j]) 132 | x_fuse.append(self.relu(y)) 133 | return x_fuse 134 | 135 | class TransitionModule(nn.Module): 136 | def __init__(self, num_channels_pre_layer, num_channels_cur_layer): 137 | super(TransitionModule, self).__init__() 138 | 139 | self.num_branches_cur = len(num_channels_cur_layer) 140 | self.num_branches_pre = len(num_channels_pre_layer) 141 | 142 | self.transition_layers = nn.ModuleList() 143 | for i in range(self.num_branches_cur): 144 | if i < self.num_branches_pre: 145 | # horizontal transition 146 | if num_channels_cur_layer[i] != num_channels_pre_layer[i]: 147 | # alter channels 148 | self.transition_layers.append(nn.Sequential( 149 | nn.Conv2d(num_channels_pre_layer[i], num_channels_cur_layer[i], 3, 1, 1, bias=False), 150 | nn.BatchNorm2d(num_channels_cur_layer[i]), 151 | nn.ReLU(inplace=True) 152 | )) 153 | else: 154 | # no operation 155 | self.transition_layers.append(None) 156 | else: 157 | # downsample transition (new scales) 158 | conv3x3s = [] 159 | for j in range(i+1-self.num_branches_pre): 160 | inchannels = num_channels_pre_layer[-1] 161 | outchannels = num_channels_cur_layer[i] if j == i-self.num_branches_pre else inchannels 162 | conv3x3s.append(nn.Sequential( 163 | nn.Conv2d(inchannels, outchannels, 3, 2, 1, bias=False), 164 | nn.BatchNorm2d(outchannels), 165 | nn.ReLU(inplace=True) 166 | )) 167 | self.transition_layers.append(nn.Sequential(*conv3x3s)) 168 | 169 | def forward(self, x): 170 | x_list = [] 171 | for i in range(self.num_branches_cur): 172 | if self.transition_layers[i] is not None: 173 | if i < self.num_branches_pre: 174 | x_list.append(self.transition_layers[i](x[i])) 175 | else: 176 | x_list.append(self.transition_layers[i](x[-1])) 177 | else: 178 | x_list.append(x[i]) 179 | return x_list 180 | 181 | class MixModule(nn.Module): 182 | def __init__(self, num_inchannels): 183 |
super(MixModule, self).__init__() 184 | self.horizontal = HorizontalModule(num_inchannels) 185 | self.fuse = FusionModule(num_inchannels) 186 | self.softmax = nn.Softmax(dim=0) 187 | 188 | def forward(self, x, weights): 189 | h_x = self.horizontal(x) 190 | f_x = self.fuse(x) 191 | 192 | weights = self.softmax(weights) 193 | 194 | x_list = [] 195 | for i in range(len(x)): 196 | x_list.append(weights[0]*h_x[i]+weights[1]*f_x[i]) 197 | 198 | return x_list 199 | 200 | 201 | class SearchSpace(nn.Module): 202 | # list_channels of branches; list_modules of stages 203 | def __init__(self, in_channel=3, list_channels=[32, 64, 128, 256], list_modules=[2,4,4,4]): 204 | super(SearchSpace,self).__init__() 205 | self.list_channels = list_channels 206 | self.list_modules = list_modules 207 | 208 | self.stage1 = self.__make_layer(in_channel, list_channels[0], list_modules[0]) 209 | 210 | self.transition1 = TransitionModule([list_channels[0]*Bottleneck.expansion],list_channels[:2]) 211 | self.stage2 = self.__make_stage(list_channels[:2],list_modules[1]) 212 | 213 | self.transition2 = TransitionModule(list_channels[:2],list_channels[:3]) 214 | self.stage3 = self.__make_stage(list_channels[:3],list_modules[2]) 215 | 216 | self.transition3 = TransitionModule(list_channels[:3],list_channels[:4]) 217 | self.stage4 = self.__make_stage(list_channels[:4],list_modules[3]) 218 | 219 | self.final_fuse = FusionModule(list_channels, multi_scale_output=False) 220 | self.final_conv = nn.Sequential( 221 | nn.Conv2d(list_channels[0], in_channel, 1, 1, bias=False) 222 | ) 223 | 224 | self.__init_architecture() 225 | self.init_weights() 226 | 227 | self.MSELoss = torch.nn.MSELoss().cuda() 228 | 229 | 230 | def __make_layer(self, inplanes, planes, blocks, stride=1): 231 | layers = [] 232 | downsample = None 233 | if stride != 1 or inplanes != planes * Bottleneck.expansion: 234 | downsample = nn.Sequential( 235 | nn.Conv2d(inplanes, planes * Bottleneck.expansion, kernel_size=1, stride=stride, bias=False), 236 | nn.BatchNorm2d(planes * Bottleneck.expansion), 237 | ) 238 | layers.append(Bottleneck(inplanes, planes, stride, downsample)) 239 | inplanes = planes * Bottleneck.expansion 240 | for _ in range(1, blocks): 241 | layers.append(Bottleneck(inplanes, planes)) 242 | return nn.Sequential(*layers) 243 | 244 | def __make_stage(self, num_inchannels, num_module): 245 | modules = nn.ModuleList() 246 | for _ in range(num_module): 247 | modules.append(MixModule(num_inchannels)) 248 | return modules 249 | 250 | def new(self): 251 | model_new = SearchSpace().cuda() 252 | for x, y in zip(model_new.arch_parameters(), self.arch_parameters()): 253 | x.data.copy_(y.data) 254 | return model_new 255 | 256 | def __init_architecture(self): 257 | self.arch_param = torch.randn(sum(self.list_modules[1:]), 2, dtype=torch.float).cuda()*1e-3 258 | self.arch_param.requires_grad = True 259 | 260 | def arch_parameters(self): 261 | return [self.arch_param] 262 | 263 | def init_weights(self): 264 | print('=> init weights from normal distribution') 265 | for m in self.modules(): 266 | if isinstance(m, nn.Conv2d): 267 | nn.init.normal_(m.weight, std=0.001) 268 | for name, _ in m.named_parameters(): 269 | if name in ['bias']: 270 | nn.init.constant_(m.bias, 0) 271 | elif isinstance(m, nn.BatchNorm2d): 272 | nn.init.constant_(m.weight, 1) 273 | nn.init.constant_(m.bias, 0) 274 | elif isinstance(m, nn.ConvTranspose2d): 275 | nn.init.normal_(m.weight, std=0.001) 276 | for name, _ in m.named_parameters(): 277 | if name in ['bias']: 278 | 
nn.init.constant_(m.bias, 0) 279 | 280 | def forward(self, x): 281 | x = self.stage1(x) 282 | 283 | x_list = self.transition1([x]) 284 | idx = 0 285 | for i in range(self.list_modules[1]): 286 | x_list = self.stage2[i](x_list, self.arch_param[idx+i]) 287 | #print([idx+i, self.arch_param[idx+i]]) 288 | 289 | x_list = self.transition2(x_list) 290 | idx = idx + self.list_modules[1] 291 | for i in range(self.list_modules[2]): 292 | x_list = self.stage3[i](x_list,self.arch_param[idx+i]) 293 | #print([idx+i, self.arch_param[idx+i]]) 294 | 295 | x_list = self.transition3(x_list) 296 | idx = idx + self.list_modules[2] 297 | for i in range(self.list_modules[3]): 298 | x_list = self.stage4[i](x_list,self.arch_param[idx+i]) 299 | #print([idx+i, self.arch_param[idx+i]]) 300 | 301 | x = self.final_fuse(x_list)[0] 302 | x = self.final_conv(x) 303 | return x 304 | 305 | def loss(self, input, target): 306 | logits = self(input) 307 | mse_loss = self.MSELoss(logits, target) 308 | arch_weights = torch.nn.functional.softmax(self.arch_param, dim=1) 309 | 310 | # regular_loss 311 | regular_loss = -arch_weights*torch.log10(arch_weights)-(1-arch_weights)*torch.log10(1-arch_weights) 312 | # latency loss: computing the average complexity (params num) of modules by running `python search_space.py` 313 | # latency_loss = torch.mean(arch_weights[:,0]*0.7+arch_weights[:,1]*0.3) 314 | 315 | return mse_loss + regular_loss.mean()*0.01 #+ latency_loss*0.1 316 | 317 | # multi-resolution network 318 | class HighResolutionNet(nn.Module): 319 | def __init__(self, ckt_path, in_channel=3, list_channels=[32, 64, 128, 256], list_modules=[2,4,4,4]): 320 | super(HighResolutionNet, self).__init__() 321 | self.list_channels = list_channels 322 | self.list_modules = list_modules 323 | 324 | arch = torch.load(ckt_path)['arch_param'] 325 | softmax = nn.Softmax(dim=1) 326 | arch_soft = softmax(arch) 327 | arch_hard = torch.argmax(arch_soft, dim=1) 328 | 329 | self.stage1 = self.__make_layer(in_channel, list_channels[0], list_modules[0]) 330 | 331 | self.transition1 = TransitionModule([list_channels[0]*Bottleneck.expansion],list_channels[:2]) 332 | idx = 0 333 | self.stage2 = self.__make_stage(arch_hard[idx:list_modules[1]], list_channels[:2], list_modules[1]) 334 | 335 | self.transition2 = TransitionModule(list_channels[:2],list_channels[:3]) 336 | idx += list_modules[1] 337 | self.stage3 = self.__make_stage(arch_hard[idx:idx+list_modules[2]], list_channels[:3],list_modules[2]) 338 | 339 | self.transition3 = TransitionModule(list_channels[:3],list_channels[:4]) 340 | idx += list_modules[2] 341 | self.stage4 = self.__make_stage(arch_hard[idx:idx+list_modules[3]], list_channels[:4],list_modules[3]) 342 | 343 | self.final_fuse = FusionModule(list_channels, multi_scale_output=False) 344 | self.final_conv = nn.Sequential( 345 | nn.Conv2d(list_channels[0], in_channel, 1, 1, bias=False) 346 | ) 347 | 348 | self.init_weights() 349 | 350 | def __make_layer(self, inplanes, planes, blocks, stride=1): 351 | layers = [] 352 | downsample = None 353 | if stride != 1 or inplanes != planes * Bottleneck.expansion: 354 | downsample = nn.Sequential( 355 | nn.Conv2d(inplanes, planes * Bottleneck.expansion, kernel_size=1, stride=stride, bias=False), 356 | nn.BatchNorm2d(planes * Bottleneck.expansion), 357 | ) 358 | layers.append(Bottleneck(inplanes, planes, stride, downsample)) 359 | inplanes = planes * Bottleneck.expansion 360 | for _ in range(1, blocks): 361 | layers.append(Bottleneck(inplanes, planes)) 362 | return nn.Sequential(*layers) 363 | 364 | def 
__make_stage(self, arch, num_inchannels, num_module): 365 | modules = [] 366 | for i in range(num_module): 367 | if arch[i] == 0: 368 | modules.append(HorizontalModule(num_inchannels)) 369 | elif arch[i] == 1: 370 | modules.append(FusionModule(num_inchannels)) 371 | return nn.Sequential(*modules) 372 | 373 | def init_weights(self): 374 | print('=> init weights from normal distribution') 375 | for m in self.modules(): 376 | if isinstance(m, nn.Conv2d): 377 | nn.init.normal_(m.weight, std=0.001) 378 | for name, _ in m.named_parameters(): 379 | if name in ['bias']: 380 | nn.init.constant_(m.bias, 0) 381 | elif isinstance(m, nn.BatchNorm2d): 382 | nn.init.constant_(m.weight, 1) 383 | nn.init.constant_(m.bias, 0) 384 | elif isinstance(m, nn.ConvTranspose2d): 385 | nn.init.normal_(m.weight, std=0.001) 386 | for name, _ in m.named_parameters(): 387 | if name in ['bias']: 388 | nn.init.constant_(m.bias, 0) 389 | 390 | def forward(self, x): 391 | x = self.stage1(x) 392 | x_list = self.transition1([x]) 393 | x_list = self.stage2(x_list) 394 | x_list = self.transition2(x_list) 395 | x_list = self.stage3(x_list) 396 | x_list = self.transition3(x_list) 397 | x_list = self.stage4(x_list) 398 | x = self.final_fuse(x_list)[0] 399 | x = self.final_conv(x) 400 | return x 401 | 402 | if __name__=='__main__': 403 | ''' 404 | Computing the average complexity (params) of modules. 405 | ''' 406 | from utils import count_parameters_in_MB 407 | import numpy as np 408 | list_channels=[32, 64, 128, 256] 409 | run_time_ratio = [] 410 | params = np.zeros((2,3)) 411 | for i in range(1, len(list_channels)): 412 | H = HorizontalModule(list_channels[:i+1]).cuda() 413 | F = FusionModule(list_channels[:i+1]).cuda() 414 | params[0,i-1] = count_parameters_in_MB(H) # parallel module 415 | params[1,i-1] = count_parameters_in_MB(F) # fusion module 416 | comp = np.sum(params, axis=1)/np.sum(params) 417 | print('params ratio: %.1f %.1f' % (comp[0], comp[1])) 418 | --------------------------------------------------------------------------------
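A minimal end-to-end sketch of how the pieces above fit together (not part of the repository): train_search.py saves checkpoints such as best_psnr_weights.pt, and HighResolutionNet hardens the stored architecture weights with softmax + argmax to build the derived multi-resolution network. The checkpoint path below is illustrative, and it is assumed that utils.save() writes a dict containing the 'arch_param' tensor, as HighResolutionNet.__init__ expects; since HighResolutionNet re-initializes its weights, the derived network still has to be trained (see train_model.py) before evaluation.

```python
# Hedged sketch: derive and run the searched denoising network.
# Run from inside denoising/; the checkpoint path is hypothetical, and the
# file must hold an 'arch_param' entry, since HighResolutionNet reads
# torch.load(ckt_path)['arch_param'] and hardens it via argmax.
import torch
import pytorch_ssim
from search_space import HighResolutionNet

ckt = 'search-EXP-20201201-120000/best_psnr_weights.pt'  # hypothetical path
model = HighResolutionNet(ckt).cuda().eval()  # weights are freshly re-initialized

noisy = torch.rand(1, 1, 64, 64).cuda()   # grayscale NCHW patch in [0, 1]
with torch.no_grad():
    restored = torch.clamp(model(noisy), 0, 1)

clean = torch.rand(1, 1, 64, 64).cuda()    # stand-in ground truth
print(pytorch_ssim.ssim(restored, clean))  # same SSIM helper infer() uses
```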