├── .gitignore
├── README.md
├── code
│   ├── data.lua
│   ├── demo.lua
│   ├── donkey.lua
│   ├── loss.lua
│   ├── main.lua
│   ├── model.lua
│   ├── opts.lua
│   ├── test.lua
│   └── train.lua
├── dataset
│   ├── .gitignore
│   └── examples
│       ├── blur_gamma
│       │   ├── seq_id-1 frame-8 Input.png
│       │   ├── seq_id-11 frame-1 Input.png
│       │   ├── seq_id-2 frame-8 Input.png
│       │   ├── seq_id-3 frame-1 Input.png
│       │   ├── seq_id-4 frame-1 Input.png
│       │   ├── seq_id-5 frame-2 Input.png
│       │   └── seq_id-6 frame-8 Input.png
│       ├── blur_lin
│       │   ├── seq_id-1 frame-8 Input.png
│       │   ├── seq_id-11 frame-1 Input.png
│       │   ├── seq_id-2 frame-8 Input.png
│       │   ├── seq_id-3 frame-1 Input.png
│       │   ├── seq_id-4 frame-1 Input.png
│       │   ├── seq_id-5 frame-2 Input.png
│       │   └── seq_id-6 frame-8 Input.png
│       ├── deblurred_gamma
│       │   ├── seq_id-1 frame-8 Input.png
│       │   ├── seq_id-11 frame-1 Input.png
│       │   ├── seq_id-2 frame-8 Input.png
│       │   ├── seq_id-3 frame-1 Input.png
│       │   ├── seq_id-4 frame-1 Input.png
│       │   ├── seq_id-5 frame-2 Input.png
│       │   └── seq_id-6 frame-8 Input.png
│       └── deblurred_lin
│           ├── seq_id-1 frame-8 Input.png
│           ├── seq_id-11 frame-1 Input.png
│           ├── seq_id-2 frame-8 Input.png
│           ├── seq_id-3 frame-1 Input.png
│           ├── seq_id-4 frame-1 Input.png
│           ├── seq_id-5 frame-2 Input.png
│           └── seq_id-6 frame-8 Input.png
├── experiment
│   └── .gitignore
└── images
    ├── Flower_blur1.png
    ├── Flower_sharp1.png
    ├── Istanbul_blur1.png
    ├── Istanbul_sharp1.png
    └── NTIRE2019.jpg

/.gitignore:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DeepDeblur_release
2 |
3 | Single image deblurring with deep learning.
4 |
5 | This is a project page for our research.
6 | Please refer to our CVPR 2017 paper for details:
7 |
8 | Deep Multi-scale Convolutional Neural Network for Dynamic Scene Deblurring
9 | [[paper](http://openaccess.thecvf.com/content_cvpr_2017/papers/Nah_Deep_Multi-Scale_Convolutional_CVPR_2017_paper.pdf)]
10 | [[supplementary](http://openaccess.thecvf.com/content_cvpr_2017/supplemental/Nah_Deep_Multi-Scale_Convolutional_2017_CVPR_supplemental.zip)]
11 | [[slide](https://drive.google.com/file/d/1sj7l2tGgJR-8wTyauvnSDGpiokjOzX_C/view?usp=sharing)]
12 |
13 |
14 | If you find our work useful in your research or publication, please cite our work:
15 | ```
16 | @InProceedings{Nah_2017_CVPR,
17 |   author = {Nah, Seungjun and Kim, Tae Hyun and Lee, Kyoung Mu},
18 |   title = {Deep Multi-Scale Convolutional Neural Network for Dynamic Scene Deblurring},
19 |   booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
20 |   month = {July},
21 |   year = {2017}
22 | }
23 | ```
24 |
25 | ## PyTorch version
26 |
27 | A PyTorch version is now available: [https://github.com/SeungjunNah/DeepDeblur-PyTorch](https://github.com/SeungjunNah/DeepDeblur-PyTorch)
28 |
29 |
30 | ## New dataset released!
31 |
32 | Check out our new **[REDS](https://seungjunnah.github.io/Datasets/reds)** dataset!
33 | At CVPR 2019, I co-organized the [4th NTIRE workshop](http://www.vision.ee.ethz.ch/ntire19/) and the corresponding video restoration challenges.
34 | We released the **[REDS](https://seungjunnah.github.io/Datasets/reds)** dataset for challenge participants to train and evaluate video deblurring / super-resolution methods.
35 | Special thanks go to my colleagues, [Sungyong Baik](https://scholar.google.com/citations?user=lQ4gotkAAAAJ&hl=en), [Seokil Hong](https://scholar.google.com/citations?user=nYDLTksAAAAJ&hl=en), [Gyeongsik Moon](https://scholar.google.com/citations?user=2f2D258AAAAJ&hl=en), [Sanghyun Son](https://scholar.google.com/citations?user=nWaSdu0AAAAJ&hl=en), [Radu Timofte](https://scholar.google.com/citations?user=u3MwH5kAAAAJ&hl=en) and [Kyoung Mu Lee](https://scholar.google.com/citations?user=Hofj9kAAAAAJ&hl=en) for collecting, processing, and releasing the dataset together.
36 |
37 | ### Updates
38 | Downloads are now available for training, validation, and test input data. A public leaderboard site is under construction.
39 | Download page: [https://seungjunnah.github.io/Datasets/reds](https://seungjunnah.github.io/Datasets/reds)
40 |
41 |
42 |
43 | ## Dependencies
44 | * [torch7](http://torch.ch/docs/getting-started.html#_)
45 | * [torchx](https://github.com/nicholas-leonard/torchx)
46 | ```bash
47 | luarocks install torchx
48 | ```
49 | * cudnn
50 | ```bash
51 | cd ~/torch/extra/cudnn
52 | git checkout R7 # R7 is for cudnn v7
53 | luarocks make
54 | ```
55 |
56 | ## Code
57 |
58 | To run the demo, download the trained models and extract them into the "experiment" folder.
59 |
60 |
61 | * [models](https://drive.google.com/file/d/1Z8dV6KuubfOKj4ganEjxymhyMoXoydfo/view?usp=sharing)
62 |
63 | Run the following commands in the "code" folder (one per blur type):
64 | ```bash
65 | qlua -i demo.lua -load -save release_scale3_adv_gamma -blur_type gamma2.2 -type cudaHalf
66 | qlua -i demo.lua -load -save release_scale3_adv_lin -blur_type linear -type cudaHalf
67 | ```
68 |
69 | To train a model, clone this repository and download the dataset described below into the "dataset" directory.
70 |
71 | The directory structure should look like
72 | "dataset/GOPRO_Large/train/GOPRxxxx_xx_xx/blur/xxxxxx.png"
73 |
74 | Then run main.lua in the "code" directory with optional parameters.
75 | ```bash
76 | th main.lua -nEpochs 450 -save scale3 # Train for 450 epochs, save in 'experiment/scale3'
77 | th main.lua -load -save scale3 # Load the saved model and open an interactive REPL
78 | > blur_dir, output_dir = ...
79 | > deblur_dir(blur_dir, output_dir)
80 | ```
81 | Optional parameters are listed in opts.lua.
82 |
83 | For example, -type sets the operation type and supports cuda and cudaHalf. In evaluation mode, a half-precision CNN has accuracy similar to single precision. However, this code is not meant to support fp16 training; the ADAM optimizer is hard to use with fp16.
84 |
85 | ## Dataset
86 |
87 | In this work, we proposed a new dataset of realistic blurry and sharp image pairs captured with a high-speed camera.
88 | However, we do not provide blur kernels, as they are unknown.
89 |
90 | * Downloads are available [here](https://seungjunnah.github.io/Datasets/gopro)
91 |
92 | Statistics | Training | Test | Total
93 | -- | -- | -- | --
94 | sequences | 22 | 11 | 33
95 | image pairs | 2103 | 1111 | 3214
96 |
97 | Here are some example images.
98 |
99 | Blurry image example 1
100 | ![Blurry image](images/Istanbul_blur1.png)
101 |
102 | Sharp image example 1
103 | ![Sharp image](images/Istanbul_sharp1.png)
104 |
105 | Blurry image example 2
106 | ![Blurry image](images/Flower_blur1.png)
107 |
108 | Sharp image example 2
109 | ![Sharp image](images/Flower_sharp1.png)
110 |
111 |
112 | ## Acknowledgment
113 |
114 | This project is partially funded by Microsoft Research Asia.
115 |
--------------------------------------------------------------------------------
/code/data.lua:
--------------------------------------------------------------------------------
1 | require 'torch' -- torch
2 | require 'torchx'
3 |
4 | local threads = require 'threads'
5 | threads.serialization('threads.sharedserialize')
6 |
7 | print('==> Reading Dataset '..opt.dataset)
8 |
9 | function get_imglist(directory, extlist)
10 |
11 |     local temp_list = paths.indexdir(directory, extlist)
12 |     local list = {}
13 |     for i = 1, temp_list:size() do
14 |         table.insert(list, temp_list:filename(i))
15 |     end
16 |     table.sort(list)
17 |
18 |     return list
19 | end
20 |
21 | ----------------------------------------------------
22 | --train_list[seq_id]['blur'][frame_id]
23 | --train_list[seq_id]['blur_gamma'][frame_id]
24 | --train_list[seq_id]['sharp'][frame_id]
25 | --test_list[seq_id]['blur'][frame_id]
26 | --test_list[seq_id]['blur_gamma'][frame_id]
27 | --test_list[seq_id]['sharp'][frame_id]
28 | ----------------------------------------------------
29 | local blur_key_linear, blur_key_gamma = 'blur', 'blur_gamma'
30 | sharp_key = 'sharp'
31 | if opt.blur_type == 'linear' then
32 |     blur_key = blur_key_linear
33 | elseif opt.blur_type == 'gamma2.2' then
34 |     blur_key = blur_key_gamma
35 | else
36 |     error('unknown camera response function')
37 | end
38 |
39 |
40 | local datadir = paths.concat('../dataset', opt.dataset)
41 | local train_dir = paths.concat(datadir, 'train')
42 | local test_dir = paths.concat(datadir, 'test')
43 |
44 | local data_list = {} -- must be a table: it is indexed below when there is no train / test division
45 | train_list, test_list = {}, {}
46 |
47 | for subset in paths.iterdirs(datadir) do
48 |     local subdir
49 |     local sublist
50 |     if subset == 'train' then
51 |         subdir = train_dir
52 |         sublist = train_list
53 |     elseif subset == 'test' then
54 |         subdir = test_dir
55 |         sublist = test_list
56 |     else -- no train / test division
57 |         subdir = datadir
58 |         sublist = data_list
59 |     end
60 |
61 |     local sequences = {}
62 |     for seq_name in paths.iterdirs(subdir) do
63 |         table.insert(sequences, seq_name)
64 |     end
65 |     table.sort(sequences)
66 |
67 |     for seq_id, seq_name in ipairs(sequences) do
68 |         local sequence_name = paths.concat(subdir, seq_name)
69 |
70 |         local blur_dir = paths.concat(sequence_name, blur_key)
71 |         local sharp_dir = paths.concat(sequence_name, sharp_key)
72 |
73 |         sublist[seq_id] = {}
74 |         sublist[seq_id][blur_key] = get_imglist(blur_dir)
75 |         sublist[seq_id][sharp_key] = get_imglist(sharp_dir)
76 |     end
77 | end
78 |
79 | if #train_list == 0 or #test_list == 0 then
80 |     train_list = data_list
81 |     test_list = data_list
82 |     -- full_data = true
83 | end
84 |
85 | average = 0.5
86 |
87 | do -- initialize data loading threads
88 |     if opt.nDonkeys > 0 then
89 |         local def_type = default_type
90 |         local options = opt -- make an upvalue to serialize over to donkey threads
91 |         local list_train, list_test = train_list, test_list
92 |         local key_blur, key_sharp = blur_key, sharp_key
93 |         donkeys = threads.Threads(
94 |             opt.nDonkeys,
95 |             function(threadid)
96 |                 require 'torch'
97 |                 require 'image'
98 |
99 |                 return threadid
100 |             end,
101 |             function(threadid)
102 |                 default_type = def_type
103 |                 torch.setdefaulttensortype(default_type)
104 |                 opt = options
105 |                 blur_key, sharp_key = key_blur, key_sharp
106 |                 tid = threadid
107 |                 local seed = torch.seed()
108 |                 print(string.format('Starting donkey with id: %d seed: %d', tid, seed))
109 |                 train_list, test_list = list_train, list_test
110 |                 blur_key, sharp_key = key_blur, key_sharp
111 |                 paths.dofile('donkey.lua')
112 |             end
113 |         )
114 |     else -- single threaded data loading. Useful for debugging.
115 |         paths.dofile('donkey.lua')
116 |         donkeys = {}
117 |         function donkeys:addjob(f1, f2) f2(f1()) end
118 |         function donkeys:synchronize() end
119 |     end
120 | end
121 |
122 | collectgarbage()
123 | collectgarbage()
--------------------------------------------------------------------------------
/code/demo.lua:
--------------------------------------------------------------------------------
1 | -- To run the demo, type:
2 | -- qlua -i demo.lua -load -save 'scale3-depth40_adv'
3 |
4 | require 'torch'
5 | require 'cutorch'
6 | require 'paths'
7 | require 'xlua'
8 | require 'optim'
9 |
10 | default_type = 'torch.FloatTensor'
11 | torch.setdefaulttensortype(default_type)
12 |
13 | local opts = paths.dofile('opts.lua')
14 | opt = opts.parse(arg)
15 |
16 | -- number of threads and fixed seed (for repeatable experiments)
17 | torch.setnumthreads(opt.threads)
18 | torch.manualSeed(opt.seed)
19 | ----------------------------------------------------------------------
20 | ----------------------------------------------------------------------
21 | dofile 'model.lua'
22 | dofile 'train.lua'
23 | dofile 'test.lua'
24 | ----------------------------------------------------------------------
25 | ----------------------------------------------------------------------
26 |
27 | local example_dir = paths.concat('..', 'dataset', 'examples')
28 | local image_dir, output_dir
29 |
30 | if opt.blur_type == 'linear' then
31 |     image_dir = paths.concat(example_dir, 'blur_lin')
32 |     output_dir = paths.concat(example_dir, 'deblurred_lin')
33 | elseif opt.blur_type == 'gamma2.2' then
34 |     image_dir = paths.concat(example_dir, 'blur_gamma')
35 |     output_dir = paths.concat(example_dir, 'deblurred_gamma')
36 | end
37 |
38 | demo(image_dir, output_dir)
39 |
40 | require 'trepl'()
--------------------------------------------------------------------------------
/code/donkey.lua:
--------------------------------------------------------------------------------
1 | require 'torch' -- torch
2 | require 'torchx'
3 | require 'cutorch'
4 | require 'xlua' -- xlua provides useful tools, like progress bars
5 | require 'optim' -- an optimization package, for online and batch methods
6 | require 'image' -- for rotating and flipping patches
7 | require 'math' -- to calculate base kernels
8 |
9 | function generate_pyramid(img, scale_levels)
10 |
11 |     local scale_levels = scale_levels or 1
12 |     local scales = {}
13 |     for i = 1, scale_levels do
14 |         scales[i] = 0.5^(i-1)
15 |     end
16 |
17 |     local average = img:mean(2):mean(3):squeeze()
18 |     for channel = 1, img:size(1) do
19 |         img[channel] = img[channel] - average[channel]
20 |     end
21 |
22 |     local pyramid = image.gaussianpyramid(img, scales)
23 |     for lv = 1, scale_levels do
24 |         for channel = 1, img:size(1) do
25 |             pyramid[lv][channel] = pyramid[lv][channel] + average[channel]
26 |         end
27 |     end
28 |
29 |     return pyramid
30 | end
31 |
32 | -- function extract_patch(train_list, seq_id, frame_id, merge_blur, supWidth, supHeight, scale_levels)
33 | function extract_patch(seq_id, frame_id, supWidth, supHeight, scale_levels)
34 |
35 |     local supWidth = supWidth or opt.supWidth
36 |     local supHeight = supHeight or supWidth
37 |     local scale_levels = scale_levels or opt.scale_levels
38 |
39 |     local imsize = image.getSize(train_list[seq_id][sharp_key][frame_id])
40 |     local scale = 0.5^(scale_levels-1)
41 |
42 |     local lux_s = torch.random(0, math.floor((imsize[3]-supWidth)*scale))
43 |     local luy_s = torch.random(0, math.floor((imsize[2]-supHeight)*scale))
44 |     local lux, luy = lux_s/scale, luy_s/scale -- prevent translation while downsampling
45 |
46 |     local input_patch = image.crop(image.load(train_list[seq_id][blur_key][frame_id]), lux, luy, lux+supWidth, luy+supHeight)
47 |     local target_patch = image.crop(image.load(train_list[seq_id][sharp_key][frame_id]), lux, luy, lux+supWidth, luy+supHeight)
48 |
49 |     collectgarbage()
50 |     collectgarbage()
51 |
52 |     return input_patch, target_patch
53 |
54 | end
55 |
56 | function augment_patch(input_patch, target_patch)
57 |
58 |     local target_input = torch.random(1, 10) == 1 -- sharp input to sharp output
59 |     local change_saturation = torch.random(1, 10) == 1
60 |     local flip_h = torch.random(0, 1) == 1
61 |     local rotate = torch.random(0, 3)
62 |
63 |     local shuffle_color = true
64 |     local add_noise = true
65 |
66 |     if target_input then
67 |         input_patch = target_patch:clone()
68 |     end
69 |
70 |     if flip_h then
71 |         input_patch = image.hflip(input_patch)
72 |         target_patch = image.hflip(target_patch)
73 |     end
74 |
75 |     if rotate > 0 then
76 |         local theta = math.pi/2 * rotate
77 |         input_patch = image.rotate(input_patch, theta)
78 |         target_patch = image.rotate(target_patch, theta)
79 |     end
80 |
81 |     if shuffle_color then
82 |         local nChannel = input_patch:size(1)
83 |         local perm = torch.randperm(nChannel):long()
84 |
85 |         input_patch = input_patch:index(1, perm)
86 |         target_patch = target_patch:index(1, perm)
87 |     end
88 |
89 |     if change_saturation then
90 |         local amp_factor = 1 + torch.uniform(-0.5, 0.5)
91 |         local input_hsv = image.rgb2hsv(input_patch)
92 |         local target_hsv = image.rgb2hsv(target_patch)
93 |
94 |         input_hsv[2]:mul(amp_factor)
95 |         target_hsv[2]:mul(amp_factor)
96 |
97 |         input_patch = image.hsv2rgb(input_hsv)
98 |         target_patch = image.hsv2rgb(target_hsv)
99 |     end
100 |
101 |     if add_noise then
102 |         local sigma_sigma = 2/255
103 |         local sigma = torch.randn(1)[1] * sigma_sigma
104 |         local noise = torch.randn(input_patch:size()) * sigma
105 |         input_patch:add(noise)
106 |     end
107 |
108 |     input_patch:clamp(0, 1)
109 |     target_patch:clamp(0, 1)
110 |
111 |     collectgarbage()
112 |     collectgarbage()
113 |
114 |     return input_patch, target_patch
115 |
116 | end
117 |
118 | function generate_batch(batch_size, scale_levels, supWidth, supHeight)
119 |
120 |     local batch_size = batch_size or opt.minibatchSize
121 |     local scale_levels = scale_levels or opt.scale_levels
122 |     local supWidth = supWidth or opt.supWidth
123 |     local supHeight = supHeight or supWidth
124 |     -- local merge_blur = merge_blur or 0 -- merge subsequent blurry images to generate even larger blurs
125 |
126 |     local input_batch, target_batch = {}, {}
127 |     for lv = 1, scale_levels do
128 |         local scale = 0.5^(lv-1)
129 |         local supHeight_lv = supHeight * scale
130 |         local supWidth_lv = supWidth * scale
131 |         input_batch[lv] = torch.zeros(batch_size, 3, supHeight_lv, supWidth_lv)
132 |         target_batch[lv] = torch.zeros(batch_size, 3, supHeight_lv, supWidth_lv)
133 |     end
134 |
135 |     local input_patch, target_patch
136 |     local input_patch_pyramid, target_patch_pyramid
137 |
138 |     local seq_prob = torch.ones(#train_list)
139 |     local patch_seq_id = torch.multinomial(seq_prob, batch_size, false)
140 |     for patch_id = 1, batch_size do
141 |         local seq_id = patch_seq_id[patch_id]
142 |         local frame_id = torch.random(#train_list[seq_id][sharp_key])
143 |
144 |         -- extract
145 |         input_patch, target_patch = extract_patch(seq_id, frame_id, supWidth, supHeight, scale_levels)
146 |         -- augment
147 |         input_patch, target_patch = augment_patch(input_patch, target_patch)
148 |         -- build the image pyramids and pack them into the batch
149 |         input_patch_pyramid = generate_pyramid(input_patch, scale_levels)
150 |         target_patch_pyramid = generate_pyramid(target_patch, scale_levels)
151 |         for lv = 1, scale_levels do
152 |             input_batch[lv][patch_id] = input_patch_pyramid[lv]
153 |             target_batch[lv][patch_id] = target_patch_pyramid[lv]
154 |         end
155 |     end
156 |
157 |     collectgarbage()
158 |     collectgarbage()
159 |
160 |     return input_batch, target_batch
161 |
162 | end
163 |
164 | function generate_testpair(seq_id, frame_id, scale_levels)
165 |
166 |     local scale_levels = scale_levels or opt.scale_levels
167 |
168 |     local input_img = image.load(test_list[seq_id][blur_key][frame_id])
169 |     local target_img = image.load(test_list[seq_id][sharp_key][frame_id])
170 |
171 |     local input_img_pyramid = generate_pyramid(input_img, scale_levels)
172 |     local target_img_pyramid = generate_pyramid(target_img, scale_levels)
173 |     for lv = 1, scale_levels do
174 |         input_img_pyramid[lv] = input_img_pyramid[lv]:repeatTensor(1,1,1,1)
175 |         target_img_pyramid[lv] = target_img_pyramid[lv]:repeatTensor(1,1,1,1)
176 |     end
177 |
178 |     collectgarbage()
179 |     collectgarbage()
180 |
181 |     return input_img_pyramid, target_img_pyramid
182 | end
--------------------------------------------------------------------------------
/code/loss.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'nn'
3 | require 'cunn'
4 |
5 | if not opt then
6 |     local opts = paths.dofile('opts.lua')
7 |     opt = opts.parse(arg)
8 |     dofile 'data.lua'
9 | end
10 |
11 | adv_train = opt.adv_weight > 0
12 |
13 | local data_term = nn.MultiCriterion()
14 | do
15 |     local abs = nn.AbsCriterion(); abs.sizeAverage = true;
16 |     data_term:add(abs, opt.abs_weight)
17 |     local mse = nn.MSECriterion(); mse.sizeAverage = true;
18 |     data_term:add(mse, opt.mse_weight)
19 | end
20 |
21 | local adv_loss
22 | if adv_train then
23 |     local weights = torch.Tensor(opt.minibatchSize):fill(1/torch.log(2))
24 |     adv_loss = nn.BCECriterion(weights * opt.adv_weight)
25 | end
26 |
27 | criterion = {}
28 | criterion.G = nn.ParallelCriterion()
29 | for lv = 1, opt.scale_levels do
30 |     criterion.G:add(data_term:clone())
31 | end
32 | criterion.D = adv_loss
33 |
34 | criterion_container = nn.ParallelCriterion()
35 |     :add(criterion.G)
36 | if adv_train then
37 |     criterion_container:add(criterion.D)
38 | end
39 |
40 | if opt.type == 'cuda' then
41 |     criterion.G = criterion.G:cuda()
42 |     if adv_train then
43 |         criterion.D = criterion.D:cuda()
44 |     end
45 |     criterion_container = criterion_container:cuda()
46 | elseif opt.type == 'cudaHalf' then
47 |     criterion.G = criterion.G:cudaHalf()
48 |     if adv_train then
49 |         criterion.D = criterion.D:cudaHalf()
50 |     end
51 |     criterion_container = criterion_container:cudaHalf()
52 | end
53 |
54 | ----------------------------------------------------------------------
55 | print '==> here is the loss function:'
56 | print(criterion_container)
57 |
--------------------------------------------------------------------------------
/code/main.lua:
--------------------------------------------------------------------------------
1 | require 'torch'
2 | require 'cutorch'
3 | require 'paths'
4 | require 'xlua'
5 | require 'optim'
6 |
7 | function record_params(opt)
8 |
9 |     if not paths.dirp('../experiment') then
10 |         paths.mkdir('../experiment')
11 |     end
12 |     if not paths.dirp(opt.save) then
13 |         paths.mkdir(opt.save)
14 |         torch.save(paths.concat(opt.save, 'opt'), opt)
15 |     end
16 |
17 |     local today = now:sub(1, 11)
18 |     local nowtime = now:sub(12, 13) .. ":" .. now:sub(15, 16) .. ":" .. now:sub(18, 19)
19 |
20 |     local modelparam = io.open(opt.save .. "/modelparam.txt", "a+")
21 |     modelparam:write("Experiment at " .. today .. nowtime .. "\n\n")
22 |     if not (opt.load or opt.continue) then
23 |         modelparam:write("Dataset : ".. opt.dataset.."\n")
24 |         modelparam:write("Camera response function : ".. opt.blur_type.."\n")
25 |         modelparam:write("model : " .. opt.model .. "\n")
26 |         modelparam:write("scale_levels : " .. opt.scale_levels .. "\n\n")
27 |
28 |         modelparam:write("supWidth : " .. opt.supWidth .. "\n")
29 |         modelparam:write("nStates : " .. opt.nStates .. "\n")
30 |         modelparam:write("filtsize : " .. opt.filtsize .. "\n")
31 |         modelparam:write("nlayers : " .. opt.nlayers .. "\n")
32 |     end
33 |     modelparam:write("\n")
34 |     modelparam:write("L1 loss weight : " .. opt.abs_weight .. "\n")
35 |     modelparam:write("L2 loss weight : " .. opt.mse_weight .. "\n")
36 |     modelparam:write("adversarial loss weight : " .. opt.adv_weight .. "\n")
37 |
38 |     modelparam:write("optimization method : " .. opt.optimization .. "\n")
39 |     if opt.optimization == 'SGD' or opt.optimization == 'ADAM' then
40 |         modelparam:write("learning rate : " .. opt.rateLearning .. "\n")
41 |         if opt.optimization == 'SGD' then
42 |             modelparam:write("momentum : " .. opt.momentum .. "\n")
43 |             modelparam:write("weight decay : " .. opt.weightDecay .. "\n")
44 |         end
45 |     end
46 |     modelparam:write("batch size : " .. opt.epochbatchSize .. "\n")
47 |     modelparam:write("mini-batch size : " .. opt.minibatchSize .. "\n")
48 |     modelparam:write("\n\n")
49 |     modelparam:close()
50 |
51 |     return
52 | end
53 |
54 | default_type = 'torch.FloatTensor'
55 | torch.setdefaulttensortype(default_type)
56 |
57 | local opts = paths.dofile('opts.lua')
58 | opt = opts.parse(arg)
59 | record_params(opt) -- save experiment parameters
60 |
61 | ----------------------------------------------------------------------
62 | ----------------------------------------------------------------------
63 | print '==> executing all'
64 | dofile 'data.lua'
65 | dofile 'loss.lua'
66 | dofile 'model.lua'
67 | dofile 'train.lua'
68 | dofile 'test.lua'
69 | ----------------------------------------------------------------------
70 | ----------------------------------------------------------------------
71 |
72 | epoch = opt.epochNumber
73 | print('epoch begins : '..epoch)
74 | local epoch_threshold = opt.nEpochs
75 | local slow_down_step = math.ceil(150 * 4000/opt.epochbatchSize)
76 |
77 | if opt.load then
78 |     require 'trepl'()
79 | else
80 |     if opt.continue and (not opt.train_only) then
81 |         if follow_up_test(1, epoch-1) == true then
82 |             model = load_main_model(epoch-1)
83 |             parameters, gradParameters = get_model_parameters()
84 |         end
85 |     end
86 |
87 |     print '==> training!'
88 |     while epoch <= epoch_threshold do
89 |         train()
90 |         if not opt.train_only then
91 |             test(nil, true)
92 |         end
93 |         if epoch == slow_down_step then
94 |             opt.rateLearning = opt.rateLearning / 10
95 |             optimState.G.learningRate = opt.rateLearning
96 |             if adv_train then
97 |                 optimState.D.learningRate = opt.rateLearning
98 |             end
99 |         end
100 |         -- next epoch
101 |         epoch = epoch + 1
102 |     end
103 |     epoch = epoch - 1
104 |     if opt.train_only then
105 |         test(nil, true)
106 |     end
107 |     if opt.train_only then
108 |         follow_up_test(1, epoch-1)
109 |         model = load_main_model(epoch)
110 |         test(epoch, true)
111 |     end
112 | end
113 |
--------------------------------------------------------------------------------
/code/model.lua:
--------------------------------------------------------------------------------
1 | require 'torch' -- torch
2 | require 'cutorch' -- cudaTensor
3 | require 'image' -- for image transforms
4 | require 'nn' -- provides all sorts of trainable modules/layers
5 | require 'cunn'
6 | require 'cudnn' -- cudnn
7 |
8 | if not opt then
9 |     opt = paths.dofile('opts.lua').parse(arg) -- opt is nil here, so opt.save cannot be read; parse the options as loss.lua does
10 | end
11 |
12 | cudnn.fastest = true
13 | cudnn.benchmark = true
14 |
15 | ----------------------------------------------------------------------
16 | print '==> define parameters'
17 | nChannel = 3;
18 |
19 | supWidth = opt.supWidth;
20 | filtsize = opt.filtsize
21 | nlayers = opt.nlayers
22 | nStates = opt.nStates
23 | inChannel, outChannel = nChannel, nChannel
24 |
25 | model_dir = paths.concat(opt.save, 'models')
26 | if not paths.dirp(model_dir) then
27 |     paths.mkdir(model_dir)
28 | end
29 |
30 | function ResBlock(filtsize, nStates, inStates)
31 |
32 |     local function shortcut(str)
33 |         local str = str or 1
34 |         if str == 1 then
35 |             return nn.Identity()
36 |         else
37 |             local str = 2
38 |             return nn.Sequential()
39 |                 :add(nn.SpatialAveragePooling(1,1, str,str))
40 |                 :add(nn.Concat(2) -- zero-pad channels: concatenate x with 0*x
41 |                     :add(nn.Identity())
42 |                     :add(nn.MulConstant(0)))
43 |         end
44 |     end
45 |
46 |     local filtsize = filtsize or opt.filtsize
47 |     local padW, padH = (filtsize-1)/2, (filtsize-1)/2
48 |     local nStates = nStates or opt.nStates
49 |     local inStates = inStates or nStates
50 |     local str = ((nStates == inStates) and 1) or 2
51 |
52 |     local block = nn.Sequential()
53 |     local concat = nn.ConcatTable()
54 |
55 |     local conv1 = nn.SpatialConvolution(inStates,nStates, filtsize,filtsize, str,str, padW,padH)
56 |     local relu = nn.ReLU(true)
57 |     local conv2 = nn.SpatialConvolution(nStates,nStates, filtsize,filtsize, 1,1, padW,padH)
58 |
59 |     local path = nn.Sequential()
60 |         :add(conv1):add(relu)
61 |         :add(conv2)
62 |     local concat = nn.ConcatTable()
63 |         :add(path)
64 |         :add(shortcut(str))
65 |
66 |     local block = nn.Sequential()
67 |         :add(concat)
68 |         :add(nn.CAddTable(true))
69 |
70 |     return block
71 | end
72 |
73 | function ResNet(nlayers, filtsize, inChannel, outChannel, nStates)
74 |
75 |     local nlayers = nlayers or opt.nlayers
76 |     local filtsize = filtsize or opt.filtsize
77 |     local padW, padH = (filtsize-1)/2, (filtsize-1)/2
78 |     local dW, dH = 1, 1
79 |     local inChannel = inChannel or nChannel
80 |     local outChannel = outChannel or nChannel
81 |     local nStates = nStates or opt.nStates
82 |
83 |     local model = nn.Sequential()
84 |     local conv = nn.SpatialConvolution(inChannel,nStates, filtsize,filtsize, dW,dH, padW,padH)
85 |
86 |     model:add(conv:clone())
87 |     for layer = 1, (nlayers-2)/2 do
88 |         model:add(ResBlock(filtsize, nStates))
89 |     end
90 |     conv = nn.SpatialConvolution(nStates,outChannel, filtsize,filtsize, dW,dH, padW,padH)
91 |     model:add(conv:clone())
92 |
93 |     return model
94 | end
95 |
96 | function generate_conv_end(inChannel, outChannel, ratio)
97 |
98 |     local inChannel = inChannel or nChannel
99 |     local outChannel = outChannel or nChannel
100 |     local ratio = ratio or 2
101 |     -- local filt, pad, adj = ratio, 0, 0
102 |
103 |     local filt = 5 --3
104 |     local dW,dH = 1,1
105 |     local padW, padH = (filt-1)/2, (filt-1)/2
106 |     local uppath = nn.Sequential()
107 |         :add(nn.SpatialConvolution(outChannel, inChannel*ratio^2, filt,filt, dW,dH, padW,padH))
108 |         :add(nn.PixelShuffle(ratio))
109 |     local conv_end = nn.ConcatTable()
110 |         :add(uppath)
111 |         :add(nn.Identity())
112 |     -- :add(nn.SpatialFullConvolution(outChannel,inChannel,filt,filt,ratio,ratio,pad,pad,adj,adj))
113 |
114 |     return conv_end
115 | end
116 |
117 | function generate_main_model(modeltype)
118 |     local modeltype = modeltype or opt.model
119 |     local scale_levels = opt.scale_levels
120 |     local generate_net
121 |
122 |     if modeltype == 'ConvNet' then
123 |         generate_net = ConvNet
124 |     elseif modeltype == 'ResNet' then
125 |         generate_net = ResNet
126 |     else
127 |         error('unknown model type')
128 |     end
129 |
130 |     local model
131 |
132 |     if scale_levels == 1 then
133 |
134 |         model = generate_net(nlayers, filtsize, inChannel, outChannel, opt.nStates)
135 |
136 |         model:insert(nn.Copy(default_type, operate_type), 1)
137 |         model:insert(nn.AddConstant(-average, true), 2)
138 |         model:add(nn.AddConstant(average, true))
139 |         model = nn.ParallelTable():add(model)
140 |
141 |     else
142 |         local conv_coarse = generate_net(nlayers, filtsize, inChannel, outChannel, opt.nStates)
143 |         conv_coarse:add(generate_conv_end(inChannel, outChannel, ratio))
144 |
145 |         local conv_fine = generate_net(nlayers, filtsize, inChannel+outChannel, outChannel, opt.nStates)
146 |         conv_fine:insert(nn.JoinTable(2), 1)
147 |         conv_fine:add(generate_conv_end(inChannel, outChannel))
148 |
149 |         local conv_finest = generate_net(nlayers, filtsize, inChannel+outChannel, outChannel, opt.nStates)
150 |         conv_finest:insert(nn.JoinTable(2), 1)
151 |
152 |         -- local conv_finest = conv_fine:clone()
153 |         -- conv_finest:remove()
154 |
155 |         model = nn.Sequential()
156 |         do -- coarse
157 |             local submodel = nn.Sequential()
158 |             local submodel_par = nn.ParallelTable()
159 |             for lv = 1, scale_levels do
160 |                 local subpath = nn.Sequential()
161 |                 subpath:add(nn.Copy(default_type, operate_type))
162 |                 subpath:add(nn.AddConstant(-average, true))
163 |                 if lv == scale_levels then
164 |                     subpath:add(conv_coarse)
165 |                 end
166 |                 submodel_par:add(subpath)
167 |             end
168 |             submodel:add(submodel_par)
169 |             submodel:add(nn.FlattenTable())
170 |             model:add(submodel)
171 |         end
172 |         for i = scale_levels-1, 2, -1 do -- fine
173 |             local submodel = nn.Sequential()
174 |             local subconcat = nn.ConcatTable()
175 |
176 |             local subconcat_path1 = nn.NarrowTable(1,i-1)
177 |             local subconcat_path2 = nn.Sequential():add(nn.NarrowTable(i, 2))
178 |             subconcat_path2:add(conv_fine:clone())
179 |             local subconcat_path3 = nn.NarrowTable(i+2, scale_levels-i)
180 |
181 |             subconcat:add(subconcat_path1)
182 |             subconcat:add(subconcat_path2)
183 |             subconcat:add(subconcat_path3)
184 |             submodel:add(subconcat)
185 |             submodel:add(nn.FlattenTable())
186 |
187 |             model:add(submodel:clone())
188 |         end
189 |         do -- finest
190 |             local submodel = nn.Sequential()
191 |             local subconcat = nn.ConcatTable()
192 |             local subconcat_path1 = nn.Sequential():add(nn.NarrowTable(1,2))
193 |             subconcat_path1:add(conv_finest)
194 |
195 |             subconcat:add(subconcat_path1)
196 |             for j = 2, scale_levels do
197 |                 subconcat:add(nn.SelectTable(j+1))
198 |             end
199 |             submodel:add(subconcat)
200 |             model:add(submodel)
201 |         end
202 |         do -- post processing
203 |             local endmodel_par = nn.ParallelTable()
204 |             for lv = 1, scale_levels do
205 |                 endmodel_par:add(nn.AddConstant(average, true))
206 |             end
207 |             model:add(endmodel_par)
208 |         end
209 |
210 |     end
211 |
212 |     model = cudnn.convert(model, cudnn):cuda()
213 |     model:reset()
214 |
215 |     return model
216 | end
217 |
218 | function generate_discriminator()
219 |
220 |     local function conv_block(filtsize, inStates, nStates, str, negval, pad)
221 |         local filtsize = filtsize or 3
222 |         local pad = pad or (filtsize-1)/2
223 |         local str = str or 1
224 |         -- local negval = nil
225 |         local block = nn.Sequential()
226 |             :add(nn.SpatialConvolution(inStates,nStates, filtsize,filtsize, str,str, pad,pad):noBias())
227 |             :add(nn.LeakyReLU(negval, true))
228 |
229 |         return block
230 |     end
231 |
232 |     local filtsize = filtsize or 5
233 |     local pad = (filtsize-1)/2
234 |
235 |     local nFeat = opt.nStates
236 |     local conv_front = nn.SpatialConvolution(nChannel,nFeat/2, filtsize,filtsize, 1,1, pad,pad):noBias()
237 |     local negval = 0.2 -- nil
238 |     --[[
239 |     local dense = nn.SpatialConvolution(1024,1, 1,1)
240 |     local model = nn.Sequential()
241 |         :add(conv_front):add(nn.LeakyReLU(negval, true))
242 |         :add(conv_block(filtsize, 32,32, 2, negval)) -- 128
243 |         :add(conv_block(filtsize, 32,64, 1, negval))
244 |         :add(conv_block(filtsize, 64,64, 2, negval)) -- 64
245 |         :add(conv_block(filtsize, 64,128, 1, negval))
246 |         :add(conv_block(filtsize, 128,128, 2, negval)) -- 32
247 |         :add(conv_block(filtsize, 128,256, 1, negval))
248 |         :add(conv_block(filtsize, 256,256, 4, negval)) -- 8
249 |         :add(conv_block(filtsize, 256,512, 1, negval))
250 |         :add(conv_block(filtsize, 512,512, 4, negval)) -- 2
251 |         :add(conv_block(filtsize, 512,1024, 2, negval)) -- 1
252 |         :add(dense):add(nn.Sigmoid())
253 |     ]]--
254 |     local dense = nn.SpatialConvolution(nFeat*8,1, 1,1)
255 |     local model = nn.Sequential()
256 |         :add(nn.SelectTable(1))
257 |         :add(conv_front):add(nn.LeakyReLU(negval, true))
258 |         :add(conv_block(filtsize, nFeat/2,nFeat/2, 2, negval)) -- 128
259 |         :add(conv_block(filtsize, nFeat/2,nFeat, 1, negval))
260 |         :add(conv_block(filtsize, nFeat,nFeat, 2, negval)) -- 64
261 |         :add(conv_block(filtsize, nFeat,nFeat*2, 1, negval))
262 |         :add(conv_block(filtsize, nFeat*2,nFeat*2, 4, negval)) -- 16
263 |         :add(conv_block(filtsize, nFeat*2,nFeat*4, 1, negval))
264 |         :add(conv_block(filtsize, nFeat*4,nFeat*4, 4, negval)) -- 4
265 |         :add(conv_block(filtsize, nFeat*4,nFeat*8, 1, negval))
266 |         :add(conv_block(4, nFeat*8,nFeat*8, 4, negval, 0)) -- 1 (filtsize 5 is equivalent to 4 here)
267 |         :add(dense):add(nn.Sigmoid())
268 |
269 |     model = cudnn.convert(model, cudnn):cuda()
270 |     model:reset()
271 | 272 | return model 273 | end 274 | 275 | function load_main_model(epochNumber) 276 | -- load trained net 277 | model = nil 278 | model_container = nil 279 | parameters, gradParameters = nil, nil 280 | collectgarbage() 281 | collectgarbage() 282 | 283 | local epochNumber = epochNumber or opt.epochNumber - 1 284 | local modelname = paths.concat(opt.save, 'models', 'model-'.. epochNumber .. '.t7') 285 | print('==> loading model from ' .. modelname) 286 | 287 | assert(paths.filep(modelname), 'no trained model found!') 288 | local model = torch.load(modelname) 289 | if torch.type(model) ~= 'table' then -- backward compatibility 290 | model = {G = model} 291 | if adv_train then 292 | model.D = generate_discriminator() 293 | end 294 | end 295 | 296 | return model 297 | 298 | end 299 | 300 | function reduce_model(nFeat_new) 301 | local nFeat_new = nFeat_new or opt.nStates 302 | -- assume nFeat_new is equal to or less than previous opt.nStates 303 | 304 | local reducer = function(module) 305 | local layername = torch.type(module) 306 | if layername:find('SpatialConvolution') then 307 | 308 | local nInputPlane = math.min(nFeat_new, module.nInputPlane) 309 | local nOutputPlane = math.min(nFeat_new, module.nOutputPlane) 310 | 311 | local conv = nn.SpatialConvolution(nInputPlane,nOutputPlane, 312 | module.kW,module.kH, module.dW,module.dH, module.padW,module.padH) 313 | if layername:find('cudnn') then 314 | conv = cudnn.convert(conv, cudnn) 315 | end 316 | conv:type(module._type) 317 | 318 | if opt.reduce_method == 'simple' then 319 | 320 | conv.weight:copy(module.weight:sub(1,conv.nOutputPlane, 1,conv.nInputPlane)) 321 | if module.bias then 322 | conv.bias:copy(module.bias:sub(1, conv.nOutputPlane)) 323 | else 324 | conv:noBias() 325 | end 326 | 327 | elseif opt.reduce_method == 'cluster' then 328 | 329 | end 330 | 331 | collectgarbage() 332 | collectgarbage() 333 | 334 | return conv 335 | else 336 | return module 337 | end 338 | end 339 | 340 | -- model.G reduction
341 | model.G:replace(reducer) 342 | 343 | end 344 | 345 | function save_main_model(epochNumber) 346 | 347 | model.G:clearState() -- do not use model_container:clearState() 348 | if adv_train then 349 | model.D:clearState() 350 | end 351 | collectgarbage() 352 | collectgarbage() 353 | local filename = paths.concat(model_dir, 'model-'..epochNumber..'.t7') 354 | print('==> saving model to ' .. filename..'\n') 355 | torch.save(filename, model) 356 | 357 | return 358 | end 359 | 360 | function get_model_parameters() 361 | -- assume there is a global variable: model, model.G, (model.D) 362 | parameters, gradParameters = {}, {} 363 | for k, v in next, model do 364 | parameters[k], gradParameters[k] = model[k]:getParameters() 365 | end 366 | 367 | return parameters, gradParameters 368 | end 369 | 370 | ---------------------------------------------------------------------- 371 | print '==> construct model' 372 | 373 | if opt.load or opt.continue then 374 | model = load_main_model(opt.epochNumber-1) 375 | else 376 | if paths.filep(opt.loadmodel) then 377 | model = torch.load(opt.loadmodel) 378 | if opt.reduce_model then 379 | reduce_model(opt.nStates) 380 | end 381 | else 382 | model = {} 383 | model.G = generate_main_model(opt.model) 384 | if adv_train then 385 | model.D = generate_discriminator() 386 | end 387 | end 388 | end 389 | 390 | if opt.type == 'cuda' then 391 | model.G:cuda() 392 | if adv_train then 393 | model.D:cuda() 394 | end 395 | elseif opt.type == 'cudaHalf' then 396 | -- convert loaded model to fp16 model 397 | print('Converting to CudaHalfTensor') 398 | model.G:type('torch.CudaHalfTensor') 399 | if adv_train then 400 | model.D:type('torch.CudaHalfTensor') 401 | end 402 | end 403 | 404 | parameters, gradParameters = get_model_parameters() 405 | 406 | if opt.prune_ratio > 0 then 407 | local val, ind = torch.abs(parameters.G):sort() 408 | local nElems = parameters.G:nElement() 409 | local nZeros = math.floor(nElems * opt.prune_ratio) 410 | if nZeros > 0 and 
nZeros < nElems then 411 | parameters.G:scatter(1, ind[{{1, nZeros}}], 0) 412 | end 413 | collectgarbage() 414 | collectgarbage() 415 | print('Weight parameters zeroed: ' .. nZeros .. '/' .. nElems) 416 | end 417 | ---------------------------------------------------------------------- 418 | do 419 | model_container = nn.Sequential() 420 | :add(model.G) 421 | if adv_train then 422 | model_container:add( 423 | nn.ConcatTable() 424 | :add(nn.Identity()) 425 | :add(model.D) 426 | ) 427 | end 428 | end 429 | 430 | if not (opt.load or opt.continue) then 431 | print '==> here is the model: Generator' 432 | print(model.G) 433 | if adv_train then 434 | print '==> here is the model: Discriminator' 435 | print(model.D) 436 | end 437 | end 438 | 439 | collectgarbage() 440 | collectgarbage() -------------------------------------------------------------------------------- /code/opts.lua: -------------------------------------------------------------------------------- 1 | 2 | function table.ismember(t, item) 3 | for key, value in next, t do 4 | if value == item then 5 | return true 6 | end 7 | end 8 | return false 9 | end 10 | 11 | local M = {} 12 | 13 | function M.parse(arg) 14 | print 'Non-uniform blind deblurring with CNN' 15 | print '==> processing options' 16 | 17 | local cmd = torch.CmdLine() 18 | cmd:text() 19 | cmd:text('Options:') 20 | -- global: 21 | cmd:option('-seed', 0, 'If nonzero, fixed input seed for repeatable experiments. 
If 0, then random seed') 22 | cmd:option('-threads', 2, 'number of main threads') 23 | cmd:option('-nDonkeys', 4, 'number of donkeys to initialize (data loading threads)') 24 | cmd:option('-gpuid', 1, 'GPU id to use, 1-based') 25 | -- data: 26 | cmd:option('-dataset', 'GOPRO_Large', 'dataset to use for training: GOPRO | GOPRO_Large') 27 | cmd:option('-blur_type', 'gamma2.2', 'camera response function: linear | gamma2.2') 28 | cmd:option('-scale_levels', 3, 'number of pyramid scales: 1 for base scale only; each additional level halves the resolution') 29 | -- model: 30 | cmd:option('-model', 'ResNet', 'type of model to construct: ConvNet | ResNet') 31 | cmd:option('-supWidth', 256, 'supWidth') 32 | cmd:option('-nStates', 64, 'number of feature maps (channels) in hidden conv layers') 33 | cmd:option('-filtsize', 5, 'filter size') 34 | cmd:option('-nlayers', 40, 'number of conv layers at each scale: at least 1') 35 | cmd:option('-reduce_model', false, 'reduce pre-trained model features') 36 | cmd:option('-reduce_method', 'simple', 'model reduction method: simple | cluster') 37 | cmd:option('-prune_ratio', 0, 'ratio of smallest-magnitude generator weights to be zeroed out. 0 <= r < 1') 38 | -- loss: 39 | cmd:option('-abs_weight', 0, 'weight of L1 loss. At least one loss should be positive') 40 | cmd:option('-mse_weight', 1, 'weight of L2 loss. At least one loss should be positive') 41 | cmd:option('-adv_weight', 1e-4, 'weight of adversarial loss.
At least one loss should be positive') 42 | -- training: 43 | now = os.date("%Y-%m-%d %H-%M-%S") 44 | cmd:option('-save', now, 'subdirectory to save/log experiments in') 45 | cmd:option('-nEpochs', -1, 'number of total epochs to run (<= 0: train indefinitely)') 46 | cmd:option('-epochNumber', 1, 'Manual epoch number (useful on restarts)') 47 | cmd:option('-epochbatchSize', 4000, 'number of training samples per epoch') 48 | cmd:option('-minibatchSize', 4, 'mini-batch size') 49 | cmd:option('-loadmodel', '.', 'path to a pretrained network to load') 50 | cmd:option('-train_only', false, 'if true, do not test while training: true | false') 51 | -- optimization 52 | cmd:option('-optimization', 'ADAM', 'optimization method: SGD | ADADELTA | ADAM | RMSPROP') 53 | cmd:option('-rateLearning', 5e-5, 'initial learning rate') 54 | cmd:option('-weightDecay', 0, 'weight decay (SGD | ADADELTA | ADAM)')--1e-6 55 | cmd:option('-momentum', 0.9, 'momentum (SGD only)') 56 | cmd:option('-beta1', 0.9, 'first momentum coefficient (ADAM)') 57 | cmd:option('-beta2', 0.999, 'second momentum coefficient (ADAM)') 58 | cmd:option('-epsilon', 1e-8, 'numerical stability constant (ADAM)') 59 | cmd:option('-type', 'cuda', 'type: float | cuda | cudaHalf') 60 | 61 | -- continue experiment 62 | cmd:option('-load', false, 'load trained data.
You may continue training') 63 | cmd:option('-continue', false, 'continue experiment.') 64 | 65 | cmd:text() 66 | local opt = cmd:parse(arg or {}) 67 | opt.save = paths.concat('../experiment', opt.save) 68 | if opt.load or opt.continue then 69 | local opt_old = torch.load(paths.concat(opt.save, 'opt')); 70 | local update_list = {} 71 | table.insert(update_list, 'save') 72 | table.insert(update_list, 'gpuid') 73 | table.insert(update_list, 'optimization') 74 | table.insert(update_list, 'rateLearning') 75 | table.insert(update_list, 'load') 76 | table.insert(update_list, 'continue') 77 | table.insert(update_list, 'nEpochs') 78 | table.insert(update_list, 'epochNumber') 79 | table.insert(update_list, 'dataset') 80 | table.insert(update_list, 'seed') 81 | table.insert(update_list, 'threads') 82 | table.insert(update_list, 'epochbatchSize') 83 | table.insert(update_list, 'minibatchSize') 84 | table.insert(update_list, 'train_only') 85 | table.insert(update_list, 'blur_type') 86 | table.insert(update_list, 'type') 87 | if opt.reduce_model then 88 | table.insert(update_list, 'nStates') 89 | end 90 | 91 | for key, value in next, opt_old do -- do not use ipairs 92 | if not table.ismember(update_list, key) then 93 | opt[key] = value 94 | end 95 | end 96 | 97 | if opt.epochNumber == 1 then -- if not set, then continue from the end 98 | local model_dir = paths.concat(opt.save, 'models') 99 | local max_epoch = 1 100 | for modelname in paths.iterfiles(model_dir) do 101 | local iter = tonumber(modelname:sub(7, -4)) 102 | if iter then -- nil if not number 103 | max_epoch = math.max(max_epoch, iter) 104 | end 105 | end 106 | opt.epochNumber = max_epoch + 1 107 | end 108 | end 109 | 110 | -- nb of threads and fixed seed (for repeatable experiments) 111 | if opt.threads <= 0 then 112 | opt.threads = 1 113 | end 114 | torch.setnumthreads(opt.threads) 115 | 116 | if opt.seed == 0 then -- not fixed 117 | opt.seed = torch.seed() 118 | else 119 | torch.manualSeed(opt.seed) 120 | end 121 | 
print(string.format('Starting main thread with seed: %d', opt.seed)) 122 | 123 | if opt.nEpochs <= 0 then 124 | opt.nEpochs = math.huge -- train forever 125 | end 126 | 127 | if opt.type == 'float' then 128 | print('==> switching to floats') 129 | operate_type = default_type 130 | elseif opt.type:find('cuda') then 131 | print('==> switching to CUDA') 132 | if opt.type == 'cuda' then 133 | operate_type = 'torch.CudaTensor' 134 | elseif opt.type == 'cudaHalf' then 135 | operate_type = 'torch.CudaHalfTensor' 136 | if not (opt.load or opt.continue) then 137 | opt.epsilon = math.sqrt(opt.epsilon) 138 | end 139 | end 140 | cutorch.setDevice(opt.gpuid) 141 | -- if cutorch.getDeviceCount() >= (opt.gpuid + opt.ngpu - 1) then 142 | -- cutorch.setDevice(opt.gpuid) 143 | -- end 144 | end 145 | 146 | return opt 147 | end 148 | 149 | return M 150 | -------------------------------------------------------------------------------- /code/test.lua: -------------------------------------------------------------------------------- 1 | require 'torch' -- torch 2 | require 'xlua' -- xlua provides useful tools, like progress bars 3 | require 'optim' -- an optimization package, for online and batch methods 4 | require 'image' 5 | 6 | ---------------------------------------------------------------------- 7 | -- parse command line arguments 8 | 9 | if not opt then 10 | local opts = paths.dofile('opts.lua') 11 | opt = opts.parse(arg) 12 | end 13 | 14 | ---------------------------------------------------------------------- 15 | print '==> defining test procedure' 16 | 17 | if opt.load or opt.continue then 18 | test_error = load_record('test', 'error') 19 | test_psnr = load_record('test', 'psnr') 20 | else 21 | test_error = {} 22 | test_psnr = {} 23 | save_record('test', 'error', test_error) 24 | save_record('test', 'psnr', test_psnr) 25 | end 26 | 27 | local test_dir = paths.concat(opt.save, 'test') 28 | if not paths.dirp(test_dir) then 29 | paths.mkdir(test_dir) 30 | end 31 | 32 | function 
get_output(input_img_pyramid) 33 | 34 | local temp = input_img_pyramid 35 | local temp_img_pyramid = {} 36 | local output_img_pyramid = {} 37 | if opt.scale_levels == 1 then 38 | temp = input_img_pyramid[1] 39 | for i = 1, #model.G:get(1) do 40 | temp = model.G:get(1):get(i):clone():forward(temp) 41 | collectgarbage() 42 | collectgarbage() 43 | end 44 | output_img_pyramid[1] = temp:float() 45 | else -- opt.scale_levels > 1 46 | for i = 1, #model.G do 47 | if i < opt.scale_levels then 48 | temp = model.G:get(i):clone():forward(temp) 49 | temp_img_pyramid[opt.scale_levels-i+1] = temp[opt.scale_levels-i+2]:clone() 50 | elseif i == opt.scale_levels then 51 | local finemodel = model.G:get(i):get(1):get(1):get(2):clone() 52 | for j = 1, #finemodel do 53 | if j == 1 then 54 | temp = finemodel:get(j):clone():forward({table.unpack(temp, 1, 2)}) 55 | else 56 | temp = finemodel:get(j):clone():forward(temp) 57 | end 58 | collectgarbage() 59 | collectgarbage() 60 | end 61 | temp_img_pyramid[1] = temp:clone() 62 | elseif i == opt.scale_levels + 1 then 63 | output_img_pyramid = model.G:get(i):clone():forward(temp_img_pyramid) 64 | for lv, img in ipairs(output_img_pyramid) do 65 | output_img_pyramid[lv] = img:float() 66 | end 67 | end 68 | end 69 | end 70 | collectgarbage() 71 | collectgarbage() 72 | 73 | return output_img_pyramid 74 | end 75 | 76 | function gen_pyramid(img, scale_levels) 77 | local scale_levels = scale_levels or opt.scale_levels 78 | local scales = {} 79 | for i = 1, scale_levels do 80 | scales[i] = 0.5^(i-1) 81 | end 82 | 83 | local average = img:mean(2):mean(3):squeeze() 84 | for channel = 1, img:size(1) do 85 | img[channel] = img[channel] - average[channel] 86 | end 87 | 88 | local pyramid = image.gaussianpyramid(img, scales) 89 | for lv = 1, scale_levels do 90 | for channel = 1, img:size(1) do 91 | pyramid[lv][channel] = pyramid[lv][channel] + average[channel] 92 | end 93 | pyramid[lv] = pyramid[lv]:repeatTensor(1,1,1,1) 94 | end 95 | 96 | collectgarbage() 97 
| collectgarbage() 98 | 99 | return pyramid 100 | end 101 | 102 | function get_output_img(input_img, input_img_name) 103 | local input_img = input_img 104 | if not input_img then 105 | input_img = image.load(input_img_name, 3) 106 | end 107 | 108 | local orig_input_img = input_img:clone() 109 | local orig_imsize = orig_input_img:size() 110 | 111 | local height, width = orig_imsize[2], orig_imsize[3] 112 | local inv_scale = 2^(opt.scale_levels-1) 113 | local pad_height, pad_width = math.fmod(-height, inv_scale), math.fmod(-width, inv_scale) 114 | if pad_height < 0 then 115 | pad_height = pad_height + inv_scale 116 | local row_to_pad = input_img:sub(1,-1, -1,-1, 1,-1):repeatTensor(1,pad_height,1) 117 | input_img = torch.cat({input_img, row_to_pad}, 2) 118 | end 119 | if pad_width < 0 then 120 | pad_width = pad_width + inv_scale 121 | local col_to_pad = input_img:sub(1,-1, 1,-1, -1,-1):repeatTensor(1,1,pad_width) 122 | input_img = torch.cat({input_img, col_to_pad}, 3) 123 | end 124 | 125 | local imsize = input_img:size() 126 | local input_img_pyramid = gen_pyramid(input_img) 127 | 128 | local output_img_pyramid = get_output(input_img_pyramid) 129 | local output_img = output_img_pyramid[1][1]:sub(1,-1, 1,height, 1,width):contiguous() 130 | 131 | collectgarbage() 132 | collectgarbage() 133 | 134 | return output_img 135 | end 136 | 137 | function get_output_img_part(input_img, input_img_name, xmin, xmax, ymin, ymax, margin) 138 | local input_img = input_img 139 | if not input_img then 140 | input_img = image.load(input_img_name, 3) 141 | end 142 | local margin = margin or 16 143 | 144 | local output_img = input_img:clone() 145 | local orig_imsize = output_img:size() 146 | 147 | local height, width = orig_imsize[2], orig_imsize[3] 148 | 149 | local xmin, ymin = xmin or 1, ymin or 1 150 | local xmax, ymax = xmax or width, ymax or height 151 | xmin, ymin = math.max(1, xmin), math.max(1, ymin) 152 | xmax, ymax = math.min(xmax, width), math.min(ymax, height) 153 | local 
xmin_, ymin_ = math.max(1, xmin-margin), math.max(1, ymin-margin) 154 | local xmax_, ymax_ = math.min(xmax+margin, width), math.min(ymax+margin, height) 155 | local margin_xmin, margin_ymin = xmin - xmin_, ymin - ymin_ 156 | local margin_xmax, margin_ymax = xmax_ - xmax, ymax_ - ymax 157 | 158 | local blur_part = input_img:sub(1,-1, ymin_,ymax_, xmin_,xmax_):contiguous() 159 | 160 | local deblurred_part = get_output_img(blur_part) 161 | output_img[{{1,-1}, {ymin,ymax}, {xmin,xmax}}]:copy( 162 | deblurred_part:sub(1,-1, margin_ymin+1,-margin_ymax-1, margin_xmin+1,-margin_xmax-1) 163 | ) 164 | 165 | collectgarbage() 166 | collectgarbage() 167 | 168 | return output_img 169 | end 170 | 171 | function deblur_dir(image_dir, output_dir) 172 | model.G:evaluate() 173 | local image_dir = image_dir 174 | local output_dir = output_dir or paths.concat(opt.save, 'deblur_result') 175 | if not paths.dirp(output_dir) then 176 | paths.mkdir(output_dir) 177 | end 178 | 179 | for img_name in paths.iterfiles(image_dir) do 180 | local fullname = paths.concat(image_dir, img_name) 181 | local output_img = get_output_img(nil, fullname) 182 | image.save(paths.concat(output_dir, img_name), output_img) 183 | end 184 | collectgarbage() 185 | collectgarbage() 186 | 187 | return 188 | end 189 | 190 | -- test function 191 | function test(epochNumber, save_result, full_data) 192 | 193 | local epochNumber = epochNumber or epoch 194 | model.G:evaluate() 195 | -- local vars 196 | local timer = torch.Timer() 197 | 198 | -- test over test data 199 | print('==> testing on test set:') 200 | local cError = 0 201 | local cPSNR = 0 202 | 203 | local test_list = test_list 204 | local scale_levels = opt.scale_levels 205 | 206 | local test_count = 0 207 | for seq_id, seq_name in ipairs(test_list) do 208 | local test_size = #test_list[seq_id][blur_key] 209 | if not full_data then 210 | test_size = math.min(test_size, 10) 211 | end 212 | 213 | for frame_id = 1, test_size do 214 | -- queue jobs to data-workers 
215 | donkeys:addjob( 216 | function() 217 | local seq_id, frame_id = seq_id, frame_id 218 | local scale_levels = scale_levels 219 | return generate_testpair(seq_id, frame_id, scale_levels) 220 | end, 221 | -- the end callback (runs in the main thread) 222 | function(input_img_pyramid, target_img_pyramid) 223 | 224 | local output_img_pyramid = get_output(input_img_pyramid) 225 | local mse = (output_img_pyramid[1]:cuda() - target_img_pyramid[1]:cuda()):pow(2):mean() 226 | local psnr = -10*math.log10(mse) 227 | 228 | cError = cError + mse 229 | cPSNR = cPSNR + psnr 230 | 231 | test_count = test_count + 1 232 | 233 | if save_result then 234 | local input_img_name = test_list[seq_id][blur_key][frame_id] 235 | local inputname = 'seq_id-'..seq_id..' frame-'.. frame_id .. ' Input.png'; 236 | local outputname = 'seq_id-'..seq_id..' frame-'.. frame_id .. ' Output.png'; 237 | 238 | file.copy(input_img_name, paths.concat(test_dir, inputname)) 239 | image.save(paths.concat(test_dir, outputname), output_img_pyramid[1]:squeeze(1)) 240 | end 241 | end 242 | ) 243 | end 244 | end 245 | donkeys:synchronize() 246 | cutorch.synchronize() 247 | collectgarbage() 248 | collectgarbage() 249 | 250 | test_error[epochNumber] = cError / test_count 251 | test_psnr[epochNumber] = cPSNR / test_count 252 | print('average PSNR(test) : ' .. test_psnr[epochNumber]) 253 | 254 | save_record('test', 'error', test_error) 255 | save_record('test', 'psnr', test_psnr) 256 | 257 | draw_error_plot(test_error, 'Test', 'MSE') 258 | draw_error_plot(test_psnr, 'Test', 'PSNR') 259 | 260 | 261 | local time = timer:time().real 262 | print("==> time to test = " .. time/60 .. ' min') 263 | print("==> time per image = " .. time/test_count .. 
' sec\n') 264 | 265 | return 266 | end 267 | 268 | -- fill in the missing test error 269 | function follow_up_test(test_begin, test_end, force_all, full_data) 270 | 271 | local test_begin = test_begin or 1 272 | local test_end = test_end or opt.epochNumber - 1 273 | 274 | local tested = false 275 | for epochNumber = test_begin, test_end do 276 | if force_all or test_error[epochNumber] == nil then 277 | model = load_main_model(epochNumber) 278 | test(epochNumber, false, full_data) 279 | tested = true 280 | end 281 | end 282 | 283 | return tested 284 | end 285 | 286 | function demo(image_dir, output_dir) 287 | model.G:evaluate() 288 | local image_dir = image_dir 289 | local output_dir = output_dir or paths.concat(opt.save, 'deblur_result') 290 | if not paths.dirp(output_dir) then 291 | paths.mkdir(output_dir) 292 | end 293 | 294 | local window 295 | 296 | for img_name in paths.iterfiles(image_dir) do 297 | 298 | local fullname = paths.concat(image_dir, img_name) 299 | img = image.load(fullname) 300 | window = image.display{image = img, min=0, max=1, offscreen = false, win = window}--, gui = false} 301 | 302 | local timer = torch.Timer() 303 | local output_img = get_output_img(nil, fullname) 304 | local deblur_time = timer:time().real 305 | print(' ' .. deblur_time .. 
' s taken') 306 | -- window = image.display{image = output_img, offscreen = false, win = window} 307 | window = image.display{image = output_img, min=0, max=1, offscreen = false, win = window}--, gui = false} 308 | image.save(paths.concat(output_dir, img_name), output_img) 309 | sys.sleep(3) 310 | end 311 | 312 | end 313 | 314 | 315 | 316 | 317 | 318 | 319 | -------------------------------------------------------------------------------- /code/train.lua: -------------------------------------------------------------------------------- 1 | require 'torch' -- torch 2 | require 'torchx' 3 | require 'cutorch' 4 | require 'xlua' -- xlua provides useful tools, like progress bars 5 | require 'optim' -- an optimization package, for online and batch methods 6 | require 'image' -- for rotating and flipping patches 7 | require 'math' -- to calculate base kernels 8 | require 'gnuplot' -- to visualize error plot 9 | 10 | function draw_error_plot(error_table, mode, error_type) 11 | -- visualize error plot 12 | 13 | local mode = mode or 'Train' -- 'Train' or 'Test' 14 | local legend = mode .. ' ' .. error_type 15 | local title = mode .. ' Plot (' .. error_type .. ')' 16 | local filename = paths.concat(opt.save, title .. 
'.pdf') 17 | local n = gnuplot.pdffigure(filename) 18 | 19 | if error_type:lower() == 'entropy' then 20 | local legend = {'gen', 'fake', 'real'} 21 | local error_tensors = {} 22 | error_tensors.gen = torch.Tensor(error_table.gen) 23 | error_tensors.fake = torch.Tensor(error_table.fake) 24 | error_tensors.real = torch.Tensor(error_table.real) 25 | 26 | gnuplot.plot( 27 | {legend[1], error_tensors.gen, '-'}, 28 | {legend[2], error_tensors.fake, '-'}, 29 | {legend[3], error_tensors.real, '-'} 30 | ) 31 | else 32 | local error_tensor = torch.Tensor(error_table) 33 | gnuplot.plot(legend, error_tensor, '-') 34 | end 35 | 36 | gnuplot.grid(true) 37 | gnuplot.title(title) 38 | if error_type == 'PSNR' then 39 | gnuplot.movelegend('right', 'bottom') 40 | elseif error_type == 'MSE' then 41 | gnuplot.movelegend('right', 'top') 42 | -- else 43 | -- if error_table[#error_table] > error_table[1] then 44 | -- gnuplot.movelegend('right', 'bottom') 45 | -- else 46 | -- gnuplot.movelegend('right', 'top') 47 | -- end 48 | end 49 | gnuplot.xlabel('iteration') 50 | gnuplot.plotflush(n) 51 | 52 | gnuplot.closeall() 53 | end 54 | 55 | ---------------------------------------------------------------------- 56 | print '==> defining some tools' 57 | ---------------------------------------------------------------------- 58 | print '==> configuring optimizer' 59 | print('optimization algorithm : '..opt.optimization) 60 | 61 | function set_state(optimization) 62 | local optimization = optimization or opt.optimization 63 | local optimState, optimMethod 64 | if optimization == 'SGD' then 65 | optimState = { 66 | learningRate = opt.rateLearning, 67 | weightDecay = opt.weightDecay, 68 | momentum = opt.momentum, 69 | dampening = 0, 70 | learningRateDecay = 1e-5, 71 | nesterov = true 72 | } 73 | optimMethod = optim.sgd 74 | elseif optimization == 'ADADELTA' then 75 | optimState = { 76 | weightDecay = opt.weightDecay 77 | } 78 | optimMethod = optim.adadelta; 79 | elseif optimization == 'ADAM' then 80 | 
optimState = { 81 | learningRate = opt.rateLearning, 82 | beta1 = opt.beta1, 83 | beta2 = opt.beta2, 84 | epsilon = opt.epsilon, 85 | weightDecay = opt.weightDecay 86 | } 87 | optimMethod = optim.adam 88 | elseif optimization == 'RMSPROP' then 89 | optimState = { 90 | learningRate = opt.rateLearning 91 | } 92 | optimMethod = optim.rmsprop 93 | else 94 | error('unknown optimization method') 95 | end 96 | 97 | return optimState, optimMethod 98 | end 99 | 100 | function load_state() 101 | -- assume set_state is called before 102 | local new_state = torch.load(paths.concat(opt.save, 'train_state.t7')) 103 | new_state.G.learningRate = optimState.G.learningRate 104 | if adv_train then 105 | new_state.D.learningRate = optimState.D.learningRate 106 | end 107 | 108 | return new_state 109 | end 110 | 111 | optimState, optimMethod = {}, {} 112 | optimState.G, optimMethod.G = set_state(opt.optimization) 113 | if adv_train then 114 | optimState.D, optimMethod.D = set_state(opt.optimization) 115 | end 116 | ---------------------------------------------------------------------- 117 | function load_record(mode, type) 118 | local mode = mode or 'train' -- train or test 119 | local filename = paths.concat(opt.save, mode .. '_' .. type ..'.t7') -- ex) train_error.t7 120 | assert(paths.filep(filename), 'No ' .. mode .. ' ' .. type .. ' record found!') 121 | 122 | return torch.load(filename) 123 | end 124 | 125 | function save_record(mode, type, loss) 126 | local mode = mode or 'train' -- train or test 127 | local filename = paths.concat(opt.save, mode .. '_' .. 
type ..'.t7') 128 | torch.save(filename, loss) 129 | 130 | return 131 | end 132 | 133 | if opt.load or opt.continue then 134 | 135 | if paths.filep(paths.concat(opt.save, 'train_state.t7')) then 136 | optimState = load_state() 137 | end 138 | 139 | train_error = load_record('train', 'error') 140 | train_psnr = load_record('train', 'psnr') 141 | if adv_train then 142 | train_entropy = load_record('train','entropy') 143 | end 144 | else 145 | train_error = {} 146 | train_psnr = {} 147 | if adv_train then 148 | train_entropy = {} 149 | train_entropy.gen, train_entropy.fake, train_entropy.real = {}, {}, {} 150 | end 151 | 152 | end 153 | 154 | ---------------------------------------------------------------------- 155 | print '==> defining training procedure' 156 | 157 | local abs, mse = 0, 0 158 | local entropy = {gen = 0, fake = 0, real = 0} 159 | local loss = 0 160 | 161 | local blur, sharp, deblurred 162 | local gt = {} 163 | local true_label, false_label 164 | if adv_train then 165 | true_label = torch.ones(opt.minibatchSize) 166 | false_label = torch.zeros(opt.minibatchSize) 167 | if opt.type == 'cuda' then 168 | true_label = true_label:cuda() 169 | false_label = false_label:cuda() 170 | elseif opt.type == 'cudaHalf' then 171 | true_label = true_label:cudaHalf() 172 | false_label = false_label:cudaHalf() 173 | end 174 | gt[2] = true_label 175 | end 176 | 177 | local feval = {} 178 | feval.G = function(x) 179 | model.G:zeroGradParameters() 180 | 181 | local output = model_container:forward(blur) 182 | if adv_train == false then 183 | deblurred, gt = output, sharp 184 | elseif adv_train == true then 185 | deblurred, gt[1] = output[1], sharp 186 | output_label = output[2] 187 | end 188 | 189 | loss = criterion_container(output, gt) 190 | abs = criterion_container.criterions[1].criterions[1].criterions[1].output 191 | mse = criterion_container.criterions[1].criterions[1].criterions[2].output 192 | if adv_train then 193 | entropy.gen =
criterion_container.criterions[2].output 194 | end 195 | 196 | model_container:backward(blur, criterion_container.gradInput) 197 | 198 | return loss, gradParameters.G 199 | end 200 | 201 | feval.D = function(x) 202 | model.D:zeroGradParameters() 203 | 204 | -- train with Generator output as negative example 205 | entropy.fake = criterion.D(output_label, false_label) 206 | model.D:backward(deblurred, criterion.D.gradInput) 207 | -- train with GT as a positive example 208 | output_label = model.D:forward(sharp) 209 | entropy.real = criterion.D(output_label, true_label) 210 | model.D:backward(sharp, criterion.D.gradInput) 211 | 212 | return entropy.fake + entropy.real, gradParameters.D 213 | 214 | end 215 | 216 | function trainBatch(inputs, targets, shuffle) 217 | cutorch.synchronize() 218 | 219 | blur = inputs 220 | sharp = {} 221 | for lv, lv_patch in ipairs(targets) do 222 | if opt.type == 'cudaHalf' then 223 | sharp[lv] = lv_patch:cudaHalf() 224 | else 225 | sharp[lv] = lv_patch:cuda() 226 | end 227 | end 228 | 229 | optimMethod.G(feval.G, parameters.G, optimState.G) 230 | if adv_train then 231 | optimMethod.D(feval.D, parameters.D, optimState.D) 232 | end 233 | cutorch.synchronize() 234 | 235 | return 236 | end 237 | 238 | function train() 239 | print('==> doing epoch on training data:') 240 | print("==> online epoch # " .. epoch .. ' [mini-batchSize = ' .. opt.minibatchSize .. 
']') 241 | 242 | -- local vars 243 | local timer = torch.Timer() 244 | 245 | -- set model to training mode (for modules that differ in training and testing, like Dropout) 246 | cutorch.synchronize() 247 | model.G:training() 248 | if adv_train then model.D:training() end 249 | 250 | local cABS, cMSE = 0, 0 251 | local cLoss = 0 252 | local cPSNR = 0 253 | local cEntropy = {gen = 0, fake = 0, real = 0} 254 | local minibatch_count = 0 255 | 256 | local function cumulate_error() 257 | cABS = cABS + abs 258 | cMSE = cMSE + mse 259 | cPSNR = cPSNR - 10*math.log10(mse) 260 | 261 | cLoss = cLoss + loss 262 | 263 | if adv_train then 264 | cEntropy.gen = cEntropy.gen + entropy.gen 265 | cEntropy.fake = cEntropy.fake + entropy.fake 266 | cEntropy.real = cEntropy.real + entropy.real 267 | end 268 | 269 | end 270 | do 271 | local opt = opt 272 | local train_list = train_list 273 | for i = 1, opt.epochbatchSize, opt.minibatchSize do 274 | -- queue jobs to data-workers 275 | donkeys:addjob( 276 | -- the job callback (runs in data-worker thread) 277 | function() 278 | return generate_batch() 279 | end, 280 | -- the end callback (runs in the main thread) 281 | function(input_batch, target_batch) 282 | trainBatch(input_batch, target_batch) 283 | cumulate_error() 284 | minibatch_count = minibatch_count + 1 285 | xlua.progress(minibatch_count*opt.minibatchSize, opt.epochbatchSize) 286 | end 287 | ) 288 | end 289 | end 290 | 291 | donkeys:synchronize() 292 | cutorch.synchronize() 293 | 294 | train_error[epoch] = cLoss / minibatch_count 295 | draw_error_plot(train_error, 'Train', 'Loss') 296 | 297 | train_psnr[epoch] = cPSNR / minibatch_count -- this is meaningless when data term is not mse 298 | draw_error_plot(train_psnr, 'Train', 'PSNR') 299 | print('average PSNR(train) : ' .. 
train_psnr[epoch]) 300 | 301 | if adv_train then 302 | train_entropy.gen[epoch] = cEntropy.gen / opt.adv_weight / minibatch_count 303 | train_entropy.fake[epoch] = cEntropy.fake / opt.adv_weight / minibatch_count 304 | train_entropy.real[epoch] = cEntropy.real / opt.adv_weight / minibatch_count 305 | draw_error_plot(train_entropy, 'Train', 'Entropy') 306 | print('average Entropy(gen) : ' .. train_entropy.gen[epoch]) 307 | print('average Entropy(fake) : ' .. train_entropy.fake[epoch]) 308 | print('average Entropy(real) : ' .. train_entropy.real[epoch]) 309 | end 310 | -- time taken 311 | local time = timer:time().real / 60 312 | print("==> time to learn 1 epoch = " .. time .. ' min') 313 | 314 | collectgarbage() 315 | collectgarbage() 316 | 317 | -- save/log current net 318 | do 319 | opt.epochNumber = epoch 320 | filename = paths.concat(opt.save, 'opt') 321 | torch.save(filename, opt) 322 | 323 | save_main_model(epoch) 324 | save_record('train', 'state', optimState) 325 | 326 | save_record('train', 'error', train_error) 327 | save_record('train', 'psnr', train_psnr) 328 | if adv_train then 329 | save_record('train', 'entropy', train_entropy) 330 | end 331 | end 332 | 333 | end 334 | 335 | -------------------------------------------------------------------------------- /dataset/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | *.png 4 | -------------------------------------------------------------------------------- /dataset/examples/blur_gamma/seq_id-1 frame-8 Input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeungjunNah/DeepDeblur_release/2d5a698560e658718f0520e48dfb15bd52c80118/dataset/examples/blur_gamma/seq_id-1 frame-8 Input.png -------------------------------------------------------------------------------- /dataset/examples/blur_gamma/seq_id-11 frame-1 Input.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeungjunNah/DeepDeblur_release/2d5a698560e658718f0520e48dfb15bd52c80118/dataset/examples/blur_gamma/seq_id-11 frame-1 Input.png -------------------------------------------------------------------------------- /dataset/examples/blur_gamma/seq_id-2 frame-8 Input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeungjunNah/DeepDeblur_release/2d5a698560e658718f0520e48dfb15bd52c80118/dataset/examples/blur_gamma/seq_id-2 frame-8 Input.png -------------------------------------------------------------------------------- /dataset/examples/blur_gamma/seq_id-3 frame-1 Input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeungjunNah/DeepDeblur_release/2d5a698560e658718f0520e48dfb15bd52c80118/dataset/examples/blur_gamma/seq_id-3 frame-1 Input.png -------------------------------------------------------------------------------- /dataset/examples/blur_gamma/seq_id-4 frame-1 Input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeungjunNah/DeepDeblur_release/2d5a698560e658718f0520e48dfb15bd52c80118/dataset/examples/blur_gamma/seq_id-4 frame-1 Input.png -------------------------------------------------------------------------------- /dataset/examples/blur_gamma/seq_id-5 frame-2 Input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeungjunNah/DeepDeblur_release/2d5a698560e658718f0520e48dfb15bd52c80118/dataset/examples/blur_gamma/seq_id-5 frame-2 Input.png -------------------------------------------------------------------------------- /dataset/examples/blur_gamma/seq_id-6 frame-8 Input.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SeungjunNah/DeepDeblur_release/2d5a698560e658718f0520e48dfb15bd52c80118/dataset/examples/blur_gamma/seq_id-6 frame-8 Input.png -------------------------------------------------------------------------------- /dataset/examples/blur_lin/seq_id-1 frame-8 Input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeungjunNah/DeepDeblur_release/2d5a698560e658718f0520e48dfb15bd52c80118/dataset/examples/blur_lin/seq_id-1 frame-8 Input.png -------------------------------------------------------------------------------- /dataset/examples/blur_lin/seq_id-11 frame-1 Input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeungjunNah/DeepDeblur_release/2d5a698560e658718f0520e48dfb15bd52c80118/dataset/examples/blur_lin/seq_id-11 frame-1 Input.png -------------------------------------------------------------------------------- /dataset/examples/blur_lin/seq_id-2 frame-8 Input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeungjunNah/DeepDeblur_release/2d5a698560e658718f0520e48dfb15bd52c80118/dataset/examples/blur_lin/seq_id-2 frame-8 Input.png -------------------------------------------------------------------------------- /dataset/examples/blur_lin/seq_id-3 frame-1 Input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeungjunNah/DeepDeblur_release/2d5a698560e658718f0520e48dfb15bd52c80118/dataset/examples/blur_lin/seq_id-3 frame-1 Input.png -------------------------------------------------------------------------------- /dataset/examples/blur_lin/seq_id-4 frame-1 Input.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SeungjunNah/DeepDeblur_release/2d5a698560e658718f0520e48dfb15bd52c80118/dataset/examples/blur_lin/seq_id-4 frame-1 Input.png -------------------------------------------------------------------------------- /dataset/examples/blur_lin/seq_id-5 frame-2 Input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeungjunNah/DeepDeblur_release/2d5a698560e658718f0520e48dfb15bd52c80118/dataset/examples/blur_lin/seq_id-5 frame-2 Input.png -------------------------------------------------------------------------------- /dataset/examples/blur_lin/seq_id-6 frame-8 Input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeungjunNah/DeepDeblur_release/2d5a698560e658718f0520e48dfb15bd52c80118/dataset/examples/blur_lin/seq_id-6 frame-8 Input.png -------------------------------------------------------------------------------- /dataset/examples/deblurred_gamma/seq_id-1 frame-8 Input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeungjunNah/DeepDeblur_release/2d5a698560e658718f0520e48dfb15bd52c80118/dataset/examples/deblurred_gamma/seq_id-1 frame-8 Input.png -------------------------------------------------------------------------------- /dataset/examples/deblurred_gamma/seq_id-11 frame-1 Input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeungjunNah/DeepDeblur_release/2d5a698560e658718f0520e48dfb15bd52c80118/dataset/examples/deblurred_gamma/seq_id-11 frame-1 Input.png -------------------------------------------------------------------------------- /dataset/examples/deblurred_gamma/seq_id-2 frame-8 Input.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SeungjunNah/DeepDeblur_release/2d5a698560e658718f0520e48dfb15bd52c80118/dataset/examples/deblurred_gamma/seq_id-2 frame-8 Input.png -------------------------------------------------------------------------------- /dataset/examples/deblurred_gamma/seq_id-3 frame-1 Input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeungjunNah/DeepDeblur_release/2d5a698560e658718f0520e48dfb15bd52c80118/dataset/examples/deblurred_gamma/seq_id-3 frame-1 Input.png -------------------------------------------------------------------------------- /dataset/examples/deblurred_gamma/seq_id-4 frame-1 Input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeungjunNah/DeepDeblur_release/2d5a698560e658718f0520e48dfb15bd52c80118/dataset/examples/deblurred_gamma/seq_id-4 frame-1 Input.png -------------------------------------------------------------------------------- /dataset/examples/deblurred_gamma/seq_id-5 frame-2 Input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeungjunNah/DeepDeblur_release/2d5a698560e658718f0520e48dfb15bd52c80118/dataset/examples/deblurred_gamma/seq_id-5 frame-2 Input.png -------------------------------------------------------------------------------- /dataset/examples/deblurred_gamma/seq_id-6 frame-8 Input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeungjunNah/DeepDeblur_release/2d5a698560e658718f0520e48dfb15bd52c80118/dataset/examples/deblurred_gamma/seq_id-6 frame-8 Input.png -------------------------------------------------------------------------------- /dataset/examples/deblurred_lin/seq_id-1 frame-8 Input.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SeungjunNah/DeepDeblur_release/2d5a698560e658718f0520e48dfb15bd52c80118/dataset/examples/deblurred_lin/seq_id-1 frame-8 Input.png -------------------------------------------------------------------------------- /dataset/examples/deblurred_lin/seq_id-11 frame-1 Input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeungjunNah/DeepDeblur_release/2d5a698560e658718f0520e48dfb15bd52c80118/dataset/examples/deblurred_lin/seq_id-11 frame-1 Input.png -------------------------------------------------------------------------------- /dataset/examples/deblurred_lin/seq_id-2 frame-8 Input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeungjunNah/DeepDeblur_release/2d5a698560e658718f0520e48dfb15bd52c80118/dataset/examples/deblurred_lin/seq_id-2 frame-8 Input.png -------------------------------------------------------------------------------- /dataset/examples/deblurred_lin/seq_id-3 frame-1 Input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeungjunNah/DeepDeblur_release/2d5a698560e658718f0520e48dfb15bd52c80118/dataset/examples/deblurred_lin/seq_id-3 frame-1 Input.png -------------------------------------------------------------------------------- /dataset/examples/deblurred_lin/seq_id-4 frame-1 Input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeungjunNah/DeepDeblur_release/2d5a698560e658718f0520e48dfb15bd52c80118/dataset/examples/deblurred_lin/seq_id-4 frame-1 Input.png -------------------------------------------------------------------------------- /dataset/examples/deblurred_lin/seq_id-5 frame-2 Input.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SeungjunNah/DeepDeblur_release/2d5a698560e658718f0520e48dfb15bd52c80118/dataset/examples/deblurred_lin/seq_id-5 frame-2 Input.png -------------------------------------------------------------------------------- /dataset/examples/deblurred_lin/seq_id-6 frame-8 Input.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeungjunNah/DeepDeblur_release/2d5a698560e658718f0520e48dfb15bd52c80118/dataset/examples/deblurred_lin/seq_id-6 frame-8 Input.png -------------------------------------------------------------------------------- /experiment/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !.gitignore 3 | -------------------------------------------------------------------------------- /images/Flower_blur1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeungjunNah/DeepDeblur_release/2d5a698560e658718f0520e48dfb15bd52c80118/images/Flower_blur1.png -------------------------------------------------------------------------------- /images/Flower_sharp1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeungjunNah/DeepDeblur_release/2d5a698560e658718f0520e48dfb15bd52c80118/images/Flower_sharp1.png -------------------------------------------------------------------------------- /images/Istanbul_blur1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeungjunNah/DeepDeblur_release/2d5a698560e658718f0520e48dfb15bd52c80118/images/Istanbul_blur1.png -------------------------------------------------------------------------------- /images/Istanbul_sharp1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/SeungjunNah/DeepDeblur_release/2d5a698560e658718f0520e48dfb15bd52c80118/images/Istanbul_sharp1.png -------------------------------------------------------------------------------- /images/NTIRE2019.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SeungjunNah/DeepDeblur_release/2d5a698560e658718f0520e48dfb15bd52c80118/images/NTIRE2019.jpg --------------------------------------------------------------------------------
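
The `average PSNR(train)` printed by `train.lua` is the mean of the per-minibatch PSNR values (`cPSNR / minibatch_count`), not the PSNR of the epoch's mean MSE. Because `-log10` is convex, the mean of per-batch PSNRs is never lower than the PSNR of the mean MSE, so the two figures can differ when batch errors vary. The minimal sketch below is plain Lua with no Torch dependency; the helper names and the batch MSE values are hypothetical, and it assumes pixel values in [0, 1], which is what the `-10*math.log10(mse)` expression in `train.lua` implies (peak value 1).

```lua
-- Hypothetical, dependency-free sketch (plain Lua, no Torch) of how the
-- per-epoch PSNR in train.lua is averaged. Helper names and batch MSE
-- values are illustrative assumptions, not part of the repository.

-- PSNR in dB for images scaled to [0, 1] (peak value 1). Uses
-- math.log(x) / math.log(10) because math.log10 is not available in
-- every Lua version.
local function psnr(mse)
    return -10 * math.log(mse) / math.log(10)
end

-- Mean of the per-minibatch PSNR values, like cPSNR / minibatch_count
local function mean_batch_psnr(batch_mse)
    local sum = 0
    for _, mse in ipairs(batch_mse) do
        sum = sum + psnr(mse)
    end
    return sum / #batch_mse
end

-- PSNR of the mean MSE, computed only for comparison
local function psnr_of_mean_mse(batch_mse)
    local sum = 0
    for _, mse in ipairs(batch_mse) do
        sum = sum + mse
    end
    return psnr(sum / #batch_mse)
end

local batch_mse = {0.01, 0.001}  -- two made-up minibatch MSE values
print(string.format("mean of batch PSNRs: %.2f dB", mean_batch_psnr(batch_mse)))   -- 25.00 dB
print(string.format("PSNR of mean MSE:    %.2f dB", psnr_of_mean_mse(batch_mse)))  -- 22.60 dB
```

With these two made-up batches (20 dB and 30 dB individually), the per-batch average reports 25.00 dB while the PSNR of the pooled MSE is about 22.60 dB, so the per-epoch training PSNR should be compared only against values computed the same way.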