├── assets
│   ├── README.md
│   ├── framework.png
│   └── overview.png
├── checkpoint
│   └── download_pretrained_weight.txt
├── .gitignore
├── prepare
│   ├── prepare_restormer.py
│   └── prepare_polyu.py
├── prepare.py
├── model.py
├── data.py
├── metric.py
├── adapt
│   ├── zsn2n.py
│   └── nbr2nbr.py
├── README.md
└── main.py

--------------------------------------------------------------------------------
/assets/README.md:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/checkpoint/download_pretrained_weight.txt:
--------------------------------------------------------------------------------

--------------------------------------------------------------------------------
/assets/framework.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chjinny/LAN/HEAD/assets/framework.png

--------------------------------------------------------------------------------
/assets/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chjinny/LAN/HEAD/assets/overview.png

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
polyu*
Restormer*
*.pth
__pycache__
.vscode
*.csv
*.log
*.yaml

--------------------------------------------------------------------------------
/prepare/prepare_restormer.py:
--------------------------------------------------------------------------------
import subprocess

# Clone the Restormer repository; model.py and metric.py import from it by file path.
subprocess.run(["git", "clone", "https://github.com/swz30/Restormer.git"])

--------------------------------------------------------------------------------
/prepare.py:
--------------------------------------------------------------------------------
import subprocess

# Build the PolyU patch dataset.
subprocess.run(["python", "prepare/prepare_polyu.py"])
# Clone the Restormer repository.
subprocess.run(["python", "prepare/prepare_restormer.py"])

--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
import importlib.util
from pathlib import Path

import torch

def get_model():
    # Load the Restormer architecture straight from the cloned repository,
    # which is a plain directory rather than an installed package.
    path = 'Restormer.basicsr.models.archs.restormer_arch'
    module_spec = importlib.util.spec_from_file_location(path, str(Path().joinpath(*path.split('.'))) + '.py')
    module = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)
    model = module.Restormer(LayerNorm_type='BiasFree').cuda()
    # Pretrained real-denoising weights (see checkpoint/download_pretrained_weight.txt).
    checkpoint = torch.load("./checkpoint/real_denoising.pth")["params"]
    model.load_state_dict(checkpoint, strict=False)
    model.eval()
    return model
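A minimal smoke test for `get_model` (not part of the repo) might look as follows. It assumes the Restormer repository has already been cloned via `prepare.py`, the pretrained weights sit at `./checkpoint/real_denoising.pth`, and a CUDA device is available; the 256×256 size is arbitrary, chosen to keep spatial dimensions divisible by 8 as the architecture expects:

```python
import torch
from model import get_model

model = get_model()                        # frozen, pretrained Restormer in eval mode
noisy = torch.rand(1, 3, 256, 256).cuda()  # dummy noisy image in [0, 1]
with torch.no_grad():
    denoised = model(noisy)                # the network returns the restored image directly
print(denoised.shape)                      # torch.Size([1, 3, 256, 256])
```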
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
import torch
import torchvision
from pathlib import Path

class Dataset(torch.utils.data.Dataset):
    def __init__(self, lq_dir, gt_dir, crop_size=256):
        self.lq_dir = Path(lq_dir)
        self.gt_dir = Path(gt_dir)
        self.crop_size = crop_size
        self.lq_paths = sorted(self.lq_dir.glob("*.png"))
        self.gt_paths = sorted(self.gt_dir.glob("*.png"))
        assert len(self.lq_paths) == len(self.gt_paths)

    def __len__(self):
        return len(self.lq_paths)

    def __getitem__(self, idx):
        lq_name = self.lq_paths[idx].stem
        gt_name = self.gt_paths[idx].stem
        assert lq_name == gt_name
        lq = torchvision.io.read_image(str(self.lq_paths[idx])) / 255.0
        gt = torchvision.io.read_image(str(self.gt_paths[idx])) / 255.0
        # Fixed top-left crop so every sample has the same spatial size.
        lq = lq[:, :self.crop_size, :self.crop_size]
        gt = gt[:, :self.crop_size, :self.crop_size]
        return lq, gt

--------------------------------------------------------------------------------
/metric.py:
--------------------------------------------------------------------------------
import importlib.util
from pathlib import Path

# Load Restormer's PSNR/SSIM utilities from the cloned repository.
path = 'Restormer.Denoising.utils'
module_spec = importlib.util.spec_from_file_location(path, str(Path().joinpath(*path.split('.'))) + '.py')
module = importlib.util.module_from_spec(module_spec)
module_spec.loader.exec_module(module)

def _to_hwc(x):
    # Drop the batch dim, rescale to [0, 255], and convert to an H x W x C numpy array.
    return (x.squeeze(0) * 255).permute(1, 2, 0).cpu().detach().numpy()

def cal_psnr(a, b):
    return module.calculate_psnr(_to_hwc(a), _to_hwc(b))

def cal_ssim(a, b):
    return module.calculate_ssim(_to_hwc(a), _to_hwc(b))

def cal_batch_psnr_ssim(pred, gt):
    psnr, ssim = [], []
    for i in range(pred.shape[0]):
        psnr.append(cal_psnr(pred[i:i+1], gt[i:i+1]))
        ssim.append(cal_ssim(pred[i:i+1], gt[i:i+1]))
    return psnr, ssim

--------------------------------------------------------------------------------
/prepare/prepare_polyu.py:
--------------------------------------------------------------------------------
import subprocess
from pathlib import Path

import torchvision
from tqdm import tqdm

subprocess.run(["git", "clone", "https://github.com/csjunxu/PolyU-Real-World-Noisy-Images-Dataset.git"])

patch_size = 512
root_dir = Path("PolyU-Real-World-Noisy-Images-Dataset/OriginalImages")
output_dir = Path("polyu")

gt_dir = output_dir / "gt"
lq_dir = output_dir / "lq"
lq_dir.mkdir(parents=True, exist_ok=True)
gt_dir.mkdir(parents=True, exist_ok=True)

# Noisy shots end in "Real" and their temporal-mean ground truths in "mean";
# sorting both globs keeps the pairs aligned.
for lq_path, gt_path in tqdm(list(zip(sorted(root_dir.glob("*Real.JPG")), sorted(root_dir.glob("*mean.JPG"))))):
    lq_name = "_".join(lq_path.stem.split("_")[:-1])
    gt_name = "_".join(gt_path.stem.split("_")[:-1])
    assert lq_name == gt_name
    lq = torchvision.io.read_image(str(lq_path)) / 255.0
    gt = torchvision.io.read_image(str(gt_path)) / 255.0
    # Tile each image into non-overlapping patch_size x patch_size patches.
    lq_patches = lq.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size).permute(1, 2, 0, 3, 4).reshape(-1, 3, patch_size, patch_size)
    gt_patches = gt.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size).permute(1, 2, 0, 3, 4).reshape(-1, 3, patch_size, patch_size)
    for i, (lq_patch, gt_patch) in enumerate(zip(lq_patches, gt_patches)):
        torchvision.utils.save_image(lq_patch, str(lq_dir / f"{lq_name}_{str(i).zfill(4)}.png"))
        torchvision.utils.save_image(gt_patch, str(gt_dir / f"{gt_name}_{str(i).zfill(4)}.png"))

subprocess.run(["rm", "-rf", "PolyU-Real-World-Noisy-Images-Dataset"])
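The `unfold` chain in prepare_polyu.py is compact but easy to misread. Here is the same shape bookkeeping on a toy 3×4×4 image with 2×2 patches instead of 512×512 ones:

```python
import torch

img = torch.arange(3 * 4 * 4, dtype=torch.float32).reshape(3, 4, 4)
p = 2
patches = (
    img.unfold(1, p, p)          # (3, 2, 4, 2): carve H into two rows of patches
       .unfold(2, p, p)          # (3, 2, 2, 2, 2): carve W the same way
       .permute(1, 2, 0, 3, 4)   # (2, 2, 3, 2, 2): patch grid first, channels inside
       .reshape(-1, 3, p, p)     # (4, 3, 2, 2): flat list of non-overlapping patches
)
print(patches.shape)  # torch.Size([4, 3, 2, 2])
```

Because the stride equals the patch size, the patches tile the image without overlap, and any border strip narrower than `patch_size` is silently dropped.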
--------------------------------------------------------------------------------
/adapt/zsn2n.py:
--------------------------------------------------------------------------------
# Adapted from the ZS-N2N demo notebook:
# https://colab.research.google.com/drive/1i82nyizTdszyHkaHBuKPbWnTzao8HF9b?usp=sharing#scrollTo=6gYE_YyWFnIS
import torch
import torch.nn.functional as F

def pair_downsampler(img):
    # img has shape B x C x H x W.
    c = img.shape[1]

    # Average the two diagonals of each 2x2 block to form two half-resolution
    # images that share content but carry (approximately) independent noise.
    filter1 = torch.FloatTensor([[[[0, 0.5], [0.5, 0]]]]).to(img.device)
    filter1 = filter1.repeat(c, 1, 1, 1)

    filter2 = torch.FloatTensor([[[[0.5, 0], [0, 0.5]]]]).to(img.device)
    filter2 = filter2.repeat(c, 1, 1, 1)

    output1 = F.conv2d(img, filter1, stride=2, groups=c)
    output2 = F.conv2d(img, filter2, stride=2, groups=c)

    return output1, output2

def mse(gt: torch.Tensor, pred: torch.Tensor) -> torch.Tensor:
    return F.mse_loss(gt, pred)

def loss_func(noisy_img, model, *args, **kwargs):
    noisy1, noisy2 = pair_downsampler(noisy_img)

    # The original notebook's network predicts the residual (pred = noisy - model(noisy));
    # the model used here already includes the residual and outputs the denoised image.
    pred1 = model(noisy1)
    pred2 = model(noisy2)

    # Residual loss: each downsampled view should predict the other.
    loss_res = 1/2 * (mse(noisy1, pred2) + mse(noisy2, pred1))

    noisy_denoised = model(noisy_img)
    denoised1, denoised2 = pair_downsampler(noisy_denoised)

    # Consistency loss: denoise-then-downsample should match downsample-then-denoise.
    loss_cons = 1/2 * (mse(pred1, denoised1) + mse(pred2, denoised2))

    return loss_res + loss_cons

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# LAN: Learning to Adapt Noise for Image Denoising
Changjin Kim, Tae Hyun Kim, Sungyong Baik

Sample code for our framework.

Paper link: [[CVPR]](https://openaccess.thecvf.com/content/CVPR2024/html/Kim_LAN_Learning_to_Adapt_Noise_for_Image_Denoising_CVPR_2024_paper.html)


> **Abstract:** *Removing noise from images, a.k.a. image denoising, can be a very challenging task, since the type and amount of noise can greatly vary for each image due to many factors, including a camera model and capturing environments. While there have been striking improvements in image denoising with the emergence of advanced deep learning architectures and real-world datasets, recent denoising networks struggle to maintain performance on images with noise that has not been seen during training. One typical approach to address the challenge would be to adapt a denoising network to a new noise distribution. Instead, in this work, we shift our attention to the input noise itself for adaptation, rather than adapting a network. Thus, we keep a pretrained network frozen and adapt an input noise to capture the fine-grained deviations. As such, we propose a new denoising algorithm, dubbed Learning-to-Adapt-Noise (LAN), where a learnable noise offset is directly added to a given noisy image to bring a given input noise closer towards the noise distribution a denoising network is trained to handle. Consequently, the proposed framework exhibits performance improvement on images with unseen noise, displaying the potential of the proposed research direction.*
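`main.py` appears in the tree above but is not included in this excerpt, so the following is only a sketch of how the released pieces could compose: a frozen Restormer from `model.py`, PolyU patches from `data.py`, and the self-supervised ZS-N2N objective from `adapt/zsn2n.py` driving a learnable noise offset, as the abstract describes. The offset name `delta`, the optimizer, the learning rate, and the step count are illustrative assumptions, not the paper's settings.

```python
import torch
from model import get_model
from data import Dataset
from adapt.zsn2n import loss_func
from metric import cal_psnr

model = get_model()  # pretrained Restormer, kept frozen throughout
for p in model.parameters():
    p.requires_grad_(False)

lq, gt = Dataset("polyu/lq", "polyu/gt")[0]
lq, gt = lq.unsqueeze(0).cuda(), gt.unsqueeze(0).cuda()

# The learnable noise offset is the only thing LAN optimizes.
delta = torch.zeros_like(lq, requires_grad=True)
optimizer = torch.optim.Adam([delta], lr=1e-3)  # placeholder hyperparameters

for step in range(10):
    optimizer.zero_grad()
    # Self-supervised: the loss only ever sees the offset-corrected noisy image.
    loss = loss_func(lq + delta, model)
    loss.backward()
    optimizer.step()

with torch.no_grad():
    denoised = model(lq + delta)
print(f"PSNR after adaptation: {cal_psnr(denoised, gt):.2f} dB")
```

Even in this toy form the key design point survives: the network weights never change; only the input is nudged toward the noise distribution the network already handles.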
## Table of Contents
- [Overview](#overview)
- [Prepare Model and Dataset](#prepare-model-and-dataset)
- [Adaptation](#adaptation)
- [Citation](#citation)
- [Acknowledgement](#acknowledgement)


## Overview

<img src = "./assets/overview.png">