├── .gitignore ├── Dehazing ├── ITS │ ├── data │ │ ├── __init__.py │ │ ├── data_augment.py │ │ └── data_load.py │ ├── eval.py │ ├── main.py │ ├── models │ │ ├── OKNet.py │ │ └── layers.py │ ├── train.py │ ├── utils.py │ └── valid.py ├── OTS │ ├── data │ │ ├── __init__.py │ │ ├── data_augment.py │ │ └── data_load.py │ ├── eval.py │ ├── main.py │ ├── models │ │ ├── OKNet.py │ │ └── layers.py │ ├── train.py │ ├── utils.py │ └── valid.py └── README.md ├── Desnowing ├── README.md ├── data │ ├── __init__.py │ ├── data_augment.py │ └── data_load.py ├── eval.py ├── main.py ├── models │ ├── OKNet.py │ └── layers.py ├── train.py ├── utils.py └── valid.py ├── LICENSE ├── README.md └── pytorch-gradual-warmup-lr ├── setup.py └── warmup_scheduler ├── __init__.py ├── run.py └── scheduler.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | -------------------------------------------------------------------------------- /Dehazing/ITS/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_augment import PairRandomCrop, PairCompose, PairRandomHorizontalFilp, PairToTensor 2 | from .data_load import train_dataloader, test_dataloader, valid_dataloader 3 | -------------------------------------------------------------------------------- /Dehazing/ITS/data/data_augment.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torchvision.transforms as transforms 3 | import torchvision.transforms.functional as F 4 | 5 | 6 | class PairRandomCrop(transforms.RandomCrop): 7 | 8 | def __call__(self, image, label): 9 | 10 | if self.padding is not None: 11 | image = F.pad(image, self.padding, self.fill, self.padding_mode) 12 | label = F.pad(label, self.padding, self.fill, self.padding_mode) 13 | 14 | # pad the width if needed 15 | if self.pad_if_needed and image.size[0] < self.size[1]: 16 | image = F.pad(image, (self.size[1] - image.size[0], 0), self.fill, self.padding_mode) 17 | label = F.pad(label, (self.size[1] - label.size[0], 0), self.fill, self.padding_mode) 18 | # pad the height if needed 19 | if self.pad_if_needed and image.size[1] < self.size[0]: 20 | image = F.pad(image, (0, self.size[0] - image.size[1]), self.fill, self.padding_mode) 21 | label = F.pad(label, (0, self.size[0] - image.size[1]), self.fill, self.padding_mode) 22 | 23 | i, j, h, w = self.get_params(image, self.size) 24 | 25 | return F.crop(image, i, j, h, w), F.crop(label, i, j, h, w) 26 | 27 | 28 | class PairCompose(transforms.Compose): 29 | def __call__(self, image, label): 30 | for t in self.transforms: 31 | image, label = t(image, label) 32 | return image, label 33 | 34 | 35 | class PairRandomHorizontalFilp(transforms.RandomHorizontalFlip): 36 | def __call__(self, img, label): 37 | """ 38 | Args: 39 | img 
(PIL Image): Image to be flipped. 40 | 41 | Returns: 42 | PIL Image: Randomly flipped image. 43 | """ 44 | if random.random() < self.p: 45 | return F.hflip(img), F.hflip(label) 46 | return img, label 47 | 48 | 49 | class PairToTensor(transforms.ToTensor): 50 | def __call__(self, pic, label): 51 | """ 52 | Args: 53 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor. 54 | 55 | Returns: 56 | Tensor: Converted image. 57 | """ 58 | return F.to_tensor(pic), F.to_tensor(label) 59 | -------------------------------------------------------------------------------- /Dehazing/ITS/data/data_load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from PIL import Image as Image 5 | from data import PairCompose, PairRandomCrop, PairRandomHorizontalFilp, PairToTensor 6 | from torchvision.transforms import functional as F 7 | from torch.utils.data import Dataset, DataLoader 8 | 9 | 10 | def train_dataloader(path, batch_size=64, num_workers=0, use_transform=True): 11 | image_dir = os.path.join(path, 'train') 12 | 13 | transform = None 14 | if use_transform: 15 | transform = PairCompose( 16 | [ 17 | PairRandomCrop(256), 18 | PairRandomHorizontalFilp(), 19 | PairToTensor() 20 | ] 21 | ) 22 | dataloader = DataLoader( 23 | DeblurDataset(image_dir, transform=transform), 24 | batch_size=batch_size, 25 | shuffle=True, 26 | num_workers=num_workers, 27 | pin_memory=True 28 | ) 29 | return dataloader 30 | 31 | 32 | def test_dataloader(path, batch_size=1, num_workers=0): 33 | image_dir = os.path.join(path, 'test') 34 | dataloader = DataLoader( 35 | DeblurDataset(image_dir, is_test=True), 36 | batch_size=batch_size, 37 | shuffle=False, 38 | num_workers=num_workers, 39 | pin_memory=True 40 | ) 41 | 42 | return dataloader 43 | 44 | 45 | def valid_dataloader(path, batch_size=1, num_workers=0): 46 | dataloader = DataLoader( 47 | DeblurDataset(os.path.join(path, 'test')), 48 | 
batch_size=batch_size, 49 | shuffle=False, 50 | num_workers=num_workers 51 | ) 52 | 53 | return dataloader 54 | 55 | 56 | class DeblurDataset(Dataset): 57 | def __init__(self, image_dir, transform=None, is_test=False): 58 | self.image_dir = image_dir 59 | self.image_list = os.listdir(os.path.join(image_dir, 'hazy/')) 60 | self._check_image(self.image_list) 61 | self.image_list.sort() 62 | self.transform = transform 63 | self.is_test = is_test 64 | 65 | def __len__(self): 66 | return len(self.image_list) 67 | 68 | def __getitem__(self, idx): 69 | image = Image.open(os.path.join(self.image_dir, 'hazy', self.image_list[idx])) 70 | label = Image.open(os.path.join(self.image_dir, 'gt', self.image_list[idx].split('_')[0]+'.png')) 71 | 72 | if self.transform: 73 | image, label = self.transform(image, label) 74 | else: 75 | image = F.to_tensor(image) 76 | label = F.to_tensor(label) 77 | if self.is_test: 78 | name = self.image_list[idx] 79 | return image, label, name 80 | return image, label 81 | 82 | @staticmethod 83 | def _check_image(lst): 84 | for x in lst: 85 | splits = x.split('.') 86 | if splits[-1] not in ['png', 'jpg', 'jpeg']: 87 | raise ValueError 88 | -------------------------------------------------------------------------------- /Dehazing/ITS/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torchvision.transforms import functional as F 4 | import numpy as np 5 | from utils import Adder 6 | from data import test_dataloader 7 | from skimage.metrics import peak_signal_noise_ratio 8 | import time 9 | from pytorch_msssim import ssim 10 | import torch.nn.functional as f 11 | 12 | from skimage import img_as_ubyte 13 | import cv2 14 | 15 | def _eval(model, args): 16 | state_dict = torch.load(args.test_model) 17 | model.load_state_dict(state_dict['model']) 18 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 19 | dataloader = test_dataloader(args.data_dir, 
batch_size=1, num_workers=0) 20 | torch.cuda.empty_cache() 21 | adder = Adder() 22 | model.eval() 23 | factor = 4 24 | with torch.no_grad(): 25 | psnr_adder = Adder() 26 | ssim_adder = Adder() 27 | 28 | for iter_idx, data in enumerate(dataloader): 29 | input_img, label_img, name = data 30 | 31 | input_img = input_img.to(device) 32 | 33 | h, w = input_img.shape[2], input_img.shape[3] 34 | H, W = ((h+factor)//factor)*factor, ((w+factor)//factor*factor) 35 | padh = H-h if h%factor!=0 else 0 36 | padw = W-w if w%factor!=0 else 0 37 | input_img = f.pad(input_img, (0, padw, 0, padh), 'reflect') 38 | 39 | tm = time.time() 40 | 41 | pred = model(input_img)[2] 42 | pred = pred[:,:,:h,:w] 43 | 44 | elapsed = time.time() - tm 45 | adder(elapsed) 46 | 47 | pred_clip = torch.clamp(pred, 0, 1) 48 | 49 | pred_numpy = pred_clip.squeeze(0).cpu().numpy() 50 | label_numpy = label_img.squeeze(0).cpu().numpy() 51 | 52 | 53 | label_img = (label_img).cuda() 54 | psnr_val = 10 * torch.log10(1 / f.mse_loss(pred_clip, label_img)) 55 | down_ratio = max(1, round(min(H, W) / 256)) 56 | ssim_val = ssim(f.adaptive_avg_pool2d(pred_clip, (int(H / down_ratio), int(W / down_ratio))), 57 | f.adaptive_avg_pool2d(label_img, (int(H / down_ratio), int(W / down_ratio))), 58 | data_range=1, size_average=False) 59 | print('%d iter PSNR_dehazing: %.2f ssim: %f' % (iter_idx + 1, psnr_val, ssim_val)) 60 | ssim_adder(ssim_val) 61 | 62 | if args.save_image: 63 | save_name = os.path.join(args.result_dir, name[0]) 64 | pred_clip += 0.5 / 255 65 | pred = F.to_pil_image(pred_clip.squeeze(0).cpu(), 'RGB') 66 | pred.save(save_name) 67 | 68 | psnr_mimo = peak_signal_noise_ratio(pred_numpy, label_numpy, data_range=1) 69 | psnr_adder(psnr_val) 70 | 71 | print('%d iter PSNR: %.2f time: %f' % (iter_idx + 1, psnr_mimo, elapsed)) 72 | 73 | print('==========================================================') 74 | print('The average PSNR is %.2f dB' % (psnr_adder.average())) 75 | print('The average SSIM is %.5f dB' % 
(ssim_adder.average())) 76 | 77 | print("Average time: %f" % adder.average()) 78 | 79 | -------------------------------------------------------------------------------- /Dehazing/ITS/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | from torch.backends import cudnn 5 | from models.OKNet import build_net 6 | from train import _train 7 | from eval import _eval 8 | import numpy as np 9 | import random 10 | 11 | def main(args): 12 | # CUDNN 13 | cudnn.benchmark = True 14 | 15 | if not os.path.exists('results/'): 16 | os.makedirs(args.model_save_dir) 17 | if not os.path.exists('results/' + args.model_name + '/'): 18 | os.makedirs('results/' + args.model_name + '/') 19 | if not os.path.exists(args.model_save_dir): 20 | os.makedirs(args.model_save_dir) 21 | if not os.path.exists(args.result_dir): 22 | os.makedirs(args.result_dir) 23 | 24 | model = build_net() 25 | print(model) 26 | 27 | if torch.cuda.is_available(): 28 | model.cuda() 29 | if args.mode == 'train': 30 | _train(model, args) 31 | 32 | elif args.mode == 'test': 33 | _eval(model, args) 34 | 35 | 36 | if __name__ == '__main__': 37 | parser = argparse.ArgumentParser() 38 | 39 | # Directories 40 | parser.add_argument('--model_name', default='OKNet', type=str) 41 | 42 | parser.add_argument('--mode', default='test', choices=['train', 'test'], type=str) 43 | parser.add_argument('--data_dir', type=str, default='') 44 | 45 | # Train 46 | parser.add_argument('--batch_size', type=int, default=8) 47 | parser.add_argument('--learning_rate', type=float, default=2e-4) 48 | parser.add_argument('--weight_decay', type=float, default=0) 49 | parser.add_argument('--num_epoch', type=int, default=1000) 50 | parser.add_argument('--print_freq', type=int, default=100) 51 | parser.add_argument('--num_worker', type=int, default=16) 52 | parser.add_argument('--save_freq', type=int, default=20) 53 | parser.add_argument('--valid_freq', type=int, 
default=20) 54 | parser.add_argument('--resume', type=str, default='') 55 | 56 | 57 | # Test 58 | parser.add_argument('--test_model', type=str, default='') 59 | parser.add_argument('--save_image', type=bool, default=False, choices=[True, False]) 60 | 61 | args = parser.parse_args() 62 | args.model_save_dir = os.path.join('results/', 'OKNet', 'ITS/') 63 | args.result_dir = os.path.join('results/', args.model_name, 'test') 64 | if not os.path.exists(args.model_save_dir): 65 | os.makedirs(args.model_save_dir) 66 | command = 'cp ' + 'models/layers.py ' + args.model_save_dir 67 | os.system(command) 68 | command = 'cp ' + 'models/OKNet.py ' + args.model_save_dir 69 | os.system(command) 70 | command = 'cp ' + 'train.py ' + args.model_save_dir 71 | os.system(command) 72 | command = 'cp ' + 'main.py ' + args.model_save_dir 73 | os.system(command) 74 | print(args) 75 | main(args) 76 | -------------------------------------------------------------------------------- /Dehazing/ITS/models/OKNet.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .layers import * 6 | 7 | class EBlock(nn.Module): 8 | def __init__(self, out_channel, num_res=8): 9 | super(EBlock, self).__init__() 10 | 11 | layers = [ResBlock(out_channel, out_channel) for _ in range(num_res)] 12 | 13 | self.layers = nn.Sequential(*layers) 14 | 15 | def forward(self, x): 16 | return self.layers(x) 17 | 18 | 19 | class DBlock(nn.Module): 20 | def __init__(self, channel, num_res=8): 21 | super(DBlock, self).__init__() 22 | 23 | layers = [ResBlock(channel, channel) for _ in range(num_res)] 24 | self.layers = nn.Sequential(*layers) 25 | 26 | def forward(self, x): 27 | return self.layers(x) 28 | 29 | 30 | class SCM(nn.Module): 31 | def __init__(self, out_plane): 32 | super(SCM, self).__init__() 33 | self.main = nn.Sequential( 34 | BasicConv(3, out_plane//4, kernel_size=3, stride=1, relu=True), 35 | 
BasicConv(out_plane // 4, out_plane // 2, kernel_size=1, stride=1, relu=True), 36 | BasicConv(out_plane // 2, out_plane // 2, kernel_size=3, stride=1, relu=True), 37 | BasicConv(out_plane // 2, out_plane, kernel_size=1, stride=1, relu=False), 38 | nn.InstanceNorm2d(out_plane, affine=True) 39 | ) 40 | 41 | def forward(self, x): 42 | x = self.main(x) 43 | return x 44 | 45 | 46 | class BottleNect(nn.Module): 47 | def __init__(self, dim) -> None: 48 | super().__init__() 49 | 50 | ker = 63 51 | pad = ker // 2 52 | self.in_conv = nn.Sequential( 53 | nn.Conv2d(dim, dim, kernel_size=1, padding=0, stride=1), 54 | nn.GELU() 55 | ) 56 | self.out_conv = nn.Conv2d(dim, dim, kernel_size=1, padding=0, stride=1) 57 | self.dw_13 = nn.Conv2d(dim, dim, kernel_size=(1,ker), padding=(0,pad), stride=1, groups=dim) 58 | self.dw_31 = nn.Conv2d(dim, dim, kernel_size=(ker,1), padding=(pad,0), stride=1, groups=dim) 59 | self.dw_33 = nn.Conv2d(dim, dim, kernel_size=ker, padding=pad, stride=1, groups=dim) 60 | self.dw_11 = nn.Conv2d(dim, dim, kernel_size=1, padding=0, stride=1, groups=dim) 61 | 62 | self.act = nn.ReLU() 63 | 64 | ### sca ### 65 | self.conv = nn.Conv2d(dim, dim, kernel_size=1, padding=0, stride=1, groups=1, bias=True) 66 | self.pool = nn.AdaptiveAvgPool2d((1,1)) 67 | 68 | ### fca ### 69 | self.fac_conv = nn.Conv2d(dim, dim, kernel_size=1, padding=0, stride=1, groups=1, bias=True) 70 | self.fac_pool = nn.AdaptiveAvgPool2d((1,1)) 71 | self.fgm = FGM(dim) 72 | 73 | def forward(self, x): 74 | out = self.in_conv(x) 75 | 76 | ### fca ### 77 | x_att = self.fac_conv(self.fac_pool(out)) 78 | x_fft = torch.fft.fft2(out, norm='backward') 79 | x_fft = x_att * x_fft 80 | x_fca = torch.fft.ifft2(x_fft, dim=(-2,-1), norm='backward') 81 | x_fca = torch.abs(x_fca) 82 | 83 | ### fca ### 84 | ### sca ### 85 | x_att = self.conv(self.pool(x_fca)) 86 | x_sca = x_att * x_fca 87 | ### sca ### 88 | x_sca = self.fgm(x_sca) 89 | 90 | out = x + self.dw_13(out) + self.dw_31(out) + self.dw_33(out) + 
self.dw_11(out) + x_sca 91 | out = self.act(out) 92 | return self.out_conv(out) 93 | 94 | class FGM(nn.Module): 95 | def __init__(self, dim) -> None: 96 | super().__init__() 97 | 98 | self.conv = nn.Conv2d(dim, dim*2, 3, 1, 1) 99 | 100 | self.dwconv1 = nn.Conv2d(dim, dim, 1, 1, groups=1) 101 | self.dwconv2 = nn.Conv2d(dim, dim, 1, 1, groups=1) 102 | self.alpha = nn.Parameter(torch.zeros(dim, 1, 1)) 103 | self.beta = nn.Parameter(torch.ones(dim, 1, 1)) 104 | 105 | def forward(self, x): 106 | # res = x.clone() 107 | fft_size = x.size()[2:] 108 | x1 = self.dwconv1(x) 109 | x2 = self.dwconv2(x) 110 | 111 | x2_fft = torch.fft.fft2(x2, norm='backward') 112 | 113 | out = x1 * x2_fft 114 | 115 | out = torch.fft.ifft2(out, dim=(-2,-1), norm='backward') 116 | out = torch.abs(out) 117 | 118 | return out * self.alpha + x * self.beta 119 | 120 | 121 | class FAM(nn.Module): 122 | def __init__(self, channel): 123 | super(FAM, self).__init__() 124 | self.merge = BasicConv(channel*2, channel, kernel_size=3, stride=1, relu=False) 125 | 126 | def forward(self, x1, x2): 127 | return self.merge(torch.cat([x1, x2], dim=1)) 128 | 129 | class OKNet(nn.Module): 130 | def __init__(self, num_res=4): 131 | super(OKNet, self).__init__() 132 | 133 | base_channel = 32 134 | 135 | self.Encoder = nn.ModuleList([ 136 | EBlock(base_channel, num_res), 137 | EBlock(base_channel*2, num_res), 138 | EBlock(base_channel*4, num_res), 139 | ]) 140 | 141 | self.feat_extract = nn.ModuleList([ 142 | BasicConv(3, base_channel, kernel_size=3, relu=True, stride=1), 143 | BasicConv(base_channel, base_channel*2, kernel_size=3, relu=True, stride=2), 144 | BasicConv(base_channel*2, base_channel*4, kernel_size=3, relu=True, stride=2), 145 | BasicConv(base_channel*4, base_channel*2, kernel_size=4, relu=True, stride=2, transpose=True), 146 | BasicConv(base_channel*2, base_channel, kernel_size=4, relu=True, stride=2, transpose=True), 147 | BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1) 148 | ]) 149 | 150 
| self.Decoder = nn.ModuleList([ 151 | DBlock(base_channel * 4, num_res), 152 | DBlock(base_channel * 2, num_res), 153 | DBlock(base_channel, num_res) 154 | ]) 155 | 156 | self.Convs = nn.ModuleList([ 157 | BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1), 158 | BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1), 159 | ]) 160 | 161 | self.ConvsOut = nn.ModuleList( 162 | [ 163 | BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1), 164 | BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1), 165 | ] 166 | ) 167 | 168 | self.FAM1 = FAM(base_channel * 4) 169 | self.SCM1 = SCM(base_channel * 4) 170 | self.FAM2 = FAM(base_channel * 2) 171 | self.SCM2 = SCM(base_channel * 2) 172 | 173 | self.bottelneck = BottleNect(base_channel * 4) 174 | 175 | 176 | def forward(self, x): 177 | x_2 = F.interpolate(x, scale_factor=0.5) 178 | x_4 = F.interpolate(x_2, scale_factor=0.5) 179 | z2 = self.SCM2(x_2) 180 | z4 = self.SCM1(x_4) 181 | 182 | outputs = list() 183 | # 256 184 | x_ = self.feat_extract[0](x) 185 | res1 = self.Encoder[0](x_) 186 | # 128 187 | z = self.feat_extract[1](res1) 188 | z = self.FAM2(z, z2) 189 | res2 = self.Encoder[1](z) 190 | # 64 191 | z = self.feat_extract[2](res2) 192 | z = self.FAM1(z, z4) 193 | z = self.Encoder[2](z) 194 | z = self.bottelneck(z) 195 | 196 | z = self.Decoder[0](z) 197 | z_ = self.ConvsOut[0](z) 198 | # 128 199 | z = self.feat_extract[3](z) 200 | outputs.append(z_+x_4) 201 | 202 | z = torch.cat([z, res2], dim=1) 203 | z = self.Convs[0](z) 204 | z = self.Decoder[1](z) 205 | z_ = self.ConvsOut[1](z) 206 | # 256 207 | z = self.feat_extract[4](z) 208 | outputs.append(z_+x_2) 209 | 210 | z = torch.cat([z, res1], dim=1) 211 | z = self.Convs[1](z) 212 | z = self.Decoder[2](z) 213 | z = self.feat_extract[5](z) 214 | outputs.append(z+x) 215 | 216 | return outputs 217 | 218 | def build_net(): 219 | return OKNet() 220 | 221 | 
-------------------------------------------------------------------------------- /Dehazing/ITS/models/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class BasicConv(nn.Module): 6 | def __init__(self, in_channel, out_channel, kernel_size, stride, bias=True, norm=False, relu=True, transpose=False): 7 | super(BasicConv, self).__init__() 8 | if bias and norm: 9 | bias = False 10 | 11 | padding = kernel_size // 2 12 | layers = list() 13 | if transpose: 14 | padding = kernel_size // 2 -1 15 | layers.append(nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias)) 16 | else: 17 | layers.append( 18 | nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias)) 19 | if norm: 20 | layers.append(nn.BatchNorm2d(out_channel)) 21 | if relu: 22 | layers.append(nn.GELU()) 23 | self.main = nn.Sequential(*layers) 24 | 25 | def forward(self, x): 26 | return self.main(x) 27 | 28 | 29 | class ResBlock(nn.Module): 30 | def __init__(self, in_channel, out_channel): 31 | super(ResBlock, self).__init__() 32 | self.main = nn.Sequential( 33 | BasicConv(in_channel, out_channel, kernel_size=3, stride=1, relu=True), 34 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False) 35 | ) 36 | 37 | def forward(self, x): 38 | return self.main(x) + x -------------------------------------------------------------------------------- /Dehazing/ITS/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from data import train_dataloader 4 | from utils import Adder, Timer, check_lr 5 | from torch.utils.tensorboard import SummaryWriter 6 | from valid import _valid 7 | import torch.nn.functional as F 8 | import torch.nn as nn 9 | 10 | from warmup_scheduler import GradualWarmupScheduler 11 | 12 | def _train(model, args): 13 | device = torch.device('cuda' 
if torch.cuda.is_available() else 'cpu') 14 | criterion = torch.nn.L1Loss() 15 | 16 | optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), eps=1e-8) 17 | dataloader = train_dataloader(args.data_dir, args.batch_size, args.num_worker) 18 | max_iter = len(dataloader) 19 | warmup_epochs=3 20 | scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.num_epoch-warmup_epochs, eta_min=1e-6) 21 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine) 22 | scheduler.step() 23 | epoch = 1 24 | if args.resume: 25 | state = torch.load(args.resume) 26 | epoch = state['epoch'] 27 | optimizer.load_state_dict(state['optimizer']) 28 | model.load_state_dict(state['model']) 29 | print('Resume from %d'%epoch) 30 | epoch += 1 31 | 32 | writer = SummaryWriter() 33 | epoch_pixel_adder = Adder() 34 | epoch_fft_adder = Adder() 35 | iter_pixel_adder = Adder() 36 | iter_fft_adder = Adder() 37 | epoch_timer = Timer('m') 38 | iter_timer = Timer('m') 39 | best_psnr=-1 40 | 41 | for epoch_idx in range(epoch, args.num_epoch + 1): 42 | 43 | epoch_timer.tic() 44 | iter_timer.tic() 45 | for iter_idx, batch_data in enumerate(dataloader): 46 | 47 | input_img, label_img = batch_data 48 | input_img = input_img.to(device) 49 | label_img = label_img.to(device) 50 | 51 | optimizer.zero_grad() 52 | pred_img = model(input_img) 53 | label_img2 = F.interpolate(label_img, scale_factor=0.5, mode='bilinear') 54 | label_img4 = F.interpolate(label_img, scale_factor=0.25, mode='bilinear') 55 | l1 = criterion(pred_img[0], label_img4) 56 | l2 = criterion(pred_img[1], label_img2) 57 | l3 = criterion(pred_img[2], label_img) 58 | loss_content = l1+l2+l3 59 | 60 | label_fft1 = torch.fft.fft2(label_img4, dim=(-2,-1)) 61 | label_fft1 = torch.stack((label_fft1.real, label_fft1.imag), -1) 62 | 63 | pred_fft1 = torch.fft.fft2(pred_img[0], dim=(-2,-1)) 64 | pred_fft1 = torch.stack((pred_fft1.real, 
pred_fft1.imag), -1) 65 | 66 | label_fft2 = torch.fft.fft2(label_img2, dim=(-2,-1)) 67 | label_fft2 = torch.stack((label_fft2.real, label_fft2.imag), -1) 68 | 69 | pred_fft2 = torch.fft.fft2(pred_img[1], dim=(-2,-1)) 70 | pred_fft2 = torch.stack((pred_fft2.real, pred_fft2.imag), -1) 71 | 72 | label_fft3 = torch.fft.fft2(label_img, dim=(-2,-1)) 73 | label_fft3 = torch.stack((label_fft3.real, label_fft3.imag), -1) 74 | 75 | pred_fft3 = torch.fft.fft2(pred_img[2], dim=(-2,-1)) 76 | pred_fft3 = torch.stack((pred_fft3.real, pred_fft3.imag), -1) 77 | 78 | f1 = criterion(pred_fft1, label_fft1) 79 | f2 = criterion(pred_fft2, label_fft2) 80 | f3 = criterion(pred_fft3, label_fft3) 81 | loss_fft = f1+f2+f3 82 | 83 | loss = loss_content + 0.1 * loss_fft 84 | loss.backward() 85 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.001) 86 | optimizer.step() 87 | 88 | iter_pixel_adder(loss_content.item()) 89 | iter_fft_adder(loss_fft.item()) 90 | 91 | epoch_pixel_adder(loss_content.item()) 92 | epoch_fft_adder(loss_fft.item()) 93 | 94 | if (iter_idx + 1) % args.print_freq == 0: 95 | print("Time: %7.4f Epoch: %03d Iter: %4d/%4d LR: %.10f Loss content: %7.4f Loss fft: %7.4f" % ( 96 | iter_timer.toc(), epoch_idx, iter_idx + 1, max_iter, scheduler.get_lr()[0], iter_pixel_adder.average(), 97 | iter_fft_adder.average())) 98 | writer.add_scalar('Pixel Loss', iter_pixel_adder.average(), iter_idx + (epoch_idx-1)* max_iter) 99 | writer.add_scalar('FFT Loss', iter_fft_adder.average(), iter_idx + (epoch_idx - 1) * max_iter) 100 | 101 | iter_timer.tic() 102 | iter_pixel_adder.reset() 103 | iter_fft_adder.reset() 104 | overwrite_name = os.path.join(args.model_save_dir, 'model.pkl') 105 | torch.save({'model': model.state_dict(), 106 | 'optimizer': optimizer.state_dict(), 107 | 'epoch': epoch_idx}, overwrite_name) 108 | 109 | if epoch_idx % args.save_freq == 0: 110 | save_name = os.path.join(args.model_save_dir, 'model_%d.pkl' % epoch_idx) 111 | torch.save({'model': model.state_dict()}, 
save_name) 112 | print("EPOCH: %02d\nElapsed time: %4.2f Epoch Pixel Loss: %7.4f Epoch FFT Loss: %7.4f" % ( 113 | epoch_idx, epoch_timer.toc(), epoch_pixel_adder.average(), epoch_fft_adder.average())) 114 | epoch_fft_adder.reset() 115 | epoch_pixel_adder.reset() 116 | scheduler.step() 117 | if epoch_idx % args.valid_freq == 0: 118 | val = _valid(model, args, epoch_idx) 119 | print('%03d epoch \n Average PSNR %.2f dB' % (epoch_idx, val)) 120 | writer.add_scalar('PSNR', val, epoch_idx) 121 | if val >= best_psnr: 122 | torch.save({'model': model.state_dict()}, os.path.join(args.model_save_dir, 'Best.pkl')) 123 | save_name = os.path.join(args.model_save_dir, 'Final.pkl') 124 | torch.save({'model': model.state_dict()}, save_name) 125 | -------------------------------------------------------------------------------- /Dehazing/ITS/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | 4 | 5 | class Adder(object): 6 | def __init__(self): 7 | self.count = 0 8 | self.num = float(0) 9 | 10 | def reset(self): 11 | self.count = 0 12 | self.num = float(0) 13 | 14 | def __call__(self, num): 15 | self.count += 1 16 | self.num += num 17 | 18 | def average(self): 19 | return self.num / self.count 20 | 21 | 22 | class Timer(object): 23 | def __init__(self, option='s'): 24 | self.tm = 0 25 | self.option = option 26 | if option == 's': 27 | self.devider = 1 28 | elif option == 'm': 29 | self.devider = 60 30 | else: 31 | self.devider = 3600 32 | 33 | def tic(self): 34 | self.tm = time.time() 35 | 36 | def toc(self): 37 | return (time.time() - self.tm) / self.devider 38 | 39 | 40 | def check_lr(optimizer): 41 | for i, param_group in enumerate(optimizer.param_groups): 42 | lr = param_group['lr'] 43 | return lr 44 | -------------------------------------------------------------------------------- /Dehazing/ITS/valid.py: -------------------------------------------------------------------------------- 1 | import 
torch 2 | from torchvision.transforms import functional as F 3 | from data import valid_dataloader 4 | from utils import Adder 5 | import os 6 | from skimage.metrics import peak_signal_noise_ratio 7 | import torch.nn.functional as f 8 | 9 | 10 | def _valid(model, args, ep): 11 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 12 | its = valid_dataloader(args.data_dir, batch_size=1, num_workers=0) 13 | model.eval() 14 | psnr_adder = Adder() 15 | 16 | with torch.no_grad(): 17 | print('Start Evaluation') 18 | factor = 4 19 | for idx, data in enumerate(its): 20 | input_img, label_img = data 21 | input_img = input_img.to(device) 22 | 23 | h, w = input_img.shape[2], input_img.shape[3] 24 | H, W = ((h+factor)//factor)*factor, ((w+factor)//factor*factor) 25 | padh = H-h if h%factor!=0 else 0 26 | padw = W-w if w%factor!=0 else 0 27 | input_img = f.pad(input_img, (0, padw, 0, padh), 'reflect') 28 | 29 | if not os.path.exists(os.path.join(args.result_dir, '%d' % (ep))): 30 | os.mkdir(os.path.join(args.result_dir, '%d' % (ep))) 31 | 32 | pred = model(input_img)[2] 33 | pred = pred[:,:,:h,:w] 34 | 35 | pred_clip = torch.clamp(pred, 0, 1) 36 | p_numpy = pred_clip.squeeze(0).cpu().numpy() 37 | label_numpy = label_img.squeeze(0).cpu().numpy() 38 | 39 | psnr = peak_signal_noise_ratio(p_numpy, label_numpy, data_range=1) 40 | 41 | psnr_adder(psnr) 42 | print('\r%03d'%idx, end=' ') 43 | 44 | print('\n') 45 | model.train() 46 | return psnr_adder.average() 47 | -------------------------------------------------------------------------------- /Dehazing/OTS/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_augment import PairRandomCrop, PairCompose, PairRandomHorizontalFilp, PairToTensor 2 | from .data_load import train_dataloader, test_dataloader, valid_dataloader 3 | -------------------------------------------------------------------------------- /Dehazing/OTS/data/data_augment.py: 
import random
import torchvision.transforms as transforms
import torchvision.transforms.functional as F


class PairRandomCrop(transforms.RandomCrop):
    """RandomCrop that applies the same crop window to an (image, label) pair."""

    def __call__(self, image, label):

        if self.padding is not None:
            image = F.pad(image, self.padding, self.fill, self.padding_mode)
            label = F.pad(label, self.padding, self.fill, self.padding_mode)

        # pad the width if needed
        if self.pad_if_needed and image.size[0] < self.size[1]:
            image = F.pad(image, (self.size[1] - image.size[0], 0), self.fill, self.padding_mode)
            label = F.pad(label, (self.size[1] - label.size[0], 0), self.fill, self.padding_mode)
        # pad the height if needed
        if self.pad_if_needed and image.size[1] < self.size[0]:
            image = F.pad(image, (0, self.size[0] - image.size[1]), self.fill, self.padding_mode)
            # BUG FIX: the original computed the label's pad from the *image*
            # height (self.size[0] - image.size[1]); use the label's own height
            # so a differently sized label is padded to the crop size too.
            label = F.pad(label, (0, self.size[0] - label.size[1]), self.fill, self.padding_mode)

        # One set of crop parameters, applied to both halves of the pair.
        i, j, h, w = self.get_params(image, self.size)

        return F.crop(image, i, j, h, w), F.crop(label, i, j, h, w)


class PairCompose(transforms.Compose):
    """Compose for transforms that take and return an (image, label) pair."""

    def __call__(self, image, label):
        for t in self.transforms:
            image, label = t(image, label)
        return image, label


class PairRandomHorizontalFilp(transforms.RandomHorizontalFlip):
    """Flip image and label together with probability `p`."""

    def __call__(self, img, label):
        """
        Args:
            img (PIL Image): Image to be flipped.

        Returns:
            PIL Image: Randomly flipped image.
        """
        if random.random() < self.p:
            return F.hflip(img), F.hflip(label)
        return img, label


class PairToTensor(transforms.ToTensor):
    """Convert both members of an (image, label) pair to tensors."""

    def __call__(self, pic, label):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return F.to_tensor(pic), F.to_tensor(label)
import os
import random

import torch
import numpy as np
from PIL import Image as Image
from PIL import ImageFile
from torchvision.transforms import functional as F
from torch.utils.data import Dataset, DataLoader

# Some RESIDE-OTS jpgs are truncated; let PIL load them anyway.
ImageFile.LOAD_TRUNCATED_IMAGES = True


def train_dataloader(path, batch_size=64, num_workers=0):
    """Training loader over <path>/train: aligned 256x256 random crops + flips."""
    image_dir = os.path.join(path, 'train')
    return DataLoader(
        DeblurDataset(image_dir, ps=256),
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True
    )


def test_dataloader(path, batch_size=1, num_workers=0):
    """Test loader over <path>/test; batches also carry the image file name."""
    image_dir = os.path.join(path, 'test')
    return DataLoader(
        DeblurDataset(image_dir, is_test=True),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )


def valid_dataloader(path, batch_size=1, num_workers=0):
    """Validation loader over <path>/test (full-size images, no cropping)."""
    return DataLoader(
        DeblurDataset(os.path.join(path, 'test'), is_valid=True),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers
    )


class DeblurDataset(Dataset):
    """Paired hazy/ground-truth dataset.

    Hazy images live in <image_dir>/hazy; the ground truth for a hazy file
    named ``XXXX_*.ext`` is <image_dir>/gt/XXXX.png for test/valid splits
    and <image_dir>/gt/XXXX.jpg for training.  With ``ps`` set, __getitem__
    returns aligned random ps x ps crops with a random horizontal flip;
    otherwise it returns the full pair.
    """

    def __init__(self, image_dir, transform=None, is_test=False, is_valid=False, ps=None):
        self.image_dir = image_dir
        self.image_list = os.listdir(os.path.join(image_dir, 'hazy/'))
        self._check_image(self.image_list)
        self.image_list.sort()
        self.transform = transform
        self.is_test = is_test
        self.is_valid = is_valid
        self.ps = ps  # training patch size; None disables cropping

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        image = Image.open(os.path.join(self.image_dir, 'hazy', self.image_list[idx])).convert('RGB')
        # gt files share the numeric prefix of the hazy name; the extension
        # differs per split (png for test/valid, jpg for train).
        prefix = self.image_list[idx].split('_')[0]
        if self.is_valid or self.is_test:
            label = Image.open(os.path.join(self.image_dir, 'gt', prefix + '.png')).convert('RGB')
        else:
            label = Image.open(os.path.join(self.image_dir, 'gt', prefix + '.jpg')).convert('RGB')

        if self.ps is not None:
            ps = self.ps
            image = F.to_tensor(image)
            label = F.to_tensor(label)

            hh, ww = label.shape[1], label.shape[2]
            if hh < ps or ww < ps:
                # Fail with a clear message instead of randint's opaque
                # "empty range" error.
                raise ValueError(
                    'image %s (%dx%d) is smaller than patch size %d'
                    % (self.image_list[idx], hh, ww, ps))

            # Aligned random crop.
            rr = random.randint(0, hh - ps)
            cc = random.randint(0, ww - ps)
            image = image[:, rr:rr + ps, cc:cc + ps]
            label = label[:, rr:rr + ps, cc:cc + ps]

            # Random horizontal flip (width axis).
            if random.random() < 0.5:
                image = image.flip(2)
                label = label.flip(2)
        else:
            image = F.to_tensor(image)
            label = F.to_tensor(label)

        if self.is_test:
            name = self.image_list[idx]
            return image, label, name
        return image, label

    @staticmethod
    def _check_image(lst):
        """Reject files whose extension is not png/jpg/jpeg."""
        for x in lst:
            if x.split('.')[-1] not in ['png', 'jpg', 'jpeg']:
                raise ValueError('unexpected image file: %s' % x)
import os
import time

import torch
import numpy as np
import torch.nn.functional as f
from torchvision.transforms import functional as F
from skimage.metrics import peak_signal_noise_ratio
from pytorch_msssim import ssim
from skimage import img_as_ubyte
import cv2

from utils import Adder
from data import test_dataloader
# ---------------------------------------------------


def _eval(model, args):
    """Evaluate the checkpoint `args.test_model` on the test split.

    Prints per-image and average PSNR / SSIM plus the average runtime, and
    optionally saves the restored images under `args.result_dir`.
    """
    state_dict = torch.load(args.test_model)
    model.load_state_dict(state_dict['model'])
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataloader = test_dataloader(args.data_dir, batch_size=1, num_workers=0)
    torch.cuda.empty_cache()
    adder = Adder()  # runtime accumulator
    model.eval()
    factor = 4  # pad H/W to a multiple of the network stride

    with torch.no_grad():
        psnr_adder = Adder()
        ssim_adder = Adder()

        for iter_idx, data in enumerate(dataloader):
            input_img, label_img, name = data
            input_img = input_img.to(device)

            # Reflect-pad so spatial dims are multiples of `factor`.
            h, w = input_img.shape[2], input_img.shape[3]
            H = ((h + factor) // factor) * factor
            W = ((w + factor) // factor) * factor
            padh = H - h if h % factor != 0 else 0
            padw = W - w if w % factor != 0 else 0
            input_img = f.pad(input_img, (0, padw, 0, padh), 'reflect')

            tm = time.time()
            pred = model(input_img)[2]   # full-resolution output
            pred = pred[:, :, :h, :w]    # drop the padding
            elapsed = time.time() - tm
            adder(elapsed)

            pred_clip = torch.clamp(pred, 0, 1)
            pred_numpy = pred_clip.squeeze(0).cpu().numpy()
            label_numpy = label_img.squeeze(0).cpu().numpy()

            # BUG FIX: was `label_img.cuda()`, which crashes on CPU-only
            # hosts even though `device` was already computed above.
            label_img = label_img.to(device)
            psnr_val = 10 * torch.log10(1 / f.mse_loss(pred_clip, label_img))

            # BUG FIX: use the cropped size (h, w) of pred/label for the SSIM
            # pooling ratio; the original used the padded (H, W) even though
            # both tensors are h x w at this point.
            down_ratio = max(1, round(min(h, w) / 256))
            ssim_val = ssim(f.adaptive_avg_pool2d(pred_clip, (int(h / down_ratio), int(w / down_ratio))),
                            f.adaptive_avg_pool2d(label_img, (int(h / down_ratio), int(w / down_ratio))),
                            data_range=1, size_average=False)
            print('%d iter PSNR_dehazing: %.2f ssim: %f' % (iter_idx + 1, psnr_val, ssim_val))
            ssim_adder(ssim_val)

            if args.save_image:
                save_name = os.path.join(args.result_dir, name[0])
                pred_clip += 0.5 / 255  # round-to-nearest when quantizing to uint8
                pred = F.to_pil_image(pred_clip.squeeze(0).cpu(), 'RGB')
                pred.save(save_name)

            psnr_mimo = peak_signal_noise_ratio(pred_numpy, label_numpy, data_range=1)
            psnr_adder(psnr_val)

            print('%d iter PSNR: %.2f time: %f' % (iter_idx + 1, psnr_mimo, elapsed))

        print('==========================================================')
        print('The average PSNR is %.2f dB' % (psnr_adder.average()))
        # BUG FIX: SSIM is dimensionless; the original printed a bogus 'dB' unit.
        print('The average SSIM is %.4f' % (ssim_adder.average()))

        print("Average time: %f" % adder.average())
import os
import shutil
import argparse

import torch
from torch.backends import cudnn
import numpy as np
import random

from models.OKNet import build_net
from train import _train
from eval import _eval


def main(args):
    """Build the network, create output directories and dispatch train/test."""
    # cudnn autotuner: pick fastest conv algorithms for fixed input sizes.
    cudnn.benchmark = True

    # Create every output directory up front.  The original checked
    # 'results/' but then created args.model_save_dir (check and action
    # did not match); exist_ok makes the checks unnecessary.
    os.makedirs(args.model_save_dir, exist_ok=True)
    os.makedirs(os.path.join('results', args.model_name), exist_ok=True)
    os.makedirs(args.result_dir, exist_ok=True)

    model = build_net()
    print(model)

    if torch.cuda.is_available():
        model.cuda()

    if args.mode == 'train':
        _train(model, args)
    elif args.mode == 'test':
        _eval(model, args)


def _str2bool(v):
    """argparse-friendly bool parser.

    BUG FIX: the original used ``type=bool``, for which any non-empty
    string -- including 'False' -- parses as True.
    """
    if isinstance(v, bool):
        return v
    if v.lower() in ('true', '1', 'yes'):
        return True
    if v.lower() in ('false', '0', 'no'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got %r' % v)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Directories
    parser.add_argument('--model_name', default='OKNet', type=str)
    parser.add_argument('--data_dir', type=str, default='')
    parser.add_argument('--mode', default='test', choices=['train', 'test'], type=str)

    # Train
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--learning_rate', type=float, default=1e-4)
    parser.add_argument('--weight_decay', type=float, default=0)
    parser.add_argument('--num_epoch', type=int, default=30)
    parser.add_argument('--print_freq', type=int, default=100)
    parser.add_argument('--num_worker', type=int, default=8)
    parser.add_argument('--save_freq', type=int, default=1)
    parser.add_argument('--valid_freq', type=int, default=1)
    parser.add_argument('--resume', type=str, default='')

    # Test
    parser.add_argument('--test_model', type=str, default='')
    parser.add_argument('--save_image', type=_str2bool, default=False, choices=[True, False])

    args = parser.parse_args()
    args.model_save_dir = os.path.join('results/', 'OKNet', 'ots/')
    args.result_dir = os.path.join('results/', args.model_name, 'test')
    os.makedirs(args.model_save_dir, exist_ok=True)

    # Snapshot the sources used for this run into the save dir
    # (shutil.copy is portable; the original shelled out to `cp`).
    for src in ('models/layers.py', 'models/OKNet.py', 'train.py', 'main.py'):
        shutil.copy(src, args.model_save_dir)

    print(args)
    main(args)
import torch
import torch.nn as nn
import torch.nn.functional as F
from .layers import *


class EBlock(nn.Module):
    """Encoder stage: `num_res` stacked ResBlocks at a fixed channel width."""

    def __init__(self, out_channel, num_res=8):
        super(EBlock, self).__init__()
        layers = [ResBlock(out_channel, out_channel) for _ in range(num_res)]
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


class DBlock(nn.Module):
    """Decoder stage: `num_res` stacked ResBlocks at a fixed channel width."""

    def __init__(self, channel, num_res=8):
        super(DBlock, self).__init__()
        layers = [ResBlock(channel, channel) for _ in range(num_res)]
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


class SCM(nn.Module):
    """Shallow conv module: embeds a (downsampled) RGB image into `out_plane` channels."""

    def __init__(self, out_plane):
        super(SCM, self).__init__()
        self.main = nn.Sequential(
            BasicConv(3, out_plane // 4, kernel_size=3, stride=1, relu=True),
            BasicConv(out_plane // 4, out_plane // 2, kernel_size=1, stride=1, relu=True),
            BasicConv(out_plane // 2, out_plane // 2, kernel_size=3, stride=1, relu=True),
            BasicConv(out_plane // 2, out_plane, kernel_size=1, stride=1, relu=False),
            nn.InstanceNorm2d(out_plane, affine=True)
        )

    def forward(self, x):
        x = self.main(x)
        return x


class BottleNect(nn.Module):
    """Bottleneck mixing large-kernel depthwise convs with channel attention.

    (Class name kept as-is, typo included -- it is referenced by OKNet and
    baked into saved checkpoints.)
    """

    def __init__(self, dim) -> None:
        super().__init__()

        ker = 63       # large depthwise kernel size
        pad = ker // 2
        self.in_conv = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=1, padding=0, stride=1),
            nn.GELU()
        )
        self.out_conv = nn.Conv2d(dim, dim, kernel_size=1, padding=0, stride=1)
        # Depthwise branches: 1x63, 63x1, 63x63 and 1x1 kernels.
        self.dw_13 = nn.Conv2d(dim, dim, kernel_size=(1, ker), padding=(0, pad), stride=1, groups=dim)
        self.dw_31 = nn.Conv2d(dim, dim, kernel_size=(ker, 1), padding=(pad, 0), stride=1, groups=dim)
        self.dw_33 = nn.Conv2d(dim, dim, kernel_size=ker, padding=pad, stride=1, groups=dim)
        self.dw_11 = nn.Conv2d(dim, dim, kernel_size=1, padding=0, stride=1, groups=dim)

        self.act = nn.ReLU()

        ### sca ### channel attention from globally pooled features
        self.conv = nn.Conv2d(dim, dim, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        ### fca ### channel attention applied in the Fourier domain
        self.fac_conv = nn.Conv2d(dim, dim, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.fac_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fgm = FGM(dim)

    def forward(self, x):
        out = self.in_conv(x)

        ### fca ### per-channel gate multiplied onto the 2D FFT of the features
        x_att = self.fac_conv(self.fac_pool(out))
        x_fft = torch.fft.fft2(out, norm='backward')
        x_fft = x_att * x_fft
        x_fca = torch.fft.ifft2(x_fft, dim=(-2, -1), norm='backward')
        x_fca = torch.abs(x_fca)

        ### sca ### per-channel gate from globally pooled (frequency-attended) features
        x_att = self.conv(self.pool(x_fca))
        x_sca = x_att * x_fca
        x_sca = self.fgm(x_sca)

        # Sum of identity, the four depthwise branches and the attention branch.
        out = x + self.dw_13(out) + self.dw_31(out) + self.dw_33(out) + self.dw_11(out) + x_sca
        out = self.act(out)
        return self.out_conv(out)


class FGM(nn.Module):
    """Frequency gating: one branch modulates the spectrum of the other,
    then the result is blended with the input via learnable alpha/beta."""

    def __init__(self, dim) -> None:
        super().__init__()

        # NOTE(review): self.conv is never used in forward(); kept so existing
        # checkpoints (which contain its weights) still load with strict=True.
        self.conv = nn.Conv2d(dim, dim * 2, 3, 1, 1)

        self.dwconv1 = nn.Conv2d(dim, dim, 1, 1, groups=1)
        self.dwconv2 = nn.Conv2d(dim, dim, 1, 1, groups=1)
        self.alpha = nn.Parameter(torch.zeros(dim, 1, 1))
        self.beta = nn.Parameter(torch.ones(dim, 1, 1))

    def forward(self, x):
        # (removed dead code: unused `fft_size` and a commented-out clone)
        x1 = self.dwconv1(x)
        x2 = self.dwconv2(x)

        x2_fft = torch.fft.fft2(x2, norm='backward')
        out = x1 * x2_fft  # complex product: x1 gates x2's spectrum
        out = torch.fft.ifft2(out, dim=(-2, -1), norm='backward')
        out = torch.abs(out)

        return out * self.alpha + x * self.beta


class FAM(nn.Module):
    """Fuse two same-sized feature maps with a single 3x3 conv on their concat."""

    def __init__(self, channel):
        super(FAM, self).__init__()
        self.merge = BasicConv(channel * 2, channel, kernel_size=3, stride=1, relu=False)

    def forward(self, x1, x2):
        return self.merge(torch.cat([x1, x2], dim=1))


class OKNet(nn.Module):
    """Omni-kernel encoder-decoder.

    forward() returns three residual outputs: [1/4-res, 1/2-res, full-res],
    each added to the correspondingly downsampled input.
    """

    def __init__(self, num_res=4):
        super(OKNet, self).__init__()

        base_channel = 32

        self.Encoder = nn.ModuleList([
            EBlock(base_channel, num_res),
            EBlock(base_channel * 2, num_res),
            EBlock(base_channel * 4, num_res),
        ])

        self.feat_extract = nn.ModuleList([
            BasicConv(3, base_channel, kernel_size=3, relu=True, stride=1),
            BasicConv(base_channel, base_channel * 2, kernel_size=3, relu=True, stride=2),
            BasicConv(base_channel * 2, base_channel * 4, kernel_size=3, relu=True, stride=2),
            BasicConv(base_channel * 4, base_channel * 2, kernel_size=4, relu=True, stride=2, transpose=True),
            BasicConv(base_channel * 2, base_channel, kernel_size=4, relu=True, stride=2, transpose=True),
            BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1)
        ])

        self.Decoder = nn.ModuleList([
            DBlock(base_channel * 4, num_res),
            DBlock(base_channel * 2, num_res),
            DBlock(base_channel, num_res)
        ])

        self.Convs = nn.ModuleList([
            BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1),
            BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1),
        ])

        self.ConvsOut = nn.ModuleList(
            [
                BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1),
                BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1),
            ]
        )

        self.FAM1 = FAM(base_channel * 4)
        self.SCM1 = SCM(base_channel * 4)
        self.FAM2 = FAM(base_channel * 2)
        self.SCM2 = SCM(base_channel * 2)

        # (sic) attribute name kept for checkpoint compatibility
        self.bottelneck = BottleNect(base_channel * 4)

    def forward(self, x):
        # Multi-scale inputs for the shallow embedding modules.
        x_2 = F.interpolate(x, scale_factor=0.5)
        x_4 = F.interpolate(x_2, scale_factor=0.5)
        z2 = self.SCM2(x_2)
        z4 = self.SCM1(x_4)

        outputs = list()
        # full resolution
        x_ = self.feat_extract[0](x)
        res1 = self.Encoder[0](x_)
        # 1/2 resolution
        z = self.feat_extract[1](res1)
        z = self.FAM2(z, z2)
        res2 = self.Encoder[1](z)
        # 1/4 resolution
        z = self.feat_extract[2](res2)
        z = self.FAM1(z, z4)
        z = self.Encoder[2](z)
        z = self.bottelneck(z)

        z = self.Decoder[0](z)
        z_ = self.ConvsOut[0](z)
        # up to 1/2 resolution
        z = self.feat_extract[3](z)
        outputs.append(z_ + x_4)

        z = torch.cat([z, res2], dim=1)
        z = self.Convs[0](z)
        z = self.Decoder[1](z)
        z_ = self.ConvsOut[1](z)
        # up to full resolution
        z = self.feat_extract[4](z)
        outputs.append(z_ + x_2)

        z = torch.cat([z, res1], dim=1)
        z = self.Convs[1](z)
        z = self.Decoder[2](z)
        z = self.feat_extract[5](z)
        outputs.append(z + x)

        return outputs


def build_net():
    """Factory used by main.py."""
    return OKNet()
import torch
import torch.nn as nn


class BasicConv(nn.Module):
    """Conv (or transposed conv) followed by optional BatchNorm and GELU.

    Submodules live in ``self.main`` (an ``nn.Sequential``) in the order
    conv -> [norm] -> [activation], matching the original layout so that
    existing checkpoints load unchanged.
    """

    def __init__(self, in_channel, out_channel, kernel_size, stride,
                 bias=True, norm=False, relu=True, transpose=False):
        super(BasicConv, self).__init__()
        # BatchNorm supplies its own affine shift, so the conv bias is dropped.
        use_bias = bias and not norm

        modules = []
        if transpose:
            # 'same'-style padding for the stride-2 upsampling convs (k=4).
            modules.append(nn.ConvTranspose2d(
                in_channel, out_channel, kernel_size,
                padding=kernel_size // 2 - 1, stride=stride, bias=use_bias))
        else:
            modules.append(nn.Conv2d(
                in_channel, out_channel, kernel_size,
                padding=kernel_size // 2, stride=stride, bias=use_bias))
        if norm:
            modules.append(nn.BatchNorm2d(out_channel))
        if relu:
            modules.append(nn.GELU())
        self.main = nn.Sequential(*modules)

    def forward(self, x):
        return self.main(x)


class ResBlock(nn.Module):
    """Two 3x3 BasicConvs with an identity skip connection."""

    def __init__(self, in_channel, out_channel):
        super(ResBlock, self).__init__()
        self.main = nn.Sequential(
            BasicConv(in_channel, out_channel, kernel_size=3, stride=1, relu=True),
            BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False),
        )

    def forward(self, x):
        return self.main(x) + x
torch.device('cuda' if torch.cuda.is_available() else 'cpu') 14 | criterion = torch.nn.L1Loss() 15 | 16 | optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), eps=1e-8) 17 | dataloader = train_dataloader(args.data_dir, args.batch_size, args.num_worker) 18 | max_iter = len(dataloader) 19 | warmup_epochs=1 20 | scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.num_epoch-warmup_epochs, eta_min=1e-6) 21 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine) 22 | scheduler.step() 23 | epoch = 1 24 | if args.resume: 25 | state = torch.load(args.resume) 26 | epoch = state['epoch'] 27 | optimizer.load_state_dict(state['optimizer']) 28 | model.load_state_dict(state['model']) 29 | print('Resume from %d'%epoch) 30 | epoch += 1 31 | 32 | writer = SummaryWriter() 33 | epoch_pixel_adder = Adder() 34 | epoch_fft_adder = Adder() 35 | iter_pixel_adder = Adder() 36 | iter_fft_adder = Adder() 37 | epoch_timer = Timer('m') 38 | iter_timer = Timer('m') 39 | best_psnr=-1 40 | 41 | eval_now = max_iter//6-1 42 | 43 | for epoch_idx in range(epoch, args.num_epoch + 1): 44 | 45 | epoch_timer.tic() 46 | iter_timer.tic() 47 | for iter_idx, batch_data in enumerate(dataloader): 48 | 49 | input_img, label_img = batch_data 50 | input_img = input_img.to(device) 51 | label_img = label_img.to(device) 52 | 53 | optimizer.zero_grad() 54 | pred_img = model(input_img) 55 | label_img2 = F.interpolate(label_img, scale_factor=0.5, mode='bilinear') 56 | label_img4 = F.interpolate(label_img, scale_factor=0.25, mode='bilinear') 57 | l1 = criterion(pred_img[0], label_img4) 58 | l2 = criterion(pred_img[1], label_img2) 59 | l3 = criterion(pred_img[2], label_img) 60 | loss_content = l1+l2+l3 61 | 62 | label_fft1 = torch.fft.fft2(label_img4, dim=(-2,-1)) 63 | label_fft1 = torch.stack((label_fft1.real, label_fft1.imag), -1) 64 | 65 | pred_fft1 = torch.fft.fft2(pred_img[0], 
dim=(-2,-1)) 66 | pred_fft1 = torch.stack((pred_fft1.real, pred_fft1.imag), -1) 67 | 68 | label_fft2 = torch.fft.fft2(label_img2, dim=(-2,-1)) 69 | label_fft2 = torch.stack((label_fft2.real, label_fft2.imag), -1) 70 | 71 | pred_fft2 = torch.fft.fft2(pred_img[1], dim=(-2,-1)) 72 | pred_fft2 = torch.stack((pred_fft2.real, pred_fft2.imag), -1) 73 | 74 | label_fft3 = torch.fft.fft2(label_img, dim=(-2,-1)) 75 | label_fft3 = torch.stack((label_fft3.real, label_fft3.imag), -1) 76 | 77 | pred_fft3 = torch.fft.fft2(pred_img[2], dim=(-2,-1)) 78 | pred_fft3 = torch.stack((pred_fft3.real, pred_fft3.imag), -1) 79 | 80 | f1 = criterion(pred_fft1, label_fft1) 81 | f2 = criterion(pred_fft2, label_fft2) 82 | f3 = criterion(pred_fft3, label_fft3) 83 | loss_fft = f1+f2+f3 84 | 85 | loss = loss_content + 0.1 * loss_fft 86 | loss.backward() 87 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.01) 88 | optimizer.step() 89 | 90 | iter_pixel_adder(loss_content.item()) 91 | iter_fft_adder(loss_fft.item()) 92 | 93 | epoch_pixel_adder(loss_content.item()) 94 | epoch_fft_adder(loss_fft.item()) 95 | 96 | if (iter_idx + 1) % args.print_freq == 0: 97 | print("Time: %7.4f Epoch: %03d Iter: %4d/%4d LR: %.10f Loss content: %7.4f Loss fft: %7.4f" % ( 98 | iter_timer.toc(), epoch_idx, iter_idx + 1, max_iter, scheduler.get_lr()[0], iter_pixel_adder.average(), 99 | iter_fft_adder.average())) 100 | writer.add_scalar('Pixel Loss', iter_pixel_adder.average(), iter_idx + (epoch_idx-1)* max_iter) 101 | writer.add_scalar('FFT Loss', iter_fft_adder.average(), iter_idx + (epoch_idx - 1) * max_iter) 102 | 103 | iter_timer.tic() 104 | iter_pixel_adder.reset() 105 | iter_fft_adder.reset() 106 | 107 | 108 | if iter_idx%eval_now==0 and iter_idx>0 and (epoch_idx>20 or epoch_idx == 1): 109 | 110 | save_name = os.path.join(args.model_save_dir, 'model_%d_%d.pkl' % (epoch_idx, iter_idx)) 111 | torch.save({'model': model.state_dict()}, save_name) 112 | 113 | val_gopro = _valid(model, args, epoch_idx) 114 | 
print('%03d epoch \n Average GOPRO PSNR %.2f dB' % (epoch_idx, val_gopro)) 115 | writer.add_scalar('PSNR_GOPRO', val_gopro, epoch_idx) 116 | if val_gopro >= best_psnr: 117 | torch.save({'model': model.state_dict()}, os.path.join(args.model_save_dir, 'Best.pkl')) 118 | 119 | 120 | overwrite_name = os.path.join(args.model_save_dir, 'model.pkl') 121 | torch.save({'model': model.state_dict()}, overwrite_name) 122 | 123 | 124 | if epoch_idx % args.save_freq == 0: 125 | save_name = os.path.join(args.model_save_dir, 'model_%d.pkl' % epoch_idx) 126 | torch.save({'model': model.state_dict()}, save_name) 127 | print("EPOCH: %02d\nElapsed time: %4.2f Epoch Pixel Loss: %7.4f Epoch FFT Loss: %7.4f" % ( 128 | epoch_idx, epoch_timer.toc(), epoch_pixel_adder.average(), epoch_fft_adder.average())) 129 | epoch_fft_adder.reset() 130 | epoch_pixel_adder.reset() 131 | scheduler.step() 132 | 133 | if epoch_idx % args.valid_freq == 0: 134 | val = _valid(model, args, epoch_idx) 135 | print('%03d epoch \n Average PSNR %.2f dB' % (epoch_idx, val)) 136 | writer.add_scalar('PSNR', val, epoch_idx) 137 | if val >= best_psnr: 138 | torch.save({'model': model.state_dict()}, os.path.join(args.model_save_dir, 'Best.pkl')) 139 | save_name = os.path.join(args.model_save_dir, 'Final.pkl') 140 | torch.save({'model': model.state_dict()}, save_name) 141 | -------------------------------------------------------------------------------- /Dehazing/OTS/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | 4 | 5 | class Adder(object): 6 | def __init__(self): 7 | self.count = 0 8 | self.num = float(0) 9 | 10 | def reset(self): 11 | self.count = 0 12 | self.num = float(0) 13 | 14 | def __call__(self, num): 15 | self.count += 1 16 | self.num += num 17 | 18 | def average(self): 19 | return self.num / self.count 20 | 21 | 22 | class Timer(object): 23 | def __init__(self, option='s'): 24 | self.tm = 0 25 | self.option = option 26 | if 
import torch
from torchvision.transforms import functional as F
from data import valid_dataloader
from utils import Adder
import os
from skimage.metrics import peak_signal_noise_ratio
import torch.nn.functional as f


def _valid(model, args, ep):
    """Run PSNR validation for epoch `ep`; return the average PSNR in dB.

    The model is put into eval() for the loop and switched back to train()
    before returning.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    ots = valid_dataloader(args.data_dir, batch_size=1, num_workers=0)
    model.eval()
    psnr_adder = Adder()

    # Create the per-epoch result directory once, up front.  The original
    # exists()+mkdir() pair ran inside the loop on every image, which is
    # racy and fails when intermediate directories are missing.
    os.makedirs(os.path.join(args.result_dir, '%d' % ep), exist_ok=True)

    with torch.no_grad():
        print('Start Evaluation')
        factor = 4  # the network downsamples twice, so pad H/W to a multiple of 4
        for idx, data in enumerate(ots):
            input_img, label_img = data
            input_img = input_img.to(device)

            # Reflect-pad so spatial dims are multiples of `factor`.
            h, w = input_img.shape[2], input_img.shape[3]
            H = ((h + factor) // factor) * factor
            W = ((w + factor) // factor) * factor
            padh = H - h if h % factor != 0 else 0
            padw = W - w if w % factor != 0 else 0
            input_img = f.pad(input_img, (0, padw, 0, padh), 'reflect')

            # Full-resolution output is the last element; crop the padding away.
            pred = model(input_img)[2]
            pred = pred[:, :, :h, :w]

            pred_clip = torch.clamp(pred, 0, 1)
            p_numpy = pred_clip.squeeze(0).cpu().numpy()
            label_numpy = label_img.squeeze(0).cpu().numpy()

            psnr = peak_signal_noise_ratio(p_numpy, label_numpy, data_range=1)
            psnr_adder(psnr)
            print('\r%03d' % idx, end=' ')

    print('\n')
    model.train()
    return psnr_adder.average()
36 | `├──reside-indoor`
37 |      `├──train`
38 |           `├──gt`
39 |           `└──hazy` 40 |      `└──test`
41 |           `├──gt`
42 |           `└──hazy` 43 | `└──reside-outdoor`
44 |      `├──train`
45 |           `├──gt`
46 |           `└──hazy` 47 |      `└──test`
48 |           `├──gt`
49 |           `└──hazy` 50 | -------------------------------------------------------------------------------- /Desnowing/README.md: -------------------------------------------------------------------------------- 1 | ### Download the Datasets 2 | - SRRS [[gdrive](https://drive.google.com/file/d/11h1cZ0NXx6ev35cl5NKOAL3PCgLlWUl2/view?usp=sharing), [Baidu](https://pan.baidu.com/s/1VXqsamkl12fPsI1Qek97TQ?pwd=vcfg)] 3 | - CSD [[gdrive](https://drive.google.com/file/d/1pns-7uWy-0SamxjA40qOCkkhSu7o7ULb/view?usp=sharing), [Baidu](https://pan.baidu.com/s/1N52Jnx0co9udJeYrbd3blA?pwd=sb4a)] 4 | - Snow100K [[gdrive](https://drive.google.com/file/d/19zJs0cJ6F3G3IlDHLU2BO7nHnCTMNrIS/view?usp=sharing), [Baidu](https://pan.baidu.com/s/1QGd5z9uM6vBKPnD5d7jQmA?pwd=aph4)] 5 | 6 | ### Training 7 | 8 | ~~~ 9 | python main.py --mode train --data_dir your_path/CSD 10 | ~~~ 11 | 12 | ### Evaluation 13 | #### Download the model [here](https://drive.google.com/drive/folders/1jrqqTBFHi2XNvBb9-rm9n0fZHfYZabFw?usp=sharing) 14 | #### Testing 15 | ~~~ 16 | python main.py --data_dir your_path/CSD 17 | ~~~ 18 | 19 | For training and testing, your directory structure should look like this 20 | 21 | `Your path`
22 |  `├──CSD`
23 |      `├──train2500`
24 |           `├──Gt`
25 |           `└──Snow` 26 |      `└──test2000`
27 |           `├──Gt`
28 |           `└──Snow` 29 |  `├──SRRS`
30 |      `├──train2500`
31 |           `├──Gt`
32 |           `└──Snow` 33 |      `└──test2000`
34 |           `├──Gt`
35 |           `└──Snow` 36 |  `└──Snow100K`
37 |      `├──train2500`
38 |           `├──Gt`
39 |           `└──Snow` 40 |      `└──test2000`
41 |           `├──Gt`
import random
import torchvision.transforms as transforms
import torchvision.transforms.functional as F


class PairRandomCrop(transforms.RandomCrop):
    """RandomCrop that applies the same crop window to an (image, label) pair."""

    def __call__(self, image, label):

        if self.padding is not None:
            image = F.pad(image, self.padding, self.fill, self.padding_mode)
            label = F.pad(label, self.padding, self.fill, self.padding_mode)

        # pad the width if needed
        if self.pad_if_needed and image.size[0] < self.size[1]:
            image = F.pad(image, (self.size[1] - image.size[0], 0), self.fill, self.padding_mode)
            label = F.pad(label, (self.size[1] - label.size[0], 0), self.fill, self.padding_mode)
        # pad the height if needed
        if self.pad_if_needed and image.size[1] < self.size[0]:
            image = F.pad(image, (0, self.size[0] - image.size[1]), self.fill, self.padding_mode)
            # BUG FIX: the original computed the label's pad from the *image*
            # height (self.size[0] - image.size[1]); use the label's own height
            # so a differently sized label is padded to the crop size too.
            label = F.pad(label, (0, self.size[0] - label.size[1]), self.fill, self.padding_mode)

        # One set of crop parameters, applied to both halves of the pair.
        i, j, h, w = self.get_params(image, self.size)

        return F.crop(image, i, j, h, w), F.crop(label, i, j, h, w)


class PairCompose(transforms.Compose):
    """Compose for transforms that take and return an (image, label) pair."""

    def __call__(self, image, label):
        for t in self.transforms:
            image, label = t(image, label)
        return image, label


class PairRandomHorizontalFilp(transforms.RandomHorizontalFlip):
    """Flip image and label together with probability `p`."""

    def __call__(self, img, label):
        """
        Args:
            img (PIL Image): Image to be flipped.

        Returns:
            PIL Image: Randomly flipped image.
        """
        if random.random() < self.p:
            return F.hflip(img), F.hflip(label)
        return img, label


class PairToTensor(transforms.ToTensor):
    """Convert both members of an (image, label) pair to tensors."""

    def __call__(self, pic, label):
        """
        Args:
            pic (PIL Image or numpy.ndarray): Image to be converted to tensor.

        Returns:
            Tensor: Converted image.
        """
        return F.to_tensor(pic), F.to_tensor(label)
img (PIL Image): Image to be flipped. 40 | 41 | Returns: 42 | PIL Image: Randomly flipped image. 43 | """ 44 | if random.random() < self.p: 45 | return F.hflip(img), F.hflip(label) 46 | return img, label 47 | 48 | 49 | class PairToTensor(transforms.ToTensor): 50 | def __call__(self, pic, label): 51 | """ 52 | Args: 53 | pic (PIL Image or numpy.ndarray): Image to be converted to tensor. 54 | 55 | Returns: 56 | Tensor: Converted image. 57 | """ 58 | return F.to_tensor(pic), F.to_tensor(label) 59 | -------------------------------------------------------------------------------- /Desnowing/data/data_load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from PIL import Image as Image 5 | from data import PairCompose, PairRandomCrop, PairRandomHorizontalFilp, PairToTensor 6 | from torchvision.transforms import functional as F 7 | from torch.utils.data import Dataset, DataLoader 8 | 9 | 10 | def train_dataloader(path, batch_size=64, num_workers=0, use_transform=True): 11 | image_dir = os.path.join(path, 'train2500') 12 | 13 | transform = None 14 | if use_transform: 15 | transform = PairCompose( 16 | [ 17 | PairRandomCrop(256), 18 | PairRandomHorizontalFilp(), 19 | PairToTensor() 20 | ] 21 | ) 22 | dataloader = DataLoader( 23 | DeblurDataset(image_dir, transform=transform), 24 | batch_size=batch_size, 25 | shuffle=True, 26 | num_workers=num_workers, 27 | pin_memory=True 28 | ) 29 | return dataloader 30 | 31 | 32 | def test_dataloader(path, batch_size=1, num_workers=0): 33 | image_dir = os.path.join(path, 'test2000') 34 | dataloader = DataLoader( 35 | DeblurDataset(image_dir, is_test=True), 36 | batch_size=batch_size, 37 | shuffle=False, 38 | num_workers=num_workers, 39 | pin_memory=True 40 | ) 41 | 42 | return dataloader 43 | 44 | 45 | def valid_dataloader(path, batch_size=1, num_workers=0): 46 | dataloader = DataLoader( 47 | DeblurDataset(os.path.join(path, 'test2000')), 48 | 
batch_size=batch_size, 49 | shuffle=False, 50 | num_workers=num_workers 51 | ) 52 | 53 | return dataloader 54 | 55 | 56 | class DeblurDataset(Dataset): 57 | def __init__(self, image_dir, transform=None, is_test=False): 58 | self.image_dir = image_dir 59 | self.image_list = os.listdir(os.path.join(image_dir, 'Snow/')) 60 | # self._check_image(self.image_list) 61 | self.image_list.sort() 62 | self.transform = transform 63 | self.is_test = is_test 64 | 65 | def __len__(self): 66 | return len(self.image_list) 67 | 68 | def __getitem__(self, idx): 69 | image = Image.open(os.path.join(self.image_dir, 'Snow', self.image_list[idx])) 70 | # label = Image.open(os.path.join(self.image_dir, 'Gt', self.image_list[idx].split('.')[0]+'.jpg'))#srrs+jpg 71 | label = Image.open(os.path.join(self.image_dir, 'Gt', self.image_list[idx])) 72 | 73 | if self.transform: 74 | image, label = self.transform(image, label) 75 | else: 76 | image = F.to_tensor(image) 77 | label = F.to_tensor(label) 78 | if self.is_test: 79 | name = self.image_list[idx] 80 | return image, label, name 81 | return image, label 82 | 83 | -------------------------------------------------------------------------------- /Desnowing/eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torchvision.transforms import functional as F 4 | import numpy as np 5 | from utils import Adder 6 | from data import test_dataloader 7 | from skimage.metrics import peak_signal_noise_ratio 8 | import time 9 | from pytorch_msssim import ssim 10 | import torch.nn.functional as f 11 | 12 | from skimage import img_as_ubyte 13 | import cv2 14 | 15 | def _eval(model, args): 16 | state_dict = torch.load(args.test_model) 17 | model.load_state_dict(state_dict['model']) 18 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 19 | dataloader = test_dataloader(args.data_dir, batch_size=1, num_workers=0) 20 | torch.cuda.empty_cache() 21 | adder = Adder() 22 | 
model.eval() 23 | factor = 4 24 | with torch.no_grad(): 25 | psnr_adder = Adder() 26 | ssim_adder = Adder() 27 | 28 | for iter_idx, data in enumerate(dataloader): 29 | input_img, label_img, name = data 30 | 31 | input_img = input_img.to(device) 32 | 33 | h, w = input_img.shape[2], input_img.shape[3] 34 | H, W = ((h+factor)//factor)*factor, ((w+factor)//factor*factor) 35 | padh = H-h if h%factor!=0 else 0 36 | padw = W-w if w%factor!=0 else 0 37 | input_img = f.pad(input_img, (0, padw, 0, padh), 'reflect') 38 | 39 | tm = time.time() 40 | 41 | pred = model(input_img)[2] 42 | pred = pred[:,:,:h,:w] 43 | 44 | elapsed = time.time() - tm 45 | adder(elapsed) 46 | 47 | pred_clip = torch.clamp(pred, 0, 1) 48 | 49 | pred_numpy = pred_clip.squeeze(0).cpu().numpy() 50 | label_numpy = label_img.squeeze(0).cpu().numpy() 51 | 52 | 53 | label_img = (label_img).cuda() 54 | psnr_val = 10 * torch.log10(1 / f.mse_loss(pred_clip, label_img)) 55 | down_ratio = max(1, round(min(H, W) / 256)) 56 | ssim_val = ssim(f.adaptive_avg_pool2d(pred_clip, (int(H / down_ratio), int(W / down_ratio))), 57 | f.adaptive_avg_pool2d(label_img, (int(H / down_ratio), int(W / down_ratio))), 58 | data_range=1, size_average=False) 59 | print('%d iter PSNR_dehazing: %.2f ssim: %f' % (iter_idx + 1, psnr_val, ssim_val)) 60 | ssim_adder(ssim_val) 61 | 62 | if args.save_image: 63 | save_name = os.path.join(args.result_dir, name[0]) 64 | pred_clip += 0.5 / 255 65 | pred = F.to_pil_image(pred_clip.squeeze(0).cpu(), 'RGB') 66 | pred.save(save_name) 67 | 68 | psnr_mimo = peak_signal_noise_ratio(pred_numpy, label_numpy, data_range=1) 69 | psnr_adder(psnr_val) 70 | 71 | print('%d iter PSNR: %.2f time: %f' % (iter_idx + 1, psnr_mimo, elapsed)) 72 | 73 | print('==========================================================') 74 | print('The average PSNR is %.2f dB' % (psnr_adder.average())) 75 | print('The average SSIM is %.5f dB' % (ssim_adder.average())) 76 | 77 | print("Average time: %f" % adder.average()) 78 | 79 | 
# ======================= Desnowing/main.py ================================
import os
import shutil
import torch
import argparse
from torch.backends import cudnn
from models.OKNet import build_net
from train import _train
from eval import _eval
import numpy as np
import random


def main(args):
    """Build OKNet and dispatch to training or evaluation."""
    # CUDNN autotuner: safe here because input sizes are fixed per phase.
    cudnn.benchmark = True

    # BUGFIX: the original tested os.path.exists('results/') but then created
    # args.model_save_dir, so needed directories could be skipped.  makedirs
    # with exist_ok=True collapses the four racy check-then-create pairs.
    os.makedirs('results/' + args.model_name + '/', exist_ok=True)
    os.makedirs(args.model_save_dir, exist_ok=True)
    os.makedirs(args.result_dir, exist_ok=True)

    model = build_net()
    print(model)

    if torch.cuda.is_available():
        model.cuda()
    if args.mode == 'train':
        _train(model, args)
    elif args.mode == 'test':
        _eval(model, args)


def _str2bool(v):
    """argparse-friendly bool: 'false'/'0'/'no' parse as False (bool('False') is True)."""
    return str(v).lower() in ('true', '1', 'yes')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # Directories
    parser.add_argument('--model_name', default='OKNet', type=str)
    parser.add_argument('--data_dir', type=str, default='')
    parser.add_argument('--mode', default='test', choices=['train', 'test'], type=str)

    # Train
    parser.add_argument('--batch_size', type=int, default=8)
    parser.add_argument('--learning_rate', type=float, default=2e-4)
    parser.add_argument('--weight_decay', type=float, default=0)
    parser.add_argument('--num_epoch', type=int, default=2000)
    parser.add_argument('--print_freq', type=int, default=100)
    parser.add_argument('--num_worker', type=int, default=16)
    parser.add_argument('--save_freq', type=int, default=50)
    parser.add_argument('--valid_freq', type=int, default=50)
    parser.add_argument('--resume', type=str, default='')

    # Test
    parser.add_argument('--test_model', type=str, default='')
    # BUGFIX: type=bool made '--save_image False' truthy; parse explicitly.
    parser.add_argument('--save_image', type=_str2bool, default=False, choices=[True, False])

    args = parser.parse_args()
    args.model_save_dir = os.path.join('results/', 'OKNet', 'CSD/')
    args.result_dir = os.path.join('results/', args.model_name, 'test')
    os.makedirs(args.model_save_dir, exist_ok=True)
    # Snapshot the sources used for this run next to the checkpoints.
    # shutil.copy replaces the original non-portable `os.system('cp ...')`.
    for src in ('models/layers.py', 'models/OKNet.py', 'train.py', 'main.py'):
        shutil.copy(src, args.model_save_dir)
    print(args)
    main(args)

# ======================= Desnowing/models/OKNet.py (part 1) ===============
import torch
import torch.nn as nn
import torch.nn.functional as F
from .layers import *


class EBlock(nn.Module):
    """Encoder stage: a stack of `num_res` channel-preserving ResBlocks."""

    def __init__(self, out_channel, num_res=8):
        super(EBlock, self).__init__()
        layers = [ResBlock(out_channel, out_channel) for _ in range(num_res)]
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


class DBlock(nn.Module):
    """Decoder stage: a stack of `num_res` channel-preserving ResBlocks."""

    def __init__(self, channel, num_res=8):
        super(DBlock, self).__init__()
        layers = [ResBlock(channel, channel) for _ in range(num_res)]
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)


class SCM(nn.Module):
    """Shallow Conv Module: embeds a downscaled RGB input into `out_plane` features."""

    def __init__(self, out_plane):
        super(SCM, self).__init__()
        self.main = nn.Sequential(
            BasicConv(3, out_plane // 4, kernel_size=3, stride=1, relu=True),
            BasicConv(out_plane // 4, out_plane // 2, kernel_size=1, stride=1, relu=True),
            BasicConv(out_plane // 2, out_plane // 2, kernel_size=3, stride=1, relu=True),
            BasicConv(out_plane // 2, out_plane, kernel_size=1, stride=1, relu=False),
            nn.InstanceNorm2d(out_plane, affine=True)
        )

    def forward(self, x):
        return self.main(x)


class BottleNect(nn.Module):
    """Omni-kernel bottleneck: depth-wise 1x63 / 63x1 / 63x63 / 1x1 branches plus
    frequency- (fca) and spatial- (sca) channel attention, fused additively.
    (Class name keeps the historical spelling; it is referenced elsewhere.)"""

    def __init__(self, dim) -> None:
        super().__init__()

        ker = 63
        pad = ker // 2
        self.in_conv = nn.Sequential(
            nn.Conv2d(dim, dim, kernel_size=1, padding=0, stride=1),
            nn.GELU()
        )
        self.out_conv = nn.Conv2d(dim, dim, kernel_size=1, padding=0, stride=1)
        # Depth-wise "omni-kernel" branches: strip, square and point kernels.
        self.dw_13 = nn.Conv2d(dim, dim, kernel_size=(1, ker), padding=(0, pad), stride=1, groups=dim)
        self.dw_31 = nn.Conv2d(dim, dim, kernel_size=(ker, 1), padding=(pad, 0), stride=1, groups=dim)
        self.dw_33 = nn.Conv2d(dim, dim, kernel_size=ker, padding=pad, stride=1, groups=dim)
        self.dw_11 = nn.Conv2d(dim, dim, kernel_size=1, padding=0, stride=1, groups=dim)

        self.act = nn.ReLU()

        ### sca ###
        self.conv = nn.Conv2d(dim, dim, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        ### fca ###
        self.fac_conv = nn.Conv2d(dim, dim, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        self.fac_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fgm = FGM(dim)

    def forward(self, x):
        out = self.in_conv(x)

        ### fca: per-channel gate applied in the frequency domain ###
        x_att = self.fac_conv(self.fac_pool(out))
        x_fft = torch.fft.fft2(out, norm='backward')
        x_fft = x_att * x_fft
        x_fca = torch.fft.ifft2(x_fft, dim=(-2, -1), norm='backward')
        x_fca = torch.abs(x_fca)

        ### sca: squeeze-and-scale on top of the fca output ###
        x_att = self.conv(self.pool(x_fca))
        x_sca = x_att * x_fca
        x_sca = self.fgm(x_sca)

        # Identity + all depth-wise branches + attention path.
        out = x + self.dw_13(out) + self.dw_31(out) + self.dw_33(out) + self.dw_11(out) + x_sca
        out = self.act(out)
        return self.out_conv(out)
class FGM(nn.Module):
    """Fourier Gated Modulation.

    Gates the spectrum of one 1x1-projected copy of the input with another,
    takes the magnitude of the inverse transform, and blends it with the
    input via learnable per-channel alpha/beta.  alpha starts at 0 and beta
    at 1, so the module is an exact identity at initialisation.
    """

    def __init__(self, dim) -> None:
        super().__init__()

        # NOTE: kept for checkpoint compatibility; not used in forward().
        self.conv = nn.Conv2d(dim, dim * 2, 3, 1, 1)

        self.dwconv1 = nn.Conv2d(dim, dim, 1, 1, groups=1)
        self.dwconv2 = nn.Conv2d(dim, dim, 1, 1, groups=1)
        self.alpha = nn.Parameter(torch.zeros(dim, 1, 1))
        self.beta = nn.Parameter(torch.ones(dim, 1, 1))

    def forward(self, x):
        gate = self.dwconv1(x)
        spec_in = self.dwconv2(x)

        spectrum = torch.fft.fft2(spec_in, norm='backward')
        modulated = gate * spectrum
        spatial = torch.abs(torch.fft.ifft2(modulated, dim=(-2, -1), norm='backward'))

        return spatial * self.alpha + x * self.beta


class FAM(nn.Module):
    """Fusion module: concatenate two same-shaped maps, merge with a 3x3 conv."""

    def __init__(self, channel):
        super(FAM, self).__init__()
        self.merge = BasicConv(channel * 2, channel, kernel_size=3, stride=1, relu=False)

    def forward(self, x1, x2):
        stacked = torch.cat([x1, x2], dim=1)
        return self.merge(stacked)


class OKNet(nn.Module):
    """Three-scale encoder/decoder with an omni-kernel bottleneck.

    forward() returns [quarter-res, half-res, full-res] residual-corrected
    predictions (each is network output + the correspondingly scaled input).
    """

    def __init__(self, num_res=4):
        super(OKNet, self).__init__()

        base_channel = 32

        self.Encoder = nn.ModuleList([
            EBlock(base_channel, num_res),
            EBlock(base_channel * 2, num_res),
            EBlock(base_channel * 4, num_res),
        ])

        # 0-2: stem + strided downsamplers, 3-4: transposed upsamplers, 5: RGB head.
        self.feat_extract = nn.ModuleList([
            BasicConv(3, base_channel, kernel_size=3, relu=True, stride=1),
            BasicConv(base_channel, base_channel * 2, kernel_size=3, relu=True, stride=2),
            BasicConv(base_channel * 2, base_channel * 4, kernel_size=3, relu=True, stride=2),
            BasicConv(base_channel * 4, base_channel * 2, kernel_size=4, relu=True, stride=2, transpose=True),
            BasicConv(base_channel * 2, base_channel, kernel_size=4, relu=True, stride=2, transpose=True),
            BasicConv(base_channel, 3, kernel_size=3, relu=False, stride=1)
        ])

        self.Decoder = nn.ModuleList([
            DBlock(base_channel * 4, num_res),
            DBlock(base_channel * 2, num_res),
            DBlock(base_channel, num_res)
        ])

        # 1x1 convs squeezing the skip-concatenated channels back down.
        self.Convs = nn.ModuleList([
            BasicConv(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1),
            BasicConv(base_channel * 2, base_channel, kernel_size=1, relu=True, stride=1),
        ])

        # RGB heads for the two intermediate (1/4 and 1/2 scale) predictions.
        self.ConvsOut = nn.ModuleList(
            [
                BasicConv(base_channel * 4, 3, kernel_size=3, relu=False, stride=1),
                BasicConv(base_channel * 2, 3, kernel_size=3, relu=False, stride=1),
            ]
        )

        self.FAM1 = FAM(base_channel * 4)
        self.SCM1 = SCM(base_channel * 4)
        self.FAM2 = FAM(base_channel * 2)
        self.SCM2 = SCM(base_channel * 2)

        # NOTE: attribute keeps the historical "bottelneck" spelling so that
        # released checkpoints keep loading.
        self.bottelneck = BottleNect(base_channel * 4)

    def forward(self, x):
        half = F.interpolate(x, scale_factor=0.5)
        quarter = F.interpolate(half, scale_factor=0.5)
        emb_half = self.SCM2(half)
        emb_quarter = self.SCM1(quarter)

        outputs = []

        # --- encoder: full -> 1/2 -> 1/4 resolution ---
        feat = self.feat_extract[0](x)
        skip_full = self.Encoder[0](feat)

        feat = self.feat_extract[1](skip_full)
        feat = self.FAM2(feat, emb_half)
        skip_half = self.Encoder[1](feat)

        feat = self.feat_extract[2](skip_half)
        feat = self.FAM1(feat, emb_quarter)
        feat = self.Encoder[2](feat)
        feat = self.bottelneck(feat)

        # --- decoder, 1/4 resolution ---
        feat = self.Decoder[0](feat)
        coarse = self.ConvsOut[0](feat)
        feat = self.feat_extract[3](feat)
        outputs.append(coarse + quarter)

        # --- decoder, 1/2 resolution ---
        feat = torch.cat([feat, skip_half], dim=1)
        feat = self.Convs[0](feat)
        feat = self.Decoder[1](feat)
        mid = self.ConvsOut[1](feat)
        feat = self.feat_extract[4](feat)
        outputs.append(mid + half)

        # --- decoder, full resolution ---
        feat = torch.cat([feat, skip_full], dim=1)
        feat = self.Convs[1](feat)
        feat = self.Decoder[2](feat)
        feat = self.feat_extract[5](feat)
        outputs.append(feat + x)

        return outputs


def build_net():
    """Factory used by main.py."""
    return OKNet()

# ----------------- Desnowing/models/layers.py follows -----------------
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class BasicConv(nn.Module): 6 | def __init__(self, in_channel, out_channel, kernel_size, stride, bias=True, norm=False, relu=True, transpose=False): 7 | super(BasicConv, self).__init__() 8 | if bias and norm: 9 | bias = False 10 | 11 | padding = kernel_size // 2 12 | layers = list() 13 | if transpose: 14 | padding = kernel_size // 2 -1 15 | layers.append(nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias)) 16 | else: 17 | layers.append( 18 | nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias)) 19 | if norm: 20 | layers.append(nn.BatchNorm2d(out_channel)) 21 | if relu: 22 | layers.append(nn.GELU()) 23 | self.main = nn.Sequential(*layers) 24 | 25 | def forward(self, x): 26 | return self.main(x) 27 | 28 | 29 | class ResBlock(nn.Module): 30 | def __init__(self, in_channel, out_channel): 31 | super(ResBlock, self).__init__() 32 | self.main = nn.Sequential( 33 | BasicConv(in_channel, out_channel, kernel_size=3, stride=1, relu=True), 34 | BasicConv(out_channel, out_channel, kernel_size=3, stride=1, relu=False) 35 | ) 36 | 37 | def forward(self, x): 38 | return self.main(x) + x -------------------------------------------------------------------------------- /Desnowing/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from data import train_dataloader 4 | from utils import Adder, Timer, check_lr 5 | from torch.utils.tensorboard import SummaryWriter 6 | from valid import _valid 7 | import torch.nn.functional as F 8 | import torch.nn as nn 9 | 10 | from warmup_scheduler import GradualWarmupScheduler 11 | 12 | def _train(model, args): 13 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 14 | criterion = torch.nn.L1Loss() 15 | 16 | optimizer = 
torch.optim.Adam(model.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), eps=1e-8) 17 | dataloader = train_dataloader(args.data_dir, args.batch_size, args.num_worker) 18 | max_iter = len(dataloader) 19 | warmup_epochs=3 20 | scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.num_epoch-warmup_epochs, eta_min=1e-6) 21 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine) 22 | scheduler.step() 23 | epoch = 1 24 | if args.resume: 25 | state = torch.load(args.resume) 26 | epoch = state['epoch'] 27 | optimizer.load_state_dict(state['optimizer']) 28 | model.load_state_dict(state['model']) 29 | print('Resume from %d'%epoch) 30 | epoch += 1 31 | 32 | writer = SummaryWriter() 33 | epoch_pixel_adder = Adder() 34 | epoch_fft_adder = Adder() 35 | iter_pixel_adder = Adder() 36 | iter_fft_adder = Adder() 37 | epoch_timer = Timer('m') 38 | iter_timer = Timer('m') 39 | best_psnr=-1 40 | 41 | for epoch_idx in range(epoch, args.num_epoch + 1): 42 | 43 | epoch_timer.tic() 44 | iter_timer.tic() 45 | for iter_idx, batch_data in enumerate(dataloader): 46 | 47 | input_img, label_img = batch_data 48 | input_img = input_img.to(device) 49 | label_img = label_img.to(device) 50 | 51 | optimizer.zero_grad() 52 | pred_img = model(input_img) 53 | label_img2 = F.interpolate(label_img, scale_factor=0.5, mode='bilinear') 54 | label_img4 = F.interpolate(label_img, scale_factor=0.25, mode='bilinear') 55 | l1 = criterion(pred_img[0], label_img4) 56 | l2 = criterion(pred_img[1], label_img2) 57 | l3 = criterion(pred_img[2], label_img) 58 | loss_content = l1+l2+l3 59 | 60 | label_fft1 = torch.fft.fft2(label_img4, dim=(-2,-1)) 61 | label_fft1 = torch.stack((label_fft1.real, label_fft1.imag), -1) 62 | 63 | pred_fft1 = torch.fft.fft2(pred_img[0], dim=(-2,-1)) 64 | pred_fft1 = torch.stack((pred_fft1.real, pred_fft1.imag), -1) 65 | 66 | label_fft2 = torch.fft.fft2(label_img2, dim=(-2,-1)) 67 | 
label_fft2 = torch.stack((label_fft2.real, label_fft2.imag), -1) 68 | 69 | pred_fft2 = torch.fft.fft2(pred_img[1], dim=(-2,-1)) 70 | pred_fft2 = torch.stack((pred_fft2.real, pred_fft2.imag), -1) 71 | 72 | label_fft3 = torch.fft.fft2(label_img, dim=(-2,-1)) 73 | label_fft3 = torch.stack((label_fft3.real, label_fft3.imag), -1) 74 | 75 | pred_fft3 = torch.fft.fft2(pred_img[2], dim=(-2,-1)) 76 | pred_fft3 = torch.stack((pred_fft3.real, pred_fft3.imag), -1) 77 | 78 | f1 = criterion(pred_fft1, label_fft1) 79 | f2 = criterion(pred_fft2, label_fft2) 80 | f3 = criterion(pred_fft3, label_fft3) 81 | loss_fft = f1+f2+f3 82 | 83 | loss = loss_content + 0.1 * loss_fft 84 | loss.backward() 85 | torch.nn.utils.clip_grad_norm_(model.parameters(), 0.001) 86 | optimizer.step() 87 | 88 | iter_pixel_adder(loss_content.item()) 89 | iter_fft_adder(loss_fft.item()) 90 | 91 | epoch_pixel_adder(loss_content.item()) 92 | epoch_fft_adder(loss_fft.item()) 93 | 94 | if (iter_idx + 1) % args.print_freq == 0: 95 | print("Time: %7.4f Epoch: %03d Iter: %4d/%4d LR: %.10f Loss content: %7.4f Loss fft: %7.4f" % ( 96 | iter_timer.toc(), epoch_idx, iter_idx + 1, max_iter, scheduler.get_lr()[0], iter_pixel_adder.average(), 97 | iter_fft_adder.average())) 98 | writer.add_scalar('Pixel Loss', iter_pixel_adder.average(), iter_idx + (epoch_idx-1)* max_iter) 99 | writer.add_scalar('FFT Loss', iter_fft_adder.average(), iter_idx + (epoch_idx - 1) * max_iter) 100 | 101 | iter_timer.tic() 102 | iter_pixel_adder.reset() 103 | iter_fft_adder.reset() 104 | overwrite_name = os.path.join(args.model_save_dir, 'model.pkl') 105 | torch.save({'model': model.state_dict(), 106 | 'optimizer': optimizer.state_dict(), 107 | 'epoch': epoch_idx}, overwrite_name) 108 | 109 | if epoch_idx % args.save_freq == 0: 110 | save_name = os.path.join(args.model_save_dir, 'model_%d.pkl' % epoch_idx) 111 | torch.save({'model': model.state_dict()}, save_name) 112 | print("EPOCH: %02d\nElapsed time: %4.2f Epoch Pixel Loss: %7.4f Epoch FFT 
Loss: %7.4f" % ( 113 | epoch_idx, epoch_timer.toc(), epoch_pixel_adder.average(), epoch_fft_adder.average())) 114 | epoch_fft_adder.reset() 115 | epoch_pixel_adder.reset() 116 | scheduler.step() 117 | if epoch_idx % args.valid_freq == 0: 118 | val = _valid(model, args, epoch_idx) 119 | print('%03d epoch \n Average PSNR %.2f dB' % (epoch_idx, val)) 120 | writer.add_scalar('PSNR', val, epoch_idx) 121 | if val >= best_psnr: 122 | torch.save({'model': model.state_dict()}, os.path.join(args.model_save_dir, 'Best.pkl')) 123 | save_name = os.path.join(args.model_save_dir, 'Final.pkl') 124 | torch.save({'model': model.state_dict()}, save_name) 125 | -------------------------------------------------------------------------------- /Desnowing/utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | 4 | 5 | class Adder(object): 6 | def __init__(self): 7 | self.count = 0 8 | self.num = float(0) 9 | 10 | def reset(self): 11 | self.count = 0 12 | self.num = float(0) 13 | 14 | def __call__(self, num): 15 | self.count += 1 16 | self.num += num 17 | 18 | def average(self): 19 | return self.num / self.count 20 | 21 | 22 | class Timer(object): 23 | def __init__(self, option='s'): 24 | self.tm = 0 25 | self.option = option 26 | if option == 's': 27 | self.devider = 1 28 | elif option == 'm': 29 | self.devider = 60 30 | else: 31 | self.devider = 3600 32 | 33 | def tic(self): 34 | self.tm = time.time() 35 | 36 | def toc(self): 37 | return (time.time() - self.tm) / self.devider 38 | 39 | 40 | def check_lr(optimizer): 41 | for i, param_group in enumerate(optimizer.param_groups): 42 | lr = param_group['lr'] 43 | return lr 44 | -------------------------------------------------------------------------------- /Desnowing/valid.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.transforms import functional as F 3 | from data import valid_dataloader 
4 | from utils import Adder 5 | import os 6 | from skimage.metrics import peak_signal_noise_ratio 7 | import torch.nn.functional as f 8 | 9 | 10 | def _valid(model, args, ep): 11 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 12 | snow = valid_dataloader(args.data_dir, batch_size=1, num_workers=0) 13 | model.eval() 14 | psnr_adder = Adder() 15 | 16 | with torch.no_grad(): 17 | print('Start Evaluation') 18 | factor = 4 19 | for idx, data in enumerate(snow): 20 | input_img, label_img = data 21 | input_img = input_img.to(device) 22 | 23 | h, w = input_img.shape[2], input_img.shape[3] 24 | H, W = ((h+factor)//factor)*factor, ((w+factor)//factor*factor) 25 | padh = H-h if h%factor!=0 else 0 26 | padw = W-w if w%factor!=0 else 0 27 | input_img = f.pad(input_img, (0, padw, 0, padh), 'reflect') 28 | 29 | if not os.path.exists(os.path.join(args.result_dir, '%d' % (ep))): 30 | os.mkdir(os.path.join(args.result_dir, '%d' % (ep))) 31 | 32 | pred = model(input_img)[2] 33 | pred = pred[:,:,:h,:w] 34 | 35 | pred_clip = torch.clamp(pred, 0, 1) 36 | p_numpy = pred_clip.squeeze(0).cpu().numpy() 37 | label_numpy = label_img.squeeze(0).cpu().numpy() 38 | 39 | psnr = peak_signal_noise_ratio(p_numpy, label_numpy, data_range=1) 40 | 41 | psnr_adder(psnr) 42 | print('\r%03d'%idx, end=' ') 43 | 44 | print('\n') 45 | model.train() 46 | return psnr_adder.average() 47 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Yuning Cui 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom 
the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Omni-Kernel Network for Image Restoration [AAAI-24] 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | ## Installation 13 | The project is built with PyTorch 3.8, PyTorch 1.8.1. CUDA 10.2, cuDNN 7.6.5 14 | For installing, follow these instructions: 15 | ~~~ 16 | conda install pytorch=1.8.1 torchvision=0.9.1 -c pytorch 17 | pip install tensorboard einops scikit-image pytorch_msssim opencv-python 18 | ~~~ 19 | Install warmup scheduler: 20 | ~~~ 21 | cd pytorch-gradual-warmup-lr/ 22 | python setup.py install 23 | cd .. 24 | ~~~ 25 | 26 | Please download pillow package using Conda instead of pip. 27 | 28 | 29 | 30 | ITS FLOPs: 39.67G, Params: 4.72M 31 | Training and testing details can be found in the individual directories. 
32 | 33 | ## [Models](https://drive.google.com/drive/folders/1jrqqTBFHi2XNvBb9-rm9n0fZHfYZabFw?usp=sharing) 34 | ## [Images](https://drive.google.com/drive/folders/1FuaHw5Wr9PTSKAKn2qEuE-IN1Ye-O8Xj?usp=sharing) 35 | 36 | 37 | 38 | ## Citation 39 | ~~~ 40 | @inproceedings{cui2024omni, 41 | title={Omni-Kernel Network for Image Restoration}, 42 | author={Cui, Yuning and Ren, Wenqi and Knoll, Alois}, 43 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, 44 | volume={38}, 45 | number={2}, 46 | pages={1426--1434}, 47 | year={2024} 48 | } 49 | ~~~ 50 | 51 | 52 | ## Contact 53 | Should you have any question, please contact Yuning Cui. 54 | -------------------------------------------------------------------------------- /pytorch-gradual-warmup-lr/setup.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import setuptools 6 | 7 | _VERSION = '0.3' 8 | 9 | REQUIRED_PACKAGES = [ 10 | ] 11 | 12 | DEPENDENCY_LINKS = [ 13 | ] 14 | 15 | setuptools.setup( 16 | name='warmup_scheduler', 17 | version=_VERSION, 18 | description='Gradually Warm-up LR Scheduler for Pytorch', 19 | install_requires=REQUIRED_PACKAGES, 20 | dependency_links=DEPENDENCY_LINKS, 21 | url='https://github.com/ildoonet/pytorch-gradual-warmup-lr', 22 | license='MIT License', 23 | package_dir={}, 24 | packages=setuptools.find_packages(exclude=['tests']), 25 | ) 26 | -------------------------------------------------------------------------------- /pytorch-gradual-warmup-lr/warmup_scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from warmup_scheduler.scheduler import GradualWarmupScheduler 3 | -------------------------------------------------------------------------------- /pytorch-gradual-warmup-lr/warmup_scheduler/run.py: 
# =========== pytorch-gradual-warmup-lr/warmup_scheduler/run.py ============
import torch
from torch.optim.lr_scheduler import StepLR, ExponentialLR
from torch.optim import SGD  # public path; torch.optim.sgd is a private module


if __name__ == '__main__':
    # Demo import kept inside the guard so importing this module does not
    # require the package to be installed.
    from warmup_scheduler import GradualWarmupScheduler

    model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
    optim = SGD(model, 0.1)

    # scheduler_warmup is chained with scheduler_steplr
    scheduler_steplr = StepLR(optim, step_size=10, gamma=0.1)
    scheduler_warmup = GradualWarmupScheduler(optim, multiplier=1, total_epoch=5, after_scheduler=scheduler_steplr)

    # this zero gradient update is needed to avoid a warning message, issue #8.
    optim.zero_grad()
    optim.step()

    for epoch in range(1, 20):
        scheduler_warmup.step(epoch)
        print(epoch, optim.param_groups[0]['lr'])

        optim.step()    # backward pass (update network)

# =========== pytorch-gradual-warmup-lr/warmup_scheduler/scheduler.py ======
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau


class GradualWarmupScheduler(_LRScheduler):
    """ Gradually warm-up(increasing) learning rate in optimizer.
    Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        multiplier: target learning rate = base lr * multiplier if multiplier > 1.0. if multiplier = 1.0, lr starts from 0 and ends up with the base_lr.
        total_epoch: target learning rate is reached at total_epoch, gradually
        after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau)
    """

    def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
        self.multiplier = multiplier
        if self.multiplier < 1.:
            # (typo fixed: "thant" -> "than")
            raise ValueError('multiplier should be greater than or equal to 1.')
        self.total_epoch = total_epoch
        self.after_scheduler = after_scheduler
        self.finished = False
        super(GradualWarmupScheduler, self).__init__(optimizer)

    def get_lr(self):
        if self.last_epoch > self.total_epoch:
            if self.after_scheduler:
                if not self.finished:
                    # Hand over to the follow-up scheduler exactly once.
                    self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
                    self.finished = True
                return self.after_scheduler.get_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

        if self.multiplier == 1.0:
            # Linear ramp 0 -> base_lr.
            return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
        else:
            # Linear ramp base_lr -> base_lr * multiplier.
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]

    def step_ReduceLROnPlateau(self, metrics, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
        self.last_epoch = epoch if epoch != 0 else 1
        if self.last_epoch <= self.total_epoch:
            warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.)
                         for base_lr in self.base_lrs]
            for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
                param_group['lr'] = lr
        else:
            # `epoch` cannot be None here (normalised above), so the original
            # dead `if epoch is None` branch was removed.
            self.after_scheduler.step(metrics, epoch - self.total_epoch)

    def step(self, epoch=None, metrics=None):
        if type(self.after_scheduler) != ReduceLROnPlateau:
            if self.finished and self.after_scheduler:
                if epoch is None:
                    self.after_scheduler.step(None)
                else:
                    self.after_scheduler.step(epoch - self.total_epoch)
            else:
                return super(GradualWarmupScheduler, self).step(epoch)
        else:
            self.step_ReduceLROnPlateau(metrics, epoch)