├── assets ├── README.md ├── framework.png └── overview.png ├── checkpoint └── download_pretrained_weight.txt ├── .gitignore ├── prepare ├── prepare_restormer.py └── prepare_polyu.py ├── prepare.py ├── model.py ├── data.py ├── metric.py ├── adapt ├── zsn2n.py └── nbr2nbr.py ├── README.md └── main.py /assets/README.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /checkpoint/download_pretrained_weight.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chjinny/LAN/HEAD/assets/framework.png -------------------------------------------------------------------------------- /assets/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chjinny/LAN/HEAD/assets/overview.png -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | polyu* 2 | Restormer* 3 | *.pth 4 | __pycache__ 5 | .vscode 6 | *.csv 7 | *.log 8 | *.yaml -------------------------------------------------------------------------------- /prepare/prepare_restormer.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | subprocess.run(["git", "clone", "https://github.com/swz30/Restormer.git"]) -------------------------------------------------------------------------------- /prepare.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | 3 | # run prepare/prepare_polyu.py 4 | subprocess.run(["python", "prepare/prepare_polyu.py"]) 5 | # 
run prepare/prepare_restormer.py 6 | subprocess.run(["python", "prepare/prepare_restormer.py"]) 7 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from pathlib import Path 3 | import torch 4 | 5 | def get_model(): 6 | path = 'Restormer.basicsr.models.archs.restormer_arch' 7 | module_spec = importlib.util.spec_from_file_location(path, str(Path().joinpath(*path.split('.')))+'.py') 8 | module = importlib.util.module_from_spec(module_spec) 9 | module_spec.loader.exec_module(module) 10 | model = module.Restormer(LayerNorm_type = 'BiasFree').cuda() 11 | checkpoint = torch.load("./checkpoint/real_denoising.pth")["params"] 12 | model.load_state_dict(checkpoint, strict=False) 13 | model.eval() 14 | return model -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from pathlib import Path 3 | import torchvision 4 | 5 | class Dataset(torch.utils.data.Dataset): 6 | def __init__(self, lq_dir, gt_dir, crop_size=256): 7 | self.lq_dir = Path(lq_dir) 8 | self.gt_dir = Path(gt_dir) 9 | self.crop_size = crop_size 10 | self.lq_paths = sorted(list(self.lq_dir.glob("*.png"))) 11 | self.gt_paths = sorted(list(self.gt_dir.glob("*.png"))) 12 | assert len(self.lq_paths) == len(self.gt_paths) 13 | 14 | def __len__(self): 15 | return len(self.lq_paths) 16 | 17 | def __getitem__(self, idx): 18 | lq_name = self.lq_paths[idx].stem 19 | gt_name = self.gt_paths[idx].stem 20 | assert lq_name == gt_name 21 | lq = torchvision.io.read_image(str(self.lq_paths[idx]))/255.0 22 | gt = torchvision.io.read_image(str(self.gt_paths[idx]))/255.0 23 | lq = lq[:, :self.crop_size, :self.crop_size] 24 | gt = gt[:, :self.crop_size, :self.crop_size] 25 | return lq, gt 
-------------------------------------------------------------------------------- /metric.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import importlib 3 | 4 | path = 'Restormer.Denoising.utils' 5 | module_spec = importlib.util.spec_from_file_location(path, str(Path().joinpath(*path.split('.')))+'.py') 6 | module = importlib.util.module_from_spec(module_spec) 7 | module_spec.loader.exec_module(module) 8 | 9 | def cal_psnr(a, b): 10 | a = a.squeeze(0) * 255 11 | b = b.squeeze(0) * 255 12 | a = a.permute(1, 2, 0).cpu().detach().numpy() 13 | b = b.permute(1, 2, 0).cpu().detach().numpy() 14 | return module.calculate_psnr(a, b) 15 | 16 | def cal_ssim(a, b): 17 | a = a.squeeze(0) * 255 18 | b = b.squeeze(0) * 255 19 | a = a.permute(1, 2, 0).cpu().detach().numpy() 20 | b = b.permute(1, 2, 0).cpu().detach().numpy() 21 | return module.calculate_ssim(a, b) 22 | 23 | def cal_batch_psnr_ssim(pred, gt): 24 | psnr = [] 25 | ssim = [] 26 | for i in range(pred.shape[0]): 27 | psnr.append(cal_psnr(pred[i:i+1], gt[i:i+1])) 28 | ssim.append(cal_ssim(pred[i:i+1], gt[i:i+1])) 29 | return psnr, ssim -------------------------------------------------------------------------------- /prepare/prepare_polyu.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import torchvision 3 | from pathlib import Path 4 | from tqdm import tqdm 5 | 6 | subprocess.run(["git", "clone", "https://github.com/csjunxu/PolyU-Real-World-Noisy-Images-Dataset.git"]) 7 | 8 | patch_size = 512 9 | root_dir = Path("PolyU-Real-World-Noisy-Images-Dataset/OriginalImages") 10 | output_dir = Path("polyu") 11 | 12 | gt_dir = output_dir / "gt" 13 | lq_dir = output_dir / "lq" 14 | for lq_path, gt_path in tqdm(list(zip(sorted(root_dir.glob("*Real.JPG")), sorted(root_dir.glob("*mean.JPG"))))): 15 | lq_name = "_".join(lq_path.stem.split("_")[:-1]) 16 | gt_name = "_".join(gt_path.stem.split("_")[:-1]) 17 
| assert lq_name == gt_name 18 | lq = torchvision.io.read_image(str(lq_path))/255.0 19 | gt = torchvision.io.read_image(str(gt_path))/255.0 20 | lq_patches = lq.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size).permute(1, 2, 0, 3, 4).reshape(-1, 3, patch_size, patch_size) 21 | gt_patches = gt.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size).permute(1, 2, 0, 3, 4).reshape(-1, 3, patch_size, patch_size) 22 | for i, (lq_patch, gt_patch) in enumerate(zip(lq_patches, gt_patches)): 23 | lq_dir.mkdir(parents=True, exist_ok=True) 24 | gt_dir.mkdir(parents=True, exist_ok=True) 25 | torchvision.utils.save_image(lq_patch, str(lq_dir / f"{lq_name}_{str(i).zfill(4)}.png")) 26 | torchvision.utils.save_image(gt_patch, str(gt_dir / f"{gt_name}_{str(i).zfill(4)}.png")) 27 | 28 | subprocess.run(["rm", "-rf", "PolyU-Real-World-Noisy-Images-Dataset"]) -------------------------------------------------------------------------------- /adapt/zsn2n.py: -------------------------------------------------------------------------------- 1 | # https://colab.research.google.com/drive/1i82nyizTdszyHkaHBuKPbWnTzao8HF9b?usp=sharing#scrollTo=6gYE_YyWFnIS 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | import torch.nn.functional as F 6 | 7 | def pair_downsampler(img): 8 | #img has shape B C H W 9 | c = img.shape[1] 10 | 11 | filter1 = torch.FloatTensor([[[[0 ,0.5],[0.5, 0]]]]).to(img.device) 12 | filter1 = filter1.repeat(c,1, 1, 1) 13 | 14 | filter2 = torch.FloatTensor([[[[0.5 ,0],[0, 0.5]]]]).to(img.device) 15 | filter2 = filter2.repeat(c,1, 1, 1) 16 | 17 | output1 = F.conv2d(img, filter1, stride=2, groups=c) 18 | output2 = F.conv2d(img, filter2, stride=2, groups=c) 19 | 20 | return output1, output2 21 | 22 | def mse(gt: torch.Tensor, pred:torch.Tensor)-> torch.Tensor: 23 | loss = torch.nn.MSELoss() 24 | return loss(gt,pred) 25 | 26 | def loss_func(noisy_img, model, *args, **kwargs): 27 | noisy1, noisy2 = pair_downsampler(noisy_img) 
28 | 29 | # pred1 = noisy1 - model(noisy1) 30 | # pred2 = noisy2 - model(noisy2) 31 | # model includes residual 32 | pred1 = model(noisy1) 33 | pred2 = model(noisy2) 34 | 35 | loss_res = 1/2*(mse(noisy1,pred2)+mse(noisy2,pred1)) 36 | 37 | # noisy_denoised = noisy_img - model(noisy_img) 38 | # model includes residual 39 | noisy_denoised = model(noisy_img) 40 | denoised1, denoised2 = pair_downsampler(noisy_denoised) 41 | 42 | loss_cons=1/2*(mse(pred1,denoised1) + mse(pred2,denoised2)) 43 | 44 | loss = loss_res + loss_cons 45 | 46 | return loss -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LAN: Learning to Adapt Noise for Image Denoising 2 | Changjin Kim, Tae Hyun Kim, Sungyong Baik 3 | 4 | A sample code for our framework. 5 | 6 | Paper link : [[CVPR]](https://openaccess.thecvf.com/content/CVPR2024/html/Kim_LAN_Learning_to_Adapt_Noise_for_Image_Denoising_CVPR_2024_paper.html) 7 | 8 | 9 | > **Abstract:** *Removing noise from images a.k.a image denoising can be a very challenging task since the type and amount of noise can greatly vary for each image due to many factors including a camera model and capturing environments. While there have been striking improvements in image denoising with the emergence of advanced deep learning architectures and real-world datasets recent denoising networks struggle to maintain performance on images with noise that has not been seen during training. One typical approach to address the challenge would be to adapt a denoising network to new noise distribution. Instead in this work we shift our attention to the input noise itself for adaptation rather than adapting a network. Thus we keep a pretrained network frozen and adapt an input noise to capture the fine-grained deviations. 
As such we propose a new denoising algorithm dubbed Learning-to-Adapt-Noise (LAN) where a learnable noise offset is directly added to a given noisy image to bring a given input noise closer towards the noise distribution a denoising network is trained to handle. Consequently the proposed framework exhibits performance improvement on images with unseen noise displaying the potential of the proposed research direction.* 10 | 11 | ## Table of Contents 12 | - [Overview](#overview) 13 | - [Prepare Model and Dataset](#prepare-model-and-dataset) 14 | - [Adaptation](#adaptation) 15 | - [Citation](#citation) 16 | - [Acknowledgement](#acknowledgement) 17 | 18 | 19 | ## Overview 20 |

21 | 22 |

23 |

24 | 25 |

26 | 27 | ## Prepare Model and Dataset 28 | ```bash 29 | git clone https://github.com/chjinny/LAN.git && cd LAN 30 | python prepare.py 31 | ``` 32 | - Dataset : [PolyU](https://github.com/csjunxu/PolyU-Real-World-Noisy-Images-Dataset) 33 | - Model : [Restormer](https://github.com/swz30/Restormer) 34 | - Download the [pretrained weight file](https://drive.google.com/drive/folders/1Qwsjyny54RZWa7zC4Apg7exixLBo4uF0) and save it as `./checkpoint/real_denoising.pth`. 35 | 36 | ## Adaptation 37 | ```bash 38 | python main.py --method {lan, finetune} --self_loss {zsn2n, nbr2nbr} 39 | ``` 40 | 41 | ## Citation 42 | ```bibtex 43 | @inproceedings{kim2024lan, 44 | title={LAN: Learning to Adapt Noise for Image Denoising}, 45 | author={Kim, Changjin and Kim, Tae Hyun and Baik, Sungyong}, 46 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 47 | pages={25193--25202}, 48 | year={2024} 49 | } 50 | ``` 51 | 52 | ## Acknowledgement 53 | 54 | The code is based on the following: 55 | - [Restormer](https://github.com/swz30/Restormer) 56 | - [Neighbor2Neighbor](https://github.com/TaoHuang2018/Neighbor2Neighbor) 57 | - [Zero-Shot Noise2Noise](https://colab.research.google.com/drive/1i82nyizTdszyHkaHBuKPbWnTzao8HF9b?usp=sharing) 58 | 59 | We thank the authors for sharing their code.
60 | -------------------------------------------------------------------------------- /adapt/nbr2nbr.py: -------------------------------------------------------------------------------- 1 | # https://github.com/TaoHuang2018/Neighbor2Neighbor/blob/main/train.py 2 | 3 | import torch 4 | 5 | operation_seed_counter = 0 6 | 7 | def get_generator(): 8 | global operation_seed_counter 9 | operation_seed_counter += 1 10 | g_cuda_generator = torch.Generator(device="cuda") 11 | g_cuda_generator.manual_seed(operation_seed_counter) 12 | return g_cuda_generator 13 | 14 | def space_to_depth(x, block_size): 15 | n, c, h, w = x.size() 16 | unfolded_x = torch.nn.functional.unfold(x, block_size, stride=block_size) 17 | return unfolded_x.view(n, c * block_size**2, h // block_size, 18 | w // block_size) 19 | 20 | def generate_mask_pair(img): 21 | # prepare masks (N x C x H/2 x W/2) 22 | n, c, h, w = img.shape 23 | mask1 = torch.zeros(size=(n * h // 2 * w // 2 * 4, ), 24 | dtype=torch.bool, 25 | device=img.device) 26 | mask2 = torch.zeros(size=(n * h // 2 * w // 2 * 4, ), 27 | dtype=torch.bool, 28 | device=img.device) 29 | # prepare random mask pairs 30 | idx_pair = torch.tensor( 31 | [[0, 1], [0, 2], [1, 3], [2, 3], [1, 0], [2, 0], [3, 1], [3, 2]], 32 | dtype=torch.int64, 33 | device=img.device) 34 | rd_idx = torch.zeros(size=(n * h // 2 * w // 2, ), 35 | dtype=torch.int64, 36 | device=img.device) 37 | torch.randint(low=0, 38 | high=8, 39 | size=(n * h // 2 * w // 2, ), 40 | generator=get_generator(), 41 | out=rd_idx) 42 | rd_pair_idx = idx_pair[rd_idx] 43 | rd_pair_idx += torch.arange(start=0, 44 | end=n * h // 2 * w // 2 * 4, 45 | step=4, 46 | dtype=torch.int64, 47 | device=img.device).reshape(-1, 1) 48 | # get masks 49 | mask1[rd_pair_idx[:, 0]] = 1 50 | mask2[rd_pair_idx[:, 1]] = 1 51 | return mask1, mask2 52 | 53 | def generate_subimages(img, mask): 54 | n, c, h, w = img.shape 55 | subimage = torch.zeros(n, 56 | c, 57 | h // 2, 58 | w // 2, 59 | dtype=img.dtype, 60 | 
layout=img.layout, 61 | device=img.device) 62 | # per channel 63 | for i in range(c): 64 | img_per_channel = space_to_depth(img[:, i:i + 1, :, :], block_size=2) 65 | img_per_channel = img_per_channel.permute(0, 2, 3, 1).reshape(-1) 66 | subimage[:, i:i + 1, :, :] = img_per_channel[mask].reshape( 67 | n, h // 2, w // 2, 1).permute(0, 3, 1, 2) 68 | return subimage 69 | 70 | def loss_func(noisy, network, tmp_iter, max_iter): 71 | mask1, mask2 = generate_mask_pair(noisy) 72 | noisy_sub1 = generate_subimages(noisy, mask1) 73 | noisy_sub2 = generate_subimages(noisy, mask2) 74 | with torch.no_grad(): 75 | noisy_denoised = network(noisy).clip(0, 1) 76 | noisy_sub1_denoised = generate_subimages(noisy_denoised, mask1) 77 | noisy_sub2_denoised = generate_subimages(noisy_denoised, mask2) 78 | 79 | noisy_output = network(noisy_sub1).clip(0, 1) 80 | noisy_target = noisy_sub2 81 | Lambda = tmp_iter / max_iter * 0.1 82 | diff = noisy_output - noisy_target 83 | exp_diff = noisy_sub1_denoised - noisy_sub2_denoised 84 | 85 | loss1 = torch.mean(diff**2) 86 | loss2 = Lambda * torch.mean((diff - exp_diff)**2) 87 | loss_all = 1.0 * loss1 + 1.0 * loss2 88 | 89 | return loss_all -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from data import Dataset 3 | from model import get_model 4 | from metric import cal_batch_psnr_ssim 5 | import pandas as pd 6 | from tqdm import tqdm 7 | import argparse 8 | from adapt import zsn2n, nbr2nbr 9 | import numpy as np 10 | 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--method", type=str, required=True, choices=["finetune", "lan"]) 13 | parser.add_argument("--self_loss", type=str, required=True, choices=["nbr2nbr", "zsn2n"]) 14 | args = parser.parse_args() 15 | 16 | if args.self_loss == "zsn2n": 17 | loss_func = zsn2n.loss_func 18 | elif args.self_loss == "nbr2nbr": 19 | loss_func = 
nbr2nbr.loss_func 20 | else: 21 | raise NotImplementedError 22 | 23 | model_generator = get_model 24 | model = model_generator() 25 | for param in model.parameters(): 26 | param.requires_grad = args.method == "finetune" 27 | print("trainable model parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad)) 28 | 29 | dataloader = torch.utils.data.DataLoader(Dataset("polyu/lq", "polyu/gt"), batch_size=1, shuffle=False) 30 | lr = 5e-4 if args.method == "lan" else 5e-6 31 | class Lan(torch.nn.Module): 32 | def __init__(self, shape): 33 | super(Lan, self).__init__() 34 | self.phi = torch.nn.parameter.Parameter(torch.zeros(shape), requires_grad=True) 35 | def forward(self, x): 36 | return x + torch.tanh(self.phi) 37 | 38 | logs_key = ["psnr", "ssim"] 39 | total_logs = {key: [] for key in logs_key} 40 | inner_loop = 20 41 | p_bar = tqdm(dataloader, ncols=100, desc=f"{args.method}_{args.self_loss}") 42 | for lq, gt in p_bar: 43 | lq = lq.cuda() 44 | gt = gt.cuda() 45 | lan = Lan(lq.shape).cuda() if args.method == "lan" else torch.nn.Identity() 46 | tmp_batch_size = lq.shape[0] 47 | model = model_generator() 48 | for param in model.parameters(): 49 | param.requires_grad = args.method == "finetune" 50 | 51 | params = list(lan.parameters()) if args.method == "lan" else list(model.parameters()) 52 | optimizer = torch.optim.Adam(params, lr=lr) 53 | logs = {key: [] for key in logs_key} 54 | 55 | for i in range(inner_loop): 56 | optimizer.zero_grad() 57 | adapted_lq = lan(lq) 58 | with torch.no_grad(): 59 | pred = model(adapted_lq).clip(0, 1) 60 | loss = loss_func(adapted_lq, model, i, inner_loop) 61 | loss.backward() 62 | optimizer.step() 63 | psnr, ssim = cal_batch_psnr_ssim(pred, gt) 64 | for key in logs_key: 65 | logs[key].append(locals()[key]) 66 | else: 67 | with torch.no_grad(): 68 | adapted_lq = lan(lq) 69 | pred = model(adapted_lq).clip(0, 1) 70 | psnr, ssim = cal_batch_psnr_ssim(pred, gt) 71 | for key in logs_key: 72 | logs[key].append(locals()[key]) 73 
| 74 | for key in logs_key: 75 | total_logs[key].extend(np.array(logs[key]).transpose()) 76 | p_bar.set_postfix( 77 | PSNR=f"{np.array(total_logs['psnr']).mean(0)[0]:.2f}->{np.array(total_logs['psnr']).mean(0)[-1]:.2f}", 78 | SSIM=f"{np.array(total_logs['ssim']).mean(0)[0]:.3f}->{np.array(total_logs['ssim']).mean(0)[-1]:.3f}" 79 | ) 80 | df_dict = { 81 | "idx": [i for i in range(len(total_logs['psnr'])) for _ in range(inner_loop+1)], 82 | "loop": [i for i in range(inner_loop+1)] * len(total_logs['psnr']), 83 | } 84 | for key in logs_key: 85 | df_dict[key] = [value for value_list in total_logs[key] for value in value_list] 86 | df = pd.DataFrame(df_dict) 87 | df.to_csv(f"result_{args.method}_{args.self_loss}.csv", index=False) 88 | print(df.groupby('loop').mean()[['psnr', 'ssim']]) --------------------------------------------------------------------------------