├── .gitignore ├── README.md ├── base ├── __init__.py ├── base_data_loader.py ├── base_model.py └── base_trainer.py ├── config.json ├── config_aligned.json ├── data_loader ├── __init__.py ├── data_loader.py └── dataset.py ├── deblur_image.py ├── demo_pic ├── 1.png ├── 2.png ├── deblurred1.png └── deblurred2.png ├── make_aligned_data.py ├── model ├── layer_utils.py ├── loss.py ├── metric.py └── model.py ├── requirements.txt ├── test.py ├── train.py ├── trainer ├── __init__.py └── trainer.py └── utils ├── __init__.py ├── logger.py ├── util.py └── visualization.py /.gitignore: -------------------------------------------------------------------------------- 1 | /pretrained_weights 2 | /data 3 | /aligned_data 4 | /saved 5 | /.idea 6 | */__pycache__/ -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeblurGAN 2 | 3 | An easy-to-read implementation of [DeblurGAN](https://arxiv.org/pdf/1711.07064.pdf) using PyTorch 4 | 5 | ## Some demos of deblurring: 6 | - ![blurred1](demo_pic/1.png) ![deblurred1](demo_pic/deblurred1.png) 7 | 8 | - ![blurred2](demo_pic/2.png) ![deblurred2](demo_pic/deblurred2.png) 9 | 10 | ## Prerequisites 11 | - NVIDIA GPU + CUDA cuDNN 12 | - Python 3.7 13 | 14 | ## Folder Structure 15 | ``` 16 | deblurGAN/ 17 | │ 18 | ├── deblur_image.py - deblur your own images 19 | ├── test.py - evaluation of trained model 20 | ├── train.py - main script to start training 21 | ├── make_aligned_data.py - make aligned data 22 | ├── config.json - demo config file 23 | ├── config_aligned.json - demo config file using aligned dataset 24 | ├── pretrained_weights/ - some pretrained weights for test 25 | │ ├── GAN/ - folder of pretrained weights using GAN loss 26 | │ └── WGAN_GP/ - folder of pretrained weights using WGAN_GP loss 27 | │ 28 | ├── base/ - abstract base classes 29 | │ ├── base_data_loader.py - abstract base class for data loaders 30 | 
│ ├── base_model.py - abstract base class for models 31 | │ └── base_trainer.py - abstract base class for trainers 32 | │ 33 | ├── data_loader/ - dataloader and dataset 34 | │ ├── data_loader.py 35 | | └── dataset.py 36 | │ 37 | ├── data/ - default directory for storing input data, containing 2 directory for blurred and sharp 38 | │ ├── blurred/ - directory for blurred images 39 | │ └── sharp/ - directory for sharp images 40 | │ 41 | ├── model/ - models, losses, and metrics 42 | │ ├── layer_utils.py 43 | │ ├── loss.py 44 | │ ├── metric.py 45 | │ └── model.py 46 | │ 47 | ├── trainer/ - trainers 48 | │ └── trainer.py 49 | │ 50 | └── utils/ 51 | ├── logger.py - class for train logging 52 | ├── util.py 53 | ├── visualization.py - class for tensorboardX visualization support 54 | └── ... 55 | ``` 56 | 57 | ## Config file format 58 | ``` 59 | { 60 | "name": "DeblurGAN", // training session name 61 | "n_gpu": 1, // number of GPUs to use for training 62 | "data_loader": { // selecting data loader 63 | "type": "GoProDataLoader", 64 | "args": { 65 | "data_dir": "data/", 66 | "batch_size": 1, 67 | "shuffle": false, 68 | "validation_split": 0.1, 69 | "num_workers": 4 70 | } 71 | }, 72 | "generator": { // architecture of generator 73 | "type": "ResNetGenerator", 74 | "args": { 75 | "input_nc": 3, 76 | "output_nc": 3 77 | } 78 | }, 79 | "discriminator": { // architecture of discriminator 80 | "type": "NLayerDiscriminator", 81 | "args": { 82 | "input_nc": 3 83 | } 84 | }, 85 | "loss": { // loss function 86 | "adversarial": "wgan_gp_loss", 87 | "content": "perceptual_loss" 88 | }, 89 | "metrics": [ // list of metrics to evaluate 90 | "PSNR" 91 | ], 92 | "optimizer": { // configuration of the optimizer (both generator and discriminator) 93 | "type": "Adam", 94 | "args": { 95 | "lr": 0.0001, 96 | "betas": [ 97 | 0.5, 98 | 0.999 99 | ], 100 | "weight_decay": 0, 101 | "amsgrad": true 102 | } 103 | }, 104 | "lr_scheduler": { // learning rate scheduler 105 | "type": "LambdaLR", 106 | 
"args": { 107 | "lr_lambda": "origin_lr_scheduler" 108 | } 109 | }, 110 | "trainer": { // configuration of the trainer 111 | "epochs": 300, 112 | "save_dir": "saved/", 113 | "save_period": 1, 114 | "verbosity": 2, 115 | "monitor": "max PSNR", 116 | "tensorboardX": true, 117 | "log_dir": "saved/runs" 118 | }, 119 | "others": { // other hyperparameters 120 | "gp_lambda": 10, 121 | "content_loss_lambda": 100 122 | } 123 | } 124 | ``` 125 | 126 | ## How to run 127 | * **Train** 128 | ``` 129 | python train.py --config config.json 130 | ``` 131 | 132 | * **Resume** 133 | ``` 134 | python train.py --resume path/to/checkpoint 135 | ``` 136 | 137 | * **Test** 138 | ``` 139 | python test.py --resume path/to/checkpoint 140 | ``` 141 | 142 | * **Deblur** 143 | ``` 144 | python deblur_image.py --blurred path/to/blurred_images --deblurred path/to/deblurred_images --resume path/to/checkpoint 145 | ``` 146 | 147 | * **Make aligned data first if you want to use aligned dataset** 148 | ``` 149 | python make_aligned_data.py --blurred path/to/blurred_images --sharp path/to/sharp_images --aligned path/to/aligned_images 150 | ``` 151 | 152 | ## Tips 153 | - If you want to use gan_loss instead of wgan_gp_loss, use_sigmoid must be set to true in generator. 154 | - Aligned dataset could boost the speed of data_loader a little bit. So run make_aligned_data.py to get aligned dataset before training. 155 | - Pretrained weights of both GAN and WGAN_GP are available. 
156 | - **Download pretrained weights: https://drive.google.com/open?id=1w-u0r3hd3cfzSjFuvvuYAs9wA-E-B-11** 157 | 158 | ## Acknowledgements 159 | The organization of this project is based on [PyTorch Template Project](https://github.com/victoresque/pytorch-template) -------------------------------------------------------------------------------- /base/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fourson/DeblurGAN-pytorch/41696fc6c68ef2a3924dfbddc0ece69821e5678b/base/__init__.py -------------------------------------------------------------------------------- /base/base_data_loader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils.data import DataLoader 3 | from torch.utils.data.dataloader import default_collate 4 | from torch.utils.data.sampler import SubsetRandomSampler 5 | 6 | 7 | class BaseDataLoader(DataLoader): 8 | """ 9 | Base class for all data loaders 10 | """ 11 | def __init__(self, dataset, batch_size, shuffle, validation_split, num_workers, collate_fn=default_collate): 12 | self.validation_split = validation_split 13 | self.shuffle = shuffle 14 | 15 | self.batch_idx = 0 16 | self.n_samples = len(dataset) 17 | 18 | self.sampler, self.valid_sampler = self._split_sampler(self.validation_split) 19 | 20 | self.init_kwargs = { 21 | 'dataset': dataset, 22 | 'batch_size': batch_size, 23 | 'shuffle': self.shuffle, 24 | 'collate_fn': collate_fn, 25 | 'num_workers': num_workers 26 | } 27 | super(BaseDataLoader, self).__init__(sampler=self.sampler, **self.init_kwargs) 28 | 29 | def _split_sampler(self, split): 30 | if split == 0.0: 31 | return None, None 32 | 33 | idx_full = np.arange(self.n_samples) 34 | 35 | np.random.seed(0) 36 | np.random.shuffle(idx_full) 37 | 38 | len_valid = int(self.n_samples * split) 39 | 40 | valid_idx = idx_full[0:len_valid] 41 | train_idx = np.delete(idx_full, np.arange(0, 
len_valid)) 42 | 43 | train_sampler = SubsetRandomSampler(train_idx) 44 | valid_sampler = SubsetRandomSampler(valid_idx) 45 | 46 | # turn off shuffle option which is mutually exclusive with sampler 47 | self.shuffle = False 48 | self.n_samples = len(train_idx) 49 | 50 | return train_sampler, valid_sampler 51 | 52 | def split_validation(self): 53 | if self.valid_sampler is None: 54 | return None 55 | else: 56 | return DataLoader(sampler=self.valid_sampler, **self.init_kwargs) 57 | -------------------------------------------------------------------------------- /base/base_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch.nn as nn 3 | 4 | 5 | class BaseModel(nn.Module): 6 | """ 7 | Base class for all models 8 | """ 9 | 10 | def __init__(self): 11 | super(BaseModel, self).__init__() 12 | self.logger = logging.getLogger(self.__class__.__name__) 13 | 14 | def forward(self, *input): 15 | """ 16 | Forward pass logic 17 | :return: Model output 18 | """ 19 | raise NotImplementedError 20 | 21 | def summary(self): 22 | """ 23 | Model summary 24 | """ 25 | model_parameters = filter(lambda p: p.requires_grad, self.parameters()) 26 | params = sum([p.numel() for p in model_parameters]) 27 | self.logger.info('Trainable parameters: {}'.format(params)) 28 | self.logger.info(self) 29 | 30 | def __str__(self): 31 | """ 32 | Model prints with number of trainable parameters 33 | """ 34 | model_parameters = filter(lambda p: p.requires_grad, self.parameters()) 35 | params = sum([p.numel() for p in model_parameters]) 36 | return super(BaseModel, self).__str__() + '\nTrainable parameters: {}'.format(params) 37 | -------------------------------------------------------------------------------- /base/base_trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import json 4 | import logging 5 | import datetime 6 | 7 | import torch 8 | 9 | from 
utils.util import ensure_dir 10 | from utils.visualization import WriterTensorboardX 11 | from model.layer_utils import init_weights 12 | 13 | 14 | class BaseTrainer: 15 | """ 16 | Base class for all trainers 17 | """ 18 | 19 | def __init__(self, config, generator, discriminator, loss, metrics, optimizer, lr_scheduler, resume, train_logger): 20 | self.config = config 21 | self.logger = logging.getLogger(self.__class__.__name__) 22 | 23 | # setup GPU device if available, move model into configured device and init the weights 24 | self.device, device_ids = self._prepare_device(config['n_gpu']) 25 | self.generator = generator.to(self.device) 26 | self.discriminator = discriminator.to(self.device) 27 | self.generator.apply(init_weights) 28 | self.discriminator.apply(init_weights) 29 | if len(device_ids) > 1: 30 | self.generator = torch.nn.DataParallel(generator, device_ids=device_ids) 31 | self.discriminator = torch.nn.DataParallel(discriminator, device_ids=device_ids) 32 | 33 | self.adversarial_loss = loss['adversarial'] 34 | self.content_loss = loss['content'] 35 | self.metrics = metrics 36 | self.generator_optimizer = optimizer['generator'] 37 | self.discriminator_optimizer = optimizer['discriminator'] 38 | self.generator_lr_scheduler = lr_scheduler['generator'] 39 | self.discriminator_lr_scheduler = lr_scheduler['discriminator'] 40 | self.train_logger = train_logger 41 | 42 | cfg_trainer = config['trainer'] 43 | self.epochs = cfg_trainer['epochs'] 44 | self.save_period = cfg_trainer['save_period'] 45 | self.verbosity = cfg_trainer['verbosity'] 46 | self.monitor = cfg_trainer.get('monitor', 'off') 47 | 48 | # configuration to monitor model performance and save best 49 | if self.monitor == 'off': 50 | self.mnt_mode = 'off' 51 | self.mnt_best = 0 52 | else: 53 | self.mnt_mode, self.mnt_metric = self.monitor.split() 54 | assert self.mnt_mode in ['min', 'max'] 55 | 56 | self.mnt_best = math.inf if self.mnt_mode == 'min' else -math.inf 57 | self.early_stop = 
cfg_trainer.get('early_stop', math.inf) 58 | 59 | self.start_epoch = 1 60 | 61 | # setup directory for checkpoint saving 62 | start_time = datetime.datetime.now().strftime('%m%d_%H%M%S') 63 | self.checkpoint_dir = os.path.join(cfg_trainer['save_dir'], config['name'], start_time) 64 | # setup visualization writer instance 65 | writer_dir = os.path.join(cfg_trainer['log_dir'], config['name'], start_time) 66 | self.writer = WriterTensorboardX(writer_dir, self.logger, cfg_trainer['tensorboardX']) 67 | 68 | # Save configuration file into checkpoint directory 69 | ensure_dir(self.checkpoint_dir) 70 | config_save_path = os.path.join(self.checkpoint_dir, 'config.json') 71 | with open(config_save_path, 'w') as handle: 72 | json.dump(config, handle, indent=4) 73 | 74 | if resume: 75 | self._resume_checkpoint(resume) 76 | 77 | def _prepare_device(self, n_gpu_use): 78 | """ 79 | setup GPU device if available, move model into configured device 80 | """ 81 | n_gpu = torch.cuda.device_count() 82 | if n_gpu_use > 0 and n_gpu == 0: 83 | self.logger.warning("Warning: There's no GPU available on this machine, training will be performed on CPU.") 84 | n_gpu_use = 0 85 | if n_gpu_use > n_gpu: 86 | self.logger.warning( 87 | "Warning: The number of GPU's configured to use is {}, but only {} are available " 88 | "on this machine.".format(n_gpu_use, n_gpu)) 89 | n_gpu_use = n_gpu 90 | device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu') 91 | device_ids = list(range(n_gpu_use)) 92 | return device, device_ids 93 | 94 | def train(self): 95 | """ 96 | Full training logic 97 | """ 98 | not_improved_count = 0 99 | for epoch in range(self.start_epoch, self.epochs + 1): 100 | result = self._train_epoch(epoch) 101 | 102 | # save logged informations into log dict 103 | log = {'epoch': epoch} 104 | for key, value in result.items(): 105 | if key == 'metrics': 106 | log.update({mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)}) 107 | elif key == 'val_metrics': 108 | 
log.update({'val_' + mtr.__name__: value[i] for i, mtr in enumerate(self.metrics)}) 109 | else: 110 | log[key] = value 111 | 112 | # print logged informations to the screen 113 | if self.train_logger is not None: 114 | self.train_logger.add_entry(log) 115 | if self.verbosity >= 1: 116 | for key, value in log.items(): 117 | self.logger.info(' {:15s}: {}'.format(str(key), value)) 118 | 119 | # evaluate model performance according to configured metric, save best checkpoint as model_best 120 | is_best = False 121 | if self.mnt_mode != 'off': 122 | try: 123 | # check whether model performance improved or not, according to specified metric(mnt_metric) 124 | improved = (self.mnt_mode == 'min' and log[self.mnt_metric] < self.mnt_best) or \ 125 | (self.mnt_mode == 'max' and log[self.mnt_metric] > self.mnt_best) 126 | except KeyError: 127 | self.logger.warning("Warning: Metric '{}' is not found. " 128 | "Model performance monitoring is disabled.".format(self.mnt_metric)) 129 | self.mnt_mode = 'off' 130 | improved = False 131 | not_improved_count = 0 132 | 133 | if improved: 134 | self.mnt_best = log[self.mnt_metric] 135 | not_improved_count = 0 136 | is_best = True 137 | else: 138 | not_improved_count += 1 139 | 140 | if not_improved_count > self.early_stop: 141 | self.logger.info("Validation performance didn't improve for {} epochs. 
" 142 | "Training stops.".format(self.early_stop)) 143 | break 144 | 145 | if epoch % self.save_period == 0: 146 | self._save_checkpoint(epoch, save_best=is_best) 147 | 148 | def _train_epoch(self, epoch): 149 | """ 150 | Training logic for an epoch 151 | 152 | :param epoch: Current epoch number 153 | """ 154 | raise NotImplementedError 155 | 156 | def _save_checkpoint(self, epoch, save_best=False): 157 | """ 158 | Saving checkpoints 159 | 160 | :param epoch: current epoch number 161 | :param save_best: if True, rename the saved checkpoint to 'model_best.pth' 162 | """ 163 | state = { 164 | 'epoch': epoch, 165 | 'logger': self.train_logger, 166 | 'generator': self.generator.state_dict(), 167 | 'discriminator': self.discriminator.state_dict(), 168 | 'generator_optimizer': self.generator_optimizer.state_dict(), 169 | 'discriminator_optimizer': self.discriminator_optimizer.state_dict(), 170 | 'generator_lr_scheduler': self.generator_lr_scheduler.state_dict(), 171 | 'discriminator_lr_scheduler': self.discriminator_lr_scheduler.state_dict(), 172 | 'monitor_best': self.mnt_best, 173 | 'config': self.config 174 | } 175 | filename = os.path.join(self.checkpoint_dir, 'checkpoint-epoch{}.pth'.format(epoch)) 176 | torch.save(state, filename) 177 | self.logger.info("Saving checkpoint: {} ...".format(filename)) 178 | if save_best: 179 | best_path = os.path.join(self.checkpoint_dir, 'model_best.pth') 180 | torch.save(state, best_path) 181 | self.logger.info("Saving current best: {} ...".format('model_best.pth')) 182 | 183 | def _resume_checkpoint(self, resume_path): 184 | """ 185 | Resume from saved checkpoints 186 | 187 | :param resume_path: Checkpoint path to be resumed 188 | """ 189 | self.logger.info("Loading checkpoint: {} ...".format(resume_path)) 190 | checkpoint = torch.load(resume_path) 191 | self.start_epoch = checkpoint['epoch'] + 1 192 | self.mnt_best = checkpoint['monitor_best'] 193 | 194 | # load params from checkpoint 195 | if checkpoint['config']['name'] != 
self.config['name']: 196 | self.logger.warning("Warning: Architecture configuration given in config file is different from that of " 197 | "checkpoint. This may yield an exception while state_dict is being loaded.") 198 | self.generator.load_state_dict(checkpoint['generator']) 199 | self.discriminator.load_state_dict(checkpoint['discriminator']) 200 | 201 | # load optimizer state from checkpoint only when optimizer type is not changed. 202 | if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']: 203 | self.logger.warning("Warning: Optimizer type given in config file is different from that of checkpoint. " 204 | "Optimizer parameters not being resumed.") 205 | else: 206 | self.generator_optimizer.load_state_dict(checkpoint['generator_optimizer']) 207 | self.discriminator_optimizer.load_state_dict(checkpoint['discriminator_optimizer']) 208 | 209 | # load learning scheduler state from checkpoint only when learning scheduler type is not changed. 210 | if checkpoint['config']['lr_scheduler']['type'] != self.config['lr_scheduler']['type']: 211 | self.logger.warning( 212 | "Warning: Learning scheduler type given in config file is different from that of checkpoint. 
" 213 | "Learning scheduler parameters not being resumed.") 214 | else: 215 | self.generator_lr_scheduler.load_state_dict(checkpoint['generator_lr_scheduler']) 216 | self.discriminator_lr_scheduler.load_state_dict(checkpoint['discriminator_lr_scheduler']) 217 | 218 | self.train_logger = checkpoint['logger'] 219 | self.logger.info("Checkpoint '{}' (epoch {}) loaded".format(resume_path, self.start_epoch)) 220 | -------------------------------------------------------------------------------- /config.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "DeblurGAN", 3 | "n_gpu": 1, 4 | "data_loader": { 5 | "type": "GoProDataLoader", 6 | "args": { 7 | "data_dir": "data/", 8 | "batch_size": 1, 9 | "shuffle": false, 10 | "validation_split": 0, 11 | "num_workers": 4 12 | } 13 | }, 14 | "generator": { 15 | "type": "ResNetGenerator", 16 | "args": { 17 | "input_nc": 3, 18 | "output_nc": 3 19 | } 20 | }, 21 | "discriminator": { 22 | "type": "NLayerDiscriminator", 23 | "args": { 24 | "input_nc": 3 25 | } 26 | }, 27 | "loss": { 28 | "adversarial": "wgan_gp_loss", 29 | "content": "perceptual_loss" 30 | }, 31 | "metrics": [ 32 | "PSNR" 33 | ], 34 | "optimizer": { 35 | "type": "Adam", 36 | "args": { 37 | "lr": 0.0001, 38 | "betas": [ 39 | 0.5, 40 | 0.999 41 | ], 42 | "weight_decay": 0, 43 | "amsgrad": true 44 | } 45 | }, 46 | "lr_scheduler": { 47 | "type": "LambdaLR", 48 | "args": { 49 | "lr_lambda": "origin_lr_scheduler" 50 | } 51 | }, 52 | "trainer": { 53 | "epochs": 300, 54 | "save_dir": "saved/", 55 | "save_period": 5, 56 | "verbosity": 2, 57 | "monitor": "max PSNR", 58 | "tensorboardX": true, 59 | "log_dir": "saved/runs" 60 | }, 61 | "others": { 62 | "gp_lambda": 10, 63 | "content_loss_lambda": 100 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /config_aligned.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": 
"DeblurGAN", 3 | "n_gpu": 1, 4 | "data_loader": { 5 | "type": "GoProAlignedDataLoader", 6 | "args": { 7 | "data_dir": "aligned_data/", 8 | "batch_size": 1, 9 | "shuffle": false, 10 | "validation_split": 0, 11 | "num_workers": 4 12 | } 13 | }, 14 | "generator": { 15 | "type": "ResNetGenerator", 16 | "args": { 17 | "input_nc": 3, 18 | "output_nc": 3 19 | } 20 | }, 21 | "discriminator": { 22 | "type": "NLayerDiscriminator", 23 | "args": { 24 | "input_nc": 3 25 | } 26 | }, 27 | "loss": { 28 | "adversarial": "wgan_gp_loss", 29 | "content": "perceptual_loss" 30 | }, 31 | "metrics": [ 32 | "PSNR" 33 | ], 34 | "optimizer": { 35 | "type": "Adam", 36 | "args": { 37 | "lr": 0.0001, 38 | "betas": [ 39 | 0.5, 40 | 0.999 41 | ], 42 | "weight_decay": 0, 43 | "amsgrad": true 44 | } 45 | }, 46 | "lr_scheduler": { 47 | "type": "LambdaLR", 48 | "args": { 49 | "lr_lambda": "origin_lr_scheduler" 50 | } 51 | }, 52 | "trainer": { 53 | "epochs": 300, 54 | "save_dir": "saved/", 55 | "save_period": 5, 56 | "verbosity": 2, 57 | "monitor": "max PSNR", 58 | "tensorboardX": true, 59 | "log_dir": "saved/runs" 60 | }, 61 | "others": { 62 | "gp_lambda": 10, 63 | "content_loss_lambda": 100 64 | } 65 | } 66 | -------------------------------------------------------------------------------- /data_loader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/fourson/DeblurGAN-pytorch/41696fc6c68ef2a3924dfbddc0ece69821e5678b/data_loader/__init__.py -------------------------------------------------------------------------------- /data_loader/data_loader.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from PIL import Image 3 | from torch.utils.data import DataLoader 4 | 5 | from . 
import dataset 6 | from base.base_data_loader import BaseDataLoader 7 | 8 | 9 | class GoProDataLoader(BaseDataLoader): 10 | """ 11 | GoPro data loader 12 | """ 13 | 14 | def __init__(self, data_dir, batch_size, shuffle, validation_split, num_workers): 15 | transform = transforms.Compose([ 16 | transforms.Resize([360, 640], Image.BICUBIC), # downscale by a factor of two (720*1280 -> 360*640) 17 | transforms.ToTensor(), # convert to tensor 18 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # normalize 19 | ]) 20 | self.dataset = dataset.GoProDataset(data_dir, transform=transform, height=360, width=640, fine_size=256) 21 | 22 | super(GoProDataLoader, self).__init__(self.dataset, batch_size, shuffle, validation_split, num_workers) 23 | 24 | 25 | class GoProAlignedDataLoader(BaseDataLoader): 26 | """ 27 | GoPro aligned data loader 28 | """ 29 | 30 | def __init__(self, data_dir, batch_size, shuffle, validation_split, num_workers): 31 | transform = transforms.Compose([ 32 | transforms.Resize([360, 1280], Image.BICUBIC), # downscale by a factor of two (720*2560 -> 360*1280) 33 | transforms.ToTensor(), # convert to tensor 34 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # normalize 35 | ]) 36 | self.dataset = dataset.GoProAlignedDataset(data_dir, transform=transform, height=360, width=1280, fine_size=256) 37 | 38 | super(GoProAlignedDataLoader, self).__init__(self.dataset, batch_size, shuffle, validation_split, num_workers) 39 | 40 | 41 | class CustomDataLoader(DataLoader): 42 | """ 43 | Custom data loader for image deblurring 44 | """ 45 | 46 | def __init__(self, data_dir): 47 | transform = transforms.Compose([ 48 | transforms.ToTensor(), # convert to tensor 49 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # normalize 50 | ]) 51 | self.dataset = dataset.CustomDataset(data_dir, transform=transform) 52 | 53 | super(CustomDataLoader, self).__init__(self.dataset) 54 | 
-------------------------------------------------------------------------------- /data_loader/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import torch 5 | from PIL import Image 6 | from torchvision import transforms 7 | from torch.utils.data import Dataset 8 | 9 | 10 | class GoProDataset(Dataset): 11 | """ 12 | GoPro dataset 13 | """ 14 | 15 | def __init__(self, data_dir='data', transform=None, height=360, width=640, fine_size=256): 16 | self.blurred_dir = os.path.join(data_dir, 'blurred') 17 | self.sharp_dir = os.path.join(data_dir, 'sharp') 18 | self.image_names = os.listdir(self.blurred_dir) # we assume that blurred and sharp images have the same names 19 | 20 | self.transform = transform 21 | 22 | assert height >= fine_size and width >= fine_size 23 | self.height = height 24 | self.width = width 25 | self.fine_size = fine_size 26 | 27 | def __len__(self): 28 | return len(self.image_names) 29 | 30 | def __getitem__(self, index): 31 | blurred = Image.open(os.path.join(self.blurred_dir, self.image_names[index])).convert('RGB') 32 | sharp = Image.open(os.path.join(self.sharp_dir, self.image_names[index])).convert('RGB') 33 | 34 | if self.transform: 35 | blurred = self.transform(blurred) 36 | sharp = self.transform(sharp) 37 | 38 | # crop image tensor to defined size 39 | # we assume that self.transform contains ToTensor() 40 | assert isinstance(blurred, torch.Tensor) and isinstance(sharp, torch.Tensor) 41 | h_offset = random.randint(0, self.height - self.fine_size) 42 | w_offset = random.randint(0, self.width - self.fine_size) 43 | blurred = blurred[:, h_offset:h_offset + self.fine_size, w_offset:w_offset + self.fine_size] 44 | sharp = sharp[:, h_offset:h_offset + self.fine_size, w_offset:w_offset + self.fine_size] 45 | 46 | return {'blurred': blurred, 'sharp': sharp} 47 | 48 | 49 | class GoProAlignedDataset(Dataset): 50 | """ 51 | GoPro aligned dataset 52 | """ 53 | 54 | def 
__init__(self, data_dir='aligned_data', transform=None, height=360, width=1280, fine_size=256): 55 | self.data_dir = data_dir 56 | self.image_names = os.listdir(self.data_dir) 57 | 58 | self.transform = transform 59 | 60 | assert height >= fine_size and width >= fine_size * 2 61 | self.height = height 62 | self.width = width 63 | self.fine_size = fine_size 64 | 65 | def __len__(self): 66 | return len(self.image_names) 67 | 68 | def __getitem__(self, index): 69 | aligned = Image.open(os.path.join(self.data_dir, self.image_names[index])).convert('RGB') 70 | 71 | if self.transform: 72 | aligned = self.transform(aligned) 73 | 74 | # crop image tensor to defined size 75 | # we assume that self.transform contains ToTensor() 76 | assert isinstance(aligned, torch.Tensor) 77 | h = self.height 78 | w = int(self.width / 2) 79 | h_offset = random.randint(0, h - self.fine_size) 80 | w_offset = random.randint(0, w - self.fine_size) 81 | blurred = aligned[:, h_offset:h_offset + self.fine_size, w_offset:w_offset + self.fine_size] 82 | sharp = aligned[:, h_offset:h_offset + self.fine_size, w_offset + w:w_offset + w + self.fine_size] 83 | return {'blurred': blurred, 'sharp': sharp} 84 | 85 | else: 86 | return {'aligned': aligned} 87 | 88 | 89 | class CustomDataset(Dataset): 90 | """Custom dataset for image deblurring""" 91 | 92 | def __init__(self, data_dir, transform=None): 93 | self.data_dir = data_dir 94 | self.transform = transform 95 | 96 | self.image_names = os.listdir(self.data_dir) 97 | 98 | def __len__(self): 99 | return len(self.image_names) 100 | 101 | def __getitem__(self, index): 102 | image_name = self.image_names[index] 103 | blurred = Image.open(os.path.join(self.data_dir, image_name)).convert('RGB') 104 | h = blurred.size[1] 105 | w = blurred.size[0] 106 | new_h = h - h % 4 + 4 if h % 4 != 0 else h 107 | new_w = w - w % 4 + 4 if w % 4 != 0 else w 108 | blurred = transforms.Resize([new_h, new_w], Image.BICUBIC)(blurred) 109 | 110 | if self.transform: 111 | blurred = 
self.transform(blurred) 112 | 113 | return {'blurred': blurred, 'image_name': image_name} 114 | -------------------------------------------------------------------------------- /deblur_image.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | from tqdm import tqdm 5 | from torchvision.transforms.functional import to_pil_image 6 | import torch 7 | 8 | 9 | def main(blurred_dir, deblurred_dir, resume): 10 | # load checkpoint 11 | checkpoint = torch.load(resume) 12 | config = checkpoint['config'] 13 | 14 | # setup data_loader instances 15 | data_loader = CustomDataLoader(data_dir=blurred_dir) 16 | 17 | # build model architecture 18 | generator_class = getattr(module_arch, config['generator']['type']) 19 | generator = generator_class(**config['generator']['args']) 20 | 21 | # prepare model for deblurring 22 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 23 | generator.to(device) 24 | 25 | generator.load_state_dict(checkpoint['generator']) 26 | 27 | generator.eval() 28 | 29 | # start to deblur 30 | with torch.no_grad(): 31 | for batch_idx, sample in enumerate(tqdm(data_loader, ascii=True)): 32 | blurred = sample['blurred'].to(device) 33 | image_name = sample['image_name'][0] 34 | 35 | deblurred = generator(blurred) 36 | deblurred_img = to_pil_image(denormalize(deblurred).squeeze().cpu()) 37 | 38 | deblurred_img.save(os.path.join(deblurred_dir, 'deblurred ' + image_name)) 39 | 40 | 41 | if __name__ == '__main__': 42 | parser = argparse.ArgumentParser(description='Deblur your own image!') 43 | 44 | parser.add_argument('-b', '--blurred', required=True, type=str, help='dir of blurred images') 45 | parser.add_argument('-d', '--deblurred', required=True, type=str, help='dir to save deblurred images') 46 | parser.add_argument('-r', '--resume', required=True, type=str, help='path to latest checkpoint') 47 | parser.add_argument('--device', default=None, type=str, help='indices of 
# --- tail of deblur_image.py (the beginning of this file is outside this chunk) ---
# NOTE(review): the option below was cut mid-string by the chunk boundary; reconstructed
# to match the identical '--device' option used in test.py / train.py — confirm upstream.
parser.add_argument('--device', default=None, type=str, help='indices of GPUs to enable (default: all)')

args = parser.parse_args()

if args.device:
    # restrict visible GPUs before importing anything that initializes CUDA
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device

import model.model as module_arch
from data_loader.data_loader import CustomDataLoader
from utils.util import denormalize

main(args.blurred, args.deblurred, args.resume)


# ============================== make_aligned_data.py ==============================
import os
import argparse

import numpy as np
from PIL import Image
from tqdm import tqdm

from utils.util import ensure_dir


def main(blurred_dir, sharp_dir, aligned_dir):
    """Concatenate each (blurred, sharp) image pair side by side and save the result.

    The blurred and sharp folders are assumed to contain files with identical names.
    """
    image_names = os.listdir(blurred_dir)  # we assume that blurred and sharp images have the same names
    ensure_dir(aligned_dir)
    for image_name in tqdm(image_names, ascii=True):
        # convert PIL image to numpy array (H, W, C), forced to 3-channel RGB
        blurred = np.array(Image.open(os.path.join(blurred_dir, image_name)).convert('RGB'), dtype=np.uint8)
        sharp = np.array(Image.open(os.path.join(sharp_dir, image_name)).convert('RGB'), dtype=np.uint8)
        aligned = np.concatenate((blurred, sharp), axis=1)  # horizontal alignment
        Image.fromarray(aligned).save(os.path.join(aligned_dir, image_name))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Make aligned data from raw data')

    parser.add_argument('-b', '--blurred', required=True, type=str, help='dir of blurred images')
    parser.add_argument('-s', '--sharp', required=True, type=str, help='dir of sharp images')
    parser.add_argument('-a', '--aligned', required=True, type=str, help='dir to save aligned images')

    args = parser.parse_args()

    main(args.blurred, args.sharp, args.aligned)


# ============================== model/layer_utils.py ==============================
import functools

import torch
import torch.nn as nn
from torchvision import models

# conv3_3 feature extractor of a pretrained VGG19, shared by the perceptual loss.
# FIX: only move it to the GPU when CUDA is available, so importing this module
# no longer crashes on CPU-only machines.
CONV3_3_IN_VGG_19 = models.vgg19(pretrained=True).features[:15]
if torch.cuda.is_available():
    CONV3_3_IN_VGG_19 = CONV3_3_IN_VGG_19.cuda()


def get_norm_layer(norm_type='instance'):
    """Return the normalization layer class (or functools.partial) selected by *norm_type*."""
    if norm_type == 'batch':
        norm_layer = nn.BatchNorm2d
    elif norm_type == 'instance':
        # we should never set track_running_stats to True in InstanceNorm
        # because it behaves differently in training and testing mode
        norm_layer = functools.partial(nn.InstanceNorm2d, track_running_stats=False)
    else:
        raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
    return norm_layer


def init_weights(m):
    """Weight initializer for `module.apply()`: N(0, 0.02) for convs, N(1, 0.02) for BatchNorm."""
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0, 0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    if isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, 1, 0.02)
        nn.init.zeros_(m.bias)


class ResNetBlock(nn.Module):
    """ResNet block: x + Conv-Norm-ReLU(-Dropout), with configurable padding."""

    def __init__(self, dim, norm_layer, padding_type, use_dropout, use_bias):
        super(ResNetBlock, self).__init__()

        sequence = list()
        padding = self._chose_padding_type(padding_type, sequence)

        sequence += [
            nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=padding, bias=use_bias),
            norm_layer(dim),
            nn.ReLU(True)
        ]

        if use_dropout:
            sequence += [nn.Dropout(0.5)]

        self.model = nn.Sequential(*sequence)

    def _chose_padding_type(self, padding_type, sequence):
        # 'reflect'/'replicate' append an explicit padding layer and keep conv padding=0;
        # 'zero' lets the convolution pad itself
        padding = 0
        if padding_type == 'reflect':
            sequence += [nn.ReflectionPad2d(1)]
        elif padding_type == 'replicate':
            sequence += [nn.ReplicationPad2d(1)]
        elif padding_type == 'zero':
            padding = 1
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)
        return padding

    def forward(self, x):
        out = x + self.model(x)  # residual connection
        return out


class MinibatchDiscrimination(nn.Module):
    """Minibatch discrimination (Salimans et al., 2016).

    Appends to each sample a B-dim feature measuring its similarity to the other
    samples in the batch, helping the discriminator detect mode collapse.
    """

    def __init__(self, in_features, out_features, kernel_dims, mean=False):
        super(MinibatchDiscrimination, self).__init__()
        self.in_features = in_features  # A
        self.out_features = out_features  # B
        self.kernel_dims = kernel_dims  # C
        self.mean = mean
        # FIX: the original used nn.Parameter(...).cuda(); .cuda() returns a new plain
        # tensor, so T was never registered as a module parameter (not optimized, not
        # saved in state_dict, not moved by .to()/.cuda()). Register it properly and
        # let module placement handle the device.
        self.T = nn.Parameter(torch.empty(in_features, out_features, kernel_dims))  # AxBxC
        nn.init.normal_(self.T, 0, 1)

    def forward(self, x):
        # x is NxA, T is AxBxC
        matrices = x.mm(self.T.view(self.in_features, -1))  # NxBC
        matrices = matrices.view(-1, self.out_features, self.kernel_dims)  # NxBxC

        M = matrices.unsqueeze(0)  # 1xNxBxC
        M_T = M.permute(1, 0, 2, 3)  # Nx1xBxC
        norm = torch.abs(M - M_T).sum(3)  # NxNxB, L1 distance between all sample pairs
        expnorm = torch.exp(-norm)
        o_b = (expnorm.sum(0))  # NxB
        if self.mean:
            o_b /= x.size(0)

        x = torch.cat((x, o_b), 1)  # Nx(A+B)
        return x


# ============================== model/loss.py ==============================
import torch
import torch.nn.functional as F
import torch.autograd as autograd

from .layer_utils import CONV3_3_IN_VGG_19


def perceptual_loss(deblurred, sharp):
    """MSE between the VGG19 conv3_3 feature maps of the deblurred and sharp images."""
    model = CONV3_3_IN_VGG_19

    # feature map of the output and target
    deblurred_feature_map = model.forward(deblurred)
    sharp_feature_map = model.forward(sharp).detach()  # we do not need the gradient of it
    loss = F.mse_loss(deblurred_feature_map, sharp_feature_map)
    return loss


def wgan_gp_loss(type, **kwargs):
    """WGAN-GP loss; *type* is 'G' (generator) or 'D' (discriminator/critic).

    For 'D' returns the pair (wgan_loss, weighted_gradient_penalty).
    """
    if type == 'G':  # generator loss
        deblurred_discriminator_out = kwargs['deblurred_discriminator_out']
        return -deblurred_discriminator_out.mean()

    elif type == 'D':  # discriminator loss
        gp_lambda = kwargs['gp_lambda']  # lambda coefficient of gradient penalty term
        interpolates = kwargs['interpolates']  # interpolates = alpha * sharp + (1 - alpha) * deblurred
        interpolates_discriminator_out = kwargs['interpolates_discriminator_out']
        sharp_discriminator_out = kwargs['sharp_discriminator_out']
        deblurred_discriminator_out = kwargs['deblurred_discriminator_out']

        # WGAN loss: the critic should score sharp images above deblurred ones
        wgan_loss = deblurred_discriminator_out.mean() - sharp_discriminator_out.mean()

        # gradient penalty: push the critic's gradient norm at the interpolates towards 1.
        # FIX: grad_outputs was torch.ones(...).cuda(), which fails on CPU; ones_like
        # follows the device and dtype of the critic output.
        gradients = autograd.grad(outputs=interpolates_discriminator_out, inputs=interpolates,
                                  grad_outputs=torch.ones_like(interpolates_discriminator_out),
                                  retain_graph=True,
                                  create_graph=True)[0]
        gradient_penalty = ((gradients.view(gradients.size(0), -1).norm(2, dim=1) - 1) ** 2).mean()

        return wgan_loss, gp_lambda * gradient_penalty

    else:
        # FIX: the original fell through and silently returned None on a bad *type*
        raise ValueError("type must be 'G' or 'D', got %r" % (type,))


def gan_loss(type, **kwargs):
    """Vanilla GAN (BCE) loss; *type* is 'G' (generator) or 'D' (discriminator)."""
    if type == 'G':
        deblurred_discriminator_out = kwargs['deblurred_discriminator_out']
        return F.binary_cross_entropy(deblurred_discriminator_out, torch.ones_like(deblurred_discriminator_out))

    elif type == 'D':
        sharp_discriminator_out = kwargs['sharp_discriminator_out']
        deblurred_discriminator_out = kwargs['deblurred_discriminator_out']

        # GAN loss: real (sharp) images labeled 1, generated (deblurred) images labeled 0
        real_loss = F.binary_cross_entropy(sharp_discriminator_out, torch.ones_like(sharp_discriminator_out))
        fake_loss = F.binary_cross_entropy(deblurred_discriminator_out, torch.zeros_like(deblurred_discriminator_out))
        return (real_loss + fake_loss) / 2.0

    else:
        # FIX: the original fell through and silently returned None on a bad *type*
        raise ValueError("type must be 'G' or 'D', got %r" % (type,))


# ============================== model/metric.py ==============================
import math

import torch


def PSNR(deblurred, sharp):
    """Peak Signal to Noise Ratio; assumes both tensors hold values in [0, 1]."""
    mse = torch.mean((deblurred - sharp) ** 2)  # mean square error
    if mse == 0:
        return 100  # identical images; cap instead of log(inf)
    PIXEL_MAX = 1
    return 10 * math.log10(PIXEL_MAX ** 2 / mse)


# ============================== model/model.py ==============================
import functools

import torch
import torch.nn as nn

from .layer_utils import get_norm_layer, ResNetBlock, MinibatchDiscrimination
from base.base_model import BaseModel


class ResNetGenerator(BaseModel):
    """Define a generator using ResNet"""

    def __init__(self, input_nc, output_nc, ngf=64, n_blocks=9, norm_type='instance', padding_type='reflect',
                 use_dropout=True, learn_residual=True):
        super(ResNetGenerator, self).__init__()

        self.learn_residual = learn_residual

        norm_layer = get_norm_layer(norm_type)
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        # 7x7 stem at full resolution
        sequence = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, stride=1, padding=0, bias=use_bias),
            norm_layer(ngf),
            nn.ReLU(True)
        ]

        n_downsampling = 2
        for i in range(n_downsampling):  # downsample the feature map
            mult = 2 ** i
            sequence += [
                nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
                norm_layer(ngf * mult * 2),
                nn.ReLU(True)
            ]

        for i in range(n_blocks):  # ResNet bottleneck
            sequence += [
                ResNetBlock(ngf * 2 ** n_downsampling, norm_layer, padding_type, use_dropout, use_bias)
            ]

        for i in range(n_downsampling):  # upsample the feature map
            mult = 2 ** (n_downsampling - i)
            sequence += [
                nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1,
                                   output_padding=1, bias=use_bias),
                norm_layer(int(ngf * mult / 2)),
                nn.ReLU(True)
            ]

        sequence += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, stride=1, padding=0),
            nn.Tanh()
        ]

        self.model = nn.Sequential(*sequence)

    def forward(self, x):
        out = self.model(x)
        if self.learn_residual:
            out = x + out
            out = torch.clamp(out, min=-1, max=1)  # clamp to [-1,1] according to normalization(mean=0.5, var=0.5)
        return out


class NLayerDiscriminator(BaseModel):
    """Define a PatchGAN discriminator"""

    def __init__(self, input_nc, ndf=64, n_layers=3, norm_type='instance', use_sigmoid=False,
                 use_minibatch_discrimination=False):
        super(NLayerDiscriminator, self).__init__()

        self.use_minibatch_discrimination = use_minibatch_discrimination
        # created lazily on the first forward pass (its input width depends on the input size)
        self.minibatch_discrimination = None

        norm_layer = get_norm_layer(norm_type)
        if type(norm_layer) == functools.partial:  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        kernel_size = 4
        padding = 1
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kernel_size, stride=2, padding=padding),
            nn.LeakyReLU(0.2, True)
        ]

        nf_mult = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kernel_size, stride=2, padding=padding,
                          bias=use_bias),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True)
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kernel_size, stride=1, padding=padding,
                      bias=use_bias),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True)
        ]

        sequence += [
            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kernel_size, stride=1, padding=padding)
        ]  # output 1 channel prediction map

        if use_sigmoid:
            sequence += [nn.Sigmoid()]

        self.model = nn.Sequential(*sequence)

    def forward(self, x):
        out = self.model(x)
        if self.use_minibatch_discrimination:
            out = out.view(out.size(0), -1)
            # FIX: the original built a brand-new MinibatchDiscrimination(a, a, 3) on
            # every forward call, so its T tensor was freshly random each time and was
            # never trained. Create it once (lazily, since its width depends on the
            # input resolution) and register it on the module.
            if self.minibatch_discrimination is None:
                a = out.size(1)
                self.minibatch_discrimination = MinibatchDiscrimination(a, a, 3).to(out.device)
            out = self.minibatch_discrimination(out)
        return out


# ============================== requirements.txt ==============================
# torch>=1.0.0
# torchvision>=0.2.2
# numpy>=1.15.4
# tqdm>=4.28.1
# pillow>=5.3.0
# tensorboardX>=1.6
# ============================== test.py ==============================
import os
import argparse

import torch
from tqdm import tqdm
import numpy as np


def main(resume):
    """Evaluate the checkpointed generator/discriminator over the full (unshuffled) dataset."""
    # load checkpoint (the training config travels inside it)
    checkpoint = torch.load(resume)
    config = checkpoint['config']

    # setup data_loader instances
    data_loader_class = getattr(module_data, config['data_loader']['type'])
    data_loader_config_args = {
        "data_dir": config['data_loader']['args']['data_dir'],
        'batch_size': 16,  # use large batch_size
        'shuffle': False,  # do not shuffle
        'validation_split': 0.0,  # do not split, just use the full dataset
        'num_workers': 16  # use large num_workers
    }
    data_loader = data_loader_class(**data_loader_config_args)

    # build model architecture
    generator_class = getattr(module_arch, config['generator']['type'])
    generator = generator_class(**config['generator']['args'])

    discriminator_class = getattr(module_arch, config['discriminator']['type'])
    discriminator = discriminator_class(**config['discriminator']['args'])

    # get function handles of loss and metrics
    loss_fn = {k: getattr(module_loss, v) for k, v in config['loss'].items()}
    metric_fns = [getattr(module_metric, met) for met in config['metrics']]

    # prepare model for testing
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    generator = generator.to(device)
    discriminator = discriminator.to(device)

    generator.load_state_dict(checkpoint['generator'])
    discriminator.load_state_dict(checkpoint['discriminator'])

    generator.eval()
    discriminator.eval()

    total_loss = 0.0
    total_metrics = np.zeros(len(metric_fns))

    with torch.no_grad():
        for batch_idx, sample in enumerate(tqdm(data_loader, ascii=True)):
            blurred = sample['blurred'].to(device)
            sharp = sample['sharp'].to(device)

            deblurred = generator(blurred)
            deblurred_discriminator_out = discriminator(deblurred)

            # map from the [-1, 1] training range back to [0, 1] for the image metrics
            denormalized_deblurred = denormalize(deblurred)
            denormalized_sharp = denormalize(sharp)

            # computing loss, metrics on test set
            content_loss_lambda = config['others']['content_loss_lambda']
            adversarial_loss_fn = loss_fn['adversarial']
            content_loss_fn = loss_fn['content']
            kwargs = {
                'deblurred_discriminator_out': deblurred_discriminator_out
            }
            loss = adversarial_loss_fn('G', **kwargs) + content_loss_fn(deblurred, sharp) * content_loss_lambda

            total_loss += loss.item()
            for i, metric in enumerate(metric_fns):
                total_metrics[i] += metric(denormalized_deblurred, denormalized_sharp)

    # NOTE(review): len(data_loader) is the number of BATCHES, so these are per-batch
    # averages despite the variable name — confirm whether per-sample was intended.
    n_samples = len(data_loader)
    log = {'loss': total_loss / n_samples}
    log.update({met.__name__: total_metrics[i].item() / n_samples for i, met in enumerate(metric_fns)})
    print(log)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='DeblurGAN')

    parser.add_argument('-r', '--resume', required=True, type=str, help='path to latest checkpoint')
    parser.add_argument('--device', default=None, type=str, help='indices of GPUs to enable (default: all)')

    args = parser.parse_args()

    if args.device:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.device

    # deferred imports: CUDA_VISIBLE_DEVICES must be set before anything touches CUDA
    import data_loader.data_loader as module_data
    import model.loss as module_loss
    import model.metric as module_metric
    import model.model as module_arch
    from utils.util import denormalize

    main(args.resume)


# ============================== train.py ==============================
import os
import json
import argparse

import torch

from trainer.trainer import Trainer
from utils.logger import Logger
from utils.util import get_lr_scheduler
from data_loader import data_loader as module_data
from model import loss as module_loss
from model import metric as module_metric
from model import model as module_arch


def main(config, resume):
    """Build data loaders, models, optimizers and schedulers from *config* and run training."""
    train_logger = Logger()

    # setup data_loader instances
    data_loader_class = getattr(module_data, config['data_loader']['type'])
    data_loader = data_loader_class(**config['data_loader']['args'])
    valid_data_loader = data_loader.split_validation()

    # build model architecture
    generator_class = getattr(module_arch, config['generator']['type'])
    generator = generator_class(**config['generator']['args'])

    discriminator_class = getattr(module_arch, config['discriminator']['type'])
    discriminator = discriminator_class(**config['discriminator']['args'])

    print(generator)
    print(discriminator)

    # get function handles of loss and metrics
    loss = {k: getattr(module_loss, v) for k, v in config['loss'].items()}
    metrics = [getattr(module_metric, met) for met in config['metrics']]

    # build optimizer for generator and discriminator
    generator_trainable_params = filter(lambda p: p.requires_grad, generator.parameters())
    discriminator_trainable_params = filter(lambda p: p.requires_grad, discriminator.parameters())
    optimizer_class = getattr(torch.optim, config['optimizer']['type'])
    optimizer = dict()
    optimizer['generator'] = optimizer_class(generator_trainable_params, **config['optimizer']['args'])
    optimizer['discriminator'] = optimizer_class(discriminator_trainable_params, **config['optimizer']['args'])

    # build learning rate scheduler for generator and discriminator
    lr_scheduler = dict()
    lr_scheduler['generator'] = get_lr_scheduler(config['lr_scheduler'], optimizer['generator'])
    lr_scheduler['discriminator'] = get_lr_scheduler(config['lr_scheduler'], optimizer['discriminator'])

    # start to train the network
    trainer = Trainer(config, generator, discriminator, loss, metrics, optimizer, lr_scheduler, resume, data_loader,
                      valid_data_loader, train_logger)
    trainer.train()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='DeblurGAN')
    parser.add_argument('-c', '--config', default=None, type=str, help='config file path (default: None)')
    parser.add_argument('-r', '--resume', default=None, type=str, help='path to latest checkpoint (default: None)')
    parser.add_argument('--device', default=None, type=str, help='indices of GPUs to enable (default: all)')
    args = parser.parse_args()

    if args.config:
        # load config file
        with open(args.config) as handle:
            config = json.load(handle)
        # setting path to save trained models and log files
        # NOTE(review): `path` is currently unused here (BaseTrainer derives the save
        # path from the config itself) — confirm before removing.
        path = os.path.join(config['trainer']['save_dir'], config['name'])
    elif args.resume:
        # load config from checkpoint if new config file is not given.
        # Use '--config' and '--resume' together to fine-tune trained model with changed configurations.
        config = torch.load(args.resume)['config']
    else:
        raise AssertionError("Configuration file need to be specified. Add '-c config.json', for example.")

    if args.device:
        os.environ["CUDA_VISIBLE_DEVICES"] = args.device

    main(config, args.resume)


# ============================== trainer/__init__.py (empty package marker) ==============================


# ============================== trainer/trainer.py ==============================
import random

import numpy as np
import torch
from torchvision.utils import make_grid

from base.base_trainer import BaseTrainer
from utils.util import denormalize


class Trainer(BaseTrainer):
    """
    Trainer class

    Note:
        Inherited from BaseTrainer.
    """

    def __init__(self, config, generator, discriminator, loss, metrics, optimizer, lr_scheduler, resume, data_loader,
                 valid_data_loader=None, train_logger=None):
        super(Trainer, self).__init__(config, generator, discriminator, loss, metrics, optimizer, lr_scheduler, resume,
                                      train_logger)

        self.data_loader = data_loader
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.log_step = int(np.sqrt(data_loader.batch_size))

    def _eval_metrics(self, output, target):
        """Compute every configured metric on (output, target) and log each value to tensorboard."""
        acc_metrics = np.zeros(len(self.metrics))
        for i, metric in enumerate(self.metrics):
            acc_metrics[i] += metric(output, target)
            self.writer.add_scalar('{}'.format(metric.__name__), acc_metrics[i])
        return acc_metrics

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Current training epoch.
        :return: A log that contains all information you want to save.

        Note:
            If you have additional information to record, for example:
                > additional_log = {"x": x, "y": y}
            merge it with log before return. i.e.
                > log = {**log, **additional_log}
                > return log

            The metrics in log must have the key 'metrics'.
        """
        # set models to train mode
        self.generator.train()
        self.discriminator.train()

        total_generator_loss = 0
        total_discriminator_loss = 0
        total_metrics = np.zeros(len(self.metrics))

        for batch_idx, sample in enumerate(self.data_loader):
            self.writer.set_step((epoch - 1) * len(self.data_loader) + batch_idx)

            # get data and send them to GPU
            blurred = sample['blurred'].to(self.device)
            sharp = sample['sharp'].to(self.device)

            # get G's output
            deblurred = self.generator(blurred)

            # denormalize
            with torch.no_grad():
                denormalized_blurred = denormalize(blurred)
                denormalized_sharp = denormalize(sharp)
                denormalized_deblurred = denormalize(deblurred)

            if batch_idx % 100 == 0:
                # save blurred, sharp and deblurred image
                self.writer.add_image('blurred', make_grid(denormalized_blurred.cpu()))
                self.writer.add_image('sharp', make_grid(denormalized_sharp.cpu()))
                self.writer.add_image('deblurred', make_grid(denormalized_deblurred.cpu()))

            # get D's output
            sharp_discriminator_out = self.discriminator(sharp)
            deblurred_discriminator_out = self.discriminator(deblurred)

            # set critic_updates (WGAN-GP trains the critic several times per generator step)
            if self.config['loss']['adversarial'] == 'wgan_gp_loss':
                critic_updates = 5
            else:
                critic_updates = 1

            # train discriminator
            discriminator_loss = 0
            for i in range(critic_updates):
                self.discriminator_optimizer.zero_grad()

                # train discriminator on real and fake
                if self.config['loss']['adversarial'] == 'wgan_gp_loss':
                    gp_lambda = self.config['others']['gp_lambda']
                    # NOTE(review): alpha is a single python float shared by the whole
                    # batch; the WGAN-GP paper samples one alpha per sample — confirm
                    # whether that deviation is intended.
                    alpha = random.random()
                    interpolates = alpha * sharp + (1 - alpha) * deblurred
                    interpolates_discriminator_out = self.discriminator(interpolates)
                    kwargs = {
                        'gp_lambda': gp_lambda,
                        'interpolates': interpolates,
                        'interpolates_discriminator_out': interpolates_discriminator_out,
                        'sharp_discriminator_out': sharp_discriminator_out,
                        'deblurred_discriminator_out': deblurred_discriminator_out
                    }
                    wgan_loss_d, gp_d = self.adversarial_loss('D', **kwargs)
                    discriminator_loss_per_update = wgan_loss_d + gp_d

                    self.writer.add_scalar('wgan_loss_d', wgan_loss_d.item())
                    self.writer.add_scalar('gp_d', gp_d.item())
                elif self.config['loss']['adversarial'] == 'gan_loss':
                    kwargs = {
                        'sharp_discriminator_out': sharp_discriminator_out,
                        'deblurred_discriminator_out': deblurred_discriminator_out
                    }
                    gan_loss_d = self.adversarial_loss('D', **kwargs)
                    discriminator_loss_per_update = gan_loss_d

                    self.writer.add_scalar('gan_loss_d', gan_loss_d.item())
                else:
                    # add other loss if you like
                    raise NotImplementedError

                # retain_graph: the same forward graph is reused by later critic updates
                # and by the generator step below
                discriminator_loss_per_update.backward(retain_graph=True)
                self.discriminator_optimizer.step()
                discriminator_loss += discriminator_loss_per_update.item()

            discriminator_loss /= critic_updates
            self.writer.add_scalar('discriminator_loss', discriminator_loss)
            total_discriminator_loss += discriminator_loss

            # train generator
            self.generator_optimizer.zero_grad()

            content_loss_lambda = self.config['others']['content_loss_lambda']
            kwargs = {
                'deblurred_discriminator_out': deblurred_discriminator_out
            }
            adversarial_loss_g = self.adversarial_loss('G', **kwargs)
            content_loss_g = self.content_loss(deblurred, sharp) * content_loss_lambda
            # FIX: the original summed adversarial_loss_g.detach() + content_loss_g, which
            # severed the adversarial gradient to the generator entirely — the generator
            # would then be trained by the content loss alone (the comment claiming
            # .detach() is required was wrong).
            generator_loss = adversarial_loss_g + content_loss_g

            self.writer.add_scalar('adversarial_loss_g', adversarial_loss_g.item())
            self.writer.add_scalar('content_loss_g', content_loss_g.item())
            self.writer.add_scalar('generator_loss', generator_loss.item())

            generator_loss.backward()
            self.generator_optimizer.step()
            total_generator_loss += generator_loss.item()

            # calculate the metrics
            total_metrics += self._eval_metrics(denormalized_deblurred, denormalized_sharp)

            if self.verbosity >= 2 and batch_idx % self.log_step == 0:
                self.logger.info(
                    'Train Epoch: {} [{}/{} ({:.0f}%)] generator_loss: {:.6f} discriminator_loss: {:.6f}'.format(
                        epoch,
                        batch_idx * self.data_loader.batch_size,
                        self.data_loader.n_samples,
                        100.0 * batch_idx / len(self.data_loader),
                        generator_loss.item(),  # it's a tensor, so we call .item() method
                        discriminator_loss  # just a num
                    )
                )

        log = {
            'generator_loss': total_generator_loss / len(self.data_loader),
            'discriminator_loss': total_discriminator_loss / len(self.data_loader),
            'metrics': (total_metrics / len(self.data_loader)).tolist()
        }

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log = {**log, **val_log}

        self.generator_lr_scheduler.step()
        self.discriminator_lr_scheduler.step()

        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :return: A log that contains information about validation

        Note:
            The validation metrics in log must have the key 'val_metrics'.
        """
        self.generator.eval()
        self.discriminator.eval()

        total_val_loss = 0
        total_val_metrics = np.zeros(len(self.metrics))

        with torch.no_grad():
            for batch_idx, sample in enumerate(self.valid_data_loader):
                blurred = sample['blurred'].to(self.device)
                sharp = sample['sharp'].to(self.device)

                deblurred = self.generator(blurred)
                deblurred_discriminator_out = self.discriminator(deblurred)

                content_loss_lambda = self.config['others']['content_loss_lambda']
                kwargs = {
                    'deblurred_discriminator_out': deblurred_discriminator_out
                }
                adversarial_loss_g = self.adversarial_loss('G', **kwargs)
                content_loss_g = self.content_loss(deblurred, sharp) * content_loss_lambda
                loss_g = adversarial_loss_g + content_loss_g

                self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
                self.writer.add_scalar('adversarial_loss_g', adversarial_loss_g.item())
                self.writer.add_scalar('content_loss_g', content_loss_g.item())
                self.writer.add_scalar('loss_g', loss_g.item())
                total_val_loss += loss_g.item()

                total_val_metrics += self._eval_metrics(denormalize(deblurred), denormalize(sharp))

        # add histogram of model parameters to the tensorboard
        for name, p in self.generator.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')

        return {
            'val_loss': total_val_loss / len(self.valid_data_loader),
            'val_metrics': (total_val_metrics / len(self.valid_data_loader)).tolist()
        }


# ============================== utils/__init__.py (empty package marker) ==============================
# ============================== utils/logger.py ==============================
import json
import logging

logging.basicConfig(level=logging.INFO, format='')


class Logger:
    """
    Training process logger

    Note:
        Used by BaseTrainer to save training history.
    """
    def __init__(self):
        self.entries = {}  # maps 1-based entry index -> log dict

    def add_entry(self, entry):
        """Append *entry* under the next sequential integer key."""
        self.entries[len(self.entries) + 1] = entry

    def __str__(self):
        return json.dumps(self.entries, sort_keys=True, indent=4)


# ============================== utils/util.py ==============================
import os

import torch


def ensure_dir(path):
    """Create *path* (including parents) if it does not already exist.

    FIX: the original check-then-create pair could raise FileExistsError when two
    processes raced to create the same directory; exist_ok=True is atomic w.r.t.
    that race.
    """
    os.makedirs(path, exist_ok=True)


def get_lr_lambda(lr_lambda):
    """Return the learning-rate multiplier function identified by name *lr_lambda*."""
    if lr_lambda == 'origin_lr_scheduler':
        # the same as origin paper's method (epoch=300)
        # "After the first 150 epochs we linearly decay the rate to zero over the next 150 epochs"
        return lambda epoch: (1 - (epoch - 150) / 150) if epoch > 150 else 1
    # add other lambdas if you want
    else:
        raise NotImplementedError('lr_lambda [%s] is not found' % lr_lambda)


def get_lr_scheduler(lr_scheduler_config, optimizer):
    """Build a torch.optim.lr_scheduler instance for *optimizer* from a config dict."""
    lr_scheduler_class = getattr(torch.optim.lr_scheduler, lr_scheduler_config['type'])
    if lr_scheduler_config['type'] == 'LambdaLR':
        # LambdaLR takes a callable, which the JSON config can only name as a string
        lr_lambda = get_lr_lambda(lr_scheduler_config['args']['lr_lambda'])
        return lr_scheduler_class(optimizer, lr_lambda)
    else:
        return lr_scheduler_class(optimizer, **lr_scheduler_config['args'])


def denormalize(image_tensor):
    # denormalize the normalized image tensor(N,C,H,W) with mean=0.5 and std=0.5 for each channel
    return (image_tensor + 1) / 2.0
# ============================== utils/visualization.py ==============================
import importlib


class WriterTensorboardX:
    """Thin wrapper around tensorboardX.SummaryWriter.

    When visualization is disabled (or tensorboardX is not installed) every writer
    method degrades to a no-op, so training code can call add_scalar etc.
    unconditionally.
    """

    def __init__(self, writer_dir, logger, enable):
        self.writer = None
        if enable:
            log_path = writer_dir
            try:
                self.writer = importlib.import_module('tensorboardX').SummaryWriter(log_path)
            except ImportError:
                message = "Warning: TensorboardX visualization is configured to use, but currently not installed on " \
                          "this machine. Please install the package by 'pip install tensorboardx' command or turn " \
                          "off the option in the 'config.json' file."
                logger.warning(message)
        self.step = 0
        self.mode = ''

        # writer methods that get the (step, mode-prefixed tag) treatment
        self.tb_writer_ftns = [
            'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio',
            'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding'
        ]
        # methods whose tag must NOT be prefixed with the train/valid mode
        self.tag_mode_exceptions = ['add_histogram', 'add_embedding']

    def set_step(self, step, mode='train'):
        """Set the global step and the mode ('train'/'valid') used to tag subsequent data."""
        self.mode = mode
        self.step = step

    def __getattr__(self, name):
        """
        If visualization is configured to use:
            return add_data() methods of tensorboard with additional information (step, tag) added.
        Otherwise:
            return a blank function handle that does nothing
        """
        if name in self.tb_writer_ftns:
            add_data = getattr(self.writer, name, None)

            def wrapper(tag, data, *args, **kwargs):
                if add_data is not None:
                    # add mode(train/valid) tag
                    if name not in self.tag_mode_exceptions:
                        tag = '{}/{}'.format(self.mode, tag)
                    add_data(tag, data, self.step, *args, **kwargs)
            return wrapper
        # FIX: the original fallback called object.__getattr__(name), which does not
        # exist (object defines no __getattr__, and the self argument was missing), so
        # the intended default lookup could never succeed and the AttributeError was
        # only reached through a pointless try/except dance. __getattr__ is invoked
        # only after normal attribute lookup has already failed, so raising directly
        # is the correct behavior (message kept identical).
        raise AttributeError("type object 'WriterTensorboardX' has no attribute '{}'".format(name))