├── .gitignore ├── Dehazing ├── ITS │ ├── README.md │ ├── data │ │ ├── __init__.py │ │ ├── data_augment.py │ │ └── data_load.py │ ├── eval.py │ ├── main.py │ ├── models │ │ ├── ConvIR.py │ │ └── layers.py │ ├── train.py │ ├── utils.py │ └── valid.py └── OTS │ ├── README.md │ ├── data │ ├── __init__.py │ ├── data_augment.py │ └── data_load.py │ ├── eval.py │ ├── main.py │ ├── models │ ├── ConvIR.py │ └── layers.py │ ├── train.py │ ├── utils.py │ └── valid.py ├── Image_deraining ├── README.md ├── data │ ├── __init__.py │ ├── data_augment.py │ └── data_load.py ├── deraining_test.m ├── main.py ├── models │ ├── ConvIR.py │ └── layers.py ├── test.py ├── train.py ├── utils.py └── valid.py ├── Image_desnowing ├── README.md ├── data │ ├── __init__.py │ ├── data_augment.py │ └── data_load.py ├── eval.py ├── main.py ├── models │ ├── ConvIR.py │ └── layers.py ├── train.py ├── utils.py └── valid.py ├── LICENSE ├── Motion_Deblurring ├── README.md ├── data │ ├── __init__.py │ ├── data_augment.py │ └── data_load.py ├── eval.py ├── main.py ├── models │ ├── ConvIR.py │ └── layers.py ├── train.py ├── utils.py └── valid.py ├── README.html ├── README.md └── pytorch-gradual-warmup-lr ├── setup.py └── warmup_scheduler ├── __init__.py ├── run.py └── scheduler.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | -------------------------------------------------------------------------------- /Dehazing/ITS/README.md: -------------------------------------------------------------------------------- 1 | ### Download the Datasets 2 | - reside-indoor [[gdrive](https://drive.google.com/drive/folders/1pbtfTp29j7Ip-mRzDpMpyopCfXd-ZJhC?usp=sharing), [Baidu](https://pan.baidu.com/s/1jD-TU0wdtSoEb4ki-Cut2A?pwd=1lr0)] 3 | - (Separate SOTS test set if needed) [[gdrive](https://drive.google.com/file/d/16j2dwVIa9q_0RtpIXMzhu-7Q6dwz_D1N/view?usp=sharing), [Baidu](https://pan.baidu.com/s/1R6qWri7sG1hC_Ifj-H6DOQ?pwd=o5sk)] 4 | 5 | 6 | ### Training 7 | On ITS: 8 | ~~~ 9 | python main.py 10 | --mode train 11 | --version small 12 | --data ITS 13 | --data_dir your_path/reside-indoor 14 | --batch_size 4 15 | --learning_rate 1e-4 16 | --num_epoch 1000 17 | --save_freq 20 18 | --valid_freq 20 19 | ~~~ 20 | 21 | To train on real haze or Haze4K, please uncomment the corresponding hyper-parameter block in main.py. 22 | 23 | Then run ``python main.py --data Haze4K`` or ``python main.py --data real_haze`` with the mode and model version you want, such as small/large. 24 | 25 | Please confirm that the model is trained with a larger patch size for real haze. 26 | ### Evaluation 27 | Download model from [gdrive](https://drive.google.com/drive/folders/1_5fO2p5xoWO5cUEVoXJ7x3Uhg1AP18FQ?usp=sharing), [百度网盘](https://pan.baidu.com/s/1oYzdxs3FvLJMWx7S5GW0rA?pwd=dvta) 28 | #### Testing 29 | ~~~ 30 | python main.py 31 | --mode test 32 | --data_dir your_path/reside-indoor 33 | --test_model path_to_its_model 34 | --data ITS 35 | ~~~ 36 | 37 | 38 | For training and testing, your directory structure should look like this 39 | 40 | `Your path` 
41 | `├──reside-indoor`
42 |      `├──train`
43 |           `├──gt`
44 |           `└──hazy` 45 |      `└──test`
46 |           `├──gt`
47 |           `└──hazy` 48 | 49 | `├──NHR`
50 |      `├──train`
51 |           `├──gt`
52 |           `└──hazy` 53 |      `└──test`
54 |           `├──gt`
55 |           `└──hazy` 56 | 57 | `├──GTA5`
58 |      `├──train`
59 |           `├──gt`
60 |           `└──hazy` 61 |      `└──test`
62 |           `├──gt`
63 |           `└──hazy` 64 | 65 | `├──haze4k`
66 |      `├──train`
67 |           `├──GT`
68 |           `└──IN` 69 |      `└──test`
70 |           `├──GT`
71 |           `└──IN` 72 | 73 | `├──Haze1k-thin`
74 |      `├──train`
75 |           `├──input`
76 |           `└──target` 77 |      `└──test`
78 |           `├──input`
79 |           `└──target` 80 | 81 | `├──Haze1k-moderate`
82 |      `├──train`
83 |           `├──input`
84 |           `└──target` 85 |      `└──test`
86 |           `├──input`
87 |           `└──target` 88 | 89 | `├──Haze1k-thick`
90 |      `├──train`
91 |           `├──input`
92 |           `└──target` 93 |      `└──test`
94 |           `├──input`
95 |           `└──target` 96 | 97 | `└──Dense-Haze`
98 |      `├──train`
99 |           `├──hazy`
100 |           `└──gt` 101 |      `└──test`
102 |           `├──hazy`
103 |           `└──gt` 104 | -------------------------------------------------------------------------------- /Dehazing/ITS/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_augment import PairRandomCrop, PairCompose, PairRandomHorizontalFilp, PairToTensor 2 | from .data_load import train_dataloader, test_dataloader, valid_dataloader 3 | -------------------------------------------------------------------------------- /Dehazing/ITS/data/data_augment.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torchvision.transforms as transforms 3 | import torchvision.transforms.functional as F 4 | 5 | 6 | class PairRandomCrop(transforms.RandomCrop): 7 | 8 | def __call__(self, image, label): 9 | 10 | if self.padding is not None: 11 | image = F.pad(image, self.padding, self.fill, self.padding_mode) 12 | label = F.pad(label, self.padding, self.fill, self.padding_mode) 13 | 14 | # pad the width if needed 15 | if self.pad_if_needed and image.size[0] < self.size[1]: 16 | image = F.pad(image, (self.size[1] - image.size[0], 0), self.fill, self.padding_mode) 17 | label = F.pad(label, (self.size[1] - label.size[0], 0), self.fill, self.padding_mode) 18 | # pad the height if needed 19 | if self.pad_if_needed and image.size[1] < self.size[0]: 20 | image = F.pad(image, (0, self.size[0] - image.size[1]), self.fill, self.padding_mode) 21 | label = F.pad(label, (0, self.size[0] - label.size[1]), self.fill, self.padding_mode)  # use the label's own height; image.size has already been changed by the pad above 22 | 23 | i, j, h, w = self.get_params(image, self.size) 24 | 25 | return F.crop(image, i, j, h, w), F.crop(label, i, j, h, w) 26 | 27 | 28 | class PairCompose(transforms.Compose): 29 | def __call__(self, image, label): 30 | for t in self.transforms: 31 | image, label = t(image, label) 32 | return image, label 33 | 34 | 35 | class PairRandomHorizontalFilp(transforms.RandomHorizontalFlip): 36 | def __call__(self, img, label): 37 | """ 38 | Args: 39 | img (PIL Image): Image to be flipped. 40 | 41 | Returns: 42 | PIL Image: Randomly flipped image. 43 | """ 44 | if random.random() < self.p: 45 | return F.hflip(img), F.hflip(label) 46 | return img, label 47 | 48 | 49 | class PairToTensor(transforms.ToTensor): 50 | def __call__(self, pic, label): 51 | """ 52 | Args: 53 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor. 54 | 55 | Returns: 56 | Tensor: Converted image. 57 | """ 58 | return F.to_tensor(pic), F.to_tensor(label) 59 | --------------------------------------------------------------------------------
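These `Pair*` transforms exist so that a hazy/clean pair stays pixel-aligned under random augmentation: each transform draws its random parameters once and applies them to both images. A minimal usage sketch, assuming it is run from `Dehazing/ITS` so the `data` package is importable (the file names are hypothetical; the composition mirrors what `train_dataloader` in `data_load.py` builds):

~~~python
# Hypothetical stand-alone example of the paired transforms defined above.
from PIL import Image
from data import PairCompose, PairRandomCrop, PairRandomHorizontalFilp, PairToTensor

transform = PairCompose([
    PairRandomCrop(256),         # one crop window, applied to both images
    PairRandomHorizontalFilp(),  # one coin flip, applied to both images
    PairToTensor(),              # both mapped to [0, 1] float tensors
])

hazy = Image.open('example_hazy.png')  # hypothetical file names
gt = Image.open('example_gt.png')
hazy_t, gt_t = transform(hazy, gt)     # both come out as (3, 256, 256) tensors
~~~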
57 | """ 58 | return F.to_tensor(pic), F.to_tensor(label) 59 | -------------------------------------------------------------------------------- /Dehazing/ITS/data/data_load.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image as Image 3 | from data import PairCompose, PairRandomCrop, PairRandomHorizontalFilp, PairToTensor 4 | from torchvision.transforms import functional as F 5 | from torch.utils.data import Dataset, DataLoader 6 | 7 | 8 | def train_dataloader(path, batch_size=64, num_workers=0, data='ITS', use_transform=True): 9 | image_dir = os.path.join(path, 'train') 10 | 11 | if data == 'real_haze': 12 | crop_size = [800,1184] 13 | else: 14 | crop_size = 256 15 | 16 | transform = None 17 | if use_transform: 18 | transform = PairCompose( 19 | [ 20 | PairRandomCrop(crop_size), 21 | PairRandomHorizontalFilp(), 22 | PairToTensor() 23 | ] 24 | ) 25 | dataloader = DataLoader( 26 | DeblurDataset(image_dir, data, transform=transform), 27 | batch_size=batch_size, 28 | shuffle=True, 29 | num_workers=num_workers, 30 | pin_memory=True 31 | ) 32 | return dataloader 33 | 34 | 35 | def test_dataloader(path, data, batch_size=1, num_workers=0): 36 | image_dir = os.path.join(path, 'test') 37 | dataloader = DataLoader( 38 | DeblurDataset(image_dir, data, is_test=True), 39 | batch_size=batch_size, 40 | shuffle=False, 41 | num_workers=num_workers, 42 | pin_memory=True 43 | ) 44 | 45 | return dataloader 46 | 47 | 48 | def valid_dataloader(path, data, batch_size=1, num_workers=0): 49 | dataloader = DataLoader( 50 | DeblurDataset(os.path.join(path, 'test'), data), 51 | batch_size=batch_size, 52 | shuffle=False, 53 | num_workers=num_workers 54 | ) 55 | 56 | return dataloader 57 | 58 | 59 | class DeblurDataset(Dataset): 60 | def __init__(self, image_dir, data, transform=None, is_test=False): 61 | self.image_dir = image_dir 62 | self.image_list = os.listdir(os.path.join(image_dir, 'hazy/')) 63 | self.image_list.sort() 64 | self.transform = transform 65 | self.is_test = is_test 66 | self.data = data 67 | 68 | def __len__(self): 69 | return len(self.image_list) 70 | 71 | def __getitem__(self, idx): 72 | if self.data == 'ITS': 73 | image = Image.open(os.path.join(self.image_dir, 'hazy', self.image_list[idx])) 74 | label = Image.open(os.path.join(self.image_dir, 'gt', self.image_list[idx].split('_')[0]+'.png')) 75 | elif self.data == 'real_haze': 76 | image = Image.open(os.path.join(self.image_dir, 'hazy', self.image_list[idx])) 77 | label = Image.open(os.path.join(self.image_dir, 'gt', self.image_list[idx]).replace('hazy', 'GT')) 78 | elif self.data == 'haze4k': 79 | image = Image.open(os.path.join(self.image_dir, 'IN', self.image_list[idx])) 80 | label = Image.open(os.path.join(self.image_dir, 'GT', self.image_list[idx])) 81 | 82 | if self.transform: 83 | image, label = self.transform(image, label) 84 | else: 85 | image = F.to_tensor(image) 86 | label = F.to_tensor(label) 87 | if self.is_test: 88 | name = self.image_list[idx] 89 | return image, label, name 90 | return image, label 91 | 92 | -------------------------------------------------------------------------------- /Dehazing/ITS/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torchvision.transforms import functional as F 4 | import numpy as np 5 | from utils import Adder 6 | from data import test_dataloader 7 | from skimage.metrics import peak_signal_noise_ratio 8 | import time 9 | from pytorch_msssim 
/Dehazing/ITS/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torchvision.transforms import functional as F 4 | import numpy as np 5 | from utils import Adder 6 | from data import test_dataloader 7 | from skimage.metrics import peak_signal_noise_ratio 8 | import time 9 | from pytorch_msssim import ssim 10 | import torch.nn.functional as f 11 | 12 | from skimage import img_as_ubyte 13 | import cv2 14 | 15 | def _eval(model, args): 16 | state_dict = torch.load(args.test_model) 17 | model.load_state_dict(state_dict['model']) 18 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 19 | dataloader = test_dataloader(args.data_dir, args.data, batch_size=1, num_workers=0)  # test_dataloader requires the dataset name; it was missing here 20 | torch.cuda.empty_cache() 21 | adder = Adder() 22 | model.eval() 23 | factor = 32 24 | with torch.no_grad(): 25 | psnr_adder = Adder() 26 | ssim_adder = Adder() 27 | 28 | for iter_idx, data in enumerate(dataloader): 29 | input_img, label_img, name = data 30 | 31 | input_img = input_img.to(device) 32 | 33 | h, w = input_img.shape[2], input_img.shape[3] 34 | H, W = ((h+factor)//factor)*factor, ((w+factor)//factor*factor) 35 | padh = H-h if h%factor!=0 else 0 36 | padw = W-w if w%factor!=0 else 0 37 | input_img = f.pad(input_img, (0, padw, 0, padh), 'reflect') 38 | 39 | tm = time.time() 40 | 41 | pred = model(input_img)[2] 42 | pred = pred[:,:,:h,:w] 43 | 44 | elapsed = time.time() - tm 45 | adder(elapsed) 46 | 47 | pred_clip = torch.clamp(pred, 0, 1) 48 | 49 | pred_numpy = pred_clip.squeeze(0).cpu().numpy() 50 | label_numpy = label_img.squeeze(0).cpu().numpy() 51 | 52 | 53 | label_img = (label_img).cuda() 54 | psnr_val = 10 * torch.log10(1 / f.mse_loss(pred_clip, label_img)) 55 | down_ratio = max(1, round(min(H, W) / 256)) 56 | ssim_val = ssim(f.adaptive_avg_pool2d(pred_clip, (int(H / down_ratio), int(W / down_ratio))), 57 | f.adaptive_avg_pool2d(label_img, (int(H / down_ratio), int(W / down_ratio))), 58 | data_range=1, size_average=False) 59 | print('%d iter PSNR_dehazing: %.2f ssim: %f' % (iter_idx + 1, psnr_val, ssim_val)) 60 | ssim_adder(ssim_val) 61 | 62 | if args.save_image: 63 | save_name = os.path.join(args.result_dir, name[0]) 64 | pred_clip += 0.5 / 255 65 | pred = F.to_pil_image(pred_clip.squeeze(0).cpu(), 'RGB') 66 | pred.save(save_name) 67 | 68 | psnr_mimo = peak_signal_noise_ratio(pred_numpy, label_numpy, data_range=1) 69 | psnr_adder(psnr_val) 70 | 71 | print('%d iter PSNR: %.2f time: %f' % (iter_idx + 1, psnr_mimo, elapsed)) 72 | 73 | print('==========================================================') 74 | print('The average PSNR is %.2f dB' % (psnr_adder.average())) 75 | print('The average SSIM is %.5f' % (ssim_adder.average())) 76 | 77 | print("Average time: %f" % adder.average()) 78 | 79 | --------------------------------------------------------------------------------
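eval.py (and valid.py further down) reflect-pads each input so that height and width are multiples of 32 before inference, then crops the prediction back: the encoder downsamples by 4 (two stride-2 convolutions) and DeepPoolLayer pools by up to 8, so a multiple of 32 keeps every feature map an integer size. The same pattern as a standalone sketch (`pad_to_multiple` is our name, not part of the repo):

~~~python
# Sketch of the "pad to a multiple of 32, then crop back" pattern used in
# eval.py and valid.py. The helper name is hypothetical.
import torch
import torch.nn.functional as f

def pad_to_multiple(x, factor=32):
    h, w = x.shape[2], x.shape[3]
    H = ((h + factor) // factor) * factor
    W = ((w + factor) // factor) * factor
    padh = H - h if h % factor != 0 else 0
    padw = W - w if w % factor != 0 else 0
    return f.pad(x, (0, padw, 0, padh), 'reflect'), (h, w)

x = torch.rand(1, 3, 460, 620)
x_pad, (h, w) = pad_to_multiple(x)
print(x_pad.shape)  # torch.Size([1, 3, 480, 640])
# run the model on x_pad, then crop back: pred = model(x_pad)[2][:, :, :h, :w]
~~~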
/Dehazing/ITS/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | from torch.backends import cudnn 5 | from models.ConvIR import build_net 6 | from train import _train 7 | from eval import _eval 8 | 9 | def main(args): 10 | # CUDNN 11 | cudnn.benchmark = True 12 | 13 | if not os.path.exists('results/'): 14 | os.makedirs(args.model_save_dir) 15 | if not os.path.exists('results/' + args.model_name + '/'): 16 | os.makedirs('results/' + args.model_name + '/') 17 | if not os.path.exists(args.result_dir): 18 | os.makedirs(args.result_dir) 19 | model = build_net(args.version, args.data) 20 | # print(model) 21 | 22 | if torch.cuda.is_available(): 23 | model.cuda() 24 | if args.mode == 'train': 25 | _train(model, args) 26 | 27 | elif args.mode == 'test': 28 | _eval(model, args) 29 | 30 | 31 | if __name__ == '__main__': 32 | parser = argparse.ArgumentParser() 33 | 34 | # Directories 35 | parser.add_argument('--model_name', default='ConvIR', type=str) 36 | parser.add_argument('--data', type=str, default='ITS', choices=['ITS', 'Haze4K', 'NHR', 'GTA5', 'real_haze']) 37 | parser.add_argument('--version', default='small', choices=['small', 'base', 'large'], type=str) 38 | 39 | parser.add_argument('--mode', default='test', choices=['train', 'test'], type=str) 40 | parser.add_argument('--data_dir', type=str, default='') 41 | 42 | # Train for its 43 | parser.add_argument('--batch_size', type=int, default=4) 44 | parser.add_argument('--learning_rate', type=float, default=1e-4) 45 | parser.add_argument('--weight_decay', type=float, default=0) 46 | parser.add_argument('--num_epoch', type=int, default=300) 47 | parser.add_argument('--print_freq', type=int, default=100) 48 | parser.add_argument('--num_worker', type=int, default=8) 49 | parser.add_argument('--save_freq', type=int, default=10) 50 | parser.add_argument('--valid_freq', type=int, default=10) 51 | parser.add_argument('--resume', type=str, default='') 52 | 53 | 54 | # uncomment for different datasets 55 | 56 | # Train for real-haze 57 | # parser.add_argument('--batch_size', type=int, default=2) 58 | # parser.add_argument('--learning_rate', type=float, default=2e-4) 59 | # parser.add_argument('--weight_decay', type=float, default=0) 60 | # parser.add_argument('--num_epoch', type=int, default=5000) 61 | # parser.add_argument('--print_freq', type=int, default=20) 62 | # parser.add_argument('--num_worker', type=int, default=4) 63 | # parser.add_argument('--save_freq', type=int, default=10) 64 | # parser.add_argument('--valid_freq', type=int, default=10) 65 | 66 | # Train for Haze4k 67 | # parser.add_argument('--batch_size', type=int, default=8) 68 | # parser.add_argument('--learning_rate', type=float, default=4e-4) 69 | # parser.add_argument('--weight_decay', type=float, default=0) 70 | # parser.add_argument('--num_epoch', type=int, default=1000) 71 | # parser.add_argument('--print_freq', type=int, default=100) 72 | # parser.add_argument('--num_worker', type=int, default=8) 73 | # parser.add_argument('--save_freq', type=int, default=20) 74 | # parser.add_argument('--valid_freq', type=int, default=20) 75 | 76 | # Test 77 | parser.add_argument('--test_model', type=str, default='') 78 | parser.add_argument('--save_image', action='store_true')  # the original `type=bool` parsed any non-empty string, including 'False', as True 79 | 80 | args = parser.parse_args() 81 | args.model_save_dir = os.path.join('results/', args.model_name, 'Training-Results/') 82 | args.result_dir = os.path.join('results/', args.model_name, 'images', args.data) 83 | if not os.path.exists(args.model_save_dir): 84 | os.makedirs(args.model_save_dir) 85 | command = 'cp ' + 'models/layers.py ' + args.model_save_dir 86 | os.system(command) 87 | command = 'cp ' + 'models/ConvIR.py ' + args.model_save_dir 88 | os.system(command) 89 | command = 'cp ' + 'train.py ' + args.model_save_dir 90 | os.system(command) 91 | command = 'cp ' + 'main.py ' + args.model_save_dir 92 | os.system(command) 93 | print(args) 94 | main(args) 95 | -------------------------------------------------------------------------------- /Dehazing/ITS/models/ConvIR.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .layers import * 6 | 7 | 8 | class EBlock(nn.Module): 9 | def __init__(self, out_channel, num_res, data): 10 | super(EBlock, self).__init__() 11 | 12 | layers = [ResBlock(out_channel, out_channel, data) for _ in range(num_res-1)] 13 | 
layers.append(ResBlock(out_channel, out_channel, data, filter=True)) 14 | 15 | self.layers = nn.Sequential(*layers) 16 | 17 | def forward(self, x): 18 | return self.layers(x) 19 | 20 | 21 | class DBlock(nn.Module): 22 | def __init__(self, channel, num_res, data): 23 | super(DBlock, self).__init__() 24 | 25 | layers = [ResBlock(channel, channel, data) for _ in range(num_res-1)] 26 | layers.append(ResBlock(channel, channel, data, filter=True)) 27 | self.layers = nn.Sequential(*layers) 28 | 29 | def forward(self, x): 30 | return self.layers(x) 31 | 32 | 33 | class SCM(nn.Module): 34 | def __init__(self, out_plane): 35 | super(SCM, self).__init__() 36 | self.main = nn.Sequential( 37 | BasicConv(3, out_plane//4, kernel_size=3, stride=1, relu=True), 38 | BasicConv(out_plane // 4, out_plane // 2, kernel_size=1, stride=1, relu=True), 39 | BasicConv(out_plane // 2, out_plane // 2, kernel_size=3, stride=1, relu=True), 40 | BasicConv(out_plane // 2, out_plane, kernel_size=1, stride=1, relu=False), 41 | nn.InstanceNorm2d(out_plane, affine=True) 42 | ) 43 | 44 | def forward(self, x): 45 | x = self.main(x) 46 | return x 47 | 48 | class FAM(nn.Module): 49 | def __init__(self, channel): 50 | super(FAM, self).__init__() 51 | self.merge = BasicConv(channel*2, channel, kernel_size=3, stride=1, relu=False) 52 | 53 | def forward(self, x1, x2): 54 | return self.merge(torch.cat([x1, x2], dim=1)) 55 | 56 | class ConvIR(nn.Module): 57 | def __init__(self, version, data): 58 | super(ConvIR, self).__init__() 59 | 60 | if version == 'small': 61 | num_res = 4 62 | elif version == 'base': 63 | num_res = 8 64 | elif version == 'large': 65 | num_res = 16 66 | 67 | base_channel = 32 68 | 69 | self.Encoder = nn.ModuleList([ 70 | EBlock(base_channel, num_res, data), 71 | EBlock(base_channel*2, num_res, data), 72 | EBlock(base_channel*4, num_res, data), 73 | ]) 74 | 75 | self.feat_extract = nn.ModuleList([ 76 | BasicConv(3, base_channel, kernel_size=3, relu=True, stride=1), 77 | BasicConv(base_channel, base_channel*2, kernel_size=3, relu=True, stride=2), 78 | BasicConv(base_channel*2, base_channel*4, kernel_size=3, relu=True, stride=2), 79 | BasicConv(base_channel*4, base_channel*2, kernel_size=4, relu=True, stride=2, transpose=True), 80 | BasicConv(base_channel*2, base_channel, kernel_size=4, relu=True, stride=2, transpose=True), 81 | BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1) 82 | ]) 83 | 84 | self.Decoder = nn.ModuleList([ 85 | DBlock(base_channel * 4, num_res, data), 86 | DBlock(base_channel * 2, num_res, data), 87 | DBlock(base_channel, num_res, data) 88 | ]) 89 | 90 | self.Convs = nn.ModuleList([ 91 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1), 92 | BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1), 93 | ]) 94 | 95 | self.ConvsOut = nn.ModuleList( 96 | [ 97 | BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1), 98 | BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1), 99 | ] 100 | ) 101 | 102 | self.FAM1 = FAM(base_channel * 4) 103 | self.SCM1 = SCM(base_channel * 4) 104 | self.FAM2 = FAM(base_channel * 2) 105 | self.SCM2 = SCM(base_channel * 2) 106 | 107 | def forward(self, x): 108 | x_2 = F.interpolate(x, scale_factor=0.5) 109 | x_4 = F.interpolate(x_2, scale_factor=0.5) 110 | z2 = self.SCM2(x_2) 111 | z4 = self.SCM1(x_4) 112 | 113 | outputs = list() 114 | # 256 115 | x_ = self.feat_extract[0](x) 116 | res1 = self.Encoder[0](x_) 117 | # 128 118 | z = self.feat_extract[1](res1) 119 | z = self.FAM2(z, 
z2) 120 | res2 = self.Encoder[1](z) 121 | # 64 122 | z = self.feat_extract[2](res2) 123 | z = self.FAM1(z, z4) 124 | z = self.Encoder[2](z) 125 | 126 | z = self.Decoder[0](z) 127 | z_ = self.ConvsOut[0](z) 128 | # 128 129 | z = self.feat_extract[3](z) 130 | outputs.append(z_+x_4) 131 | 132 | z = torch.cat([z, res2], dim=1) 133 | z = self.Convs[0](z) 134 | z = self.Decoder[1](z) 135 | z_ = self.ConvsOut[1](z) 136 | # 256 137 | z = self.feat_extract[4](z) 138 | outputs.append(z_+x_2) 139 | 140 | z = torch.cat([z, res1], dim=1) 141 | z = self.Convs[1](z) 142 | z = self.Decoder[2](z) 143 | z = self.feat_extract[5](z) 144 | outputs.append(z+x) 145 | 146 | return outputs 147 | 148 | 149 | def build_net(version, data): 150 | return ConvIR(version, data) 151 | -------------------------------------------------------------------------------- /Dehazing/ITS/models/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class BasicConv(nn.Module): 6 | def __init__(self, in_channel, out_channel, kernel_size, stride, bias=True, norm=False, relu=True, transpose=False): 7 | super(BasicConv, self).__init__() 8 | if bias and norm: 9 | bias = False 10 | 11 | padding = kernel_size // 2 12 | layers = list() 13 | if transpose: 14 | padding = kernel_size // 2 -1 15 | layers.append(nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias)) 16 | else: 17 | layers.append( 18 | nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias)) 19 | if norm: 20 | layers.append(nn.BatchNorm2d(out_channel)) 21 | if relu: 22 | layers.append(nn.GELU()) 23 | self.main = nn.Sequential(*layers) 24 | 25 | def forward(self, x): 26 | return self.main(x) 27 | 28 | 29 | class ResBlock(nn.Module): 30 | def __init__(self, in_channel, out_channel, data, filter=False): 31 | super(ResBlock, self).__init__() 32 | self.main = nn.Sequential( 33 | BasicConv(in_channel, out_channel, kernel_size=3, stride=1, relu=True), 34 | DeepPoolLayer(in_channel, out_channel, data) if filter else nn.Identity(), 35 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False) 36 | ) 37 | 38 | def forward(self, x): 39 | return self.main(x) + x 40 | 41 | 42 | class DeepPoolLayer(nn.Module): 43 | def __init__(self, k, k_out, data): 44 | super(DeepPoolLayer, self).__init__() 45 | self.pools_sizes = [8,4,2] 46 | 47 | if data == 'GTA5':  # the original chained test `data == 'ITS' or 'Densehaze' or ...` was always True, so this GTA5 case could never be reached 48 | dilation = [5,9,11] 49 | else:  # ITS, Densehaze, Haze4k, Ihaze, Nhhaze, NHR, Ohaze, real haze, ... 50 | dilation = [7,9,11] 51 | 52 | pools, convs, dynas = [],[],[] 53 | for j, i in enumerate(self.pools_sizes): 54 | pools.append(nn.AvgPool2d(kernel_size=i, stride=i)) 55 | convs.append(nn.Conv2d(k, k, 3, 1, 1, bias=False)) 56 | dynas.append(MultiShapeKernel(dim=k, kernel_size=3, dilation=dilation[j])) 57 | self.pools = nn.ModuleList(pools) 58 | self.convs = nn.ModuleList(convs) 59 | self.dynas = nn.ModuleList(dynas) 60 | self.relu = nn.GELU() 61 | self.conv_sum = nn.Conv2d(k, k_out, 3, 1, 1, bias=False) 62 | 63 | def forward(self, x): 64 | x_size = x.size() 65 | resl = x 66 | for i in range(len(self.pools_sizes)): 67 | if i == 0: 68 | y = self.dynas[i](self.convs[i](self.pools[i](x))) 69 | else: 70 | y = self.dynas[i](self.convs[i](self.pools[i](x)+y_up)) 71 | resl = torch.add(resl, F.interpolate(y, x_size[2:], mode='bilinear', align_corners=True)) 72 | if i != len(self.pools_sizes)-1: 73 | y_up = 
F.interpolate(y, scale_factor=2, mode='bilinear', align_corners=True) 74 | resl = self.relu(resl) 75 | resl = self.conv_sum(resl) 76 | 77 | return resl 78 | 79 | 80 | class dynamic_filter(nn.Module): 81 | def __init__(self, inchannels, kernel_size=3, dilation=1, stride=1, group=8): 82 | super(dynamic_filter, self).__init__() 83 | self.stride = stride 84 | self.kernel_size = kernel_size 85 | self.group = group 86 | self.dilation = dilation 87 | 88 | self.conv = nn.Conv2d(inchannels, group*kernel_size**2, kernel_size=1, stride=1, bias=False) 89 | self.bn = nn.BatchNorm2d(group*kernel_size**2) 90 | self.act = nn.Tanh() 91 | 92 | nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='relu') 93 | self.lamb_l = nn.Parameter(torch.zeros(inchannels), requires_grad=True) 94 | self.lamb_h = nn.Parameter(torch.zeros(inchannels), requires_grad=True) 95 | self.pad = nn.ReflectionPad2d(self.dilation*(kernel_size-1)//2) 96 | 97 | self.ap = nn.AdaptiveAvgPool2d((1, 1)) 98 | self.gap = nn.AdaptiveAvgPool2d(1) 99 | 100 | self.inside_all = nn.Parameter(torch.zeros(inchannels,1,1), requires_grad=True) 101 | 102 | def forward(self, x): 103 | identity_input = x 104 | low_filter = self.ap(x) 105 | low_filter = self.conv(low_filter) 106 | low_filter = self.bn(low_filter) 107 | 108 | n, c, h, w = x.shape 109 | x = F.unfold(self.pad(x), kernel_size=self.kernel_size, dilation=self.dilation).reshape(n, self.group, c//self.group, self.kernel_size**2, h*w) 110 | 111 | n,c1,p,q = low_filter.shape 112 | low_filter = low_filter.reshape(n, c1//self.kernel_size**2, self.kernel_size**2, p*q).unsqueeze(2) 113 | 114 | low_filter = self.act(low_filter) 115 | 116 | low_part = torch.sum(x * low_filter, dim=3).reshape(n, c, h, w) 117 | 118 | out_low = low_part * (self.inside_all + 1.) - self.inside_all * self.gap(identity_input) 119 | 120 | out_low = out_low * self.lamb_l[None,:,None,None] 121 | 122 | out_high = (identity_input) * (self.lamb_h[None,:,None,None] + 1.) 
123 | 124 | return out_low + out_high 125 | 126 | 127 | class cubic_attention(nn.Module): 128 | def __init__(self, dim, group, dilation, kernel) -> None: 129 | super().__init__() 130 | 131 | self.H_spatial_att = spatial_strip_att(dim, dilation=dilation, group=group, kernel=kernel) 132 | self.W_spatial_att = spatial_strip_att(dim, dilation=dilation, group=group, kernel=kernel, H=False) 133 | self.gamma = nn.Parameter(torch.zeros(dim,1,1)) 134 | self.beta = nn.Parameter(torch.ones(dim,1,1)) 135 | 136 | def forward(self, x): 137 | out = self.H_spatial_att(x) 138 | out = self.W_spatial_att(out) 139 | return self.gamma * out + x * self.beta 140 | 141 | 142 | class spatial_strip_att(nn.Module): 143 | def __init__(self, dim, kernel=3, dilation=1, group=2, H=True) -> None: 144 | super().__init__() 145 | 146 | self.k = kernel 147 | pad = dilation*(kernel-1) // 2 148 | self.kernel = (1, kernel) if H else (kernel, 1) 149 | self.padding = (kernel//2, 1) if H else (1, kernel//2) 150 | self.dilation = dilation 151 | self.group = group 152 | self.pad = nn.ReflectionPad2d((pad, pad, 0, 0)) if H else nn.ReflectionPad2d((0, 0, pad, pad)) 153 | self.conv = nn.Conv2d(dim, group*kernel, kernel_size=1, stride=1, bias=False) 154 | self.ap = nn.AdaptiveAvgPool2d((1, 1)) 155 | self.filter_act = nn.Tanh() 156 | self.inside_all = nn.Parameter(torch.zeros(dim,1,1), requires_grad=True) 157 | self.lamb_l = nn.Parameter(torch.zeros(dim), requires_grad=True) 158 | self.lamb_h = nn.Parameter(torch.zeros(dim), requires_grad=True) 159 | gap_kernel = (None,1) if H else (1, None) 160 | self.gap = nn.AdaptiveAvgPool2d(gap_kernel) 161 | 162 | def forward(self, x): 163 | identity_input = x.clone() 164 | filter = self.ap(x) 165 | filter = self.conv(filter) 166 | n, c, h, w = x.shape 167 | x = F.unfold(self.pad(x), kernel_size=self.kernel, dilation=self.dilation).reshape(n, self.group, c//self.group, self.k, h*w) 168 | n, c1, p, q = filter.shape 169 | filter = filter.reshape(n, c1//self.k, self.k, p*q).unsqueeze(2) 170 | filter = self.filter_act(filter) 171 | out = torch.sum(x * filter, dim=3).reshape(n, c, h, w) 172 | 173 | out_low = out * (self.inside_all + 1.) - self.inside_all * self.gap(identity_input) 174 | out_low = out_low * self.lamb_l[None,:,None,None] 175 | out_high = identity_input * (self.lamb_h[None,:,None,None]+1.) 
176 | 177 | return out_low + out_high 178 | 179 | 180 | class MultiShapeKernel(nn.Module): 181 | def __init__(self, dim, kernel_size=3, dilation=1, group=8): 182 | super().__init__() 183 | 184 | self.square_att = dynamic_filter(inchannels=dim, dilation=dilation, group=group, kernel_size=kernel_size) 185 | self.strip_att = cubic_attention(dim, group=group, dilation=dilation, kernel=kernel_size) 186 | 187 | def forward(self, x): 188 | 189 | x1 = self.strip_att(x) 190 | x2 = self.square_att(x) 191 | 192 | return x1+x2 193 | 194 | 195 | -------------------------------------------------------------------------------- /Dehazing/ITS/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from data import train_dataloader 4 | from utils import Adder, Timer, check_lr 5 | from torch.utils.tensorboard import SummaryWriter 6 | from valid import _valid 7 | import torch.nn.functional as F 8 | import torch.nn as nn 9 | 10 | from warmup_scheduler import GradualWarmupScheduler 11 | 12 | def _train(model, args): 13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 14 | criterion = torch.nn.L1Loss() 15 | 16 | optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), eps=1e-8) 17 | dataloader = train_dataloader(args.data_dir, args.batch_size, args.num_worker, args.data) 18 | max_iter = len(dataloader) 19 | warmup_epochs=3 20 | scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.num_epoch-warmup_epochs, eta_min=1e-6) 21 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine) 22 | scheduler.step() 23 | epoch = 1 24 | if args.resume: 25 | state = torch.load(args.resume) 26 | epoch = state['epoch'] 27 | optimizer.load_state_dict(state['optimizer']) 28 | model.load_state_dict(state['model']) 29 | print('Resume from %d'%epoch) 30 | epoch += 1 31 | 32 | writer = SummaryWriter() 33 | epoch_pixel_adder = Adder() 34 | epoch_fft_adder = Adder() 35 | iter_pixel_adder = Adder() 36 | iter_fft_adder = Adder() 37 | epoch_timer = Timer('m') 38 | iter_timer = Timer('m') 39 | best_psnr=-1 40 | 41 | for epoch_idx in range(epoch, args.num_epoch + 1): 42 | 43 | epoch_timer.tic() 44 | iter_timer.tic() 45 | for iter_idx, batch_data in enumerate(dataloader): 46 | 47 | input_img, label_img = batch_data 48 | input_img = input_img.to(device) 49 | label_img = label_img.to(device) 50 | 51 | optimizer.zero_grad() 52 | pred_img = model(input_img) 53 | label_img2 = F.interpolate(label_img, scale_factor=0.5, mode='bilinear') 54 | label_img4 = F.interpolate(label_img, scale_factor=0.25, mode='bilinear') 55 | l1 = criterion(pred_img[0], label_img4) 56 | l2 = criterion(pred_img[1], label_img2) 57 | l3 = criterion(pred_img[2], label_img) 58 | loss_content = l1+l2+l3 59 | 60 | label_fft1 = torch.fft.fft2(label_img4, dim=(-2,-1)) 61 | label_fft1 = torch.stack((label_fft1.real, label_fft1.imag), -1) 62 | 63 | pred_fft1 = torch.fft.fft2(pred_img[0], dim=(-2,-1)) 64 | pred_fft1 = torch.stack((pred_fft1.real, pred_fft1.imag), -1) 65 | 66 | label_fft2 = torch.fft.fft2(label_img2, dim=(-2,-1)) 67 | label_fft2 = torch.stack((label_fft2.real, label_fft2.imag), -1) 68 | 69 | pred_fft2 = torch.fft.fft2(pred_img[1], dim=(-2,-1)) 70 | pred_fft2 = torch.stack((pred_fft2.real, pred_fft2.imag), -1) 71 | 72 | label_fft3 = torch.fft.fft2(label_img, dim=(-2,-1)) 73 | label_fft3 = torch.stack((label_fft3.real, label_fft3.imag), -1) 74 | 75 | 
pred_fft3 = torch.fft.fft2(pred_img[2], dim=(-2,-1)) 76 | pred_fft3 = torch.stack((pred_fft3.real, pred_fft3.imag), -1) 77 | 78 | f1 = criterion(pred_fft1, label_fft1) 79 | f2 = criterion(pred_fft2, label_fft2) 80 | f3 = criterion(pred_fft3, label_fft3) 81 | loss_fft = f1+f2+f3 82 | 83 | loss = loss_content + 0.1 * loss_fft 84 | loss.backward() 85 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.001) 86 | optimizer.step() 87 | 88 | iter_pixel_adder(loss_content.item()) 89 | iter_fft_adder(loss_fft.item()) 90 | 91 | epoch_pixel_adder(loss_content.item()) 92 | epoch_fft_adder(loss_fft.item()) 93 | 94 | if (iter_idx + 1) % args.print_freq == 0: 95 | print("Time: %7.4f Epoch: %03d Iter: %4d/%4d LR: %.10f Loss content: %7.4f Loss fft: %7.4f" % ( 96 | iter_timer.toc(), epoch_idx, iter_idx + 1, max_iter, scheduler.get_lr()[0], iter_pixel_adder.average(), 97 | iter_fft_adder.average())) 98 | writer.add_scalar('Pixel Loss', iter_pixel_adder.average(), iter_idx + (epoch_idx-1)* max_iter) 99 | writer.add_scalar('FFT Loss', iter_fft_adder.average(), iter_idx + (epoch_idx - 1) * max_iter) 100 | 101 | iter_timer.tic() 102 | iter_pixel_adder.reset() 103 | iter_fft_adder.reset() 104 | overwrite_name = os.path.join(args.model_save_dir, 'model.pkl') 105 | torch.save({'model': model.state_dict(), 106 | 'optimizer': optimizer.state_dict(), 107 | 'epoch': epoch_idx}, overwrite_name) 108 | 109 | if epoch_idx % args.save_freq == 0: 110 | save_name = os.path.join(args.model_save_dir, 'model_%d.pkl' % epoch_idx) 111 | torch.save({'model': model.state_dict()}, save_name) 112 | print("EPOCH: %02d\nElapsed time: %4.2f Epoch Pixel Loss: %7.4f Epoch FFT Loss: %7.4f" % ( 113 | epoch_idx, epoch_timer.toc(), epoch_pixel_adder.average(), epoch_fft_adder.average())) 114 | epoch_fft_adder.reset() 115 | epoch_pixel_adder.reset() 116 | scheduler.step() 117 | if epoch_idx % args.valid_freq == 0: 118 | val = _valid(model, args, epoch_idx) 119 | print('%03d epoch \n Average PSNR %.2f dB' % (epoch_idx, val)) 120 | writer.add_scalar('PSNR', val, epoch_idx) 121 | if val >= best_psnr: 122 | best_psnr = val  # remember the best score; without this line every validation overwrote Best.pkl 123 | torch.save({'model': model.state_dict()}, os.path.join(args.model_save_dir, 'Best.pkl')) 124 | save_name = os.path.join(args.model_save_dir, 'Final.pkl') 125 | torch.save({'model': model.state_dict()}, save_name) 126 | -------------------------------------------------------------------------------- /Dehazing/ITS/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | 4 | 5 | class Adder(object): 6 | def __init__(self): 7 | self.count = 0 8 | self.num = float(0) 9 | 10 | def reset(self): 11 | self.count = 0 12 | self.num = float(0) 13 | 14 | def __call__(self, num): 15 | self.count += 1 16 | self.num += num 17 | 18 | def average(self): 19 | return self.num / self.count 20 | 21 | 22 | class Timer(object): 23 | def __init__(self, option='s'): 24 | self.tm = 0 25 | self.option = option 26 | if option == 's': 27 | self.divider = 1 28 | elif option == 'm': 29 | self.divider = 60 30 | else: 31 | self.divider = 3600 32 | 33 | def tic(self): 34 | self.tm = time.time() 35 | 36 | def toc(self): 37 | return (time.time() - self.tm) / self.divider 38 | 39 | 40 | def check_lr(optimizer): 41 | for i, param_group in enumerate(optimizer.param_groups): 42 | lr = param_group['lr'] 43 | return lr 44 | --------------------------------------------------------------------------------
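For reference, a small usage sketch of the two helpers above (the numbers are made up; the calls are the ones defined in utils.py):

~~~python
# Illustrative use of Adder and Timer from utils.py.
from utils import Adder, Timer

timer = Timer('m')   # report elapsed time in minutes
timer.tic()

psnr_adder = Adder()
for psnr in (28.1, 29.4, 30.2):
    psnr_adder(psnr)             # accumulate per-image scores

print(psnr_adder.average())      # 29.233...
print(timer.toc())               # minutes since tic()
~~~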
/Dehazing/ITS/valid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.transforms import functional as F 3 | from data import valid_dataloader 4 | from utils import Adder 5 | import os 6 | from skimage.metrics import peak_signal_noise_ratio 7 | import torch.nn.functional as f 8 | 9 | 10 | def _valid(model, args, ep): 11 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 12 | dataset = valid_dataloader(args.data_dir, args.data, batch_size=1, num_workers=0) 13 | model.eval() 14 | psnr_adder = Adder() 15 | 16 | with torch.no_grad(): 17 | print('Start Evaluation') 18 | factor = 32 19 | for idx, data in enumerate(dataset): 20 | input_img, label_img = data 21 | input_img = input_img.to(device) 22 | 23 | h, w = input_img.shape[2], input_img.shape[3] 24 | H, W = ((h+factor)//factor)*factor, ((w+factor)//factor*factor) 25 | padh = H-h if h%factor!=0 else 0 26 | padw = W-w if w%factor!=0 else 0 27 | input_img = f.pad(input_img, (0, padw, 0, padh), 'reflect') 28 | 29 | if not os.path.exists(os.path.join(args.result_dir, '%d' % (ep))): 30 | os.mkdir(os.path.join(args.result_dir, '%d' % (ep))) 31 | 32 | pred = model(input_img)[2] 33 | pred = pred[:,:,:h,:w] 34 | 35 | pred_clip = torch.clamp(pred, 0, 1) 36 | p_numpy = pred_clip.squeeze(0).cpu().numpy() 37 | label_numpy = label_img.squeeze(0).cpu().numpy() 38 | 39 | psnr = peak_signal_noise_ratio(p_numpy, label_numpy, data_range=1) 40 | 41 | psnr_adder(psnr) 42 | print('\r%03d'%idx, end=' ') 43 | 44 | print('\n') 45 | model.train() 46 | return psnr_adder.average() 47 | -------------------------------------------------------------------------------- /Dehazing/OTS/README.md: -------------------------------------------------------------------------------- 1 | ### Download the Datasets 2 | - reside-outdoor [[gdrive](https://drive.google.com/drive/folders/1eL4Qs-WNj7PzsKwDRsgUEzmysdjkRs22?usp=sharing)] 3 | - (Separate SOTS test set if needed) [[gdrive](https://drive.google.com/file/d/16j2dwVIa9q_0RtpIXMzhu-7Q6dwz_D1N/view?usp=sharing), [Baidu](https://pan.baidu.com/s/1R6qWri7sG1hC_Ifj-H6DOQ?pwd=o5sk)] 4 | 5 | 6 | ### Train on RESIDE-Outdoor 7 | For example, to train the small model: 8 | ~~~ 9 | python main.py --mode train --data_dir your_path/reside-outdoor --type small 10 | ~~~ 11 | 12 | ### Evaluation 13 | Download model from [gdrive](https://drive.google.com/drive/folders/1_5fO2p5xoWO5cUEVoXJ7x3Uhg1AP18FQ?usp=sharing), [百度网盘](https://pan.baidu.com/s/1oYzdxs3FvLJMWx7S5GW0rA?pwd=dvta) 14 | 15 | ~~~ 16 | python main.py --mode test --data_dir your_path/reside-outdoor --test_model path_to_ots_model --type small 17 | ~~~ 18 | 19 | For training and testing, your directory structure should look like this 20 | 21 | `Your path` 
22 | `└──reside-outdoor`
23 |      `├──train`
24 |           `├──gt`
25 |           `└──hazy` 26 |      `└──test`
27 |           `├──gt`
28 |           `└──hazy` -------------------------------------------------------------------------------- /Dehazing/OTS/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_augment import PairRandomCrop, PairCompose, PairRandomHorizontalFilp, PairToTensor 2 | from .data_load import train_dataloader, test_dataloader, valid_dataloader 3 | -------------------------------------------------------------------------------- /Dehazing/OTS/data/data_augment.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torchvision.transforms as transforms 3 | import torchvision.transforms.functional as F 4 | 5 | 6 | class PairRandomCrop(transforms.RandomCrop): 7 | 8 | def __call__(self, image, label): 9 | 10 | if self.padding is not None: 11 | image = F.pad(image, self.padding, self.fill, self.padding_mode) 12 | label = F.pad(label, self.padding, self.fill, self.padding_mode) 13 | 14 | # pad the width if needed 15 | if self.pad_if_needed and image.size[0] < self.size[1]: 16 | image = F.pad(image, (self.size[1] - image.size[0], 0), self.fill, self.padding_mode) 17 | label = F.pad(label, (self.size[1] - label.size[0], 0), self.fill, self.padding_mode) 18 | # pad the height if needed 19 | if self.pad_if_needed and image.size[1] < self.size[0]: 20 | image = F.pad(image, (0, self.size[0] - image.size[1]), self.fill, self.padding_mode) 21 | label = F.pad(label, (0, self.size[0] - label.size[1]), self.fill, self.padding_mode)  # use the label's own height; image.size has already been changed by the pad above 22 | 23 | i, j, h, w = self.get_params(image, self.size) 24 | 25 | return F.crop(image, i, j, h, w), F.crop(label, i, j, h, w) 26 | 27 | 28 | class PairCompose(transforms.Compose): 29 | def __call__(self, image, label): 30 | for t in self.transforms: 31 | image, label = t(image, label) 32 | return image, label 33 | 34 | 35 | class PairRandomHorizontalFilp(transforms.RandomHorizontalFlip): 36 | def __call__(self, img, label): 37 | """ 38 | Args: 39 | img (PIL Image): Image to be flipped. 40 | 41 | Returns: 42 | PIL Image: Randomly flipped image. 43 | """ 44 | if random.random() < self.p: 45 | return F.hflip(img), F.hflip(label) 46 | return img, label 47 | 48 | 49 | class PairToTensor(transforms.ToTensor): 50 | def __call__(self, pic, label): 51 | """ 52 | Args: 53 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor. 54 | 55 | Returns: 56 | Tensor: Converted image. 57 | """ 58 | return F.to_tensor(pic), F.to_tensor(label) 59 | --------------------------------------------------------------------------------
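Unlike the ITS pipeline, the OTS loader below does not use these transform classes for training; it crops and flips tensor pairs inline once `ps` is set. That inline logic, extracted as a hedged standalone helper (`paired_crop_flip` is our name, not part of the repo):

~~~python
# What DeblurDataset in data_load.py effectively does when ps is set:
# one crop window and one coin flip, applied to both tensors.
import random
import torch

def paired_crop_flip(image, label, ps=256):
    hh, ww = label.shape[1], label.shape[2]
    rr = random.randint(0, hh - ps)
    cc = random.randint(0, ww - ps)
    image = image[:, rr:rr+ps, cc:cc+ps]
    label = label[:, rr:rr+ps, cc:cc+ps]
    if random.random() < 0.5:  # horizontal flip along the width axis
        image, label = image.flip(2), label.flip(2)
    return image, label

img, lbl = torch.rand(3, 480, 640), torch.rand(3, 480, 640)
img_c, lbl_c = paired_crop_flip(img, lbl)   # both come out as (3, 256, 256)
~~~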
57 | """ 58 | return F.to_tensor(pic), F.to_tensor(label) 59 | -------------------------------------------------------------------------------- /Dehazing/OTS/data/data_load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from PIL import Image as Image 5 | from torchvision.transforms import functional as F 6 | from torch.utils.data import Dataset, DataLoader 7 | from PIL import ImageFile 8 | ImageFile.LOAD_TRUNCATED_IMAGES = True 9 | 10 | def train_dataloader(path, batch_size=64, num_workers=0): 11 | image_dir = os.path.join(path, 'train') 12 | 13 | dataloader = DataLoader( 14 | DeblurDataset(image_dir, ps=256), 15 | batch_size=batch_size, 16 | shuffle=True, 17 | num_workers=num_workers, 18 | pin_memory=True 19 | ) 20 | return dataloader 21 | 22 | 23 | def test_dataloader(path, batch_size=1, num_workers=0): 24 | image_dir = os.path.join(path, 'test') 25 | dataloader = DataLoader( 26 | DeblurDataset(image_dir, is_test=True), 27 | batch_size=batch_size, 28 | shuffle=False, 29 | num_workers=num_workers, 30 | pin_memory=True 31 | ) 32 | 33 | return dataloader 34 | 35 | 36 | def valid_dataloader(path, batch_size=1, num_workers=0): 37 | dataloader = DataLoader( 38 | DeblurDataset(os.path.join(path, 'test'), is_valid=True), 39 | batch_size=batch_size, 40 | shuffle=False, 41 | num_workers=num_workers 42 | ) 43 | 44 | return dataloader 45 | 46 | import random 47 | class DeblurDataset(Dataset): 48 | def __init__(self, image_dir, transform=None, is_test=False, is_valid=False, ps=None): 49 | self.image_dir = image_dir 50 | self.image_list = os.listdir(os.path.join(image_dir, 'hazy/')) 51 | self._check_image(self.image_list) 52 | self.image_list.sort() 53 | self.transform = transform 54 | self.is_test = is_test 55 | self.is_valid = is_valid 56 | self.ps = ps 57 | 58 | def __len__(self): 59 | return len(self.image_list) 60 | 61 | def __getitem__(self, idx): 62 | image = Image.open(os.path.join(self.image_dir, 'hazy', self.image_list[idx])).convert('RGB') 63 | if self.is_valid or self.is_test: 64 | label = Image.open(os.path.join(self.image_dir, 'gt', self.image_list[idx].split('_')[0]+'.png')).convert('RGB') 65 | else: 66 | label = Image.open(os.path.join(self.image_dir, 'gt', self.image_list[idx].split('_')[0]+'.jpg')).convert('RGB') 67 | ps = self.ps 68 | 69 | if self.ps is not None: 70 | image = F.to_tensor(image) 71 | label = F.to_tensor(label) 72 | 73 | hh, ww = label.shape[1], label.shape[2] 74 | 75 | rr = random.randint(0, hh-ps) 76 | cc = random.randint(0, ww-ps) 77 | 78 | image = image[:, rr:rr+ps, cc:cc+ps] 79 | label = label[:, rr:rr+ps, cc:cc+ps] 80 | 81 | if random.random() < 0.5: 82 | image = image.flip(2) 83 | label = label.flip(2) 84 | else: 85 | image = F.to_tensor(image) 86 | label = F.to_tensor(label) 87 | 88 | if self.is_test: 89 | name = self.image_list[idx] 90 | return image, label, name 91 | return image, label 92 | 93 | 94 | 95 | @staticmethod 96 | def _check_image(lst): 97 | for x in lst: 98 | splits = x.split('.') 99 | if splits[-1] not in ['png', 'jpg', 'jpeg']: 100 | raise ValueError 101 | -------------------------------------------------------------------------------- /Dehazing/OTS/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torchvision.transforms import functional as F 4 | import numpy as np 5 | from utils import Adder 6 | from data import test_dataloader 7 | from skimage.metrics import 
peak_signal_noise_ratio 8 | import time 9 | from pytorch_msssim import ssim 10 | import torch.nn.functional as f 11 | 12 | from skimage import img_as_ubyte 13 | import cv2 14 | # --------------------------------------------------- 15 | 16 | def _eval(model, args): 17 | state_dict = torch.load(args.test_model) 18 | model.load_state_dict(state_dict['model']) 19 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 20 | dataloader = test_dataloader(args.data_dir, batch_size=1, num_workers=0) 21 | torch.cuda.empty_cache() 22 | adder = Adder() 23 | model.eval() 24 | factor = 32 25 | with torch.no_grad(): 26 | psnr_adder = Adder() 27 | ssim_adder = Adder() 28 | 29 | for iter_idx, data in enumerate(dataloader): 30 | input_img, label_img, name = data 31 | 32 | input_img = input_img.to(device) 33 | 34 | h, w = input_img.shape[2], input_img.shape[3] 35 | H, W = ((h+factor)//factor)*factor, ((w+factor)//factor*factor) 36 | padh = H-h if h%factor!=0 else 0 37 | padw = W-w if w%factor!=0 else 0 38 | input_img = f.pad(input_img, (0, padw, 0, padh), 'reflect') 39 | 40 | torch.cuda.synchronize() 41 | tm = time.time() 42 | 43 | pred = model(input_img)[2] 44 | pred = pred[:,:,:h,:w] 45 | torch.cuda.synchronize() 46 | 47 | elapsed = time.time() - tm 48 | adder(elapsed) 49 | 50 | pred_clip = torch.clamp(pred, 0, 1) 51 | 52 | pred_numpy = pred_clip.squeeze(0).cpu().numpy() 53 | label_numpy = label_img.squeeze(0).cpu().numpy() 54 | 55 | label_img = (label_img).cuda() 56 | psnr_val = 10 * torch.log10(1 / f.mse_loss(pred_clip, label_img)) 57 | down_ratio = max(1, round(min(H, W) / 256)) 58 | ssim_val = ssim(f.adaptive_avg_pool2d(pred_clip, (int(H / down_ratio), int(W / down_ratio))), 59 | f.adaptive_avg_pool2d(label_img, (int(H / down_ratio), int(W / down_ratio))), 60 | data_range=1, size_average=False) 61 | print('%d iter PSNR_dehazing: %.2f ssim: %f' % (iter_idx + 1, psnr_val, ssim_val)) 62 | ssim_adder(ssim_val) 63 | 64 | if args.save_image: 65 | save_name = os.path.join(args.result_dir, name[0]) 66 | pred_clip += 0.5 / 255 67 | pred = F.to_pil_image(pred_clip.squeeze(0).cpu(), 'RGB') 68 | pred.save(save_name) 69 | 70 | psnr_mimo = peak_signal_noise_ratio(pred_numpy, label_numpy, data_range=1) 71 | psnr_adder(psnr_val) 72 | 73 | print('%d iter PSNR: %.2f time: %f' % (iter_idx + 1, psnr_mimo, elapsed)) 74 | 75 | print('==========================================================') 76 | print('The average PSNR is %.2f dB' % (psnr_adder.average())) 77 | print('The average SSIM is %.4f dB' % (ssim_adder.average())) 78 | 79 | print("Average time: %f" % adder.average()) 80 | 81 | -------------------------------------------------------------------------------- /Dehazing/OTS/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | from torch.backends import cudnn 5 | from models.ConvIR import build_net 6 | from train import _train 7 | from eval import _eval 8 | 9 | 10 | def main(args): 11 | # CUDNN 12 | cudnn.benchmark = True 13 | 14 | if not os.path.exists('results/'): 15 | os.makedirs(args.model_save_dir) 16 | if not os.path.exists('results/' + args.model_name + '/'): 17 | os.makedirs('results/' + args.model_name + '/') 18 | if not os.path.exists(args.result_dir): 19 | os.makedirs(args.result_dir) 20 | 21 | model = build_net(args.type) 22 | # print(model) 23 | 24 | if torch.cuda.is_available(): 25 | model.cuda() 26 | if args.mode == 'train': 27 | _train(model, args) 28 | 29 | elif args.mode == 'test': 30 | 
_eval(model, args) 31 | 32 | 33 | if __name__ == '__main__': 34 | parser = argparse.ArgumentParser() 35 | 36 | # Directories 37 | parser.add_argument('--model_name', default='ConvIR', type=str) 38 | parser.add_argument('--data_dir', type=str, default='') 39 | parser.add_argument('--mode', default='test', choices=['train', 'test'], type=str) 40 | parser.add_argument('--type', default='small', choices=['small', 'base', 'large'], type=str) 41 | 42 | # Train 43 | parser.add_argument('--batch_size', type=int, default=8) 44 | parser.add_argument('--learning_rate', type=float, default=1e-4) 45 | parser.add_argument('--weight_decay', type=float, default=0) 46 | parser.add_argument('--num_epoch', type=int, default=30) 47 | parser.add_argument('--print_freq', type=int, default=100) 48 | parser.add_argument('--num_worker', type=int, default=8) 49 | parser.add_argument('--save_freq', type=int, default=1) 50 | parser.add_argument('--valid_freq', type=int, default=1) 51 | parser.add_argument('--resume', type=str, default='') 52 | 53 | 54 | # Test 55 | parser.add_argument('--test_model', type=str, default='') 56 | parser.add_argument('--save_image', type=bool, default=False, choices=[True, False]) 57 | 58 | args = parser.parse_args() 59 | args.model_save_dir = os.path.join('results/', 'ConvIR', 'OTS/') 60 | args.result_dir = os.path.join('results/', args.model_name, 'test') 61 | if not os.path.exists(args.model_save_dir): 62 | os.makedirs(args.model_save_dir) 63 | command = 'cp ' + 'models/layers.py ' + args.model_save_dir 64 | os.system(command) 65 | command = 'cp ' + 'models/ConvIR.py ' + args.model_save_dir 66 | os.system(command) 67 | command = 'cp ' + 'train.py ' + args.model_save_dir 68 | os.system(command) 69 | command = 'cp ' + 'main.py ' + args.model_save_dir 70 | os.system(command) 71 | print(args) 72 | main(args) 73 | -------------------------------------------------------------------------------- /Dehazing/OTS/models/ConvIR.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .layers import * 6 | 7 | 8 | class EBlock(nn.Module): 9 | def __init__(self, out_channel, num_res=8): 10 | super(EBlock, self).__init__() 11 | 12 | layers = [ResBlock(out_channel, out_channel) for _ in range(num_res-1)] 13 | layers.append(ResBlock(out_channel, out_channel, filter=True)) 14 | 15 | self.layers = nn.Sequential(*layers) 16 | 17 | def forward(self, x): 18 | return self.layers(x) 19 | 20 | 21 | class DBlock(nn.Module): 22 | def __init__(self, channel, num_res=8): 23 | super(DBlock, self).__init__() 24 | 25 | layers = [ResBlock(channel, channel) for _ in range(num_res-1)] 26 | layers.append(ResBlock(channel, channel, filter=True)) 27 | self.layers = nn.Sequential(*layers) 28 | 29 | def forward(self, x): 30 | return self.layers(x) 31 | 32 | 33 | class SCM(nn.Module): 34 | def __init__(self, out_plane): 35 | super(SCM, self).__init__() 36 | self.main = nn.Sequential( 37 | BasicConv(3, out_plane//4, kernel_size=3, stride=1, relu=True), 38 | BasicConv(out_plane // 4, out_plane // 2, kernel_size=1, stride=1, relu=True), 39 | BasicConv(out_plane // 2, out_plane // 2, kernel_size=3, stride=1, relu=True), 40 | BasicConv(out_plane // 2, out_plane, kernel_size=1, stride=1, relu=False), 41 | nn.InstanceNorm2d(out_plane, affine=True) 42 | ) 43 | 44 | def forward(self, x): 45 | x = self.main(x) 46 | return x 47 | 48 | class FAM(nn.Module): 49 | def __init__(self, channel): 50 | super(FAM, 
self).__init__() 51 | self.merge = BasicConv(channel*2, channel, kernel_size=3, stride=1, relu=False) 52 | 53 | def forward(self, x1, x2): 54 | return self.merge(torch.cat([x1, x2], dim=1)) 55 | 56 | class ConvIR(nn.Module): 57 | def __init__(self, version): 58 | super(ConvIR, self).__init__() 59 | 60 | if version == 'small': 61 | num_res = 4 62 | elif version == 'base': 63 | num_res = 8 64 | elif version == 'large': 65 | num_res = 16 66 | 67 | base_channel = 32 68 | 69 | self.Encoder = nn.ModuleList([ 70 | EBlock(base_channel, num_res), 71 | EBlock(base_channel*2, num_res), 72 | EBlock(base_channel*4, num_res), 73 | ]) 74 | 75 | self.feat_extract = nn.ModuleList([ 76 | BasicConv(3, base_channel, kernel_size=3, relu=True, stride=1), 77 | BasicConv(base_channel, base_channel*2, kernel_size=3, relu=True, stride=2), 78 | BasicConv(base_channel*2, base_channel*4, kernel_size=3, relu=True, stride=2), 79 | BasicConv(base_channel*4, base_channel*2, kernel_size=4, relu=True, stride=2, transpose=True), 80 | BasicConv(base_channel*2, base_channel, kernel_size=4, relu=True, stride=2, transpose=True), 81 | BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1) 82 | ]) 83 | 84 | self.Decoder = nn.ModuleList([ 85 | DBlock(base_channel * 4, num_res), 86 | DBlock(base_channel * 2, num_res), 87 | DBlock(base_channel, num_res) 88 | ]) 89 | 90 | self.Convs = nn.ModuleList([ 91 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1), 92 | BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1), 93 | ]) 94 | 95 | self.ConvsOut = nn.ModuleList( 96 | [ 97 | BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1), 98 | BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1), 99 | ] 100 | ) 101 | 102 | self.FAM1 = FAM(base_channel * 4) 103 | self.SCM1 = SCM(base_channel * 4) 104 | self.FAM2 = FAM(base_channel * 2) 105 | self.SCM2 = SCM(base_channel * 2) 106 | 107 | def forward(self, x): 108 | x_2 = F.interpolate(x, scale_factor=0.5) 109 | x_4 = F.interpolate(x_2, scale_factor=0.5) 110 | z2 = self.SCM2(x_2) 111 | z4 = self.SCM1(x_4) 112 | 113 | outputs = list() 114 | # 256 115 | x_ = self.feat_extract[0](x) 116 | res1 = self.Encoder[0](x_) 117 | # 128 118 | z = self.feat_extract[1](res1) 119 | z = self.FAM2(z, z2) 120 | res2 = self.Encoder[1](z) 121 | # 64 122 | z = self.feat_extract[2](res2) 123 | z = self.FAM1(z, z4) 124 | z = self.Encoder[2](z) 125 | 126 | z = self.Decoder[0](z) 127 | z_ = self.ConvsOut[0](z) 128 | # 128 129 | z = self.feat_extract[3](z) 130 | outputs.append(z_+x_4) 131 | 132 | z = torch.cat([z, res2], dim=1) 133 | z = self.Convs[0](z) 134 | z = self.Decoder[1](z) 135 | z_ = self.ConvsOut[1](z) 136 | # 256 137 | z = self.feat_extract[4](z) 138 | outputs.append(z_+x_2) 139 | 140 | z = torch.cat([z, res1], dim=1) 141 | z = self.Convs[1](z) 142 | z = self.Decoder[2](z) 143 | z = self.feat_extract[5](z) 144 | outputs.append(z+x) 145 | 146 | return outputs 147 | 148 | 149 | def build_net(version): 150 | return ConvIR(version) 151 | -------------------------------------------------------------------------------- /Dehazing/OTS/models/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class BasicConv(nn.Module): 6 | def __init__(self, in_channel, out_channel, kernel_size, stride, bias=True, norm=False, relu=True, transpose=False): 7 | super(BasicConv, self).__init__() 8 | if bias and norm: 9 | 
bias = False 10 | 11 | padding = kernel_size // 2 12 | layers = list() 13 | if transpose: 14 | padding = kernel_size // 2 -1 15 | layers.append(nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias)) 16 | else: 17 | layers.append( 18 | nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias)) 19 | if norm: 20 | layers.append(nn.BatchNorm2d(out_channel)) 21 | if relu: 22 | layers.append(nn.GELU()) 23 | self.main = nn.Sequential(*layers) 24 | 25 | def forward(self, x): 26 | return self.main(x) 27 | 28 | 29 | class ResBlock(nn.Module): 30 | def __init__(self, in_channel, out_channel, filter=False): 31 | super(ResBlock, self).__init__() 32 | self.main = nn.Sequential( 33 | BasicConv(in_channel, out_channel, kernel_size=3, stride=1, relu=True), 34 | DeepPoolLayer(in_channel, out_channel) if filter else nn.Identity(), 35 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False) 36 | ) 37 | 38 | def forward(self, x): 39 | return self.main(x) + x 40 | 41 | 42 | class DeepPoolLayer(nn.Module): 43 | def __init__(self, k, k_out): 44 | super(DeepPoolLayer, self).__init__() 45 | self.pools_sizes = [8,4,2] 46 | dilation = [7,9,11] 47 | pools, convs, dynas = [],[],[] 48 | for j, i in enumerate(self.pools_sizes): 49 | pools.append(nn.AvgPool2d(kernel_size=i, stride=i)) 50 | convs.append(nn.Conv2d(k, k, 3, 1, 1, bias=False)) 51 | dynas.append(MultiShapeKernel(dim=k, kernel_size=3, dilation=dilation[j])) 52 | self.pools = nn.ModuleList(pools) 53 | self.convs = nn.ModuleList(convs) 54 | self.dynas = nn.ModuleList(dynas) 55 | self.relu = nn.GELU() 56 | self.conv_sum = nn.Conv2d(k, k_out, 3, 1, 1, bias=False) 57 | 58 | def forward(self, x): 59 | x_size = x.size() 60 | resl = x 61 | for i in range(len(self.pools_sizes)): 62 | if i == 0: 63 | y = self.dynas[i](self.convs[i](self.pools[i](x))) 64 | else: 65 | y = self.dynas[i](self.convs[i](self.pools[i](x)+y_up)) 66 | resl = torch.add(resl, F.interpolate(y, x_size[2:], mode='bilinear', align_corners=True)) 67 | if i != len(self.pools_sizes)-1: 68 | y_up = F.interpolate(y, scale_factor=2, mode='bilinear', align_corners=True) 69 | resl = self.relu(resl) 70 | resl = self.conv_sum(resl) 71 | 72 | return resl 73 | 74 | class dynamic_filter(nn.Module): 75 | def __init__(self, inchannels, kernel_size=3, dilation=1, stride=1, group=8): 76 | super(dynamic_filter, self).__init__() 77 | self.stride = stride 78 | self.kernel_size = kernel_size 79 | self.group = group 80 | self.dilation = dilation 81 | 82 | self.conv = nn.Conv2d(inchannels, group*kernel_size**2, kernel_size=1, stride=1, bias=False) 83 | self.bn = nn.BatchNorm2d(group*kernel_size**2) 84 | self.act = nn.Tanh() 85 | 86 | nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='relu') 87 | self.lamb_l = nn.Parameter(torch.zeros(inchannels), requires_grad=True) 88 | self.lamb_h = nn.Parameter(torch.zeros(inchannels), requires_grad=True) 89 | self.pad = nn.ReflectionPad2d(self.dilation*(kernel_size-1)//2) 90 | 91 | self.ap = nn.AdaptiveAvgPool2d((1, 1)) 92 | self.gap = nn.AdaptiveAvgPool2d(1) 93 | 94 | self.inside_all = nn.Parameter(torch.zeros(inchannels,1,1), requires_grad=True) 95 | 96 | def forward(self, x): 97 | identity_input = x 98 | low_filter = self.ap(x) 99 | low_filter = self.conv(low_filter) 100 | low_filter = self.bn(low_filter) 101 | 102 | n, c, h, w = x.shape 103 | x = F.unfold(self.pad(x), kernel_size=self.kernel_size, dilation=self.dilation).reshape(n, self.group, c//self.group, 
self.kernel_size**2, h*w) 104 | 105 | n,c1,p,q = low_filter.shape 106 | low_filter = low_filter.reshape(n, c1//self.kernel_size**2, self.kernel_size**2, p*q).unsqueeze(2) 107 | 108 | low_filter = self.act(low_filter) 109 | 110 | low_part = torch.sum(x * low_filter, dim=3).reshape(n, c, h, w) 111 | 112 | out_low = low_part * (self.inside_all + 1.) - self.inside_all * self.gap(identity_input) 113 | 114 | out_low = out_low * self.lamb_l[None,:,None,None] 115 | 116 | out_high = (identity_input) * (self.lamb_h[None,:,None,None] + 1.) 117 | 118 | return out_low + out_high 119 | 120 | 121 | class cubic_attention(nn.Module): 122 | def __init__(self, dim, group, dilation, kernel) -> None: 123 | super().__init__() 124 | 125 | self.H_spatial_att = spatial_strip_att(dim, dilation=dilation, group=group, kernel=kernel) 126 | self.W_spatial_att = spatial_strip_att(dim, dilation=dilation, group=group, kernel=kernel, H=False) 127 | self.gamma = nn.Parameter(torch.zeros(dim,1,1)) 128 | self.beta = nn.Parameter(torch.ones(dim,1,1)) 129 | 130 | def forward(self, x): 131 | out = self.H_spatial_att(x) 132 | out = self.W_spatial_att(out) 133 | return self.gamma * out + x * self.beta 134 | 135 | 136 | class spatial_strip_att(nn.Module): 137 | def __init__(self, dim, kernel=3, dilation=1, group=2, H=True) -> None: 138 | super().__init__() 139 | 140 | self.k = kernel 141 | pad = dilation*(kernel-1) // 2 142 | self.kernel = (1, kernel) if H else (kernel, 1) 143 | self.padding = (kernel//2, 1) if H else (1, kernel//2) 144 | self.dilation = dilation 145 | self.group = group 146 | self.pad = nn.ReflectionPad2d((pad, pad, 0, 0)) if H else nn.ReflectionPad2d((0, 0, pad, pad)) 147 | self.conv = nn.Conv2d(dim, group*kernel, kernel_size=1, stride=1, bias=False) 148 | self.ap = nn.AdaptiveAvgPool2d((1, 1)) 149 | self.filter_act = nn.Tanh() 150 | self.inside_all = nn.Parameter(torch.zeros(dim,1,1), requires_grad=True) 151 | self.lamb_l = nn.Parameter(torch.zeros(dim), requires_grad=True) 152 | self.lamb_h = nn.Parameter(torch.zeros(dim), requires_grad=True) 153 | gap_kernel = (None,1) if H else (1, None) 154 | self.gap = nn.AdaptiveAvgPool2d(gap_kernel) 155 | 156 | def forward(self, x): 157 | identity_input = x.clone() 158 | filter = self.ap(x) 159 | filter = self.conv(filter) 160 | n, c, h, w = x.shape 161 | x = F.unfold(self.pad(x), kernel_size=self.kernel, dilation=self.dilation).reshape(n, self.group, c//self.group, self.k, h*w) 162 | n, c1, p, q = filter.shape 163 | filter = filter.reshape(n, c1//self.k, self.k, p*q).unsqueeze(2) 164 | filter = self.filter_act(filter) 165 | out = torch.sum(x * filter, dim=3).reshape(n, c, h, w) 166 | 167 | out_low = out * (self.inside_all + 1.) - self.inside_all * self.gap(identity_input) 168 | out_low = out_low * self.lamb_l[None,:,None,None] 169 | out_high = identity_input * (self.lamb_h[None,:,None,None]+1.) 
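# out_low is the dynamically filtered low-frequency response, re-centred by the
# strip-pooled mean and gated by lamb_l, while out_high rescales the untouched
# identity path by (1 + lamb_h); both lambdas are zero-initialised, so the block
# starts out as an identity mapping and learns how strongly to re-weight each band.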
170 | 171 | return out_low + out_high 172 | 173 | 174 | class MultiShapeKernel(nn.Module): 175 | def __init__(self, dim, kernel_size=3, dilation=1, group=8): 176 | super().__init__() 177 | 178 | self.square_att = dynamic_filter(inchannels=dim, dilation=dilation, group=group, kernel_size=kernel_size) 179 | self.strip_att = cubic_attention(dim, group=group, dilation=dilation, kernel=kernel_size) 180 | 181 | def forward(self, x): 182 | 183 | x1 = self.strip_att(x) 184 | x2 = self.square_att(x) 185 | 186 | return x1+x2 187 | 188 | 189 | -------------------------------------------------------------------------------- /Dehazing/OTS/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from data import train_dataloader 4 | from utils import Adder, Timer, check_lr 5 | from torch.utils.tensorboard import SummaryWriter 6 | from valid import _valid 7 | import torch.nn.functional as F 8 | import torch.nn as nn 9 | 10 | from warmup_scheduler import GradualWarmupScheduler 11 | 12 | def _train(model, args): 13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 14 | criterion = torch.nn.L1Loss() 15 | 16 | optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), eps=1e-8) 17 | dataloader = train_dataloader(args.data_dir, args.batch_size, args.num_worker) 18 | max_iter = len(dataloader) 19 | warmup_epochs=1 20 | scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.num_epoch-warmup_epochs, eta_min=1e-6) 21 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine) 22 | scheduler.step() 23 | epoch = 1 24 | if args.resume: 25 | state = torch.load(args.resume) 26 | epoch = state['epoch'] 27 | optimizer.load_state_dict(state['optimizer']) 28 | model.load_state_dict(state['model']) 29 | print('Resume from %d'%epoch) 30 | epoch += 1 31 | 32 | writer = SummaryWriter() 33 | epoch_pixel_adder = Adder() 34 | epoch_fft_adder = Adder() 35 | iter_pixel_adder = Adder() 36 | iter_fft_adder = Adder() 37 | epoch_timer = Timer('m') 38 | iter_timer = Timer('m') 39 | best_psnr=-1 40 | 41 | eval_now = max_iter//6-1 42 | 43 | for epoch_idx in range(epoch, args.num_epoch + 1): 44 | 45 | epoch_timer.tic() 46 | iter_timer.tic() 47 | for iter_idx, batch_data in enumerate(dataloader): 48 | 49 | input_img, label_img = batch_data 50 | input_img = input_img.to(device) 51 | label_img = label_img.to(device) 52 | 53 | optimizer.zero_grad() 54 | pred_img = model(input_img) 55 | label_img2 = F.interpolate(label_img, scale_factor=0.5, mode='bilinear') 56 | label_img4 = F.interpolate(label_img, scale_factor=0.25, mode='bilinear') 57 | l1 = criterion(pred_img[0], label_img4) 58 | l2 = criterion(pred_img[1], label_img2) 59 | l3 = criterion(pred_img[2], label_img) 60 | loss_content = l1+l2+l3 61 | 62 | label_fft1 = torch.fft.fft2(label_img4, dim=(-2,-1)) 63 | label_fft1 = torch.stack((label_fft1.real, label_fft1.imag), -1) 64 | 65 | pred_fft1 = torch.fft.fft2(pred_img[0], dim=(-2,-1)) 66 | pred_fft1 = torch.stack((pred_fft1.real, pred_fft1.imag), -1) 67 | 68 | label_fft2 = torch.fft.fft2(label_img2, dim=(-2,-1)) 69 | label_fft2 = torch.stack((label_fft2.real, label_fft2.imag), -1) 70 | 71 | pred_fft2 = torch.fft.fft2(pred_img[1], dim=(-2,-1)) 72 | pred_fft2 = torch.stack((pred_fft2.real, pred_fft2.imag), -1) 73 | 74 | label_fft3 = torch.fft.fft2(label_img, dim=(-2,-1)) 75 | label_fft3 = torch.stack((label_fft3.real, 
label_fft3.imag), -1) 76 | 77 | pred_fft3 = torch.fft.fft2(pred_img[2], dim=(-2,-1)) 78 | pred_fft3 = torch.stack((pred_fft3.real, pred_fft3.imag), -1) 79 | 80 | f1 = criterion(pred_fft1, label_fft1) 81 | f2 = criterion(pred_fft2, label_fft2) 82 | f3 = criterion(pred_fft3, label_fft3) 83 | loss_fft = f1+f2+f3 84 | 85 | loss = loss_content + 0.1 * loss_fft 86 | loss.backward() 87 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.01) 88 | optimizer.step() 89 | 90 | iter_pixel_adder(loss_content.item()) 91 | iter_fft_adder(loss_fft.item()) 92 | 93 | epoch_pixel_adder(loss_content.item()) 94 | epoch_fft_adder(loss_fft.item()) 95 | 96 | if (iter_idx + 1) % args.print_freq == 0: 97 | print("Time: %7.4f Epoch: %03d Iter: %4d/%4d LR: %.10f Loss content: %7.4f Loss fft: %7.4f" % ( 98 | iter_timer.toc(), epoch_idx, iter_idx + 1, max_iter, scheduler.get_lr()[0], iter_pixel_adder.average(), 99 | iter_fft_adder.average())) 100 | writer.add_scalar('Pixel Loss', iter_pixel_adder.average(), iter_idx + (epoch_idx-1)* max_iter) 101 | writer.add_scalar('FFT Loss', iter_fft_adder.average(), iter_idx + (epoch_idx - 1) * max_iter) 102 | 103 | iter_timer.tic() 104 | iter_pixel_adder.reset() 105 | iter_fft_adder.reset() 106 | 107 | 108 | if iter_idx%eval_now==0 and iter_idx>0 and (epoch_idx>20 or epoch_idx == 1): 109 | 110 | save_name = os.path.join(args.model_save_dir, 'model_%d_%d.pkl' % (epoch_idx, iter_idx)) 111 | torch.save({'model': model.state_dict()}, save_name) 112 | 113 | val_gopro = _valid(model, args, epoch_idx) 114 | print('%03d epoch \n Average GOPRO PSNR %.2f dB' % (epoch_idx, val_gopro)) 115 | writer.add_scalar('PSNR_GOPRO', val_gopro, epoch_idx) 116 | if val_gopro >= best_psnr: 117 | torch.save({'model': model.state_dict()}, os.path.join(args.model_save_dir, 'Best.pkl')) 118 | 119 | 120 | overwrite_name = os.path.join(args.model_save_dir, 'model.pkl') 121 | torch.save({'model': model.state_dict()}, overwrite_name) 122 | 123 | 124 | if epoch_idx % args.save_freq == 0: 125 | save_name = os.path.join(args.model_save_dir, 'model_%d.pkl' % epoch_idx) 126 | torch.save({'model': model.state_dict()}, save_name) 127 | print("EPOCH: %02d\nElapsed time: %4.2f Epoch Pixel Loss: %7.4f Epoch FFT Loss: %7.4f" % ( 128 | epoch_idx, epoch_timer.toc(), epoch_pixel_adder.average(), epoch_fft_adder.average())) 129 | epoch_fft_adder.reset() 130 | epoch_pixel_adder.reset() 131 | scheduler.step() 132 | 133 | if epoch_idx % args.valid_freq == 0: 134 | val = _valid(model, args, epoch_idx) 135 | print('%03d epoch \n Average PSNR %.2f dB' % (epoch_idx, val)) 136 | writer.add_scalar('PSNR', val, epoch_idx) 137 | if val >= best_psnr: 138 | torch.save({'model': model.state_dict()}, os.path.join(args.model_save_dir, 'Best.pkl')) 139 | save_name = os.path.join(args.model_save_dir, 'Final.pkl') 140 | torch.save({'model': model.state_dict()}, save_name) 141 | -------------------------------------------------------------------------------- /Dehazing/OTS/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | 4 | 5 | class Adder(object): 6 | def __init__(self): 7 | self.count = 0 8 | self.num = float(0) 9 | 10 | def reset(self): 11 | self.count = 0 12 | self.num = float(0) 13 | 14 | def __call__(self, num): 15 | self.count += 1 16 | self.num += num 17 | 18 | def average(self): 19 | return self.num / self.count 20 | 21 | 22 | class Timer(object): 23 | def __init__(self, option='s'): 24 | self.tm = 0 25 | self.option = option 26 | if option == 's': 27 | 
self.devider = 1 28 | elif option == 'm': 29 | self.devider = 60 30 | else: 31 | self.devider = 3600 32 | 33 | def tic(self): 34 | self.tm = time.time() 35 | 36 | def toc(self): 37 | return (time.time() - self.tm) / self.devider 38 | 39 | 40 | def check_lr(optimizer): 41 | for i, param_group in enumerate(optimizer.param_groups): 42 | lr = param_group['lr'] 43 | return lr 44 | -------------------------------------------------------------------------------- /Dehazing/OTS/valid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.transforms import functional as F 3 | from data import valid_dataloader 4 | from utils import Adder 5 | import os 6 | from skimage.metrics import peak_signal_noise_ratio 7 | import torch.nn.functional as f 8 | 9 | 10 | def _valid(model, args, ep): 11 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 12 | ots = valid_dataloader(args.data_dir, batch_size=1, num_workers=0) 13 | model.eval() 14 | psnr_adder = Adder() 15 | 16 | with torch.no_grad(): 17 | print('Start Evaluation') 18 | factor = 32 19 | for idx, data in enumerate(ots): 20 | input_img, label_img = data 21 | input_img = input_img.to(device) 22 | 23 | h, w = input_img.shape[2], input_img.shape[3] 24 | H, W = ((h+factor)//factor)*factor, ((w+factor)//factor*factor) 25 | padh = H-h if h%factor!=0 else 0 26 | padw = W-w if w%factor!=0 else 0 27 | input_img = f.pad(input_img, (0, padw, 0, padh), 'reflect') 28 | 29 | if not os.path.exists(os.path.join(args.result_dir, '%d' % (ep))): 30 | os.mkdir(os.path.join(args.result_dir, '%d' % (ep))) 31 | 32 | pred = model(input_img)[2] 33 | pred = pred[:,:,:h,:w] 34 | 35 | pred_clip = torch.clamp(pred, 0, 1) 36 | p_numpy = pred_clip.squeeze(0).cpu().numpy() 37 | label_numpy = label_img.squeeze(0).cpu().numpy() 38 | 39 | psnr = peak_signal_noise_ratio(p_numpy, label_numpy, data_range=1) 40 | 41 | psnr_adder(psnr) 42 | print('\r%03d'%idx, end=' ') 43 | 44 | print('\n') 45 | model.train() 46 | return psnr_adder.average() 47 | -------------------------------------------------------------------------------- /Image_deraining/README.md: -------------------------------------------------------------------------------- 1 | ### Download the dataset and model 2 | 3 | [training_rainy](https://drive.google.com/file/d/1zmT3EuYAKfqiLlrnAz5SrOyuEHl52M71/view?usp=sharing) 4 | [training_gt](https://drive.google.com/file/d/1unbTfkL3hhWtiwL27_ipLkPMC-aW3Ifc/view?usp=drive_link) 5 | [testset](https://drive.google.com/file/d/1Tmz96YWtiBcA_mEu6Jrmp2R86j2g6Mfp/view?usp=drive_link) 6 | model: [gdrive](https://drive.google.com/drive/folders/1_5fO2p5xoWO5cUEVoXJ7x3Uhg1AP18FQ?usp=sharing), [百度网盘](https://pan.baidu.com/s/1oYzdxs3FvLJMWx7S5GW0rA?pwd=dvta) 7 | 8 | 9 | ### Training 10 | 11 | After setting the data_dir and valid_data in ``main.py``, then run 12 | 13 | ~~~ 14 | python main.py 15 | ~~~ 16 | 17 | ### Testing 18 | After setting the data_dir and test_model in ``test.py``, then run 19 | ~~~ 20 | python test.py 21 | ~~~ 22 | 23 | By doing this, the resulting images will be saved. 24 | Next, run the matlab file to get the scores. 
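If MATLAB is not available, the sketch below is a rough Python stand-in for the same protocol (PSNR/SSIM on the Y channel at data range 255). The directory names are placeholders for your own paths, it assumes predictions and ground truths share file names, and skimage's default SSIM window differs slightly from the MATLAB reference, so scores will be close but not bit-identical.
~~~
import os
import numpy as np
from PIL import Image
from skimage.color import rgb2ycbcr
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

def y_channel(path):
    # same convention as MATLAB rgb2ycbcr: luma channel in [16, 235]
    return rgb2ycbcr(np.array(Image.open(path).convert('RGB')))[..., 0]

pred_dir = 'results/ConvIR/deraining/Rain100L/'          # written by test.py
gt_dir = 'your_path/deraining_testset/Rain100L/target/'  # placeholder path

psnrs, ssims = [], []
for name in sorted(os.listdir(pred_dir)):
    pred = y_channel(os.path.join(pred_dir, name))
    gt = y_channel(os.path.join(gt_dir, name))
    psnrs.append(peak_signal_noise_ratio(gt, pred, data_range=255))
    ssims.append(structural_similarity(gt, pred, data_range=255))
print('PSNR: %.2f  SSIM: %.4f' % (np.mean(psnrs), np.mean(ssims)))
~~~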
25 | 26 | 27 | -------------------------------------------------------------------------------- /Image_deraining/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_augment import PairRandomCrop, PairCompose, PairRandomHorizontalFilp, PairToTensor, PairCenterCrop 2 | from .data_load import train_dataloader, test_dataloader, valid_dataloader 3 | -------------------------------------------------------------------------------- /Image_deraining/data/data_augment.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torchvision.transforms as transforms 3 | import torchvision.transforms.functional as F 4 | 5 | 6 | class PairRandomCrop(transforms.RandomCrop): 7 | 8 | def __call__(self, image, label): 9 | 10 | w,h = image.size[0], image.size[1] 11 | padw = self.size[0]-w if w<self.size[0] else 0 12 | padh = self.size[1]-h if h<self.size[1] else 0 13 | image = F.pad(image, [0, 0, padw, padh], padding_mode='reflect') 14 | label = F.pad(label, [0, 0, padw, padh], padding_mode='reflect') 15 | 16 | i, j, h, w = self.get_params(image, self.size) 17 | 18 | return F.crop(image, i, j, h, w), F.crop(label, i, j, h, w) -------------------------------------------------------------------------------- /Image_deraining/deraining_test.m: -------------------------------------------------------------------------------- 21 | if img_num > 0 22 | parfor j = 1:img_num 23 | image_name = path_list(j).name; 24 | gt_name = gt_list(j).name; 25 | input = imread(strcat(file_path,image_name)); 26 | gt = imread(strcat(gt_path, gt_name)); 27 | ssim_val = compute_ssim(input, gt); 28 | psnr_val = compute_psnr(input, gt); 29 | total_ssim = total_ssim + ssim_val; 30 | total_psnr = total_psnr + psnr_val; 31 | end 32 | end 33 | qm_psnr = total_psnr / img_num; 34 | qm_ssim = total_ssim / img_num; 35 | 36 | fprintf('For %s dataset PSNR: %f SSIM: %f\n', datasets{idx_set}, qm_psnr, qm_ssim); 37 | 38 | psnr_alldatasets = psnr_alldatasets + qm_psnr; 39 | ssim_alldatasets = ssim_alldatasets + qm_ssim; 40 | 41 | end 42 | 43 | fprintf('For all datasets PSNR: %f SSIM: %f\n', psnr_alldatasets/num_set, ssim_alldatasets/num_set); 44 | 45 | delete(gcp('nocreate')) 46 | toc 47 | 48 | function ssim_mean=compute_ssim(img1,img2) 49 | if size(img1, 3) == 3 50 | img1 = rgb2ycbcr(img1); 51 | img1 = img1(:, :, 1); 52 | end 53 | 54 | if size(img2, 3) == 3 55 | img2 = rgb2ycbcr(img2); 56 | img2 = img2(:, :, 1); 57 | end 58 | ssim_mean = SSIM_index(img1, img2); 59 | end 60 | 61 | function psnr=compute_psnr(img1,img2) 62 | if size(img1, 3) == 3 63 | img1 = rgb2ycbcr(img1); 64 | img1 = img1(:, :, 1); 65 | end 66 | 67 | if size(img2, 3) == 3 68 | img2 = rgb2ycbcr(img2); 69 | img2 = img2(:, :, 1); 70 | end 71 | 72 | imdff = double(img1) - double(img2); 73 | imdff = imdff(:); 74 | rmse = sqrt(mean(imdff.^2)); 75 | psnr = 20*log10(255/rmse); 76 | 77 | end 78 | 79 | function [mssim, ssim_map] = SSIM_index(img1, img2, K, window, L) 80 | 81 | %======================================================================== 82 | %SSIM Index, Version 1.0 83 | %Copyright(c) 2003 Zhou Wang 84 | %All Rights Reserved. 85 | % 86 | %The author is with Howard Hughes Medical Institute, and Laboratory 87 | %for Computational Vision at Center for Neural Science and Courant 88 | %Institute of Mathematical Sciences, New York University. 89 | % 90 | %---------------------------------------------------------------------- 91 | %Permission to use, copy, or modify this software and its documentation 92 | %for educational and research purposes only and without fee is hereby 93 | %granted, provided that this copyright notice and the original authors' 94 | %names appear on all copies and supporting documentation. This program 95 | %shall not be used, rewritten, or adapted as the basis of a commercial 96 | %software or hardware product without first obtaining permission of the 97 | %authors. The authors make no representations about the suitability of 98 | %this software for any purpose. 
It is provided "as is" without express 99 | %or implied warranty. 100 | %---------------------------------------------------------------------- 101 | % 102 | %This is an implementation of the algorithm for calculating the 103 | %Structural SIMilarity (SSIM) index between two images. Please refer 104 | %to the following paper: 105 | % 106 | %Z. Wang, A. C. Bovik, H. R. Sheikh, and E. P. Simoncelli, "Image 107 | %quality assessment: From error measurement to structural similarity" 108 | %IEEE Transactios on Image Processing, vol. 13, no. 1, Jan. 2004. 109 | % 110 | %Kindly report any suggestions or corrections to zhouwang@ieee.org 111 | % 112 | %---------------------------------------------------------------------- 113 | % 114 | %Input : (1) img1: the first image being compared 115 | % (2) img2: the second image being compared 116 | % (3) K: constants in the SSIM index formula (see the above 117 | % reference). defualt value: K = [0.01 0.03] 118 | % (4) window: local window for statistics (see the above 119 | % reference). default widnow is Gaussian given by 120 | % window = fspecial('gaussian', 11, 1.5); 121 | % (5) L: dynamic range of the images. default: L = 255 122 | % 123 | %Output: (1) mssim: the mean SSIM index value between 2 images. 124 | % If one of the images being compared is regarded as 125 | % perfect quality, then mssim can be considered as the 126 | % quality measure of the other image. 127 | % If img1 = img2, then mssim = 1. 128 | % (2) ssim_map: the SSIM index map of the test image. The map 129 | % has a smaller size than the input images. The actual size: 130 | % size(img1) - size(window) + 1. 131 | % 132 | %Default Usage: 133 | % Given 2 test images img1 and img2, whose dynamic range is 0-255 134 | % 135 | % [mssim ssim_map] = ssim_index(img1, img2); 136 | % 137 | %Advanced Usage: 138 | % User defined parameters. 
For example 139 | % 140 | % K = [0.05 0.05]; 141 | % window = ones(8); 142 | % L = 100; 143 | % [mssim ssim_map] = ssim_index(img1, img2, K, window, L); 144 | % 145 | %See the results: 146 | % 147 | % mssim %Gives the mssim value 148 | % imshow(max(0, ssim_map).^4) %Shows the SSIM index map 149 | % 150 | %======================================================================== 151 | 152 | 153 | if (nargin < 2 || nargin > 5) 154 | ssim_index = -Inf; 155 | ssim_map = -Inf; 156 | return; 157 | end 158 | 159 | if (size(img1) ~= size(img2)) 160 | ssim_index = -Inf; 161 | ssim_map = -Inf; 162 | return; 163 | end 164 | 165 | [M N] = size(img1); 166 | 167 | if (nargin == 2) 168 | if ((M < 11) || (N < 11)) 169 | ssim_index = -Inf; 170 | ssim_map = -Inf; 171 | return 172 | end 173 | window = fspecial('gaussian', 11, 1.5); % 174 | K(1) = 0.01; % default settings 175 | K(2) = 0.03; % 176 | L = 255; % 177 | end 178 | 179 | if (nargin == 3) 180 | if ((M < 11) || (N < 11)) 181 | ssim_index = -Inf; 182 | ssim_map = -Inf; 183 | return 184 | end 185 | window = fspecial('gaussian', 11, 1.5); 186 | L = 255; 187 | if (length(K) == 2) 188 | if (K(1) < 0 || K(2) < 0) 189 | ssim_index = -Inf; 190 | ssim_map = -Inf; 191 | return; 192 | end 193 | else 194 | ssim_index = -Inf; 195 | ssim_map = -Inf; 196 | return; 197 | end 198 | end 199 | 200 | if (nargin == 4) 201 | [H W] = size(window); 202 | if ((H*W) < 4 || (H > M) || (W > N)) 203 | ssim_index = -Inf; 204 | ssim_map = -Inf; 205 | return 206 | end 207 | L = 255; 208 | if (length(K) == 2) 209 | if (K(1) < 0 || K(2) < 0) 210 | ssim_index = -Inf; 211 | ssim_map = -Inf; 212 | return; 213 | end 214 | else 215 | ssim_index = -Inf; 216 | ssim_map = -Inf; 217 | return; 218 | end 219 | end 220 | 221 | if (nargin == 5) 222 | [H W] = size(window); 223 | if ((H*W) < 4 || (H > M) || (W > N)) 224 | ssim_index = -Inf; 225 | ssim_map = -Inf; 226 | return 227 | end 228 | if (length(K) == 2) 229 | if (K(1) < 0 || K(2) < 0) 230 | ssim_index = -Inf; 231 | ssim_map = -Inf; 232 | return; 233 | end 234 | else 235 | ssim_index = -Inf; 236 | ssim_map = -Inf; 237 | return; 238 | end 239 | end 240 | 241 | C1 = (K(1)*L)^2; 242 | C2 = (K(2)*L)^2; 243 | window = window/sum(sum(window)); 244 | img1 = double(img1); 245 | img2 = double(img2); 246 | 247 | mu1 = filter2(window, img1, 'valid'); 248 | mu2 = filter2(window, img2, 'valid'); 249 | mu1_sq = mu1.*mu1; 250 | mu2_sq = mu2.*mu2; 251 | mu1_mu2 = mu1.*mu2; 252 | sigma1_sq = filter2(window, img1.*img1, 'valid') - mu1_sq; 253 | sigma2_sq = filter2(window, img2.*img2, 'valid') - mu2_sq; 254 | sigma12 = filter2(window, img1.*img2, 'valid') - mu1_mu2; 255 | 256 | if (C1 > 0 & C2 > 0) 257 | ssim_map = ((2*mu1_mu2 + C1).*(2*sigma12 + C2))./((mu1_sq + mu2_sq + C1).*(sigma1_sq + sigma2_sq + C2)); 258 | else 259 | numerator1 = 2*mu1_mu2 + C1; 260 | numerator2 = 2*sigma12 + C2; 261 | denominator1 = mu1_sq + mu2_sq + C1; 262 | denominator2 = sigma1_sq + sigma2_sq + C2; 263 | ssim_map = ones(size(mu1)); 264 | index = (denominator1.*denominator2 > 0); 265 | ssim_map(index) = (numerator1(index).*numerator2(index))./(denominator1(index).*denominator2(index)); 266 | index = (denominator1 ~= 0) & (denominator2 == 0); 267 | ssim_map(index) = numerator1(index)./denominator1(index); 268 | end 269 | 270 | mssim = mean2(ssim_map); 271 | 272 | end 273 | -------------------------------------------------------------------------------- /Image_deraining/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import 
torch 3 | import argparse 4 | from torch.backends import cudnn 5 | from models.ConvIR import build_net 6 | from train import _train 7 | from eval import _eval 8 | 9 | def main(args): 10 | cudnn.benchmark = True 11 | 12 | if not os.path.exists('results/'): 13 | os.makedirs(args.model_save_dir) 14 | if not os.path.exists('results/' + args.model_name + '/'): 15 | os.makedirs('results/' + args.model_name + '/') 16 | if not os.path.exists(args.result_dir): 17 | os.makedirs(args.result_dir) 18 | 19 | model = build_net() 20 | print(model) 21 | 22 | if torch.cuda.is_available(): 23 | model.cuda() 24 | if args.mode == 'train': 25 | _train(model, args) 26 | 27 | elif args.mode == 'test': 28 | _eval(model, args) 29 | 30 | 31 | if __name__ == '__main__': 32 | parser = argparse.ArgumentParser() 33 | 34 | # Directories 35 | parser.add_argument('--model_name', default='ConvIR', type=str) 36 | parser.add_argument('--data_dir', type=str, default='../../dataset/deraining/train/Rain13K/') 37 | parser.add_argument('--valid_data', type=str, default='/Rain100K/') 38 | 39 | # Train 40 | parser.add_argument('--batch_size', type=int, default=4) 41 | parser.add_argument('--learning_rate', type=float, default=1e-4) 42 | parser.add_argument('--weight_decay', type=float, default=0) 43 | parser.add_argument('--num_epoch', type=int, default=300) 44 | parser.add_argument('--print_freq', type=int, default=100) 45 | parser.add_argument('--num_worker', type=int, default=8) 46 | parser.add_argument('--save_freq', type=int, default=10) 47 | 48 | parser.add_argument('--valid_freq', type=int, default=10) 49 | parser.add_argument('--resume', type=str, default='') 50 | parser.add_argument('--gamma', type=float, default=0.5) 51 | # Test 52 | parser.add_argument('--test_model', type=str, default='') 53 | parser.add_argument('--save_image', type=bool, default=False, choices=[True, False]) 54 | 55 | args = parser.parse_args() 56 | args.model_save_dir = os.path.join('results/', 'ConvIR', 'train_results/') 57 | args.result_dir = os.path.join('results/', args.model_name, 'test') 58 | if not os.path.exists(args.model_save_dir): 59 | os.makedirs(args.model_save_dir) 60 | command = 'cp ' + 'models/layers.py ' + args.model_save_dir 61 | os.system(command) 62 | command = 'cp ' + 'models/ConvIR.py ' + args.model_save_dir 63 | os.system(command) 64 | command = 'cp ' + 'train.py ' + args.model_save_dir 65 | os.system(command) 66 | command = 'cp ' + 'main.py ' + args.model_save_dir 67 | os.system(command) 68 | print(args) 69 | main(args) 70 | -------------------------------------------------------------------------------- /Image_deraining/models/ConvIR.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .layers import * 6 | 7 | 8 | class EBlock(nn.Module): 9 | def __init__(self, out_channel, num_res=8): 10 | super(EBlock, self).__init__() 11 | 12 | layers = [ResBlock(out_channel, out_channel) for _ in range(num_res-1)] 13 | layers.append(ResBlock(out_channel, out_channel, filter=True)) 14 | 15 | self.layers = nn.Sequential(*layers) 16 | 17 | def forward(self, x): 18 | return self.layers(x) 19 | 20 | 21 | class DBlock(nn.Module): 22 | def __init__(self, channel, num_res=8): 23 | super(DBlock, self).__init__() 24 | 25 | layers = [ResBlock(channel, channel) for _ in range(num_res-1)] 26 | layers.append(ResBlock(channel, channel, filter=True)) 27 | self.layers = nn.Sequential(*layers) 28 | 29 | def forward(self, x): 30 | 
return self.layers(x) 31 | 32 | 33 | class SCM(nn.Module): 34 | def __init__(self, out_plane): 35 | super(SCM, self).__init__() 36 | self.main = nn.Sequential( 37 | BasicConv(3, out_plane//4, kernel_size=3, stride=1, relu=True), 38 | BasicConv(out_plane // 4, out_plane // 2, kernel_size=1, stride=1, relu=True), 39 | BasicConv(out_plane // 2, out_plane // 2, kernel_size=3, stride=1, relu=True), 40 | BasicConv(out_plane // 2, out_plane, kernel_size=1, stride=1, relu=False), 41 | nn.InstanceNorm2d(out_plane, affine=True) 42 | ) 43 | 44 | def forward(self, x): 45 | x = self.main(x) 46 | return x 47 | 48 | class FAM(nn.Module): 49 | def __init__(self, channel): 50 | super(FAM, self).__init__() 51 | self.merge = BasicConv(channel*2, channel, kernel_size=3, stride=1, relu=False) 52 | 53 | def forward(self, x1, x2): 54 | return self.merge(torch.cat([x1, x2], dim=1)) 55 | 56 | class ConvIR(nn.Module): 57 | def __init__(self, num_res=16): 58 | super(ConvIR, self).__init__() 59 | 60 | base_channel = 32 61 | 62 | self.Encoder = nn.ModuleList([ 63 | EBlock(base_channel, num_res), 64 | EBlock(base_channel*2, num_res), 65 | EBlock(base_channel*4, num_res), 66 | ]) 67 | 68 | self.feat_extract = nn.ModuleList([ 69 | BasicConv(3, base_channel, kernel_size=3, relu=True, stride=1), 70 | BasicConv(base_channel, base_channel*2, kernel_size=3, relu=True, stride=2), 71 | BasicConv(base_channel*2, base_channel*4, kernel_size=3, relu=True, stride=2), 72 | BasicConv(base_channel*4, base_channel*2, kernel_size=4, relu=True, stride=2, transpose=True), 73 | BasicConv(base_channel*2, base_channel, kernel_size=4, relu=True, stride=2, transpose=True), 74 | BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1) 75 | ]) 76 | 77 | self.Decoder = nn.ModuleList([ 78 | DBlock(base_channel * 4, num_res), 79 | DBlock(base_channel * 2, num_res), 80 | DBlock(base_channel, num_res) 81 | ]) 82 | 83 | self.Convs = nn.ModuleList([ 84 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1), 85 | BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1), 86 | ]) 87 | 88 | self.ConvsOut = nn.ModuleList( 89 | [ 90 | BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1), 91 | BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1), 92 | ] 93 | ) 94 | 95 | self.FAM1 = FAM(base_channel * 4) 96 | self.SCM1 = SCM(base_channel * 4) 97 | self.FAM2 = FAM(base_channel * 2) 98 | self.SCM2 = SCM(base_channel * 2) 99 | 100 | def forward(self, x): 101 | x_2 = F.interpolate(x, scale_factor=0.5) 102 | x_4 = F.interpolate(x_2, scale_factor=0.5) 103 | z2 = self.SCM2(x_2) 104 | z4 = self.SCM1(x_4) 105 | 106 | outputs = list() 107 | # 256 108 | x_ = self.feat_extract[0](x) 109 | res1 = self.Encoder[0](x_) 110 | # 128 111 | z = self.feat_extract[1](res1) 112 | z = self.FAM2(z, z2) 113 | res2 = self.Encoder[1](z) 114 | # 64 115 | z = self.feat_extract[2](res2) 116 | z = self.FAM1(z, z4) 117 | z = self.Encoder[2](z) 118 | 119 | z = self.Decoder[0](z) 120 | z_ = self.ConvsOut[0](z) 121 | # 128 122 | z = self.feat_extract[3](z) 123 | outputs.append(z_+x_4) 124 | 125 | z = torch.cat([z, res2], dim=1) 126 | z = self.Convs[0](z) 127 | z = self.Decoder[1](z) 128 | z_ = self.ConvsOut[1](z) 129 | # 256 130 | z = self.feat_extract[4](z) 131 | outputs.append(z_+x_2) 132 | 133 | z = torch.cat([z, res1], dim=1) 134 | z = self.Convs[1](z) 135 | z = self.Decoder[2](z) 136 | z = self.feat_extract[5](z) 137 | outputs.append(z+x) 138 | 139 | return outputs 140 | 141 | 142 | def build_net(): 143 | 
return ConvIR() 144 | -------------------------------------------------------------------------------- /Image_deraining/models/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class BasicConv(nn.Module): 6 | def __init__(self, in_channel, out_channel, kernel_size, stride, bias=True, norm=False, relu=True, transpose=False): 7 | super(BasicConv, self).__init__() 8 | if bias and norm: 9 | bias = False 10 | 11 | padding = kernel_size // 2 12 | layers = list() 13 | if transpose: 14 | padding = kernel_size // 2 -1 15 | layers.append(nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias)) 16 | else: 17 | layers.append( 18 | nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias)) 19 | if norm: 20 | layers.append(nn.BatchNorm2d(out_channel)) 21 | if relu: 22 | layers.append(nn.GELU()) 23 | self.main = nn.Sequential(*layers) 24 | 25 | def forward(self, x): 26 | return self.main(x) 27 | 28 | class ResBlock(nn.Module): 29 | def __init__(self, in_channel, out_channel, filter=False): 30 | super(ResBlock, self).__init__() 31 | self.main = nn.Sequential( 32 | BasicConv(in_channel, out_channel, kernel_size=3, stride=1, relu=True), 33 | DeepPoolLayer(in_channel, out_channel) if filter else nn.Identity(), 34 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False) 35 | ) 36 | 37 | def forward(self, x): 38 | return self.main(x) + x 39 | 40 | class DeepPoolLayer(nn.Module): 41 | def __init__(self, k, k_out): 42 | super(DeepPoolLayer, self).__init__() 43 | self.pools_sizes = [8,4,2] 44 | dilation = [3,7,9] 45 | pools, convs, dynas = [],[],[] 46 | for j, i in enumerate(self.pools_sizes): 47 | pools.append(nn.AvgPool2d(kernel_size=i, stride=i)) 48 | convs.append(nn.Conv2d(k, k, 3, 1, 1, bias=False)) 49 | dynas.append(MultiShapeKernel(dim=k, kernel_size=3, dilation=dilation[j])) 50 | self.pools = nn.ModuleList(pools) 51 | self.convs = nn.ModuleList(convs) 52 | self.dynas = nn.ModuleList(dynas) 53 | self.relu = nn.GELU() 54 | self.conv_sum = nn.Conv2d(k, k_out, 3, 1, 1, bias=False) 55 | 56 | def forward(self, x): 57 | x_size = x.size() 58 | resl = x 59 | for i in range(len(self.pools_sizes)): 60 | if i == 0: 61 | y = self.dynas[i](self.convs[i](self.pools[i](x))) 62 | else: 63 | y = self.dynas[i](self.convs[i](self.pools[i](x)+y_up)) 64 | resl = torch.add(resl, F.interpolate(y, x_size[2:], mode='bilinear', align_corners=True)) 65 | if i != len(self.pools_sizes)-1: 66 | y_up = F.interpolate(y, scale_factor=2, mode='bilinear', align_corners=True) 67 | resl = self.relu(resl) 68 | resl = self.conv_sum(resl) 69 | 70 | return resl 71 | 72 | class dynamic_filter(nn.Module): 73 | def __init__(self, inchannels, kernel_size=3, dilation=1, stride=1, group=8): 74 | super(dynamic_filter, self).__init__() 75 | self.stride = stride 76 | self.kernel_size = kernel_size 77 | self.group = group 78 | self.dilation = dilation 79 | 80 | self.conv = nn.Conv2d(inchannels, group*kernel_size**2, kernel_size=1, stride=1, bias=False) 81 | self.bn = nn.BatchNorm2d(group*kernel_size**2) 82 | self.act = nn.Tanh() 83 | 84 | nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='relu') 85 | self.lamb_l = nn.Parameter(torch.zeros(inchannels), requires_grad=True) 86 | self.lamb_h = nn.Parameter(torch.zeros(inchannels), requires_grad=True) 87 | self.pad = nn.ReflectionPad2d(self.dilation*(kernel_size-1)//2) 
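# reflection padding of dilation*(kernel_size-1)//2 on each side keeps the
# dilated unfold in forward() at exactly h*w sliding positions, so the
# predicted low-pass kernel can be applied at every spatial location
# without shrinking the feature map.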
88 | 89 | self.ap = nn.AdaptiveAvgPool2d((1, 1)) 90 | self.gap = nn.AdaptiveAvgPool2d(1) 91 | 92 | self.inside_all = nn.Parameter(torch.zeros(inchannels,1,1), requires_grad=True) 93 | 94 | def forward(self, x): 95 | identity_input = x 96 | low_filter = self.ap(x) 97 | low_filter = self.conv(low_filter) 98 | low_filter = self.bn(low_filter) 99 | 100 | n, c, h, w = x.shape 101 | x = F.unfold(self.pad(x), kernel_size=self.kernel_size, dilation=self.dilation).reshape(n, self.group, c//self.group, self.kernel_size**2, h*w) 102 | 103 | n,c1,p,q = low_filter.shape 104 | low_filter = low_filter.reshape(n, c1//self.kernel_size**2, self.kernel_size**2, p*q).unsqueeze(2) 105 | 106 | low_filter = self.act(low_filter) 107 | 108 | low_part = torch.sum(x * low_filter, dim=3).reshape(n, c, h, w) 109 | 110 | out_low = low_part * (self.inside_all + 1.) - self.inside_all * self.gap(identity_input) 111 | 112 | out_low = out_low * self.lamb_l[None,:,None,None] 113 | 114 | out_high = (identity_input) * (self.lamb_h[None,:,None,None] + 1.) 115 | 116 | return out_low + out_high 117 | 118 | 119 | class cubic_attention(nn.Module): 120 | def __init__(self, dim, group, dilation, kernel) -> None: 121 | super().__init__() 122 | 123 | self.H_spatial_att = spatial_strip_att(dim, dilation=dilation, group=group, kernel=kernel) 124 | self.W_spatial_att = spatial_strip_att(dim, dilation=dilation, group=group, kernel=kernel, H=False) 125 | self.gamma = nn.Parameter(torch.zeros(dim,1,1)) 126 | self.beta = nn.Parameter(torch.ones(dim,1,1)) 127 | 128 | def forward(self, x): 129 | out = self.H_spatial_att(x) 130 | out = self.W_spatial_att(out) 131 | return self.gamma * out + x * self.beta 132 | 133 | 134 | class spatial_strip_att(nn.Module): 135 | def __init__(self, dim, kernel=3, dilation=1, group=2, H=True) -> None: 136 | super().__init__() 137 | 138 | self.k = kernel 139 | pad = dilation*(kernel-1) // 2 140 | self.kernel = (1, kernel) if H else (kernel, 1) 141 | self.padding = (kernel//2, 1) if H else (1, kernel//2) 142 | self.dilation = dilation 143 | self.group = group 144 | self.pad = nn.ReflectionPad2d((pad, pad, 0, 0)) if H else nn.ReflectionPad2d((0, 0, pad, pad)) 145 | self.conv = nn.Conv2d(dim, group*kernel, kernel_size=1, stride=1, bias=False) 146 | self.ap = nn.AdaptiveAvgPool2d((1, 1)) 147 | self.filter_act = nn.Tanh() 148 | self.inside_all = nn.Parameter(torch.zeros(dim,1,1), requires_grad=True) 149 | self.lamb_l = nn.Parameter(torch.zeros(dim), requires_grad=True) 150 | self.lamb_h = nn.Parameter(torch.zeros(dim), requires_grad=True) 151 | gap_kernel = (None,1) if H else (1, None) 152 | self.gap = nn.AdaptiveAvgPool2d(gap_kernel) 153 | 154 | def forward(self, x): 155 | identity_input = x.clone() 156 | filter = self.ap(x) 157 | filter = self.conv(filter) 158 | n, c, h, w = x.shape 159 | x = F.unfold(self.pad(x), kernel_size=self.kernel, dilation=self.dilation).reshape(n, self.group, c//self.group, self.k, h*w) 160 | n, c1, p, q = filter.shape 161 | filter = filter.reshape(n, c1//self.k, self.k, p*q).unsqueeze(2) 162 | filter = self.filter_act(filter) 163 | out = torch.sum(x * filter, dim=3).reshape(n, c, h, w) 164 | 165 | out_low = out * (self.inside_all + 1.) - self.inside_all * self.gap(identity_input) 166 | out_low = out_low * self.lamb_l[None,:,None,None] 167 | out_high = identity_input * (self.lamb_h[None,:,None,None]+1.) 
170 | 171 | return out_low + out_high 172 | 173 | 174 | class MultiShapeKernel(nn.Module): 175 | def __init__(self, dim, kernel_size=3, dilation=1, group=8): 176 | super().__init__() 177 | 178 | self.square_att = dynamic_filter(inchannels=dim, dilation=dilation, group=group, kernel_size=kernel_size) 179 | self.strip_att = cubic_attention(dim, group=group, dilation=dilation, kernel=kernel_size) 180 | 181 | def forward(self, x): 182 | 183 | x1 = self.strip_att(x) 184 | x2 = self.square_att(x) 185 | 186 | return x1+x2 187 | 188 | 189 | -------------------------------------------------------------------------------- /Image_deraining/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | from models.ConvIR import build_net 5 | from data import test_dataloader 6 | from utils import Adder 7 | import time 8 | from torchvision.transforms import functional as F 9 | from skimage.metrics import peak_signal_noise_ratio 10 | from torch.utils.data import Dataset, DataLoader 11 | from PIL import Image 12 | from tqdm import tqdm 13 | import torch.nn.functional as f 14 | 15 | class DeblurDataset(Dataset): 16 | def __init__(self, image_dir, transform=None, is_test=False): 17 | self.image_dir = image_dir 18 | self.image_list = os.listdir(os.path.join(image_dir, 'input/')) 19 | self._check_image(self.image_list) 20 | self.image_list.sort() 21 | self.transform = transform 22 | self.is_test = is_test 23 | 24 | def __len__(self): 25 | return len(self.image_list) 26 | 27 | def __getitem__(self, idx): 28 | image = Image.open(os.path.join(self.image_dir, 'input', self.image_list[idx])) 29 | label = Image.open(os.path.join(self.image_dir, 'target', self.image_list[idx])) 30 | 31 | if self.transform: 32 | image, label = self.transform(image, label) 33 | else: 34 | image = F.to_tensor(image) 35 | label = F.to_tensor(label) 36 | if self.is_test: 37 | name = self.image_list[idx] 38 | return image, label, name 39 | return image, label 40 | 41 | @staticmethod 42 | def _check_image(lst): 43 | for x in lst: 44 | splits = x.split('.') 45 | if splits[-1] not in ['png', 'jpg', 'jpeg']: 46 | raise ValueError 47 | 48 | def test_dataloader(path, batch_size=1, num_workers=0): 49 | dataloader = DataLoader( 50 | DeblurDataset(path, is_test=True), 51 | batch_size=batch_size, 52 | shuffle=False, 53 | num_workers=num_workers, 54 | pin_memory=True 55 | ) 56 | 57 | return dataloader 58 | 59 | 60 | parser = argparse.ArgumentParser() 61 | 62 | # Directories 63 | parser.add_argument('--model_name', default='ConvIR', type=str) 64 | parser.add_argument('--data_dir', type=str, default='/root/autodl-tmp/deraining_testset') 65 | 66 | parser.add_argument('--test_model', type=str, default='/root/autodl-tmp/sfnet/deraining.pkl') 67 | parser.add_argument('--save_image', type=bool, default=True, choices=[True, False]) 68 | args = parser.parse_args() 69 | 70 | args.result_dir = os.path.join('results/', args.model_name, 'deraining/') 71 | 72 | if not os.path.exists('results/'): 73 | os.makedirs('results/') 74 | if not os.path.exists('results/' + args.model_name + '/'): 75 | os.makedirs('results/' + args.model_name + '/') 76 | if not os.path.exists(args.result_dir): 77 | os.makedirs(args.result_dir) 78 | 79 | model = build_net() 80 | 81 | if torch.cuda.is_available(): 82 | model.cuda() 83 | 84 | state_dict = torch.load(args.test_model) 85 | model.load_state_dict(state_dict['model']) 86 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 87 
| torch.cuda.empty_cache() 88 | adder = Adder() 89 | model.eval() 90 | 91 | datasets = ['Rain100L', 'Rain100H', 'Test100', 'Test1200', 'Test2800'] 92 | 93 | for dataset in datasets: 94 | if not os.path.exists(args.result_dir+dataset+'/'): 95 | os.makedirs(args.result_dir+dataset) 96 | print(args.result_dir+dataset) 97 | dataloader = test_dataloader(os.path.join(args.data_dir, dataset), batch_size=1, num_workers=4) 98 | factor = 32 99 | with torch.no_grad(): 100 | psnr_adder = Adder() 101 | 102 | 103 | # Main Evaluation 104 | for iter_idx, data in enumerate(tqdm(dataloader), 0): 105 | input_img, label_img, name = data 106 | 107 | input_img = input_img.to(device) 108 | 109 | h, w = input_img.shape[2], input_img.shape[3] 110 | H, W = ((h+factor)//factor)*factor, ((w+factor)//factor*factor) 111 | padh = H-h if h%factor!=0 else 0 112 | padw = W-w if w%factor!=0 else 0 113 | input_img = f.pad(input_img, (0, padw, 0, padh), 'reflect') 114 | 115 | 116 | tm = time.time() 117 | 118 | pred = model(input_img)[2] 119 | 120 | elapsed = time.time() - tm 121 | adder(elapsed) 122 | 123 | pred_clip = torch.clamp(pred, 0, 1) 124 | 125 | pred_numpy = pred_clip.squeeze(0).cpu().numpy() 126 | label_numpy = label_img.squeeze(0).cpu().numpy() 127 | 128 | if args.save_image: 129 | save_name = os.path.join(args.result_dir, dataset, name[0]) 130 | pred_clip += 0.5 / 255 131 | pred = F.to_pil_image(pred_clip.squeeze(0).cpu(), 'RGB') 132 | pred.save(save_name) 133 | 134 | print('==========================================================') 135 | -------------------------------------------------------------------------------- /Image_deraining/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from data import train_dataloader 4 | from utils import Adder, Timer 5 | from torch.utils.tensorboard import SummaryWriter 6 | from valid import _valid 7 | import torch.nn.functional as F 8 | import torch.nn as nn 9 | 10 | from warmup_scheduler import GradualWarmupScheduler 11 | 12 | 13 | def _train(model, args): 14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 15 | criterion = torch.nn.L1Loss() 16 | 17 | optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), eps=1e-8) 18 | dataloader = train_dataloader(args.data_dir, args.batch_size, args.num_worker) 19 | max_iter = len(dataloader) 20 | warmup_epochs=3 21 | scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.num_epoch-warmup_epochs, eta_min=1e-6) 22 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine) 23 | scheduler.step() 24 | epoch = 1 25 | if args.resume: 26 | state = torch.load(args.resume) 27 | epoch = state['epoch'] 28 | optimizer.load_state_dict(state['optimizer']) 29 | model.load_state_dict(state['model']) 30 | print('Resume from %d'%epoch) 31 | epoch += 1 32 | 33 | 34 | 35 | writer = SummaryWriter() 36 | epoch_pixel_adder = Adder() 37 | epoch_fft_adder = Adder() 38 | iter_pixel_adder = Adder() 39 | iter_fft_adder = Adder() 40 | epoch_timer = Timer('m') 41 | iter_timer = Timer('m') 42 | best_psnr=-1 43 | 44 | for epoch_idx in range(epoch, args.num_epoch + 1): 45 | 46 | epoch_timer.tic() 47 | iter_timer.tic() 48 | for iter_idx, batch_data in enumerate(dataloader): 49 | 50 | input_img, label_img = batch_data 51 | input_img = input_img.to(device) 52 | label_img = label_img.to(device) 53 | 54 | optimizer.zero_grad() 55 | pred_img = 
model(input_img) 56 | label_img2 = F.interpolate(label_img, scale_factor=0.5, mode='bilinear') 57 | label_img4 = F.interpolate(label_img, scale_factor=0.25, mode='bilinear') 58 | l1 = criterion(pred_img[0], label_img4) 59 | l2 = criterion(pred_img[1], label_img2) 60 | l3 = criterion(pred_img[2], label_img) 61 | loss_content = l1+l2+l3 62 | 63 | label_fft1 = torch.fft.fft2(label_img4, dim=(-2,-1)) 64 | label_fft1 = torch.stack((label_fft1.real, label_fft1.imag), -1) 65 | 66 | pred_fft1 = torch.fft.fft2(pred_img[0], dim=(-2,-1)) 67 | pred_fft1 = torch.stack((pred_fft1.real, pred_fft1.imag), -1) 68 | 69 | label_fft2 = torch.fft.fft2(label_img2, dim=(-2,-1)) 70 | label_fft2 = torch.stack((label_fft2.real, label_fft2.imag), -1) 71 | 72 | pred_fft2 = torch.fft.fft2(pred_img[1], dim=(-2,-1)) 73 | pred_fft2 = torch.stack((pred_fft2.real, pred_fft2.imag), -1) 74 | 75 | label_fft3 = torch.fft.fft2(label_img, dim=(-2,-1)) 76 | label_fft3 = torch.stack((label_fft3.real, label_fft3.imag), -1) 77 | 78 | pred_fft3 = torch.fft.fft2(pred_img[2], dim=(-2,-1)) 79 | pred_fft3 = torch.stack((pred_fft3.real, pred_fft3.imag), -1) 80 | 81 | f1 = criterion(pred_fft1, label_fft1) 82 | f2 = criterion(pred_fft2, label_fft2) 83 | f3 = criterion(pred_fft3, label_fft3) 84 | loss_fft = f1+f2+f3 85 | 86 | loss = loss_content + 0.1 * loss_fft 87 | loss.backward() 88 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.01) 89 | optimizer.step() 90 | 91 | iter_pixel_adder(loss_content.item()) 92 | iter_fft_adder(loss_fft.item()) 93 | 94 | epoch_pixel_adder(loss_content.item()) 95 | epoch_fft_adder(loss_fft.item()) 96 | 97 | if (iter_idx + 1) % args.print_freq == 0: 98 | print("Time: %7.4f Epoch: %03d Iter: %4d/%4d LR: %.10f Loss content: %7.4f Loss fft: %7.4f" % ( 99 | iter_timer.toc(), epoch_idx, iter_idx + 1, max_iter, scheduler.get_lr()[0], iter_pixel_adder.average(), 100 | iter_fft_adder.average())) 101 | writer.add_scalar('Pixel Loss', iter_pixel_adder.average(), iter_idx + (epoch_idx-1)* max_iter) 102 | writer.add_scalar('FFT Loss', iter_fft_adder.average(), iter_idx + (epoch_idx - 1) * max_iter) 103 | 104 | iter_timer.tic() 105 | iter_pixel_adder.reset() 106 | iter_fft_adder.reset() 107 | overwrite_name = os.path.join(args.model_save_dir, 'model.pkl') 108 | torch.save({'model': model.state_dict(), 109 | 'optimizer': optimizer.state_dict(), 110 | 'epoch': epoch_idx}, overwrite_name) 111 | 112 | if epoch_idx % args.save_freq == 0: 113 | save_name = os.path.join(args.model_save_dir, 'model_%d.pkl' % epoch_idx) 114 | torch.save({'model': model.state_dict()}, save_name) 115 | print("EPOCH: %02d\nElapsed time: %4.2f Epoch Pixel Loss: %7.4f Epoch FFT Loss: %7.4f" % ( 116 | epoch_idx, epoch_timer.toc(), epoch_pixel_adder.average(), epoch_fft_adder.average())) 117 | epoch_fft_adder.reset() 118 | epoch_pixel_adder.reset() 119 | scheduler.step() 120 | if epoch_idx % args.valid_freq == 0: 121 | val_rain = _valid(model, args, epoch_idx) 122 | print('%03d epoch \n Average DeRain PSNR %.2f dB' % (epoch_idx, val_rain)) 123 | writer.add_scalar('PSNR_DeRain', val_rain, epoch_idx) 124 | if val_rain >= best_psnr: 125 | torch.save({'model': model.state_dict()}, os.path.join(args.model_save_dir, 'Best.pkl')) 126 | save_name = os.path.join(args.model_save_dir, 'Final.pkl') 127 | torch.save({'model': model.state_dict()}, save_name) 128 | -------------------------------------------------------------------------------- /Image_deraining/utils.py: -------------------------------------------------------------------------------- 1 | import time 
2 | import numpy as np 3 | 4 | 5 | class Adder(object): 6 | def __init__(self): 7 | self.count = 0 8 | self.num = float(0) 9 | 10 | def reset(self): 11 | self.count = 0 12 | self.num = float(0) 13 | 14 | def __call__(self, num): 15 | self.count += 1 16 | self.num += num 17 | 18 | def average(self): 19 | return self.num / self.count 20 | 21 | 22 | class Timer(object): 23 | def __init__(self, option='s'): 24 | self.tm = 0 25 | self.option = option 26 | if option == 's': 27 | self.devider = 1 28 | elif option == 'm': 29 | self.devider = 60 30 | else: 31 | self.devider = 3600 32 | 33 | def tic(self): 34 | self.tm = time.time() 35 | 36 | def toc(self): 37 | return (time.time() - self.tm) / self.devider 38 | 39 | 40 | def check_lr(optimizer): 41 | for i, param_group in enumerate(optimizer.param_groups): 42 | lr = param_group['lr'] 43 | return lr 44 | -------------------------------------------------------------------------------- /Image_deraining/valid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.transforms import functional as F 3 | from data import valid_dataloader 4 | from utils import Adder 5 | import os 6 | from skimage.metrics import peak_signal_noise_ratio 7 | import torch.nn.functional as f 8 | 9 | 10 | def _valid(model, args, ep): 11 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 12 | gopro = valid_dataloader(args.valid_data, batch_size=1, num_workers=0) 13 | model.eval() 14 | psnr_adder = Adder() 15 | 16 | with torch.no_grad(): 17 | print('Start Derain Evaluation') 18 | factor = 32 19 | for idx, data in enumerate(gopro): 20 | input_img, label_img = data 21 | input_img = input_img.to(device) 22 | 23 | h, w = input_img.shape[2], input_img.shape[3] 24 | H, W = ((h+factor)//factor)*factor, ((w+factor)//factor*factor) 25 | padh = H-h if h%factor!=0 else 0 26 | padw = W-w if w%factor!=0 else 0 27 | input_img = f.pad(input_img, (0, padw, 0, padh), 'reflect') 28 | 29 | if not os.path.exists(os.path.join(args.result_dir, '%d' % (ep))): 30 | os.mkdir(os.path.join(args.result_dir, '%d' % (ep))) 31 | 32 | pred = model(input_img)[2] 33 | pred = pred[:,:,:h,:w] 34 | 35 | pred_clip = torch.clamp(pred, 0, 1) 36 | p_numpy = pred_clip.squeeze(0).cpu().numpy() 37 | label_numpy = label_img.squeeze(0).cpu().numpy() 38 | 39 | psnr = peak_signal_noise_ratio(p_numpy, label_numpy, data_range=1) 40 | 41 | psnr_adder(psnr) 42 | print('\r%03d'%idx, end=' ') 43 | 44 | print('\n') 45 | model.train() 46 | return psnr_adder.average() 47 | -------------------------------------------------------------------------------- /Image_desnowing/README.md: -------------------------------------------------------------------------------- 1 | ### Download the Datasets 2 | - SRRS [[gdrive](https://drive.google.com/file/d/11h1cZ0NXx6ev35cl5NKOAL3PCgLlWUl2/view?usp=sharing), [Baidu](https://pan.baidu.com/s/1VXqsamkl12fPsI1Qek97TQ?pwd=vcfg)] 3 | - CSD [[gdrive](https://drive.google.com/file/d/1pns-7uWy-0SamxjA40qOCkkhSu7o7ULb/view?usp=sharing), [Baidu](https://pan.baidu.com/s/1N52Jnx0co9udJeYrbd3blA?pwd=sb4a)] 4 | - Snow100K [[gdrive](https://drive.google.com/file/d/19zJs0cJ6F3G3IlDHLU2BO7nHnCTMNrIS/view?usp=sharing), [Baidu](https://pan.baidu.com/s/1QGd5z9uM6vBKPnD5d7jQmA?pwd=aph4)] 5 | 6 | ### Training 7 | version=small/base/large for different versions 8 | ~~~ 9 | python main.py --data CSD --version small --mode train --data_dir your_path/CSD 10 | python main.py --data SRRS --version small --mode train --data_dir your_path/SRRS 11 
| python main.py --data Snow100K --version small --mode train --data_dir your_path/Snow100K 12 | ~~~ 13 | 14 | ### Evaluation 15 | #### Download the model 16 | [gdrive](https://drive.google.com/drive/folders/1_5fO2p5xoWO5cUEVoXJ7x3Uhg1AP18FQ?usp=sharing), [百度网盘](https://pan.baidu.com/s/1oYzdxs3FvLJMWx7S5GW0rA?pwd=dvta) 17 | #### Testing 18 | version=small/base/large for different versions 19 | ~~~ 20 | python main.py --data CSD --version small --save_image True --mode test --data_dir your_path/CSD --test_model path_to_CSD_model 21 | 22 | python main.py --data SRRS --version small --save_image True --mode test --data_dir your_path/SRRS --test_model path_to_SRRS_model 23 | 24 | python main.py --data Snow100K --version small --save_image True --mode test --data_dir your_path/Snow100K --test_model path_to_Snow100K_model 25 | ~~~ 26 | 27 | 28 | For training and testing, your directory structure should look like this 29 | 30 | `Your path`
31 |  `├──CSD`
32 |      `├──train2500`
33 |           `├──Gt`
34 |           `└──Snow` 35 |      `└──test2000`
36 |           `├──Gt`
37 |           `└──Snow` 38 |  `├──SRRS`
39 |      `├──train2500`
40 |           `├──Gt`
41 |           `└──Snow` 42 |      `└──test2000`
43 |           `├──Gt`
44 |           `└──Snow` 45 |  `└──Snow100K`
46 |      `├──train2500`
47 |           `├──Gt`
48 |           `└──Snow` 49 |      `└──test2000`
50 |           `├──Gt`
51 |           `└──Snow` 52 | -------------------------------------------------------------------------------- /Image_desnowing/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_augment import PairRandomCrop, PairCompose, PairRandomHorizontalFilp, PairToTensor 2 | from .data_load import train_dataloader, test_dataloader, valid_dataloader 3 | -------------------------------------------------------------------------------- /Image_desnowing/data/data_augment.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torchvision.transforms as transforms 3 | import torchvision.transforms.functional as F 4 | 5 | 6 | class PairRandomCrop(transforms.RandomCrop): 7 | 8 | def __call__(self, image, label): 9 | 10 | if self.padding is not None: 11 | image = F.pad(image, self.padding, self.fill, self.padding_mode) 12 | label = F.pad(label, self.padding, self.fill, self.padding_mode) 13 | 14 | # pad the width if needed 15 | if self.pad_if_needed and image.size[0] < self.size[1]: 16 | image = F.pad(image, (self.size[1] - image.size[0], 0), self.fill, self.padding_mode) 17 | label = F.pad(label, (self.size[1] - label.size[0], 0), self.fill, self.padding_mode) 18 | # pad the height if needed 19 | if self.pad_if_needed and image.size[1] < self.size[0]: 20 | image = F.pad(image, (0, self.size[0] - image.size[1]), self.fill, self.padding_mode) 21 | label = F.pad(label, (0, self.size[0] - image.size[1]), self.fill, self.padding_mode) 22 | 23 | i, j, h, w = self.get_params(image, self.size) 24 | 25 | return F.crop(image, i, j, h, w), F.crop(label, i, j, h, w) 26 | 27 | 28 | class PairCompose(transforms.Compose): 29 | def __call__(self, image, label): 30 | for t in self.transforms: 31 | image, label = t(image, label) 32 | return image, label 33 | 34 | 35 | class PairRandomHorizontalFilp(transforms.RandomHorizontalFlip): 36 | def __call__(self, img, label): 37 | """ 38 | Args: 39 | img (PIL Image): Image to be flipped. 40 | 41 | Returns: 42 | PIL Image: Randomly flipped image. 43 | """ 44 | if random.random() < self.p: 45 | return F.hflip(img), F.hflip(label) 46 | return img, label 47 | 48 | 49 | class PairToTensor(transforms.ToTensor): 50 | def __call__(self, pic, label): 51 | """ 52 | Args: 53 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor. 54 | 55 | Returns: 56 | Tensor: Converted image. 
57 | """ 58 | return F.to_tensor(pic), F.to_tensor(label) 59 | -------------------------------------------------------------------------------- /Image_desnowing/data/data_load.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image as Image 3 | from data import PairCompose, PairRandomCrop, PairRandomHorizontalFilp, PairToTensor 4 | from torchvision.transforms import functional as F 5 | from torch.utils.data import Dataset, DataLoader 6 | 7 | 8 | def train_dataloader(path, batch_size=64, num_workers=0, data='CSD', use_transform=True): 9 | image_dir = os.path.join(path, 'train2500') 10 | 11 | transform = None 12 | if use_transform: 13 | transform = PairCompose( 14 | [ 15 | PairRandomCrop(256), 16 | PairRandomHorizontalFilp(), 17 | PairToTensor() 18 | ] 19 | ) 20 | dataloader = DataLoader( 21 | DeblurDataset(image_dir, data, transform=transform), 22 | batch_size=batch_size, 23 | shuffle=True, 24 | num_workers=num_workers, 25 | pin_memory=True 26 | ) 27 | return dataloader 28 | 29 | 30 | def test_dataloader(path, data, batch_size=1, num_workers=0): 31 | image_dir = os.path.join(path, 'test2000') 32 | dataloader = DataLoader( 33 | DeblurDataset(image_dir, data, is_test=True), 34 | 35 | batch_size=batch_size, 36 | shuffle=False, 37 | num_workers=num_workers, 38 | pin_memory=True 39 | ) 40 | 41 | return dataloader 42 | 43 | 44 | def valid_dataloader(path, data, batch_size=1, num_workers=0): 45 | dataloader = DataLoader( 46 | DeblurDataset(os.path.join(path, 'test2000'), data), 47 | batch_size=batch_size, 48 | shuffle=False, 49 | num_workers=num_workers 50 | ) 51 | 52 | return dataloader 53 | 54 | 55 | class DeblurDataset(Dataset): 56 | def __init__(self, image_dir, data, transform=None, is_test=False): 57 | self.image_dir = image_dir 58 | self.image_list = os.listdir(os.path.join(image_dir, 'Snow/')) 59 | self.image_list.sort() 60 | self.transform = transform 61 | self.is_test = is_test 62 | self.data = data 63 | 64 | def __len__(self): 65 | return len(self.image_list) 66 | 67 | def __getitem__(self, idx): 68 | image = Image.open(os.path.join(self.image_dir, 'Snow', self.image_list[idx])) 69 | if self.data == 'SRRS': 70 | label = Image.open(os.path.join(self.image_dir, 'Gt', self.image_list[idx].split('.')[0]+'.jpg')) 71 | else: 72 | label = Image.open(os.path.join(self.image_dir, 'Gt', self.image_list[idx])) 73 | 74 | if self.transform: 75 | image, label = self.transform(image, label) 76 | else: 77 | image = F.to_tensor(image) 78 | label = F.to_tensor(label) 79 | if self.is_test: 80 | name = self.image_list[idx] 81 | return image, label, name 82 | return image, label 83 | 84 | -------------------------------------------------------------------------------- /Image_desnowing/eval.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import torch 4 | from pytorch_msssim import ssim 5 | from torchvision.transforms import functional as F 6 | from utils import Adder 7 | from data import test_dataloader 8 | from skimage.metrics import peak_signal_noise_ratio 9 | import torch.nn.functional as f 10 | 11 | def _eval(model, args): 12 | state_dict = torch.load(args.test_model) 13 | model.load_state_dict(state_dict['model']) 14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 15 | dataloader = test_dataloader(args.data_dir, args.data, batch_size=1, num_workers=0) 16 | torch.cuda.empty_cache() 17 | model.eval() 18 | factor = 32 19 | with torch.no_grad(): 20 | 
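        # Pad each test image to a multiple of `factor`, run the model, crop the
        # prediction back to the original size, then accumulate PSNR and SSIM
        # over the whole test set.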
psnr_adder = Adder() 21 | ssim_adder = Adder() 22 | 23 | for iter_idx, data in enumerate(dataloader): 24 | input_img, label_img, name = data 25 | input_img = input_img.to(device) 26 | 27 | h, w = input_img.shape[2], input_img.shape[3] 28 | H, W = ((h+factor)//factor)*factor, ((w+factor)//factor*factor) 29 | padh = H-h if h%factor!=0 else 0 30 | padw = W-w if w%factor!=0 else 0 31 | input_img = f.pad(input_img, (0, padw, 0, padh), 'reflect') 32 | 33 | pred = model(input_img)[2] 34 | pred = pred[:,:,:h,:w] 35 | 36 | pred_clip = torch.clamp(pred, 0, 1) 37 | 38 | pred_numpy = pred_clip.squeeze(0).cpu().numpy() 39 | label_numpy = label_img.squeeze(0).cpu().numpy() 40 | 41 | 42 | if args.save_image: 43 | save_name = os.path.join(args.result_dir, name[0]) 44 | pred_clip += 0.5 / 255 45 | pred = F.to_pil_image(pred_clip.squeeze(0).cpu(), 'RGB') 46 | pred.save(save_name) 47 | 48 | 49 | label_img = (label_img).cuda() 50 | down_ratio = max(1, round(min(H, W) / 256)) 51 | ssim_val = ssim(f.adaptive_avg_pool2d(pred_clip, (int(H / down_ratio), int(W / down_ratio))), 52 | f.adaptive_avg_pool2d(label_img, (int(H / down_ratio), int(W / down_ratio))), 53 | data_range=1, size_average=False) 54 | ssim_adder(ssim_val) 55 | 56 | psnr = peak_signal_noise_ratio(pred_numpy, label_numpy, data_range=1) 57 | psnr_adder(psnr) 58 | 59 | print('%d iter PSNR: %.2f SSIM: %f' % (iter_idx + 1, psnr, ssim_val)) 60 | 61 | print('==========================================================') 62 | print('The average PSNR is %.2f dB' % (psnr_adder.average())) 63 | print('The average SSIM is %.4f' % (ssim_adder.average())) 64 | 65 | -------------------------------------------------------------------------------- /Image_desnowing/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | from torch.backends import cudnn 5 | from models.ConvIR import build_net 6 | from train import _train 7 | from eval import _eval 8 | 9 | def main(args): 10 | cudnn.benchmark = True 11 | 12 | if not os.path.exists('results/'): 13 | os.makedirs(args.model_save_dir) 14 | if not os.path.exists('results/' + args.model_name + '/'): 15 | os.makedirs('results/' + args.model_name + '/') 16 | if not os.path.exists(args.result_dir): 17 | os.makedirs(args.result_dir) 18 | 19 | model = build_net(args.version) 20 | # print(model) 21 | 22 | if torch.cuda.is_available(): 23 | model.cuda() 24 | 25 | if args.mode == 'train': 26 | _train(model, args) 27 | 28 | elif args.mode == 'test': 29 | _eval(model, args) 30 | 31 | 32 | if __name__ == '__main__': 33 | parser = argparse.ArgumentParser() 34 | 35 | # Directories 36 | parser.add_argument('--model_name', default='ConvIR', type=str) 37 | parser.add_argument('--data', type=str, default='CSD', choices=['CSD', 'SRRS', 'Snow100K']) 38 | 39 | parser.add_argument('--data_dir', type=str, default='CSD') 40 | parser.add_argument('--mode', default='train', choices=['train', 'test'], type=str) 41 | parser.add_argument('--version', default='small', choices=['small', 'base', 'large'], type=str) 42 | 43 | # Train 44 | parser.add_argument('--batch_size', type=int, default=8) 45 | parser.add_argument('--learning_rate', type=float, default=2e-4) 46 | parser.add_argument('--weight_decay', type=float, default=0) 47 | parser.add_argument('--num_epoch', type=int, default=2000) 48 | parser.add_argument('--print_freq', type=int, default=100) 49 | parser.add_argument('--num_worker', type=int, default=8) 50 | parser.add_argument('--save_freq', type=int, 
default=50) 51 | parser.add_argument('--valid_freq', type=int, default=50) 52 | parser.add_argument('--resume', type=str, default='') 53 | 54 | # Test 55 | parser.add_argument('--test_model', type=str, default='CSD.pkl') 56 | parser.add_argument('--save_image', type=bool, default=False, choices=[True, False]) 57 | 58 | args = parser.parse_args() 59 | args.model_save_dir = os.path.join('results/', args.model_name, 'Training-Results/') 60 | args.result_dir = os.path.join('results/', args.model_name, 'images', args.data) 61 | if not os.path.exists(args.model_save_dir): 62 | os.makedirs(args.model_save_dir) 63 | command = 'cp ' + 'models/layers.py ' + args.model_save_dir 64 | os.system(command) 65 | command = 'cp ' + 'models/ConvIR.py ' + args.model_save_dir 66 | os.system(command) 67 | command = 'cp ' + 'train.py ' + args.model_save_dir 68 | os.system(command) 69 | command = 'cp ' + 'main.py ' + args.model_save_dir 70 | os.system(command) 71 | print(args) 72 | main(args) 73 | -------------------------------------------------------------------------------- /Image_desnowing/models/ConvIR.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .layers import * 6 | 7 | 8 | class EBlock(nn.Module): 9 | def __init__(self, out_channel, num_res=8): 10 | super(EBlock, self).__init__() 11 | 12 | layers = [ResBlock(out_channel, out_channel) for _ in range(num_res-1)] 13 | layers.append(ResBlock(out_channel, out_channel, filter=True)) 14 | 15 | self.layers = nn.Sequential(*layers) 16 | 17 | def forward(self, x): 18 | return self.layers(x) 19 | 20 | 21 | class DBlock(nn.Module): 22 | def __init__(self, channel, num_res=8): 23 | super(DBlock, self).__init__() 24 | 25 | layers = [ResBlock(channel, channel) for _ in range(num_res-1)] 26 | layers.append(ResBlock(channel, channel, filter=True)) 27 | self.layers = nn.Sequential(*layers) 28 | 29 | def forward(self, x): 30 | return self.layers(x) 31 | 32 | 33 | class SCM(nn.Module): 34 | def __init__(self, out_plane): 35 | super(SCM, self).__init__() 36 | self.main = nn.Sequential( 37 | BasicConv(3, out_plane//4, kernel_size=3, stride=1, relu=True), 38 | BasicConv(out_plane // 4, out_plane // 2, kernel_size=1, stride=1, relu=True), 39 | BasicConv(out_plane // 2, out_plane // 2, kernel_size=3, stride=1, relu=True), 40 | BasicConv(out_plane // 2, out_plane, kernel_size=1, stride=1, relu=False), 41 | nn.InstanceNorm2d(out_plane, affine=True) 42 | ) 43 | 44 | def forward(self, x): 45 | x = self.main(x) 46 | return x 47 | 48 | class FAM(nn.Module): 49 | def __init__(self, channel): 50 | super(FAM, self).__init__() 51 | self.merge = BasicConv(channel*2, channel, kernel_size=3, stride=1, relu=False) 52 | 53 | def forward(self, x1, x2): 54 | return self.merge(torch.cat([x1, x2], dim=1)) 55 | 56 | class ConvIR(nn.Module): 57 | def __init__(self, version): 58 | super(ConvIR, self).__init__() 59 | 60 | if version == 'small': 61 | num_res = 4 62 | elif version == 'base': 63 | num_res = 8 64 | elif version == 'large': 65 | num_res = 16 66 | 67 | base_channel = 32 68 | 69 | self.Encoder = nn.ModuleList([ 70 | EBlock(base_channel, num_res), 71 | EBlock(base_channel*2, num_res), 72 | EBlock(base_channel*4, num_res), 73 | ]) 74 | 75 | self.feat_extract = nn.ModuleList([ 76 | BasicConv(3, base_channel, kernel_size=3, relu=True, stride=1), 77 | BasicConv(base_channel, base_channel*2, kernel_size=3, relu=True, stride=2), 78 | BasicConv(base_channel*2, 
base_channel*4, kernel_size=3, relu=True, stride=2), 79 | BasicConv(base_channel*4, base_channel*2, kernel_size=4, relu=True, stride=2, transpose=True), 80 | BasicConv(base_channel*2, base_channel, kernel_size=4, relu=True, stride=2, transpose=True), 81 | BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1) 82 | ]) 83 | 84 | self.Decoder = nn.ModuleList([ 85 | DBlock(base_channel * 4, num_res), 86 | DBlock(base_channel * 2, num_res), 87 | DBlock(base_channel, num_res) 88 | ]) 89 | 90 | self.Convs = nn.ModuleList([ 91 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1), 92 | BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1), 93 | ]) 94 | 95 | self.ConvsOut = nn.ModuleList( 96 | [ 97 | BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1), 98 | BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1), 99 | ] 100 | ) 101 | 102 | self.FAM1 = FAM(base_channel * 4) 103 | self.SCM1 = SCM(base_channel * 4) 104 | self.FAM2 = FAM(base_channel * 2) 105 | self.SCM2 = SCM(base_channel * 2) 106 | 107 | def forward(self, x): 108 | x_2 = F.interpolate(x, scale_factor=0.5) 109 | x_4 = F.interpolate(x_2, scale_factor=0.5) 110 | z2 = self.SCM2(x_2) 111 | z4 = self.SCM1(x_4) 112 | 113 | outputs = list() 114 | # 256 115 | x_ = self.feat_extract[0](x) 116 | res1 = self.Encoder[0](x_) 117 | # 128 118 | z = self.feat_extract[1](res1) 119 | z = self.FAM2(z, z2) 120 | res2 = self.Encoder[1](z) 121 | # 64 122 | z = self.feat_extract[2](res2) 123 | z = self.FAM1(z, z4) 124 | z = self.Encoder[2](z) 125 | 126 | z = self.Decoder[0](z) 127 | z_ = self.ConvsOut[0](z) 128 | # 128 129 | z = self.feat_extract[3](z) 130 | outputs.append(z_+x_4) 131 | 132 | z = torch.cat([z, res2], dim=1) 133 | z = self.Convs[0](z) 134 | z = self.Decoder[1](z) 135 | z_ = self.ConvsOut[1](z) 136 | # 256 137 | z = self.feat_extract[4](z) 138 | outputs.append(z_+x_2) 139 | 140 | z = torch.cat([z, res1], dim=1) 141 | z = self.Convs[1](z) 142 | z = self.Decoder[2](z) 143 | z = self.feat_extract[5](z) 144 | outputs.append(z+x) 145 | 146 | return outputs 147 | 148 | 149 | def build_net(version): 150 | return ConvIR(version) 151 | -------------------------------------------------------------------------------- /Image_desnowing/models/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class BasicConv(nn.Module): 6 | def __init__(self, in_channel, out_channel, kernel_size, stride, bias=True, norm=False, relu=True, transpose=False): 7 | super(BasicConv, self).__init__() 8 | if bias and norm: 9 | bias = False 10 | 11 | padding = kernel_size // 2 12 | layers = list() 13 | if transpose: 14 | padding = kernel_size // 2 -1 15 | layers.append(nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias)) 16 | else: 17 | layers.append( 18 | nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias)) 19 | if norm: 20 | layers.append(nn.BatchNorm2d(out_channel)) 21 | if relu: 22 | layers.append(nn.GELU()) 23 | self.main = nn.Sequential(*layers) 24 | 25 | def forward(self, x): 26 | return self.main(x) 27 | 28 | 29 | class ResBlock(nn.Module): 30 | def __init__(self, in_channel, out_channel, filter=False): 31 | super(ResBlock, self).__init__() 32 | self.main = nn.Sequential( 33 | BasicConv(in_channel, out_channel, kernel_size=3, stride=1, relu=True), 34 | 
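            # When filter=True (only the last ResBlock of each encoder/decoder
            # stage), a multi-scale DeepPoolLayer is inserted between the two
            # 3x3 convolutions; otherwise this slot is a no-op identity.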
DeepPoolLayer(in_channel, out_channel) if filter else nn.Identity(), 35 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False) 36 | ) 37 | 38 | def forward(self, x): 39 | return self.main(x) + x 40 | 41 | 42 | class DeepPoolLayer(nn.Module): 43 | def __init__(self, k, k_out): 44 | super(DeepPoolLayer, self).__init__() 45 | self.pools_sizes = [8,4,2] 46 | dilation = [7,9,11] 47 | pools, convs, dynas = [],[],[] 48 | for j, i in enumerate(self.pools_sizes): 49 | pools.append(nn.AvgPool2d(kernel_size=i, stride=i)) 50 | convs.append(nn.Conv2d(k, k, 3, 1, 1, bias=False)) 51 | dynas.append(MultiShapeKernel(dim=k, kernel_size=3, dilation=dilation[j])) 52 | self.pools = nn.ModuleList(pools) 53 | self.convs = nn.ModuleList(convs) 54 | self.dynas = nn.ModuleList(dynas) 55 | self.relu = nn.GELU() 56 | self.conv_sum = nn.Conv2d(k, k_out, 3, 1, 1, bias=False) 57 | 58 | def forward(self, x): 59 | x_size = x.size() 60 | resl = x 61 | for i in range(len(self.pools_sizes)): 62 | if i == 0: 63 | y = self.dynas[i](self.convs[i](self.pools[i](x))) 64 | else: 65 | y = self.dynas[i](self.convs[i](self.pools[i](x)+y_up)) 66 | resl = torch.add(resl, F.interpolate(y, x_size[2:], mode='bilinear', align_corners=True)) 67 | if i != len(self.pools_sizes)-1: 68 | y_up = F.interpolate(y, scale_factor=2, mode='bilinear', align_corners=True) 69 | resl = self.relu(resl) 70 | resl = self.conv_sum(resl) 71 | 72 | return resl 73 | 74 | 75 | class dynamic_filter(nn.Module): 76 | def __init__(self, inchannels, kernel_size=3, dilation=1, stride=1, group=8): 77 | super(dynamic_filter, self).__init__() 78 | self.stride = stride 79 | self.kernel_size = kernel_size 80 | self.group = group 81 | self.dilation = dilation 82 | 83 | self.conv = nn.Conv2d(inchannels, group*kernel_size**2, kernel_size=1, stride=1, bias=False) 84 | self.bn = nn.BatchNorm2d(group*kernel_size**2) 85 | self.act = nn.Tanh() 86 | 87 | nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='relu') 88 | self.lamb_l = nn.Parameter(torch.zeros(inchannels), requires_grad=True) 89 | self.lamb_h = nn.Parameter(torch.zeros(inchannels), requires_grad=True) 90 | self.pad = nn.ReflectionPad2d(self.dilation*(kernel_size-1)//2) 91 | 92 | self.ap = nn.AdaptiveAvgPool2d((1, 1)) 93 | self.gap = nn.AdaptiveAvgPool2d(1) 94 | 95 | self.inside_all = nn.Parameter(torch.zeros(inchannels,1,1), requires_grad=True) 96 | 97 | def forward(self, x): 98 | identity_input = x 99 | low_filter = self.ap(x) 100 | low_filter = self.conv(low_filter) 101 | low_filter = self.bn(low_filter) 102 | 103 | n, c, h, w = x.shape 104 | x = F.unfold(self.pad(x), kernel_size=self.kernel_size, dilation=self.dilation).reshape(n, self.group, c//self.group, self.kernel_size**2, h*w) 105 | 106 | n,c1,p,q = low_filter.shape 107 | low_filter = low_filter.reshape(n, c1//self.kernel_size**2, self.kernel_size**2, p*q).unsqueeze(2) 108 | 109 | low_filter = self.act(low_filter) 110 | 111 | low_part = torch.sum(x * low_filter, dim=3).reshape(n, c, h, w) 112 | 113 | out_low = low_part * (self.inside_all + 1.) - self.inside_all * self.gap(identity_input) 114 | 115 | out_low = out_low * self.lamb_l[None,:,None,None] 116 | 117 | out_high = (identity_input) * (self.lamb_h[None,:,None,None] + 1.) 
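        # out_low: the adaptively low-pass-filtered branch, re-centered by the
        # learned per-channel gate inside_all and scaled by lamb_l.
        # out_high: the identity branch scaled by (lamb_h + 1), which carries
        # the high-frequency residual. Their sum is the filtered feature map.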
118 | 119 | return out_low + out_high 120 | 121 | 122 | class cubic_attention(nn.Module): 123 | def __init__(self, dim, group, dilation, kernel) -> None: 124 | super().__init__() 125 | 126 | self.H_spatial_att = spatial_strip_att(dim, dilation=dilation, group=group, kernel=kernel) 127 | self.W_spatial_att = spatial_strip_att(dim, dilation=dilation, group=group, kernel=kernel, H=False) 128 | self.gamma = nn.Parameter(torch.zeros(dim,1,1)) 129 | self.beta = nn.Parameter(torch.ones(dim,1,1)) 130 | 131 | def forward(self, x): 132 | out = self.H_spatial_att(x) 133 | out = self.W_spatial_att(out) 134 | return self.gamma * out + x * self.beta 135 | 136 | 137 | class spatial_strip_att(nn.Module): 138 | def __init__(self, dim, kernel=3, dilation=1, group=2, H=True) -> None: 139 | super().__init__() 140 | 141 | self.k = kernel 142 | pad = dilation*(kernel-1) // 2 143 | self.kernel = (1, kernel) if H else (kernel, 1) 144 | self.padding = (kernel//2, 1) if H else (1, kernel//2) 145 | self.dilation = dilation 146 | self.group = group 147 | self.pad = nn.ReflectionPad2d((pad, pad, 0, 0)) if H else nn.ReflectionPad2d((0, 0, pad, pad)) 148 | self.conv = nn.Conv2d(dim, group*kernel, kernel_size=1, stride=1, bias=False) 149 | self.ap = nn.AdaptiveAvgPool2d((1, 1)) 150 | self.filter_act = nn.Tanh() 151 | self.inside_all = nn.Parameter(torch.zeros(dim,1,1), requires_grad=True) 152 | self.lamb_l = nn.Parameter(torch.zeros(dim), requires_grad=True) 153 | self.lamb_h = nn.Parameter(torch.zeros(dim), requires_grad=True) 154 | gap_kernel = (None,1) if H else (1, None) 155 | self.gap = nn.AdaptiveAvgPool2d(gap_kernel) 156 | 157 | def forward(self, x): 158 | identity_input = x.clone() 159 | filter = self.ap(x) 160 | filter = self.conv(filter) 161 | n, c, h, w = x.shape 162 | x = F.unfold(self.pad(x), kernel_size=self.kernel, dilation=self.dilation).reshape(n, self.group, c//self.group, self.k, h*w) 163 | n, c1, p, q = filter.shape 164 | filter = filter.reshape(n, c1//self.k, self.k, p*q).unsqueeze(2) 165 | filter = self.filter_act(filter) 166 | out = torch.sum(x * filter, dim=3).reshape(n, c, h, w) 167 | 168 | out_low = out * (self.inside_all + 1.) - self.inside_all * self.gap(identity_input) 169 | out_low = out_low * self.lamb_l[None,:,None,None] 170 | out_high = identity_input * (self.lamb_h[None,:,None,None]+1.) 
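        # Same low/high-frequency gating as dynamic_filter, but with a 1D strip
        # kernel applied along a single axis; cubic_attention stacks one
        # horizontal and one vertical strip so both axes are covered.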
171 | 172 | return out_low + out_high 173 | 174 | 175 | class MultiShapeKernel(nn.Module): 176 | def __init__(self, dim, kernel_size=3, dilation=1, group=8): 177 | super().__init__() 178 | 179 | self.square_att = dynamic_filter(inchannels=dim, dilation=dilation, group=group, kernel_size=kernel_size) 180 | self.strip_att = cubic_attention(dim, group=group, dilation=dilation, kernel=kernel_size) 181 | 182 | def forward(self, x): 183 | 184 | x1 = self.strip_att(x) 185 | x2 = self.square_att(x) 186 | 187 | return x1+x2 188 | 189 | 190 | -------------------------------------------------------------------------------- /Image_desnowing/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from data import train_dataloader 4 | from utils import Adder, Timer 5 | from torch.utils.tensorboard import SummaryWriter 6 | from valid import _valid 7 | import torch.nn.functional as F 8 | import torch.nn as nn 9 | 10 | from warmup_scheduler import GradualWarmupScheduler 11 | 12 | 13 | def _train(model, args): 14 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 15 | criterion = torch.nn.L1Loss() 16 | 17 | optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), eps=1e-8) 18 | dataloader = train_dataloader(args.data_dir, args.batch_size, args.num_worker, args.data) 19 | max_iter = len(dataloader) 20 | warmup_epochs=3 21 | scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.num_epoch-warmup_epochs, eta_min=1e-6) 22 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine) 23 | scheduler.step() 24 | epoch = 1 25 | if args.resume: 26 | state = torch.load(args.resume) 27 | epoch = state['epoch'] 28 | optimizer.load_state_dict(state['optimizer']) 29 | model.load_state_dict(state['model']) 30 | print('Resume from %d'%epoch) 31 | epoch += 1 32 | 33 | 34 | 35 | writer = SummaryWriter() 36 | epoch_pixel_adder = Adder() 37 | epoch_fft_adder = Adder() 38 | iter_pixel_adder = Adder() 39 | iter_fft_adder = Adder() 40 | epoch_timer = Timer('m') 41 | iter_timer = Timer('m') 42 | best_psnr=-1 43 | 44 | for epoch_idx in range(epoch, args.num_epoch + 1): 45 | 46 | epoch_timer.tic() 47 | iter_timer.tic() 48 | for iter_idx, batch_data in enumerate(dataloader): 49 | 50 | input_img, label_img = batch_data 51 | input_img = input_img.to(device) 52 | label_img = label_img.to(device) 53 | 54 | optimizer.zero_grad() 55 | pred_img = model(input_img) 56 | label_img2 = F.interpolate(label_img, scale_factor=0.5, mode='bilinear') 57 | label_img4 = F.interpolate(label_img, scale_factor=0.25, mode='bilinear') 58 | l1 = criterion(pred_img[0], label_img4) 59 | l2 = criterion(pred_img[1], label_img2) 60 | l3 = criterion(pred_img[2], label_img) 61 | loss_content = l1+l2+l3 62 | 63 | label_fft1 = torch.fft.fft2(label_img4, dim=(-2,-1)) 64 | label_fft1 = torch.stack((label_fft1.real, label_fft1.imag), -1) 65 | 66 | pred_fft1 = torch.fft.fft2(pred_img[0], dim=(-2,-1)) 67 | pred_fft1 = torch.stack((pred_fft1.real, pred_fft1.imag), -1) 68 | 69 | label_fft2 = torch.fft.fft2(label_img2, dim=(-2,-1)) 70 | label_fft2 = torch.stack((label_fft2.real, label_fft2.imag), -1) 71 | 72 | pred_fft2 = torch.fft.fft2(pred_img[1], dim=(-2,-1)) 73 | pred_fft2 = torch.stack((pred_fft2.real, pred_fft2.imag), -1) 74 | 75 | label_fft3 = torch.fft.fft2(label_img, dim=(-2,-1)) 76 | label_fft3 = torch.stack((label_fft3.real, label_fft3.imag), -1) 77 | 
78 | pred_fft3 = torch.fft.fft2(pred_img[2], dim=(-2,-1)) 79 | pred_fft3 = torch.stack((pred_fft3.real, pred_fft3.imag), -1) 80 | 81 | f1 = criterion(pred_fft1, label_fft1) 82 | f2 = criterion(pred_fft2, label_fft2) 83 | f3 = criterion(pred_fft3, label_fft3) 84 | loss_fft = f1+f2+f3 85 | 86 | loss = loss_content + 0.1 * loss_fft 87 | loss.backward() 88 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.01) 89 | optimizer.step() 90 | 91 | iter_pixel_adder(loss_content.item()) 92 | iter_fft_adder(loss_fft.item()) 93 | 94 | epoch_pixel_adder(loss_content.item()) 95 | epoch_fft_adder(loss_fft.item()) 96 | 97 | if (iter_idx + 1) % args.print_freq == 0: 98 | print("Time: %7.4f Epoch: %03d Iter: %4d/%4d LR: %.10f Loss content: %7.4f Loss fft: %7.4f" % ( 99 | iter_timer.toc(), epoch_idx, iter_idx + 1, max_iter, scheduler.get_lr()[0], iter_pixel_adder.average(), 100 | iter_fft_adder.average())) 101 | writer.add_scalar('Pixel Loss', iter_pixel_adder.average(), iter_idx + (epoch_idx-1)* max_iter) 102 | writer.add_scalar('FFT Loss', iter_fft_adder.average(), iter_idx + (epoch_idx - 1) * max_iter) 103 | 104 | iter_timer.tic() 105 | iter_pixel_adder.reset() 106 | iter_fft_adder.reset() 107 | overwrite_name = os.path.join(args.model_save_dir, 'model.pkl') 108 | torch.save({'model': model.state_dict(), 109 | 'optimizer': optimizer.state_dict(), 110 | 'epoch': epoch_idx}, overwrite_name) 111 | 112 | if epoch_idx % args.save_freq == 0: 113 | save_name = os.path.join(args.model_save_dir, 'model_%d.pkl' % epoch_idx) 114 | torch.save({'model': model.state_dict()}, save_name) 115 | print("EPOCH: %02d\nElapsed time: %4.2f Epoch Pixel Loss: %7.4f Epoch FFT Loss: %7.4f" % ( 116 | epoch_idx, epoch_timer.toc(), epoch_pixel_adder.average(), epoch_fft_adder.average())) 117 | epoch_fft_adder.reset() 118 | epoch_pixel_adder.reset() 119 | scheduler.step() 120 | if epoch_idx % args.valid_freq == 0: 121 | val_snow = _valid(model, args, epoch_idx) 122 | print('%03d epoch \n Average PSNR %.2f dB' % (epoch_idx, val_snow)) 123 | writer.add_scalar('PSNR_Desnowing', val_snow, epoch_idx) 124 | if val_snow >= best_psnr: 125 | torch.save({'model': model.state_dict()}, os.path.join(args.model_save_dir, 'Best.pkl')) 126 | save_name = os.path.join(args.model_save_dir, 'Final.pkl') 127 | torch.save({'model': model.state_dict()}, save_name) 128 | -------------------------------------------------------------------------------- /Image_desnowing/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | 4 | 5 | class Adder(object): 6 | def __init__(self): 7 | self.count = 0 8 | self.num = float(0) 9 | 10 | def reset(self): 11 | self.count = 0 12 | self.num = float(0) 13 | 14 | def __call__(self, num): 15 | self.count += 1 16 | self.num += num 17 | 18 | def average(self): 19 | return self.num / self.count 20 | 21 | 22 | class Timer(object): 23 | def __init__(self, option='s'): 24 | self.tm = 0 25 | self.option = option 26 | if option == 's': 27 | self.devider = 1 28 | elif option == 'm': 29 | self.devider = 60 30 | else: 31 | self.devider = 3600 32 | 33 | def tic(self): 34 | self.tm = time.time() 35 | 36 | def toc(self): 37 | return (time.time() - self.tm) / self.devider 38 | 39 | 40 | def check_lr(optimizer): 41 | for i, param_group in enumerate(optimizer.param_groups): 42 | lr = param_group['lr'] 43 | return lr 44 | -------------------------------------------------------------------------------- /Image_desnowing/valid.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.transforms import functional as F 3 | from data import valid_dataloader 4 | from utils import Adder 5 | import os 6 | from skimage.metrics import peak_signal_noise_ratio 7 | import torch.nn.functional as f 8 | 9 | 10 | def _valid(model, args, ep): 11 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 12 | snow_data = valid_dataloader(args.data_dir, args.data, batch_size=1, num_workers=0) 13 | model.eval() 14 | psnr_adder = Adder() 15 | 16 | with torch.no_grad(): 17 | print('Start Desnowing Evaluation') 18 | factor = 32 19 | for idx, data in enumerate(snow_data): 20 | input_img, label_img = data 21 | input_img = input_img.to(device) 22 | 23 | h, w = input_img.shape[2], input_img.shape[3] 24 | H, W = ((h+factor)//factor)*factor, ((w+factor)//factor*factor) 25 | padh = H-h if h%factor!=0 else 0 26 | padw = W-w if w%factor!=0 else 0 27 | input_img = f.pad(input_img, (0, padw, 0, padh), 'reflect') 28 | 29 | if not os.path.exists(os.path.join(args.result_dir, '%d' % (ep))): 30 | os.mkdir(os.path.join(args.result_dir, '%d' % (ep))) 31 | 32 | pred = model(input_img)[2] 33 | pred = pred[:,:,:h,:w] 34 | 35 | pred_clip = torch.clamp(pred, 0, 1) 36 | p_numpy = pred_clip.squeeze(0).cpu().numpy() 37 | label_numpy = label_img.squeeze(0).cpu().numpy() 38 | 39 | psnr = peak_signal_noise_ratio(p_numpy, label_numpy, data_range=1) 40 | 41 | psnr_adder(psnr) 42 | print('\r%03d'%idx, end=' ') 43 | 44 | print('\n') 45 | model.train() 46 | return psnr_adder.average() 47 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Yuning Cui 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /Motion_Deblurring/README.md: -------------------------------------------------------------------------------- 1 | ### Download the Datasets 2 | - Gopro [[gdrive](https://drive.google.com/file/d/1y_wQ5G5B65HS_mdIjxKYTcnRys_AGh5v/view?usp=sharing), [百度网盘](https://pan.baidu.com/s/1eNCvqewdUp15-0dD2MfJbg?pwd=ea0r)] 3 | 4 | ### Training on GoPro 5 | ~~~ 6 | python main.py --data_dir your_path/GOPRO 7 | ~~~ 8 | ### Evaluation 9 | Download model: [gdrive](https://drive.google.com/drive/folders/1_5fO2p5xoWO5cUEVoXJ7x3Uhg1AP18FQ?usp=sharing), [百度网盘](https://pan.baidu.com/s/1oYzdxs3FvLJMWx7S5GW0rA?pwd=dvta) 10 | #### Testing on GoPro 11 | ~~~ 12 | python main.py --mode test --data_dir your_path/GOPRO --test_model path_to_gopro_model --save_image True 13 | ~~~ 14 | -------------------------------------------------------------------------------- /Motion_Deblurring/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_augment import PairRandomCrop, PairCompose, PairRandomHorizontalFilp, PairToTensor 2 | from .data_load import train_dataloader, test_dataloader, valid_dataloader 3 | -------------------------------------------------------------------------------- /Motion_Deblurring/data/data_augment.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torchvision.transforms as transforms 3 | import torchvision.transforms.functional as F 4 | 5 | 6 | class PairRandomCrop(transforms.RandomCrop): 7 | 8 | def __call__(self, image, label): 9 | 10 | if self.padding is not None: 11 | image = F.pad(image, self.padding, self.fill, self.padding_mode) 12 | label = F.pad(label, self.padding, self.fill, self.padding_mode) 13 | 14 | # pad the width if needed 15 | if self.pad_if_needed and image.size[0] < self.size[1]: 16 | image = F.pad(image, (self.size[1] - image.size[0], 0), self.fill, self.padding_mode) 17 | label = F.pad(label, (self.size[1] - label.size[0], 0), self.fill, self.padding_mode) 18 | # pad the height if needed 19 | if self.pad_if_needed and image.size[1] < self.size[0]: 20 | image = F.pad(image, (0, self.size[0] - image.size[1]), self.fill, self.padding_mode) 21 | label = F.pad(label, (0, self.size[0] - label.size[1]), self.fill, self.padding_mode) 22 | 23 | i, j, h, w = self.get_params(image, self.size) 24 | 25 | return F.crop(image, i, j, h, w), F.crop(label, i, j, h, w) 26 | 27 | 28 | class PairCompose(transforms.Compose): 29 | def __call__(self, image, label): 30 | for t in self.transforms: 31 | image, label = t(image, label) 32 | return image, label 33 | 34 | 35 | class PairRandomHorizontalFilp(transforms.RandomHorizontalFlip): 36 | def __call__(self, img, label): 37 | """ 38 | Args: 39 | img (PIL Image): Image to be flipped. 40 | 41 | Returns: 42 | PIL Image: Randomly flipped image. 43 | """ 44 | if random.random() < self.p: 45 | return F.hflip(img), F.hflip(label) 46 | return img, label 47 | 48 | 49 | #class PairRandomVerticalFlip(transforms.RandomVerticalFlip): 50 | # def __call__(self, img, label): 51 | """ 52 | Args: 53 | img (PIL Image): Image to be flipped. 54 | 55 | Returns: 56 | PIL Image: Randomly flipped image.
57 | """ 58 | # if random.random() < self.p: 59 | # return F.vflip(img), F.vflip(label) 60 | # return img, label 61 | 62 | 63 | class PairToTensor(transforms.ToTensor): 64 | def __call__(self, pic, label): 65 | """ 66 | Args: 67 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor. 68 | 69 | Returns: 70 | Tensor: Converted image. 71 | """ 72 | return F.to_tensor(pic), F.to_tensor(label) 73 | -------------------------------------------------------------------------------- /Motion_Deblurring/data/data_load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from PIL import Image as Image 5 | from data import PairCompose, PairRandomCrop, PairRandomHorizontalFilp, PairToTensor 6 | from torchvision.transforms import functional as F 7 | from torch.utils.data import Dataset, DataLoader 8 | 9 | 10 | def train_dataloader(path, batch_size=64, num_workers=0, use_transform=True): 11 | image_dir = os.path.join(path, 'train') 12 | 13 | transform = None 14 | if use_transform: 15 | transform = PairCompose( 16 | [ 17 | PairRandomCrop(256), 18 | PairRandomHorizontalFilp(), 19 | PairToTensor() 20 | ] 21 | ) 22 | dataloader = DataLoader( 23 | DeblurDataset(image_dir, transform=transform), 24 | batch_size=batch_size, 25 | shuffle=True, 26 | num_workers=num_workers, 27 | pin_memory=True 28 | ) 29 | return dataloader 30 | 31 | 32 | def test_dataloader(path, batch_size=1, num_workers=0): 33 | image_dir = os.path.join(path, 'valid') 34 | dataloader = DataLoader( 35 | DeblurDataset(image_dir, is_test=True), 36 | batch_size=batch_size, 37 | shuffle=False, 38 | num_workers=num_workers, 39 | pin_memory=True 40 | ) 41 | 42 | return dataloader 43 | 44 | 45 | def valid_dataloader(path, batch_size=1, num_workers=0): 46 | dataloader = DataLoader( 47 | DeblurDataset(os.path.join(path, 'valid')), 48 | batch_size=batch_size, 49 | shuffle=False, 50 | num_workers=num_workers 51 | ) 52 | 53 | return dataloader 54 | 55 | 56 | class DeblurDataset(Dataset): 57 | def __init__(self, image_dir, transform=None, is_test=False): 58 | self.image_dir = image_dir 59 | self.image_list = os.listdir(os.path.join(image_dir, 'blur/')) 60 | self._check_image(self.image_list) 61 | self.image_list.sort() 62 | self.transform = transform 63 | self.is_test = is_test 64 | 65 | def __len__(self): 66 | return len(self.image_list) 67 | 68 | def __getitem__(self, idx): 69 | image = Image.open(os.path.join(self.image_dir, 'blur', self.image_list[idx])) 70 | label = Image.open(os.path.join(self.image_dir, 'sharp', self.image_list[idx].replace('blur', 'gt'))) 71 | 72 | if self.transform: 73 | image, label = self.transform(image, label) 74 | else: 75 | image = F.to_tensor(image) 76 | label = F.to_tensor(label) 77 | if self.is_test: 78 | name = self.image_list[idx] 79 | return image, label, name 80 | return image, label 81 | 82 | @staticmethod 83 | def _check_image(lst): 84 | for x in lst: 85 | splits = x.split('.') 86 | if splits[-1] not in ['png', 'jpg', 'jpeg']: 87 | raise ValueError 88 | -------------------------------------------------------------------------------- /Motion_Deblurring/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torchvision.transforms import functional as F 4 | from utils import Adder 5 | from data import test_dataloader 6 | from skimage.metrics import peak_signal_noise_ratio 7 | import time 8 | import torch.nn.functional as f 9 | 10 | factor = 32 11 | 
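# Test inputs are reflect-padded up to a multiple of 32 before inference: the
# encoder downsamples by 4 (two stride-2 convolutions) and DeepPoolLayer pools
# by up to 8 at the coarsest scale, so 32 keeps every feature map evenly sized.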
12 | def _eval(model, args): 13 | state_dict = torch.load(args.test_model) 14 | model.load_state_dict(state_dict['model']) 15 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 16 | dataloader = test_dataloader(args.data_dir, batch_size=1, num_workers=0) 17 | adder = Adder() 18 | model.eval() 19 | 20 | with torch.no_grad(): 21 | psnr_adder = Adder() 22 | for iter_idx, data in enumerate(dataloader): 23 | input_img, label_img, name = data 24 | 25 | input_img = input_img.to(device) 26 | h, w = input_img.shape[2], input_img.shape[3] 27 | H, W = ((h+factor)//factor)*factor, ((w+factor)//factor*factor) 28 | padh = H-h if h%factor!=0 else 0 29 | padw = W-w if w%factor!=0 else 0 30 | input_img = f.pad(input_img, (0, padw, 0, padh), 'reflect') 31 | tm = time.time() 32 | 33 | pred = model(input_img)[2] 34 | pred = pred[:,:,:h,:w] 35 | elapsed = time.time() - tm 36 | adder(elapsed) 37 | 38 | pred_clip = torch.clamp(pred, 0, 1) 39 | pred_numpy = pred_clip.squeeze(0).cpu().numpy() 40 | label_numpy = label_img.squeeze(0).cpu().numpy() 41 | 42 | if args.save_image: 43 | save_name = os.path.join(args.result_dir, name[0]) 44 | pred_clip += 0.5 / 255 45 | pred = F.to_pil_image(pred_clip.squeeze(0).cpu(), 'RGB') 46 | pred.save(save_name) 47 | 48 | psnr = peak_signal_noise_ratio(pred_numpy, label_numpy, data_range=1) 49 | psnr_adder(psnr) 50 | print('%d iter PSNR: %.4f time: %f' % (iter_idx + 1, psnr, elapsed)) 51 | 52 | print('==========================================================') 53 | print('The average PSNR is %.4f dB' % (psnr_adder.average())) 54 | print("Average time: %f" % adder.average()) 55 | -------------------------------------------------------------------------------- /Motion_Deblurring/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | from torch.backends import cudnn 5 | from models.ConvIR import build_net 6 | from train import _train 7 | from eval import _eval 8 | 9 | def main(args): 10 | cudnn.benchmark = True 11 | 12 | if not os.path.exists('results/'): 13 | os.makedirs(args.model_save_dir) 14 | if not os.path.exists('results/' + args.model_name + '/'): 15 | os.makedirs('results/' + args.model_name + '/') 16 | if not os.path.exists(args.model_save_dir): 17 | os.makedirs(args.model_save_dir) 18 | if not os.path.exists(args.result_dir): 19 | os.makedirs(args.result_dir) 20 | 21 | model = build_net() 22 | # print(model) 23 | 24 | if torch.cuda.is_available(): 25 | model.cuda() 26 | 27 | if args.mode == 'train': 28 | _train(model, args) 29 | 30 | elif args.mode == 'test': 31 | _eval(model, args) 32 | 33 | 34 | if __name__ == '__main__': 35 | parser = argparse.ArgumentParser() 36 | 37 | # Directories 38 | parser.add_argument('--model_name', default='ConvIR', type=str) 39 | parser.add_argument('--data_dir', type=str, default='') 40 | 41 | parser.add_argument('--mode', default='train', choices=['train', 'test'], type=str) 42 | 43 | # Train 44 | parser.add_argument('--batch_size', type=int, default=4) 45 | parser.add_argument('--learning_rate', type=float, default=1e-4) 46 | parser.add_argument('--weight_decay', type=float, default=0) 47 | parser.add_argument('--num_epoch', type=int, default=3000) 48 | parser.add_argument('--print_freq', type=int, default=100) 49 | parser.add_argument('--num_worker', type=int, default=8) 50 | parser.add_argument('--save_freq', type=int, default=100) 51 | parser.add_argument('--valid_freq', type=int, default=100) 52 | parser.add_argument('--resume', 
type=str, default='') 53 | 54 | # Test 55 | parser.add_argument('--test_model', type=str, default='gopro.pkl') 56 | parser.add_argument('--save_image', type=bool, default=False, choices=[True, False]) 57 | 58 | args = parser.parse_args() 59 | args.model_save_dir = os.path.join('results/', 'ConvIR', 'test') 60 | args.result_dir = os.path.join('results/', args.model_name, 'GOPRO') 61 | if not os.path.exists(args.model_save_dir): 62 | os.makedirs(args.model_save_dir) 63 | command = 'cp ' + 'models/layers.py ' + args.model_save_dir 64 | os.system(command) 65 | command = 'cp ' + 'models/ConvIR.py ' + args.model_save_dir 66 | os.system(command) 67 | command = 'cp ' + 'train.py ' + args.model_save_dir 68 | os.system(command) 69 | command = 'cp ' + 'main.py ' + args.model_save_dir 70 | os.system(command) 71 | print(args) 72 | main(args) 73 | -------------------------------------------------------------------------------- /Motion_Deblurring/models/ConvIR.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .layers import * 6 | 7 | 8 | class EBlock(nn.Module): 9 | def __init__(self, out_channel, num_res=8): 10 | super(EBlock, self).__init__() 11 | 12 | layers = [ResBlock(out_channel, out_channel) for _ in range(num_res-1)] 13 | layers.append(ResBlock(out_channel, out_channel, filter=True)) 14 | 15 | self.layers = nn.Sequential(*layers) 16 | 17 | def forward(self, x): 18 | return self.layers(x) 19 | 20 | 21 | class DBlock(nn.Module): 22 | def __init__(self, channel, num_res=8): 23 | super(DBlock, self).__init__() 24 | 25 | layers = [ResBlock(channel, channel) for _ in range(num_res-1)] 26 | layers.append(ResBlock(channel, channel, filter=True)) 27 | self.layers = nn.Sequential(*layers) 28 | 29 | def forward(self, x): 30 | return self.layers(x) 31 | 32 | 33 | class SCM(nn.Module): 34 | def __init__(self, out_plane): 35 | super(SCM, self).__init__() 36 | self.main = nn.Sequential( 37 | BasicConv(3, out_plane//4, kernel_size=3, stride=1, relu=True), 38 | BasicConv(out_plane // 4, out_plane // 2, kernel_size=1, stride=1, relu=True), 39 | BasicConv(out_plane // 2, out_plane // 2, kernel_size=3, stride=1, relu=True), 40 | BasicConv(out_plane // 2, out_plane, kernel_size=1, stride=1, relu=False), 41 | nn.InstanceNorm2d(out_plane, affine=True) 42 | ) 43 | 44 | def forward(self, x): 45 | x = self.main(x) 46 | return x 47 | 48 | class FAM(nn.Module): 49 | def __init__(self, channel): 50 | super(FAM, self).__init__() 51 | self.merge = BasicConv(channel*2, channel, kernel_size=3, stride=1, relu=False) 52 | 53 | def forward(self, x1, x2): 54 | return self.merge(torch.cat([x1, x2], dim=1)) 55 | 56 | class ConvIR(nn.Module): 57 | def __init__(self, num_res=16): 58 | super(ConvIR, self).__init__() 59 | 60 | base_channel = 32 61 | 62 | self.Encoder = nn.ModuleList([ 63 | EBlock(base_channel, num_res), 64 | EBlock(base_channel*2, num_res), 65 | EBlock(base_channel*4, num_res), 66 | ]) 67 | 68 | self.feat_extract = nn.ModuleList([ 69 | BasicConv(3, base_channel, kernel_size=3, relu=True, stride=1), 70 | BasicConv(base_channel, base_channel*2, kernel_size=3, relu=True, stride=2), 71 | BasicConv(base_channel*2, base_channel*4, kernel_size=3, relu=True, stride=2), 72 | BasicConv(base_channel*4, base_channel*2, kernel_size=4, relu=True, stride=2, transpose=True), 73 | BasicConv(base_channel*2, base_channel, kernel_size=4, relu=True, stride=2, transpose=True), 74 | BasicConv(base_channel, 3, 
kernel_size=3, relu=False, stride=1) 75 | ]) 76 | 77 | self.Decoder = nn.ModuleList([ 78 | DBlock(base_channel * 4, num_res), 79 | DBlock(base_channel * 2, num_res), 80 | DBlock(base_channel, num_res) 81 | ]) 82 | 83 | self.Convs = nn.ModuleList([ 84 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1), 85 | BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1), 86 | ]) 87 | 88 | self.ConvsOut = nn.ModuleList( 89 | [ 90 | BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1), 91 | BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1), 92 | ] 93 | ) 94 | 95 | self.FAM1 = FAM(base_channel * 4) 96 | self.SCM1 = SCM(base_channel * 4) 97 | self.FAM2 = FAM(base_channel * 2) 98 | self.SCM2 = SCM(base_channel * 2) 99 | 100 | def forward(self, x): 101 | x_2 = F.interpolate(x, scale_factor=0.5) 102 | x_4 = F.interpolate(x_2, scale_factor=0.5) 103 | z2 = self.SCM2(x_2) 104 | z4 = self.SCM1(x_4) 105 | 106 | outputs = list() 107 | # 256 108 | x_ = self.feat_extract[0](x) 109 | res1 = self.Encoder[0](x_) 110 | # 128 111 | z = self.feat_extract[1](res1) 112 | z = self.FAM2(z, z2) 113 | res2 = self.Encoder[1](z) 114 | # 64 115 | z = self.feat_extract[2](res2) 116 | z = self.FAM1(z, z4) 117 | z = self.Encoder[2](z) 118 | 119 | z = self.Decoder[0](z) 120 | z_ = self.ConvsOut[0](z) 121 | # 128 122 | z = self.feat_extract[3](z) 123 | outputs.append(z_+x_4) 124 | 125 | z = torch.cat([z, res2], dim=1) 126 | z = self.Convs[0](z) 127 | z = self.Decoder[1](z) 128 | z_ = self.ConvsOut[1](z) 129 | # 256 130 | z = self.feat_extract[4](z) 131 | outputs.append(z_+x_2) 132 | 133 | z = torch.cat([z, res1], dim=1) 134 | z = self.Convs[1](z) 135 | z = self.Decoder[2](z) 136 | z = self.feat_extract[5](z) 137 | outputs.append(z+x) 138 | 139 | return outputs 140 | 141 | 142 | def build_net(): 143 | return ConvIR() 144 | -------------------------------------------------------------------------------- /Motion_Deblurring/models/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class BasicConv(nn.Module): 6 | def __init__(self, in_channel, out_channel, kernel_size, stride, bias=True, norm=False, relu=True, transpose=False): 7 | super(BasicConv, self).__init__() 8 | if bias and norm: 9 | bias = False 10 | 11 | padding = kernel_size // 2 12 | layers = list() 13 | if transpose: 14 | padding = kernel_size // 2 -1 15 | layers.append(nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias)) 16 | else: 17 | layers.append( 18 | nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias)) 19 | if norm: 20 | layers.append(nn.BatchNorm2d(out_channel)) 21 | if relu: 22 | layers.append(nn.GELU()) 23 | self.main = nn.Sequential(*layers) 24 | 25 | def forward(self, x): 26 | return self.main(x) 27 | 28 | 29 | class ResBlock(nn.Module): 30 | def __init__(self, in_channel, out_channel, filter=False): 31 | super(ResBlock, self).__init__() 32 | self.main = nn.Sequential( 33 | BasicConv(in_channel, out_channel, kernel_size=3, stride=1, relu=True), 34 | DeepPoolLayer(in_channel, out_channel) if filter else nn.Identity(), 35 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False) 36 | ) 37 | 38 | def forward(self, x): 39 | return self.main(x) + x 40 | 41 | 42 | class DeepPoolLayer(nn.Module): 43 | def __init__(self, k, k_out): 44 | super(DeepPoolLayer, 
self).__init__() 45 | self.pools_sizes = [8,4,2] 46 | dilation = [7,9,11] 47 | pools, convs, dynas = [],[],[] 48 | for j, i in enumerate(self.pools_sizes): 49 | pools.append(nn.AvgPool2d(kernel_size=i, stride=i)) 50 | convs.append(nn.Conv2d(k, k, 3, 1, 1, bias=False)) 51 | dynas.append(MultiShapeKernel(dim=k, kernel_size=3, dilation=dilation[j])) 52 | self.pools = nn.ModuleList(pools) 53 | self.convs = nn.ModuleList(convs) 54 | self.dynas = nn.ModuleList(dynas) 55 | self.relu = nn.GELU() 56 | self.conv_sum = nn.Conv2d(k, k_out, 3, 1, 1, bias=False) 57 | 58 | def forward(self, x): 59 | x_size = x.size() 60 | resl = x 61 | for i in range(len(self.pools_sizes)): 62 | if i == 0: 63 | y = self.dynas[i](self.convs[i](self.pools[i](x))) 64 | else: 65 | y = self.dynas[i](self.convs[i](self.pools[i](x)+y_up)) 66 | resl = torch.add(resl, F.interpolate(y, x_size[2:], mode='bilinear', align_corners=True)) 67 | if i != len(self.pools_sizes)-1: 68 | y_up = F.interpolate(y, scale_factor=2, mode='bilinear', align_corners=True) 69 | resl = self.relu(resl) 70 | resl = self.conv_sum(resl) 71 | 72 | return resl 73 | 74 | class dynamic_filter(nn.Module): 75 | def __init__(self, inchannels, kernel_size=3, dilation=1, stride=1, group=8): 76 | super(dynamic_filter, self).__init__() 77 | self.stride = stride 78 | self.kernel_size = kernel_size 79 | self.group = group 80 | self.dilation = dilation 81 | 82 | self.conv = nn.Conv2d(inchannels, group*kernel_size**2, kernel_size=1, stride=1, bias=False) 83 | self.bn = nn.BatchNorm2d(group*kernel_size**2) 84 | self.act = nn.Tanh() 85 | 86 | nn.init.kaiming_normal_(self.conv.weight, mode='fan_out', nonlinearity='relu') 87 | self.lamb_l = nn.Parameter(torch.zeros(inchannels), requires_grad=True) 88 | self.lamb_h = nn.Parameter(torch.zeros(inchannels), requires_grad=True) 89 | self.pad = nn.ReflectionPad2d(self.dilation*(kernel_size-1)//2) 90 | 91 | self.ap = nn.AdaptiveAvgPool2d((1, 1)) 92 | self.gap = nn.AdaptiveAvgPool2d(1) 93 | 94 | self.inside_all = nn.Parameter(torch.zeros(inchannels,1,1), requires_grad=True) 95 | 96 | def forward(self, x): 97 | identity_input = x 98 | low_filter = self.ap(x) 99 | low_filter = self.conv(low_filter) 100 | low_filter = self.bn(low_filter) 101 | 102 | n, c, h, w = x.shape 103 | x = F.unfold(self.pad(x), kernel_size=self.kernel_size, dilation=self.dilation).reshape(n, self.group, c//self.group, self.kernel_size**2, h*w) 104 | 105 | n,c1,p,q = low_filter.shape 106 | low_filter = low_filter.reshape(n, c1//self.kernel_size**2, self.kernel_size**2, p*q).unsqueeze(2) 107 | 108 | low_filter = self.act(low_filter) 109 | 110 | low_part = torch.sum(x * low_filter, dim=3).reshape(n, c, h, w) 111 | 112 | out_low = low_part * (self.inside_all + 1.) - self.inside_all * self.gap(identity_input) 113 | 114 | out_low = out_low * self.lamb_l[None,:,None,None] 115 | 116 | out_high = (identity_input) * (self.lamb_h[None,:,None,None] + 1.) 
117 | 118 | return out_low + out_high 119 | 120 | 121 | class cubic_attention(nn.Module): 122 | def __init__(self, dim, group, dilation, kernel) -> None: 123 | super().__init__() 124 | 125 | self.H_spatial_att = spatial_strip_att(dim, dilation=dilation, group=group, kernel=kernel) 126 | self.W_spatial_att = spatial_strip_att(dim, dilation=dilation, group=group, kernel=kernel, H=False) 127 | self.gamma = nn.Parameter(torch.zeros(dim,1,1)) 128 | self.beta = nn.Parameter(torch.ones(dim,1,1)) 129 | 130 | def forward(self, x): 131 | out = self.H_spatial_att(x) 132 | out = self.W_spatial_att(out) 133 | return self.gamma * out + x * self.beta 134 | 135 | 136 | class spatial_strip_att(nn.Module): 137 | def __init__(self, dim, kernel=3, dilation=1, group=2, H=True) -> None: 138 | super().__init__() 139 | 140 | self.k = kernel 141 | pad = dilation*(kernel-1) // 2 142 | self.kernel = (1, kernel) if H else (kernel, 1) 143 | self.padding = (kernel//2, 1) if H else (1, kernel//2) 144 | self.dilation = dilation 145 | self.group = group 146 | self.pad = nn.ReflectionPad2d((pad, pad, 0, 0)) if H else nn.ReflectionPad2d((0, 0, pad, pad)) 147 | self.conv = nn.Conv2d(dim, group*kernel, kernel_size=1, stride=1, bias=False) 148 | self.ap = nn.AdaptiveAvgPool2d((1, 1)) 149 | self.filter_act = nn.Tanh() 150 | self.inside_all = nn.Parameter(torch.zeros(dim,1,1), requires_grad=True) 151 | self.lamb_l = nn.Parameter(torch.zeros(dim), requires_grad=True) 152 | self.lamb_h = nn.Parameter(torch.zeros(dim), requires_grad=True) 153 | gap_kernel = (None,1) if H else (1, None) 154 | self.gap = nn.AdaptiveAvgPool2d(gap_kernel) 155 | 156 | def forward(self, x): 157 | identity_input = x.clone() 158 | filter = self.ap(x) 159 | filter = self.conv(filter) 160 | n, c, h, w = x.shape 161 | x = F.unfold(self.pad(x), kernel_size=self.kernel, dilation=self.dilation).reshape(n, self.group, c//self.group, self.k, h*w) 162 | n, c1, p, q = filter.shape 163 | filter = filter.reshape(n, c1//self.k, self.k, p*q).unsqueeze(2) 164 | filter = self.filter_act(filter) 165 | out = torch.sum(x * filter, dim=3).reshape(n, c, h, w) 166 | 167 | out_low = out * (self.inside_all + 1.) - self.inside_all * self.gap(identity_input) 168 | out_low = out_low * self.lamb_l[None,:,None,None] 169 | out_high = identity_input * (self.lamb_h[None,:,None,None]+1.) 
170 | 171 | return out_low + out_high 172 | 173 | 174 | class MultiShapeKernel(nn.Module): 175 | def __init__(self, dim, kernel_size=3, dilation=1, group=8): 176 | super().__init__() 177 | 178 | self.square_att = dynamic_filter(inchannels=dim, dilation=dilation, group=group, kernel_size=kernel_size) 179 | self.strip_att = cubic_attention(dim, group=group, dilation=dilation, kernel=kernel_size) 180 | 181 | def forward(self, x): 182 | 183 | x1 = self.strip_att(x) 184 | x2 = self.square_att(x) 185 | 186 | return x1+x2 187 | 188 | 189 | -------------------------------------------------------------------------------- /Motion_Deblurring/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from data import train_dataloader 4 | from utils import Adder, Timer, check_lr 5 | from torch.utils.tensorboard import SummaryWriter 6 | from valid import _valid 7 | import torch.nn.functional as F 8 | 9 | from warmup_scheduler import GradualWarmupScheduler 10 | 11 | def _train(model, args): 12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 13 | criterion = torch.nn.L1Loss() 14 | 15 | optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), eps=1e-8) 16 | dataloader = train_dataloader(args.data_dir, args.batch_size, args.num_worker) 17 | max_iter = len(dataloader) 18 | warmup_epochs=3 19 | scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.num_epoch-warmup_epochs, eta_min=1e-6) 20 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine) 21 | scheduler.step() 22 | epoch = 1 23 | if args.resume: 24 | state = torch.load(args.resume) 25 | epoch = state['epoch'] 26 | optimizer.load_state_dict(state['optimizer']) 27 | model.load_state_dict(state['model']) 28 | print('Resume from %d'%epoch) 29 | epoch += 1 30 | 31 | writer = SummaryWriter() 32 | epoch_pixel_adder = Adder() 33 | epoch_fft_adder = Adder() 34 | iter_pixel_adder = Adder() 35 | iter_fft_adder = Adder() 36 | epoch_timer = Timer('m') 37 | iter_timer = Timer('m') 38 | best_psnr=-1 39 | 40 | for epoch_idx in range(epoch, args.num_epoch + 1): 41 | 42 | epoch_timer.tic() 43 | iter_timer.tic() 44 | for iter_idx, batch_data in enumerate(dataloader): 45 | 46 | input_img, label_img = batch_data 47 | input_img = input_img.to(device) 48 | label_img = label_img.to(device) 49 | 50 | optimizer.zero_grad() 51 | pred_img = model(input_img) 52 | label_img2 = F.interpolate(label_img, scale_factor=0.5, mode='bilinear') 53 | label_img4 = F.interpolate(label_img, scale_factor=0.25, mode='bilinear') 54 | l1 = criterion(pred_img[0], label_img4) 55 | l2 = criterion(pred_img[1], label_img2) 56 | l3 = criterion(pred_img[2], label_img) 57 | loss_content = l1+l2+l3 58 | 59 | label_fft1 = torch.fft.fft2(label_img4, dim=(-2,-1)) 60 | label_fft1 = torch.stack((label_fft1.real, label_fft1.imag), -1) 61 | 62 | pred_fft1 = torch.fft.fft2(pred_img[0], dim=(-2,-1)) 63 | pred_fft1 = torch.stack((pred_fft1.real, pred_fft1.imag), -1) 64 | 65 | label_fft2 = torch.fft.fft2(label_img2, dim=(-2,-1)) 66 | label_fft2 = torch.stack((label_fft2.real, label_fft2.imag), -1) 67 | 68 | pred_fft2 = torch.fft.fft2(pred_img[1], dim=(-2,-1)) 69 | pred_fft2 = torch.stack((pred_fft2.real, pred_fft2.imag), -1) 70 | 71 | label_fft3 = torch.fft.fft2(label_img, dim=(-2,-1)) 72 | label_fft3 = torch.stack((label_fft3.real, label_fft3.imag), -1) 73 | 74 | pred_fft3 = 
torch.fft.fft2(pred_img[2], dim=(-2,-1)) 75 | pred_fft3 = torch.stack((pred_fft3.real, pred_fft3.imag), -1) 76 | 77 | f1 = criterion(pred_fft1, label_fft1) 78 | f2 = criterion(pred_fft2, label_fft2) 79 | f3 = criterion(pred_fft3, label_fft3) 80 | loss_fft = f1+f2+f3 81 | 82 | loss = loss_content + 0.1 * loss_fft 83 | loss.backward() 84 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.001) 85 | optimizer.step() 86 | 87 | iter_pixel_adder(loss_content.item()) 88 | iter_fft_adder(loss_fft.item()) 89 | 90 | epoch_pixel_adder(loss_content.item()) 91 | epoch_fft_adder(loss_fft.item()) 92 | 93 | if (iter_idx + 1) % args.print_freq == 0: 94 | print("Time: %7.4f Epoch: %03d Iter: %4d/%4d LR: %.10f Loss content: %7.4f Loss fft: %7.4f" % ( 95 | iter_timer.toc(), epoch_idx, iter_idx + 1, max_iter, scheduler.get_lr()[0], iter_pixel_adder.average(), 96 | iter_fft_adder.average())) 97 | writer.add_scalar('Pixel Loss', iter_pixel_adder.average(), iter_idx + (epoch_idx-1)* max_iter) 98 | writer.add_scalar('FFT Loss', iter_fft_adder.average(), iter_idx + (epoch_idx - 1) * max_iter) 99 | iter_timer.tic() 100 | iter_pixel_adder.reset() 101 | iter_fft_adder.reset() 102 | overwrite_name = os.path.join(args.model_save_dir, 'model.pkl') 103 | torch.save({'model': model.state_dict(), 104 | 'optimizer': optimizer.state_dict(), 105 | 'epoch': epoch_idx}, overwrite_name) 106 | 107 | if epoch_idx % args.save_freq == 0: 108 | save_name = os.path.join(args.model_save_dir, 'model_%d.pkl' % epoch_idx) 109 | torch.save({'model': model.state_dict()}, save_name) 110 | print("EPOCH: %02d\nElapsed time: %4.2f Epoch Pixel Loss: %7.4f Epoch FFT Loss: %7.4f" % ( 111 | epoch_idx, epoch_timer.toc(), epoch_pixel_adder.average(), epoch_fft_adder.average())) 112 | epoch_fft_adder.reset() 113 | epoch_pixel_adder.reset() 114 | scheduler.step() 115 | 116 | if epoch_idx % args.valid_freq == 0: 117 | val_gopro = _valid(model, args, epoch_idx) 118 | print('%03d epoch \n Average GOPRO PSNR %.2f dB' % (epoch_idx, val_gopro)) 119 | writer.add_scalar('PSNR_GOPRO', val_gopro, epoch_idx) 120 | if val_gopro >= best_psnr: 121 | torch.save({'model': model.state_dict()}, os.path.join(args.model_save_dir, 'Best.pkl')) 122 | 123 | save_name = os.path.join(args.model_save_dir, 'Final.pkl') 124 | torch.save({'model': model.state_dict()}, save_name) 125 | -------------------------------------------------------------------------------- /Motion_Deblurring/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | 4 | 5 | class Adder(object): 6 | def __init__(self): 7 | self.count = 0 8 | self.num = float(0) 9 | 10 | def reset(self): 11 | self.count = 0 12 | self.num = float(0) 13 | 14 | def __call__(self, num): 15 | self.count += 1 16 | self.num += num 17 | 18 | def average(self): 19 | return self.num / self.count 20 | 21 | 22 | class Timer(object): 23 | def __init__(self, option='s'): 24 | self.tm = 0 25 | self.option = option 26 | if option == 's': 27 | self.devider = 1 28 | elif option == 'm': 29 | self.devider = 60 30 | else: 31 | self.devider = 3600 32 | 33 | def tic(self): 34 | self.tm = time.time() 35 | 36 | def toc(self): 37 | return (time.time() - self.tm) / self.devider 38 | 39 | 40 | def check_lr(optimizer): 41 | for i, param_group in enumerate(optimizer.param_groups): 42 | lr = param_group['lr'] 43 | return lr 44 | -------------------------------------------------------------------------------- /Motion_Deblurring/valid.py: 
--------------------------------------------------------------------------------
/Motion_Deblurring/utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 |
4 |
5 | class Adder(object):
6 |     def __init__(self):
7 |         self.count = 0
8 |         self.num = float(0)
9 |
10 |     def reset(self):
11 |         self.count = 0
12 |         self.num = float(0)
13 |
14 |     def __call__(self, num):
15 |         self.count += 1
16 |         self.num += num
17 |
18 |     def average(self):
19 |         return self.num / self.count
20 |
21 |
22 | class Timer(object):
23 |     def __init__(self, option='s'):
24 |         self.tm = 0
25 |         self.option = option
26 |         if option == 's':
27 |             self.divider = 1
28 |         elif option == 'm':
29 |             self.divider = 60
30 |         else:
31 |             self.divider = 3600
32 |
33 |     def tic(self):
34 |         self.tm = time.time()
35 |
36 |     def toc(self):
37 |         return (time.time() - self.tm) / self.divider
38 |
39 |
40 | def check_lr(optimizer):
41 |     for i, param_group in enumerate(optimizer.param_groups):
42 |         lr = param_group['lr']
43 |     return lr
44 |
--------------------------------------------------------------------------------
/Motion_Deblurring/valid.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision.transforms import functional as F
3 | from data import valid_dataloader
4 | from utils import Adder
5 | import os
6 | from skimage.metrics import peak_signal_noise_ratio
7 | import torch.nn.functional as f
8 |
9 |
10 | def _valid(model, args, ep):
11 |     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12 |     gopro = valid_dataloader(args.data_dir, batch_size=1, num_workers=0)
13 |     model.eval()
14 |     psnr_adder = Adder()
15 |
16 |     with torch.no_grad():
17 |         print('Start GoPro Evaluation')
18 |         factor = 32
19 |         for idx, data in enumerate(gopro):
20 |             input_img, label_img = data
21 |             input_img = input_img.to(device)
22 |
23 |             h, w = input_img.shape[2], input_img.shape[3]
24 |             H, W = ((h+factor)//factor)*factor, ((w+factor)//factor)*factor
25 |             padh = H-h if h%factor!=0 else 0
26 |             padw = W-w if w%factor!=0 else 0
27 |             input_img = f.pad(input_img, (0, padw, 0, padh), 'reflect')
28 |
29 |             if not os.path.exists(os.path.join(args.result_dir, '%d' % (ep))):
30 |                 os.mkdir(os.path.join(args.result_dir, '%d' % (ep)))
31 |
32 |             pred = model(input_img)[2]
33 |             pred = pred[:,:,:h,:w]
34 |
35 |             pred_clip = torch.clamp(pred, 0, 1)
36 |             p_numpy = pred_clip.squeeze(0).cpu().numpy()
37 |             label_numpy = label_img.squeeze(0).cpu().numpy()
38 |
39 |             psnr = peak_signal_noise_ratio(p_numpy, label_numpy, data_range=1)
40 |
41 |             psnr_adder(psnr)
42 |             print('\r%03d'%idx, end=' ')
43 |
44 |     print('\n')
45 |     model.train()
46 |     return psnr_adder.average()
47 |
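The reflect-padding in `_valid` just rounds H and W up to the next multiple of 32 so the encoder–decoder strides divide evenly, then crops back after inference. A stand-alone sketch of that pattern (the function name and shapes are illustrative, not from the repository):

~~~python
import torch
import torch.nn.functional as f

def pad_to_multiple(x, factor=32):
    """Reflect-pad an NCHW tensor so H and W are multiples of `factor`."""
    h, w = x.shape[2], x.shape[3]
    padh = (factor - h % factor) % factor
    padw = (factor - w % factor) % factor
    return f.pad(x, (0, padw, 0, padh), 'reflect'), (h, w)

x = torch.rand(1, 3, 100, 130)
padded, (h, w) = pad_to_multiple(x)
assert padded.shape[2] % 32 == 0 and padded.shape[3] % 32 == 0
restored = padded[:, :, :h, :w]   # crop back to the original size
assert restored.shape == x.shape
~~~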

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/revitalizing-convolutional-network-for-image/image-dehazing-on-sots-indoor)](https://paperswithcode.com/sota/image-dehazing-on-sots-indoor)
2 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/revitalizing-convolutional-network-for-image/image-dehazing-on-sots-outdoor)](https://paperswithcode.com/sota/image-dehazing-on-sots-outdoor)
3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/revitalizing-convolutional-network-for-image/image-dehazing-on-haze4k)](https://paperswithcode.com/sota/image-dehazing-on-haze4k)
4 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/revitalizing-convolutional-network-for-image/image-dehazing-on-i-haze)](https://paperswithcode.com/sota/image-dehazing-on-i-haze)
5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/revitalizing-convolutional-network-for-image/image-dehazing-on-o-haze)](https://paperswithcode.com/sota/image-dehazing-on-o-haze)
6 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/revitalizing-convolutional-network-for-image/snow-removal-on-snow100k)](https://paperswithcode.com/sota/snow-removal-on-snow100k)
7 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/revitalizing-convolutional-network-for-image/snow-removal-on-srrs)](https://paperswithcode.com/sota/snow-removal-on-srrs)
8 |
9 |
10 | ## Revitalizing Convolutional Network for Image Restoration
11 |
12 | The official PyTorch implementation of the paper **[Revitalizing Convolutional Network for Image Restoration](https://ieeexplore.ieee.org/abstract/document/10571568)** (T-PAMI'24)
13 |
14 | #### Yuning Cui, Wenqi Ren, Xiaochun Cao, Alois Knoll
15 |
16 | ## News
17 | All result images and pre-trained models are available at the provided links.
18 |
19 | **11/26/2024** Code for real haze and Haze4K is released.
20 |
21 | **07/22/2024** We release the code for dehazing (ITS/OTS), desnowing, deraining, and motion deblurring.
22 |
23 | ## Pretrained models
24 | [gdrive](https://drive.google.com/drive/folders/1_5fO2p5xoWO5cUEVoXJ7x3Uhg1AP18FQ?usp=sharing), [Baidu](https://pan.baidu.com/s/1oYzdxs3FvLJMWx7S5GW0rA?pwd=dvta)
25 |
26 |
27 | ## Installation
28 | The project is built with Python 3.8, PyTorch 1.8.1, CUDA 10.2, and cuDNN 7.6.5.
29 | To install, follow these instructions:
30 | ~~~
31 | conda install pytorch=1.8.1 torchvision=0.9.1 -c pytorch
32 | pip install tensorboard einops scikit-image pytorch_msssim opencv-python
33 | ~~~
34 |
35 | *Please use the Pillow package installed by conda rather than pip.*
36 |
37 |
38 |
39 | Install the warmup scheduler:
40 | ~~~
41 | cd pytorch-gradual-warmup-lr/
42 | python setup.py install
43 | cd ..
44 | ~~~
45 | ## Training and Evaluation
46 | Please refer to the respective directories.
47 | ## Results
48 | ### Visualization Results: [gdrive](https://drive.google.com/drive/folders/1YiuiYG36zqgHsoUhbk6UJAAywGc0avnj?usp=sharing), [Baidu](https://pan.baidu.com/s/1mDlRfEoMSi8vpCLRUxk2tQ?pwd=y2gv)
49 | |Model|Parameters|FLOPs|
50 | |------|-----|-----|
51 | |*ConvIR-S (small)*|5.53M|42.1G|
52 | |**ConvIR-B (base)**|8.63M|71.22G|
53 | |ConvIR-L (large)|14.83M|129.34G|
54 |
55 | In the table below, *italic*, **bold**, and plain entries report ConvIR-S, ConvIR-B, and ConvIR-L, respectively.
56 |
57 | |Task|Dataset|PSNR|SSIM|
58 | |----|------|-----|----|
59 | |**Image Dehazing**|SOTS-Indoor|*41.53*/**42.72**|*0.996*/**0.997**|
60 | ||SOTS-Outdoor|*37.95*/**39.42**|*0.994*/**0.996**|
61 | ||Haze4K|*33.36*/**34.15**/34.50|*0.99*/**0.99**/0.99|
62 | ||Dense-Haze|*17.45*/**16.86**|*0.648*/**0.621**|
63 | ||NH-HAZE|*20.65*/**20.66**|*0.807*/**0.802**|
64 | ||O-HAZE|*25.25*/**25.36**|*0.784*/**0.780**|
65 | ||I-HAZE|*21.95*/**22.44**|*0.888*/**0.887**|
66 | ||SateHaze-1k-Thin/Moderate/Thick|*25.11*/*26.79*/*22.65*|*0.978*/*0.978*/*0.950*|
67 | ||NHR|*28.85*/**29.49**|*0.981*/**0.983**|
68 | ||GTA5|*31.68*/**31.83**|*0.917*/**0.921**|
69 | |**Image Desnowing**|CSD|*38.43*/**39.10**|*0.99*/**0.99**|
70 | ||SRRS|*32.25*/**32.39**|*0.98*/**0.98**|
71 | ||Snow100K|*33.79*/**33.92**|*0.95*/**0.96**|
72 | |**Image Deraining**|Test100|31.40|0.919|
73 | ||Test2800|33.73|0.937|
74 | |**Defocus Deblurring**|DPDD|*26.06*/**26.16**/26.36|*0.810*/**0.814**/0.820|
75 | |**Motion Deblurring**|GoPro|33.28|0.963|
76 | ||RSBlur|34.06|0.868|
77 |
78 |
79 | ## Citation
80 | ~~~
81 | @article{cui2024revitalizing,
82 |   title={Revitalizing Convolutional Network for Image Restoration},
83 |   author={Cui, Yuning and Ren, Wenqi and Cao, Xiaochun and Knoll, Alois},
84 |   journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
85 |   year={2024},
86 |   publisher={IEEE}
87 | }
88 |
89 | @inproceedings{cui2023irnext,
90 |   title={IRNeXt: Rethinking Convolutional Network Design for Image Restoration},
91 |   author={Cui, Yuning and Ren, Wenqi and Yang, Sining and Cao, Xiaochun and Knoll, Alois},
92 |   booktitle={International Conference on Machine Learning},
93 |   pages={6545--6564},
94 |   year={2023},
95 |   organization={PMLR}
96 | }
97 | ~~~
98 |
99 | ## Contact
100 | Should you have any problems, please contact Yuning Cui.
101 |
--------------------------------------------------------------------------------
/pytorch-gradual-warmup-lr/setup.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from __future__ import division
3 | from __future__ import print_function
4 |
5 | import setuptools
6 |
7 | _VERSION = '0.3'
8 |
9 | REQUIRED_PACKAGES = [
10 | ]
11 |
12 | DEPENDENCY_LINKS = [
13 | ]
14 |
15 | setuptools.setup(
16 |     name='warmup_scheduler',
17 |     version=_VERSION,
18 |     description='Gradually Warm-up LR Scheduler for Pytorch',
19 |     install_requires=REQUIRED_PACKAGES,
20 |     dependency_links=DEPENDENCY_LINKS,
21 |     url='https://github.com/ildoonet/pytorch-gradual-warmup-lr',
22 |     license='MIT License',
23 |     package_dir={},
24 |     packages=setuptools.find_packages(exclude=['tests']),
25 | )
--------------------------------------------------------------------------------
/pytorch-gradual-warmup-lr/warmup_scheduler/__init__.py:
--------------------------------------------------------------------------------
1 | from warmup_scheduler.scheduler import GradualWarmupScheduler
--------------------------------------------------------------------------------
/pytorch-gradual-warmup-lr/warmup_scheduler/run.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.lr_scheduler import StepLR, ExponentialLR
3 | from torch.optim.sgd import SGD
4 |
5 | from warmup_scheduler import GradualWarmupScheduler
6 |
7 |
8 | if __name__ == '__main__':
9 |     model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
10 |     optim = SGD(model, 0.1)
11 |
12 |     # scheduler_warmup is chained with scheduler_steplr
13 |     scheduler_steplr = StepLR(optim, step_size=10, gamma=0.1)
14 |     scheduler_warmup = GradualWarmupScheduler(optim, multiplier=1, total_epoch=5, after_scheduler=scheduler_steplr)
15 |
16 |     # this zero-gradient update is needed to avoid a warning message, see issue #8
17 |     optim.zero_grad()
18 |     optim.step()
19 |
20 |     for epoch in range(1, 20):
21 |         scheduler_warmup.step(epoch)
22 |         print(epoch, optim.param_groups[0]['lr'])
23 |
24 |         optim.step()    # backward pass (update network)
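Running the demo above prints a linear ramp during warmup. A quick hedged sanity check of that ramp (same configuration as run.py; the tolerance and variable names are mine):

~~~python
import torch
from torch.optim.sgd import SGD
from torch.optim.lr_scheduler import StepLR
from warmup_scheduler import GradualWarmupScheduler

params = [torch.nn.Parameter(torch.zeros(1))]
optim = SGD(params, 0.1)
steplr = StepLR(optim, step_size=10, gamma=0.1)
warmup = GradualWarmupScheduler(optim, multiplier=1, total_epoch=5, after_scheduler=steplr)

optim.step()                 # avoid the "step before lr_scheduler" warning
for epoch in range(1, 6):
    warmup.step(epoch)
    # with multiplier=1 the ramp is base_lr * epoch / total_epoch
    assert abs(optim.param_groups[0]['lr'] - 0.1 * epoch / 5) < 1e-9
# from epoch 6 on, control passes to the wrapped StepLR
~~~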
--------------------------------------------------------------------------------
/pytorch-gradual-warmup-lr/warmup_scheduler/scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import _LRScheduler
2 | from torch.optim.lr_scheduler import ReduceLROnPlateau
3 |
4 |
5 | class GradualWarmupScheduler(_LRScheduler):
6 |     """ Gradually warms up (increases) the learning rate in the optimizer.
7 |     Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
8 |
9 |     Args:
10 |         optimizer (Optimizer): Wrapped optimizer.
11 |         multiplier: target learning rate = base lr * multiplier if multiplier > 1.0; if multiplier = 1.0, the lr starts from 0 and ends at the base lr.
12 |         total_epoch: the target learning rate is reached gradually at total_epoch
13 |         after_scheduler: the scheduler to use after total_epoch (e.g. ReduceLROnPlateau)
14 |     """
15 |
16 |     def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
17 |         self.multiplier = multiplier
18 |         if self.multiplier < 1.:
19 |             raise ValueError('multiplier should be greater than or equal to 1.')
20 |         self.total_epoch = total_epoch
21 |         self.after_scheduler = after_scheduler
22 |         self.finished = False
23 |         super(GradualWarmupScheduler, self).__init__(optimizer)
24 |
25 |     def get_lr(self):
26 |         if self.last_epoch > self.total_epoch:
27 |             if self.after_scheduler:
28 |                 if not self.finished:
29 |                     self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
30 |                     self.finished = True
31 |                 return self.after_scheduler.get_lr()
32 |             return [base_lr * self.multiplier for base_lr in self.base_lrs]
33 |
34 |         if self.multiplier == 1.0:
35 |             return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
36 |         else:
37 |             return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
38 |
39 |     def step_ReduceLROnPlateau(self, metrics, epoch=None):
40 |         if epoch is None:
41 |             epoch = self.last_epoch + 1
42 |         self.last_epoch = epoch if epoch != 0 else 1  # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
43 |         if self.last_epoch <= self.total_epoch:
44 |             warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
45 |             for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
46 |                 param_group['lr'] = lr
47 |         else:
48 |             if epoch is None:
49 |                 self.after_scheduler.step(metrics, None)
50 |             else:
51 |                 self.after_scheduler.step(metrics, epoch - self.total_epoch)
52 |
53 |     def step(self, epoch=None, metrics=None):
54 |         if type(self.after_scheduler) != ReduceLROnPlateau:
55 |             if self.finished and self.after_scheduler:
56 |                 if epoch is None:
57 |                     self.after_scheduler.step(None)
58 |                 else:
59 |                     self.after_scheduler.step(epoch - self.total_epoch)
60 |             else:
61 |                 return super(GradualWarmupScheduler, self).step(epoch)
62 |         else:
63 |             self.step_ReduceLROnPlateau(metrics, epoch)
--------------------------------------------------------------------------------
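To close the loop with Motion_Deblurring/train.py above, a minimal sketch of how this scheduler is chained with cosine annealing in the training scripts (the optimizer hyper-parameters mirror `_train`; `num_epoch` is a placeholder for `args.num_epoch`, and the toy parameter list is mine):

~~~python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR
from warmup_scheduler import GradualWarmupScheduler

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=1e-4, betas=(0.9, 0.999), eps=1e-8)
num_epoch, warmup_epochs = 1000, 3

scheduler_cosine = CosineAnnealingLR(optimizer, T_max=num_epoch - warmup_epochs, eta_min=1e-6)
scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs,
                                   after_scheduler=scheduler_cosine)
scheduler.step()  # primed once before the first epoch, as in _train

for epoch_idx in range(1, num_epoch + 1):
    # ... run one epoch of training here ...
    scheduler.step()  # advanced once per epoch, after the inner iteration loop
~~~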