├── utils
│   ├── __init__.py
│   ├── degradation_utils.py
│   ├── val_utils.py
│   ├── pytorch_ssim
│   │   └── __init__.py
│   ├── loss_utils.py
│   ├── image_utils.py
│   ├── image_io.py
│   ├── imresize.py
│   ├── schedulers.py
│   └── dataset_utils.py
├── requirements.txt
├── .gitignore
└── README.md

/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | --extra-index-url https://download.pytorch.org/whl/cu118
2 | 
3 | einops==0.8.0
4 | fvcore
5 | lightning==2.0.1
6 | lightning-cloud==0.5.32
7 | lightning-utilities==0.8.0
8 | matplotlib
9 | ninja==1.11.1
10 | numpy==1.26.3
11 | opencv-python==4.7.0.68
12 | pandas==1.5.3
13 | pydantic==1.10.7
14 | scikit-image
15 | scikit-video
16 | scikit-learn
17 | seaborn
18 | tensorboard
19 | timm
20 | torch==2.0.0
21 | torchaudio==2.0.1
22 | torchvision==0.15.1
23 | torchmetrics
24 | triton==2.0.0
25 | tqdm
26 | wandb==0.18.7
27 | warmup_scheduler==0.3
28 | lpips
29 | typeguard
30 | gdown
--------------------------------------------------------------------------------
/utils/degradation_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision.transforms import ToPILImage, Compose, RandomCrop, ToTensor, Grayscale
3 | 
4 | from PIL import Image
5 | import random
6 | import numpy as np
7 | 
8 | from utils.image_utils import crop_img
9 | 
10 | 
11 | class Degradation(object):
12 |     def __init__(self, args):
13 |         super(Degradation, self).__init__()
14 |         self.args = args
15 |         self.toTensor = ToTensor()
16 |         self.crop_transform = Compose([
17 |             ToPILImage(),
18 |             RandomCrop(args.patch_size),
19 |         ])
20 | 
21 |     def _add_gaussian_noise(self, clean_patch, sigma):
22 |         # clean_patch is expected to be a uint8 array in [0, 255].
23 |         noise = np.random.randn(*clean_patch.shape)
24 |         noisy_patch = np.clip(clean_patch + noise * sigma, 0, 255).astype(np.uint8)
25 |         return noisy_patch, clean_patch
26 | 
27 |     def _degrade_by_type(self, clean_patch, degrade_type):
28 |         if degrade_type == 0:
29 |             # denoise sigma=15
30 |             degraded_patch, clean_patch = self._add_gaussian_noise(clean_patch, sigma=15)
31 |         elif degrade_type == 1:
32 |             # denoise sigma=25
33 |             degraded_patch, clean_patch = self._add_gaussian_noise(clean_patch, sigma=25)
34 |         elif degrade_type == 2:
35 |             # denoise sigma=50
36 |             degraded_patch, clean_patch = self._add_gaussian_noise(clean_patch, sigma=50)
37 |         else:
38 |             raise ValueError('Unknown degrade_type: {}'.format(degrade_type))
39 | 
40 |         return degraded_patch, clean_patch
41 | 
42 |     def degrade(self, clean_patch_1, clean_patch_2, degrade_type=None):
43 |         if degrade_type is None:
44 |             # Only types 0-2 (Gaussian noise with sigma 15/25/50) are implemented.
45 |             degrade_type = random.randint(0, 2)
46 | 
47 |         degrad_patch_1, _ = self._degrade_by_type(clean_patch_1, degrade_type)
48 |         degrad_patch_2, _ = self._degrade_by_type(clean_patch_2, degrade_type)
49 |         return degrad_patch_1, degrad_patch_2
50 | 
51 |     def single_degrade(self, clean_patch, degrade_type=None):
52 |         if degrade_type is None:
53 |             degrade_type = random.randint(0, 2)
54 | 
55 |         degrad_patch_1, _ = self._degrade_by_type(clean_patch, degrade_type)
56 |         return degrad_patch_1
--------------------------------------------------------------------------------
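A minimal usage sketch for the `Degradation` class above (illustrative only; `args` here is a hypothetical stand-in namespace, of which only `patch_size` is used by this class):

```python
from types import SimpleNamespace

import numpy as np

from utils.degradation_utils import Degradation

args = SimpleNamespace(patch_size=128)  # hypothetical minimal args
deg = Degradation(args)

# Two clean uint8 patches in [0, 255], HWC layout.
clean_1 = np.random.randint(0, 256, (128, 128, 3), dtype=np.uint8)
clean_2 = np.random.randint(0, 256, (128, 128, 3), dtype=np.uint8)

# Apply the same randomly drawn noise level (sigma in {15, 25, 50}) to both patches.
noisy_1, noisy_2 = deg.degrade(clean_1, clean_2)

# Or request a specific level, e.g. sigma=25 (degrade_type=1).
noisy = deg.single_degrade(clean_1, degrade_type=1)
```

/utils/val_utils.py: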
--------------------------------------------------------------------------------
1 | 
2 | import time
3 | import numpy as np
4 | from skimage.metrics import peak_signal_noise_ratio, structural_similarity
5 | from skvideo.measure import niqe
6 | 
7 | 
8 | class AverageMeter():
9 |     """ Computes and stores the average and current value """
10 | 
11 |     def __init__(self):
12 |         self.reset()
13 | 
14 |     def reset(self):
15 |         """ Reset all statistics """
16 |         self.val = 0
17 |         self.avg = 0
18 |         self.sum = 0
19 |         self.count = 0
20 | 
21 |     def update(self, val, n=1):
22 |         """ Update statistics """
23 |         self.val = val
24 |         self.sum += val * n
25 |         self.count += n
26 |         self.avg = self.sum / self.count
27 | 
28 | 
29 | def accuracy(output, target, topk=(1,)):
30 |     """ Computes the precision@k for the specified values of k """
31 |     maxk = max(topk)
32 |     batch_size = target.size(0)
33 | 
34 |     _, pred = output.topk(maxk, 1, True, True)
35 |     pred = pred.t()
36 |     # one-hot case
37 |     if target.ndimension() > 1:
38 |         target = target.max(1)[1]
39 | 
40 |     correct = pred.eq(target.view(1, -1).expand_as(pred))
41 | 
42 |     res = []
43 |     for k in topk:
44 |         # reshape() rather than view(): the sliced tensor is non-contiguous.
45 |         correct_k = correct[:k].reshape(-1).float().sum(0)
46 |         res.append(correct_k.mul_(1.0 / batch_size))
47 | 
48 |     return res
49 | 
50 | 
51 | def compute_psnr_ssim(recovered, clean):
52 |     assert recovered.shape == clean.shape
53 |     recovered = np.clip(recovered.detach().cpu().numpy(), 0, 1)
54 |     clean = np.clip(clean.detach().cpu().numpy(), 0, 1)
55 | 
56 |     recovered = recovered.transpose(0, 2, 3, 1)
57 |     clean = clean.transpose(0, 2, 3, 1)
58 |     psnr = 0
59 |     ssim = 0
60 | 
61 |     for i in range(recovered.shape[0]):
62 |         psnr += peak_signal_noise_ratio(clean[i], recovered[i], data_range=1)
63 |         ssim += structural_similarity(clean[i], recovered[i], data_range=1, channel_axis=2)
64 |     return psnr / recovered.shape[0], ssim / recovered.shape[0], recovered.shape[0]
65 | 
66 | 
67 | def compute_niqe(image):
68 |     image = np.clip(image.detach().cpu().numpy(), 0, 1)
69 |     image = image.transpose(0, 2, 3, 1)
70 |     niqe_val = niqe(image)
71 | 
72 |     return niqe_val.mean()
73 | 
74 | 
75 | class timer():
76 |     def __init__(self):
77 |         self.acc = 0
78 |         self.tic()
79 | 
80 |     def tic(self):
81 |         self.t0 = time.time()
82 | 
83 |     def toc(self):
84 |         return time.time() - self.t0
85 | 
86 |     def hold(self):
87 |         self.acc += self.toc()
88 | 
89 |     def release(self):
90 |         ret = self.acc
91 |         self.acc = 0
92 | 
93 |         return ret
94 | 
95 |     def reset(self):
96 |         self.acc = 0
--------------------------------------------------------------------------------
/utils/pytorch_ssim/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 | import numpy as np
5 | from math import exp
6 | 
7 | # Matlab-style 1D Gaussian filter.
8 | def gaussian(window_size, sigma):
9 |     gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
10 |     return gauss/gauss.sum()
11 | 
12 | # Matlab-style 2D Gaussian filter window, built as the outer product of two 1D kernels.
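For intuition, a quick standalone check of this construction (a sketch, assuming the package is importable as `utils.pytorch_ssim`):

```python
import torch

from utils.pytorch_ssim import gaussian

# The 1D kernel is normalized to sum to 1.
g = gaussian(11, 1.5)
print(g.sum())  # ~1.0

# create_window() below builds the 2D window as the outer product g g^T,
# so the 2D window also sums to 1 and is separable.
w = g.unsqueeze(1).mm(g.unsqueeze(1).t())
print(w.shape, w.sum())  # torch.Size([11, 11]), ~1.0
```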
13 | def create_window(window_size, channel): 14 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 15 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 16 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 17 | return window 18 | 19 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 20 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 21 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 22 | 23 | mu1_sq = mu1.pow(2) 24 | mu2_sq = mu2.pow(2) 25 | mu1_mu2 = mu1*mu2 26 | 27 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 28 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 29 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 30 | 31 | C1 = 0.01**2 32 | C2 = 0.03**2 33 | 34 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 35 | 36 | 37 | # I added this for sm 38 | # ssim_map = torch.exp(1 + ssim_map) 39 | 40 | if size_average: 41 | return ssim_map.mean() 42 | else: 43 | return ssim_map.mean(1).mean(1).mean(1) 44 | 45 | class SSIM(torch.nn.Module): 46 | def __init__(self, window_size = 11, size_average = True): 47 | super(SSIM, self).__init__() 48 | self.window_size = window_size 49 | self.size_average = size_average 50 | self.channel = 1 51 | self.window = create_window(window_size, self.channel) 52 | 53 | def forward(self, img1, img2): 54 | (_, channel, _, _) = img1.size() 55 | 56 | if channel == self.channel and self.window.data.type() == img1.data.type(): 57 | window = self.window 58 | else: 59 | window = create_window(self.window_size, channel) 60 | 61 | if img1.is_cuda: 62 | window = window.cuda(img1.get_device()) 63 | window = window.type_as(img1) 64 | 65 | self.window = window 66 | self.channel = channel 67 | 68 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 69 | 70 | def ssim(img1, img2, window_size = 11, size_average = True): 71 | (_, channel, _, _) = img1.size() 72 | window = create_window(window_size, channel) 73 | 74 | if img1.is_cuda: 75 | window = window.cuda(img1.get_device()) 76 | window = window.type_as(img1) 77 | 78 | return _ssim(img1, img2, window, window_size, channel, size_average) 79 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | /wandb 86 | /logs 87 | /train_ckpt 88 | /joblogs 89 | /output 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # poetry 103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 104 | # This is especially recommended for binary packages to ensure reproducibility, and is more 105 | # commonly ignored for libraries. 106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 107 | #poetry.lock 108 | 109 | # pdm 110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 111 | #pdm.lock 112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 113 | # in version control. 114 | # https://pdm.fming.dev/#use-with-ide 115 | .pdm.toml 116 | 117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
165 | #.idea/ 166 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AnyIR 2 | ### Any Image Restoration via Efficient Spatial-Frequency Degradation Adaptation 3 | 4 | The official PyTorch Implementation of AnyIR for All-in-One Image Restoration 5 | 6 | #### [Bin Ren 1,2,3](https://amazingren.github.io/), [Eduard Zamfir4](https://eduardzamfir.github.io), [Zongwei Wu4](https://sites.google.com/view/zwwu/accueil), [Yawei Li 4](https://yaweili.bitbucket.io/), [Yidi Li3](https://liyidi.github.io/), [Danda Pani Paudel3](https://people.ee.ethz.ch/~paudeld/), [Radu Timofte 4](https://www.informatik.uni-wuerzburg.de/computervision/), [Ming-Hsuan Yang 7](https://scholar.google.com/citations?user=p9-ohHsAAAAJ&hl=en), [Luc Van Gool 3](https://scholar.google.com/citations?user=TwMib_QAAAAJ&hl=en), and [Nicu Sebe 2](https://scholar.google.com/citations?user=stFCYOAAAAAJ&hl=en) 7 | 8 | 1 University of Pisa, Italy,
9 | 2 University of Trento, Italy,
10 | 3 INSAIT Sofia University, "St. Kliment Ohridski", Bulgaria,
11 | 4 University of Würzburg, Germany,
12 | 5 ETH Zürich, Switzerland,
13 | 6 Taiyuan University of Technology, China,
14 | 7 University of California, Merced, USA
15 | 
16 | 
17 | [![paper](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/pdf/2407.13372)
18 | 
19 | 
20 | ## Latest
21 | - `08/01/2025`: Unfortunately, although the AC recommended `Accept` for this work at ACM MM 2025, the PC ultimately rejected it without giving a reason, so the work is still under review.
22 | - `07/18/2024`: Repository is created. Our code will be made publicly available upon acceptance.
23 | 
24 | 
25 | ## Method
26 | 
27 |
28 | 
29 | Abstract
30 | 
31 | Restoring any degraded image efficiently via just one model has become increasingly significant and impactful, especially with the proliferation of mobile devices. Traditional solutions typically involve training dedicated models per degradation, resulting in inefficiency and redundancy. More recent approaches either introduce additional modules to learn visual prompts - significantly increasing the size of the model - or incorporate cross-modal transfer from large language models trained on vast datasets, adding complexity to the system architecture. In contrast, our approach, termed AnyIR, takes a unified path that leverages the inherent similarity across various degradations to enable both efficient and comprehensive restoration through a joint embedding mechanism, without scaling up the model or relying on large language models.
32 | Specifically, we examine the sub-latent space of each input, identifying key components and reweighting them first in a gated manner. To fuse intrinsic degradation awareness and contextualized attention, a spatial-frequency parallel fusion strategy is proposed to enhance spatially aware local-global interactions and enrich restoration details from the frequency perspective. Extensive benchmarking in the all-in-one restoration setting confirms AnyIR's SOTA performance, reducing model complexity by around **82%** in parameters and **85%** in FLOPs compared to the baseline solution.
33 | Our code will be available upon acceptance.
34 | 
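To make the spatial-frequency idea concrete, below is a deliberately simplified toy of a parallel spatial/frequency branch pair. It only illustrates the general pattern; it is not AnyIR's actual fusion module, and every layer choice in it is invented for this sketch:

```python
import torch
import torch.nn as nn


class ToySpatialFreqFusion(nn.Module):
    """Toy parallel fusion: a spatial branch (depthwise conv) and a
    frequency branch (per-channel reweighting of the rFFT spectrum),
    summed back into the input. Illustrative only."""

    def __init__(self, channels):
        super().__init__()
        self.spatial = nn.Conv2d(channels, channels, 3, padding=1, groups=channels)
        # One learnable gain per channel for the real/imaginary spectrum parts.
        self.freq_gain = nn.Parameter(torch.ones(1, channels, 1, 1, 2))

    def forward(self, x):
        spat = self.spatial(x)
        spec = torch.fft.rfft2(x, norm='ortho')
        spec = torch.view_as_real(spec) * self.freq_gain
        freq = torch.fft.irfft2(torch.view_as_complex(spec.contiguous()),
                                s=x.shape[-2:], norm='ortho')
        return x + spat + freq


y = ToySpatialFreqFusion(16)(torch.randn(2, 16, 64, 64))  # -> [2, 16, 64, 64]
```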
35 | 36 | 37 | ## Installation 38 | 39 | ### Environments 40 | ``` 41 | # Step1: Create the virtual environments via micromamba or conda: 42 | micromamba create -n anyir python=3.9 -y 43 | or 44 | conda create -n anyir python=3.9 -y 45 | 46 | # Step2: Prepare PyTorch and other libs 47 | pip install -r requirements.txt 48 | 49 | # Step3: Set cuda 50 | export LD_LIBRARY_PATH=/opt/modules/nvidia-cuda-11.8/lib64:$LD_LIBRARY_PATH 51 | export PATH=/opt/modules/nvidia-cuda-11.8/bin:$PATH 52 | 53 | ``` 54 | 55 | 56 | ### Datasets 57 | 58 | 59 | ### Checkpoints Downloads: 60 | 61 | 62 | ### Visual Results Downloads: 63 | 64 | 65 | ### Training 66 | 67 | 68 | ### Evaluation: 69 | (I). 3-Degradation Setting: 70 | 71 | (II). 5-Degradation Setting: 72 | 73 | (III). Mix-Degradation Setting: 74 | 75 | (IV). Real-World Setting: 76 | 77 | 78 | 79 | 80 | 81 | ## Citation 82 | 83 | If you find our work helpful, please consider citing the following paper and/or ⭐ the repo. 84 | ``` 85 | @misc{ren2025any, 86 | title={Any Image Restoration via Efficient Spatial-Frequency Degradation Adaptation}, 87 | author={Ren, Bin and Zamfir, Eduard and Wu, Zongwei and Li, Yawei and Li, Yidi and Paudel, Danda Pani and Timofte, Radu and Yang, Ming-Hsuan and Van Gool, Luc and Sebe, Nicu}, 88 | year={2025}, 89 | eprint={2504.14249}, 90 | archivePrefix={arXiv}, 91 | primaryClass={cs.CV} 92 | } 93 | ``` 94 | 95 | 96 | ## Acknowledgements 97 | 98 | This code is built on [PromptIR](https://github.com/va1shn9v/PromptIR) and [AirNet](https://github.com/XLearning-SCU/2022-CVPR-AirNet). -------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn.functional import mse_loss 6 | 7 | 8 | class GANLoss(nn.Module): 9 | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0, 10 | tensor=torch.FloatTensor): 11 | super(GANLoss, self).__init__() 12 | self.real_label = target_real_label 13 | self.fake_label = target_fake_label 14 | self.real_label_var = None 15 | self.fake_label_var = None 16 | self.Tensor = tensor 17 | if use_lsgan: 18 | self.loss = nn.MSELoss() 19 | else: 20 | self.loss = nn.BCELoss() 21 | 22 | def get_target_tensor(self, input, target_is_real): 23 | target_tensor = None 24 | if target_is_real: 25 | create_label = ((self.real_label_var is None) or(self.real_label_var.numel() != input.numel())) 26 | # pdb.set_trace() 27 | if create_label: 28 | real_tensor = self.Tensor(input.size()).fill_(self.real_label) 29 | # self.real_label_var = Variable(real_tensor, requires_grad=False) 30 | # self.real_label_var = torch.Tensor(real_tensor) 31 | self.real_label_var = real_tensor 32 | target_tensor = self.real_label_var 33 | else: 34 | # pdb.set_trace() 35 | create_label = ((self.fake_label_var is None) or (self.fake_label_var.numel() != input.numel())) 36 | if create_label: 37 | fake_tensor = self.Tensor(input.size()).fill_(self.fake_label) 38 | # self.fake_label_var = Variable(fake_tensor, requires_grad=False) 39 | # self.fake_label_var = torch.Tensor(fake_tensor) 40 | self.fake_label_var = fake_tensor 41 | target_tensor = self.fake_label_var 42 | return target_tensor 43 | 44 | def __call__(self, input, target_is_real): 45 | target_tensor = self.get_target_tensor(input, target_is_real) 46 | # pdb.set_trace() 47 | return self.loss(input, target_tensor) 48 | 49 | 50 | 51 | class FocalL1Loss(nn.Module): 52 | def 
__init__(self, gamma=2.0, epsilon=1e-6, alpha=0.1): 53 | """ 54 | Focal L1 Loss with adjusted weighting for output values in [0, 1]. 55 | 56 | Args: 57 | gamma (float): Focusing parameter. Larger gamma focuses more on hard examples. 58 | epsilon (float): Small constant to prevent weights from being zero. 59 | alpha (float): Scaling factor to normalize error values. 60 | """ 61 | super(FocalL1Loss, self).__init__() 62 | self.gamma = gamma 63 | self.epsilon = epsilon 64 | self.alpha = alpha # Scaling factor to prevent error values from being too small. 65 | 66 | def forward(self, pred, target): 67 | """ 68 | Compute the Focal L1 Loss between the predicted and target images. 69 | 70 | Args: 71 | pred (torch.Tensor): Predicted image [b, c, h, w]. 72 | target (torch.Tensor): Ground truth image [b, c, h, w]. 73 | 74 | Returns: 75 | torch.Tensor: Scalar Focal L1 Loss. 76 | """ 77 | # Compute the absolute error (L1 Loss) and scale it by alpha 78 | abs_err = torch.abs(pred - target) / self.alpha 79 | 80 | # Apply a logarithmic transformation to the error to prevent very small weights 81 | focal_weight = (torch.log(1 + abs_err + self.epsilon)) ** self.gamma 82 | 83 | # Compute the weighted loss 84 | focal_l1_loss = focal_weight * abs_err 85 | 86 | # Return the mean loss across all pixels 87 | return focal_l1_loss.mean() 88 | 89 | 90 | class FFTLoss(nn.Module): 91 | def __init__(self, loss_weight=1.0, reduction='mean'): 92 | super(FFTLoss, self).__init__() 93 | self.loss_weight = loss_weight 94 | self.criterion = torch.nn.L1Loss(reduction=reduction) 95 | 96 | def forward(self, pred, target): 97 | pred_fft = torch.fft.rfft2(pred) 98 | target_fft = torch.fft.rfft2(target) 99 | 100 | pred_fft = torch.stack([pred_fft.real, pred_fft.imag], dim=-1) 101 | target_fft = torch.stack([target_fft.real, target_fft.imag], dim=-1) 102 | 103 | return self.loss_weight * self.criterion(pred_fft, target_fft) 104 | 105 | 106 | class TemperatureScheduler: 107 | def __init__(self, start_temp, end_temp, total_steps): 108 | """ 109 | Scheduler for Gumbel-Softmax temperature that decreases using a cosine annealing schedule. 110 | 111 | Args: 112 | - start_temp (float): Initial temperature (e.g., 5.0). 113 | - end_temp (float): Final temperature (e.g., 0.01). 114 | - total_steps (int): Total number of steps/epochs to anneal over. 115 | """ 116 | self.start_temp = start_temp 117 | self.end_temp = end_temp 118 | self.total_steps = total_steps 119 | 120 | def get_temperature(self, step): 121 | """ 122 | Get the temperature value for the current step, following a cosine annealing schedule. 123 | 124 | Args: 125 | - step (int): Current step or epoch. 126 | 127 | Returns: 128 | - temperature (float): The temperature for the Gumbel-Softmax at this step. 
129 | """ 130 | if step >= self.total_steps: 131 | return self.end_temp 132 | 133 | # Cosine annealing formula to compute the temperature 134 | cos_inner = math.pi * step / self.total_steps 135 | #temp = self.end_temp + 0.5 * (self.start_temp - self.end_temp) * (1 + math.cos(cos_inner)) 136 | temp = self.start_temp + 0.5 * (self.end_temp - self.start_temp) * (1 - math.cos(cos_inner)) 137 | 138 | return temp 139 | -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on 2020/9/8 3 | 4 | @author: Boyun Li 5 | """ 6 | import os 7 | import numpy as np 8 | import torch 9 | import random 10 | import torch.nn as nn 11 | from torch.nn import init 12 | from PIL import Image 13 | 14 | class EdgeComputation(nn.Module): 15 | def __init__(self, test=False): 16 | super(EdgeComputation, self).__init__() 17 | self.test = test 18 | def forward(self, x): 19 | if self.test: 20 | x_diffx = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1]) 21 | x_diffy = torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :]) 22 | 23 | # y = torch.Tensor(x.size()).cuda() 24 | y = torch.Tensor(x.size()) 25 | y.fill_(0) 26 | y[:, :, :, 1:] += x_diffx 27 | y[:, :, :, :-1] += x_diffx 28 | y[:, :, 1:, :] += x_diffy 29 | y[:, :, :-1, :] += x_diffy 30 | y = torch.sum(y, 1, keepdim=True) / 3 31 | y /= 4 32 | return y 33 | else: 34 | x_diffx = torch.abs(x[:, :, 1:] - x[:, :, :-1]) 35 | x_diffy = torch.abs(x[:, 1:, :] - x[:, :-1, :]) 36 | 37 | y = torch.Tensor(x.size()) 38 | y.fill_(0) 39 | y[:, :, 1:] += x_diffx 40 | y[:, :, :-1] += x_diffx 41 | y[:, 1:, :] += x_diffy 42 | y[:, :-1, :] += x_diffy 43 | y = torch.sum(y, 0) / 3 44 | y /= 4 45 | return y.unsqueeze(0) 46 | 47 | 48 | # randomly crop a patch from image 49 | def crop_patch(im, pch_size): 50 | H = im.shape[0] 51 | W = im.shape[1] 52 | ind_H = random.randint(0, H - pch_size) 53 | ind_W = random.randint(0, W - pch_size) 54 | pch = im[ind_H:ind_H + pch_size, ind_W:ind_W + pch_size] 55 | return pch 56 | 57 | 58 | # crop an image to the multiple of base 59 | def crop_img(image, base=64): 60 | h = image.shape[0] 61 | w = image.shape[1] 62 | crop_h = h % base 63 | crop_w = w % base 64 | return image[crop_h // 2:h - crop_h + crop_h // 2, crop_w // 2:w - crop_w + crop_w // 2, :] 65 | 66 | 67 | # image (H, W, C) -> patches (B, H, W, C) 68 | def slice_image2patches(image, patch_size=64, overlap=0): 69 | assert image.shape[0] % patch_size == 0 and image.shape[1] % patch_size == 0 70 | H = image.shape[0] 71 | W = image.shape[1] 72 | patches = [] 73 | image_padding = np.pad(image, ((overlap, overlap), (overlap, overlap), (0, 0)), mode='edge') 74 | for h in range(H // patch_size): 75 | for w in range(W // patch_size): 76 | idx_h = [h * patch_size, (h + 1) * patch_size + overlap] 77 | idx_w = [w * patch_size, (w + 1) * patch_size + overlap] 78 | patches.append(np.expand_dims(image_padding[idx_h[0]:idx_h[1], idx_w[0]:idx_w[1], :], axis=0)) 79 | return np.concatenate(patches, axis=0) 80 | 81 | 82 | # patches (B, H, W, C) -> image (H, W, C) 83 | def splice_patches2image(patches, image_size, overlap=0): 84 | assert len(image_size) > 1 85 | assert patches.shape[-3] == patches.shape[-2] 86 | H = image_size[0] 87 | W = image_size[1] 88 | patch_size = patches.shape[-2] - overlap 89 | image = np.zeros(image_size) 90 | idx = 0 91 | for h in range(H // patch_size): 92 | for w in range(W // patch_size): 93 | image[h * patch_size:(h + 1) * patch_size, w * patch_size:(w + 1) * 
patch_size, :] = patches[idx,
94 |                 overlap:patch_size + overlap,
95 |                 overlap:patch_size + overlap,
96 |                 :]
97 |             idx += 1
98 |     return image
99 | 
100 | 
101 | # def data_augmentation(image, mode):
102 | #     if mode == 0:
103 | #         # original
104 | #         out = image.numpy()
105 | #     elif mode == 1:
106 | #         # flip up and down
107 | #         out = np.flipud(image)
108 | #     elif mode == 2:
109 | #         # rotate counterclockwise 90 degrees
110 | #         out = np.rot90(image, axes=(1, 2))
111 | #     elif mode == 3:
112 | #         # rotate 90 degrees and flip up and down
113 | #         out = np.rot90(image, axes=(1, 2))
114 | #         out = np.flipud(out)
115 | #     elif mode == 4:
116 | #         # rotate 180 degrees
117 | #         out = np.rot90(image, k=2, axes=(1, 2))
118 | #     elif mode == 5:
119 | #         # rotate 180 degrees and flip
120 | #         out = np.rot90(image, k=2, axes=(1, 2))
121 | #         out = np.flipud(out)
122 | #     elif mode == 6:
123 | #         # rotate 270 degrees
124 | #         out = np.rot90(image, k=3, axes=(1, 2))
125 | #     elif mode == 7:
126 | #         # rotate 270 degrees and flip
127 | #         out = np.rot90(image, k=3, axes=(1, 2))
128 | #         out = np.flipud(out)
129 | #     else:
130 | #         raise Exception('Invalid choice of image transformation')
131 | #     return out
132 | 
133 | def data_augmentation(image, mode):
134 |     if mode == 0:
135 |         # original
136 |         out = image.numpy()
137 |     elif mode == 1:
138 |         # flip up and down
139 |         out = np.flipud(image)
140 |     elif mode == 2:
141 |         # rotate counterclockwise 90 degrees
142 |         out = np.rot90(image)
143 |     elif mode == 3:
144 |         # rotate 90 degrees and flip up and down
145 |         out = np.rot90(image)
146 |         out = np.flipud(out)
147 |     elif mode == 4:
148 |         # rotate 180 degrees
149 |         out = np.rot90(image, k=2)
150 |     elif mode == 5:
151 |         # rotate 180 degrees and flip
152 |         out = np.rot90(image, k=2)
153 |         out = np.flipud(out)
154 |     elif mode == 6:
155 |         # rotate 270 degrees
156 |         out = np.rot90(image, k=3)
157 |     elif mode == 7:
158 |         # rotate 270 degrees and flip
159 |         out = np.rot90(image, k=3)
160 |         out = np.flipud(out)
161 |     else:
162 |         raise Exception('Invalid choice of image transformation')
163 |     return out
164 | 
165 | 
166 | # def random_augmentation(*args):
167 | #     out = []
168 | #     if random.randint(0, 1) == 1:
169 | #         flag_aug = random.randint(1, 7)
170 | #         for data in args:
171 | #             out.append(data_augmentation(data, flag_aug).copy())
172 | #     else:
173 | #         for data in args:
174 | #             out.append(data)
175 | #     return out
176 | 
177 | def random_augmentation(*args):
178 |     out = []
179 |     flag_aug = random.randint(1, 7)
180 |     for data in args:
181 |         out.append(data_augmentation(data, flag_aug).copy())
182 |     return out
183 | 
184 | 
185 | def weights_init_normal_(m):
186 |     classname = m.__class__.__name__
187 |     if classname.find('Conv') != -1:
188 |         init.normal_(m.weight.data, 0.0, 0.02)
189 |     elif classname.find('Linear') != -1:
190 |         init.normal_(m.weight.data, 0.0, 0.02)
191 |     elif classname.find('BatchNorm2d') != -1:
192 |         init.normal_(m.weight.data, 1.0, 0.02)
193 |         init.constant_(m.bias.data, 0.0)
194 | 
195 | 
196 | def weights_init_normal(m):
197 |     classname = m.__class__.__name__
198 |     if classname.find('Conv2d') != -1:
199 |         m.apply(weights_init_normal_)
200 |     elif classname.find('Linear') != -1:
201 |         init.normal_(m.weight.data, 0.0, 0.02)
202 |     elif classname.find('BatchNorm2d') != -1:
203 |         init.normal_(m.weight.data, 1.0, 0.02)
204 |         init.constant_(m.bias.data, 0.0)
205 | 
206 | 
207 | def weights_init_xavier(m):
208 |     classname = m.__class__.__name__
209 |     if classname.find('Conv') != -1:
210 |         init.xavier_normal_(m.weight.data, gain=1)
211 |     elif classname.find('Linear') != -1:
212 |         init.xavier_normal_(m.weight.data, gain=1)
213 |     elif classname.find('BatchNorm2d') != -1:
214 |         init.normal_(m.weight.data, 1.0, 0.02)
215 |         init.constant_(m.bias.data, 0.0)
216 | 
217 | 
218 | def weights_init_kaiming(m):
219 |     classname = m.__class__.__name__
220 |     if classname.find('Conv') != -1:
221 |         init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
222 |     elif classname.find('Linear') != -1:
223 |         init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
224 |     elif classname.find('BatchNorm2d') != -1:
225 |         init.normal_(m.weight.data, 1.0, 0.02)
226 |         init.constant_(m.bias.data, 0.0)
227 | 
228 | 
229 | def weights_init_orthogonal(m):
230 |     classname = m.__class__.__name__
231 |     print(classname)
232 |     if classname.find('Conv') != -1:
233 |         init.orthogonal_(m.weight.data, gain=1)
234 |     elif classname.find('Linear') != -1:
235 |         init.orthogonal_(m.weight.data, gain=1)
236 |     elif classname.find('BatchNorm2d') != -1:
237 |         init.normal_(m.weight.data, 1.0, 0.02)
238 |         init.constant_(m.bias.data, 0.0)
239 | 
240 | 
241 | def init_weights(net, init_type='normal'):
242 |     print('initialization method [%s]' % init_type)
243 |     if init_type == 'normal':
244 |         net.apply(weights_init_normal)
245 |     elif init_type == 'xavier':
246 |         net.apply(weights_init_xavier)
247 |     elif init_type == 'kaiming':
248 |         net.apply(weights_init_kaiming)
249 |     elif init_type == 'orthogonal':
250 |         net.apply(weights_init_orthogonal)
251 |     else:
252 |         raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
253 | 
254 | 
255 | def np_to_torch(img_np):
256 |     """
257 |     Converts image in numpy.array to torch.Tensor.
258 | 
259 |     From C x W x H [0..1] to C x W x H [0..1]
260 | 
261 |     :param img_np:
262 |     :return:
263 |     """
264 |     return torch.from_numpy(img_np)[None, :]
265 | 
266 | 
267 | def torch_to_np(img_var):
268 |     """
269 |     Converts an image in torch.Tensor format to np.array.
270 | 
271 |     From 1 x C x W x H [0..1] to C x W x H [0..1]
272 |     :param img_var:
273 |     :return:
274 |     """
275 |     return img_var.detach().cpu().numpy()
276 |     # return img_var.detach().cpu().numpy()[0]
277 | 
278 | 
279 | def save_image(name, image_np, output_path="output/normal/"):
280 |     if not os.path.exists(output_path):
281 |         os.makedirs(output_path)
282 | 
283 |     p = np_to_pil(image_np)
284 |     p.save(output_path + "{}.png".format(name))
285 | 
286 | 
287 | def np_to_pil(img_np):
288 |     """
289 |     Converts image in np.array format to PIL image.
290 | 291 | From C x W x H [0..1] to W x H x C [0...255] 292 | :param img_np: 293 | :return: 294 | """ 295 | ar = np.clip(img_np * 255, 0, 255).astype(np.uint8) 296 | 297 | if img_np.shape[0] == 1: 298 | ar = ar[0] 299 | else: 300 | assert img_np.shape[0] == 3, img_np.shape 301 | ar = ar.transpose(1, 2, 0) 302 | 303 | return Image.fromarray(ar) -------------------------------------------------------------------------------- /utils/image_io.py: -------------------------------------------------------------------------------- 1 | import glob 2 | 3 | import torch 4 | import torchvision 5 | import matplotlib 6 | import matplotlib.pyplot as plt 7 | import numpy as np 8 | from PIL import Image 9 | 10 | # import skvideo.io 11 | 12 | matplotlib.use('agg') 13 | 14 | 15 | def prepare_hazy_image(file_name): 16 | img_pil = crop_image(get_image(file_name, -1)[0], d=32) 17 | return pil_to_np(img_pil) 18 | 19 | 20 | def prepare_gt_img(file_name, SOTS=True): 21 | if SOTS: 22 | img_pil = crop_image(crop_a_image(get_image(file_name, -1)[0], d=10), d=32) 23 | else: 24 | img_pil = crop_image(get_image(file_name, -1)[0], d=32) 25 | 26 | return pil_to_np(img_pil) 27 | 28 | 29 | def crop_a_image(img, d=10): 30 | bbox = [ 31 | int((d)), 32 | int((d)), 33 | int((img.size[0] - d)), 34 | int((img.size[1] - d)), 35 | ] 36 | img_cropped = img.crop(bbox) 37 | return img_cropped 38 | 39 | 40 | def crop_image(img, d=32): 41 | """ 42 | Make dimensions divisible by d 43 | 44 | :param pil img: 45 | :param d: 46 | :return: 47 | """ 48 | 49 | new_size = (img.size[0] - img.size[0] % d, 50 | img.size[1] - img.size[1] % d) 51 | 52 | bbox = [ 53 | int((img.size[0] - new_size[0]) / 2), 54 | int((img.size[1] - new_size[1]) / 2), 55 | int((img.size[0] + new_size[0]) / 2), 56 | int((img.size[1] + new_size[1]) / 2), 57 | ] 58 | 59 | img_cropped = img.crop(bbox) 60 | return img_cropped 61 | 62 | 63 | def crop_np_image(img_np, d=32): 64 | return torch_to_np(crop_torch_image(np_to_torch(img_np), d)) 65 | 66 | 67 | def crop_torch_image(img, d=32): 68 | """ 69 | Make dimensions divisible by d 70 | image is [1, 3, W, H] or [3, W, H] 71 | :param pil img: 72 | :param d: 73 | :return: 74 | """ 75 | new_size = (img.shape[-2] - img.shape[-2] % d, 76 | img.shape[-1] - img.shape[-1] % d) 77 | pad = ((img.shape[-2] - new_size[-2]) // 2, (img.shape[-1] - new_size[-1]) // 2) 78 | 79 | if len(img.shape) == 4: 80 | return img[:, :, pad[-2]: pad[-2] + new_size[-2], pad[-1]: pad[-1] + new_size[-1]] 81 | assert len(img.shape) == 3 82 | return img[:, pad[-2]: pad[-2] + new_size[-2], pad[-1]: pad[-1] + new_size[-1]] 83 | 84 | 85 | def get_params(opt_over, net, net_input, downsampler=None): 86 | """ 87 | Returns parameters that we want to optimize over. 88 | :param opt_over: comma separated list, e.g. "net,input" or "net" 89 | :param net: network 90 | :param net_input: torch.Tensor that stores input `z` 91 | :param downsampler: 92 | :return: 93 | """ 94 | 95 | opt_over_list = opt_over.split(',') 96 | params = [] 97 | 98 | for opt in opt_over_list: 99 | 100 | if opt == 'net': 101 | params += [x for x in net.parameters()] 102 | elif opt == 'down': 103 | assert downsampler is not None 104 | params = [x for x in downsampler.parameters()] 105 | elif opt == 'input': 106 | net_input.requires_grad = True 107 | params += [net_input] 108 | else: 109 | assert False, 'what is it?' 110 | 111 | return params 112 | 113 | 114 | def get_image_grid(images_np, nrow=8): 115 | """ 116 | Creates a grid from a list of images by concatenating them. 
117 | :param images_np: 118 | :param nrow: 119 | :return: 120 | """ 121 | images_torch = [torch.from_numpy(x).type(torch.FloatTensor) for x in images_np] 122 | torch_grid = torchvision.utils.make_grid(images_torch, nrow) 123 | 124 | return torch_grid.numpy() 125 | 126 | 127 | def plot_image_grid(name, images_np, interpolation='lanczos', output_path="output/"): 128 | """ 129 | Draws images in a grid 130 | 131 | Args: 132 | images_np: list of images, each image is np.array of size 3xHxW or 1xHxW 133 | nrow: how many images will be in one row 134 | interpolation: interpolation used in plt.imshow 135 | """ 136 | assert len(images_np) == 2 137 | n_channels = max(x.shape[0] for x in images_np) 138 | assert (n_channels == 3) or (n_channels == 1), "images should have 1 or 3 channels" 139 | 140 | images_np = [x if (x.shape[0] == n_channels) else np.concatenate([x, x, x], axis=0) for x in images_np] 141 | 142 | grid = get_image_grid(images_np, 2) 143 | 144 | if images_np[0].shape[0] == 1: 145 | plt.imshow(grid[0], cmap='gray', interpolation=interpolation) 146 | else: 147 | plt.imshow(grid.transpose(1, 2, 0), interpolation=interpolation) 148 | 149 | plt.savefig(output_path + "{}.png".format(name)) 150 | 151 | 152 | def save_image_np(name, image_np, output_path="output/"): 153 | p = np_to_pil(image_np) 154 | p.save(output_path + "{}.png".format(name)) 155 | 156 | 157 | def save_image_tensor(image_tensor, output_path="output/"): 158 | image_np = torch_to_np(image_tensor) 159 | # print(image_np.shape) 160 | p = np_to_pil(image_np) 161 | p.save(output_path) 162 | 163 | 164 | def video_to_images(file_name, name): 165 | video = prepare_video(file_name) 166 | for i, f in enumerate(video): 167 | save_image(name + "_{0:03d}".format(i), f) 168 | 169 | 170 | def images_to_video(images_dir, name, gray=True): 171 | num = len(glob.glob(images_dir + "/*.jpg")) 172 | c = [] 173 | for i in range(num): 174 | if gray: 175 | img = prepare_gray_image(images_dir + "/" + name + "_{}.jpg".format(i)) 176 | else: 177 | img = prepare_image(images_dir + "/" + name + "_{}.jpg".format(i)) 178 | print(img.shape) 179 | c.append(img) 180 | save_video(name, np.array(c)) 181 | 182 | 183 | def save_heatmap(name, image_np): 184 | cmap = plt.get_cmap('jet') 185 | 186 | rgba_img = cmap(image_np) 187 | rgb_img = np.delete(rgba_img, 3, 2) 188 | save_image(name, rgb_img.transpose(2, 0, 1)) 189 | 190 | 191 | def save_graph(name, graph_list, output_path="output/"): 192 | plt.clf() 193 | plt.plot(graph_list) 194 | plt.savefig(output_path + name + ".png") 195 | 196 | 197 | def create_augmentations(np_image): 198 | """ 199 | convention: original, left, upside-down, right, rot1, rot2, rot3 200 | :param np_image: 201 | :return: 202 | """ 203 | aug = [np_image.copy(), np.rot90(np_image, 1, (1, 2)).copy(), 204 | np.rot90(np_image, 2, (1, 2)).copy(), np.rot90(np_image, 3, (1, 2)).copy()] 205 | flipped = np_image[:, ::-1, :].copy() 206 | aug += [flipped.copy(), np.rot90(flipped, 1, (1, 2)).copy(), np.rot90(flipped, 2, (1, 2)).copy(), 207 | np.rot90(flipped, 3, (1, 2)).copy()] 208 | return aug 209 | 210 | 211 | def create_video_augmentations(np_video): 212 | """ 213 | convention: original, left, upside-down, right, rot1, rot2, rot3 214 | :param np_video: 215 | :return: 216 | """ 217 | aug = [np_video.copy(), np.rot90(np_video, 1, (2, 3)).copy(), 218 | np.rot90(np_video, 2, (2, 3)).copy(), np.rot90(np_video, 3, (2, 3)).copy()] 219 | flipped = np_video[:, :, ::-1, :].copy() 220 | aug += [flipped.copy(), np.rot90(flipped, 1, (2, 3)).copy(), 
np.rot90(flipped, 2, (2, 3)).copy(),
221 |            np.rot90(flipped, 3, (2, 3)).copy()]
222 |     return aug
223 | 
224 | 
225 | def save_graphs(name, graph_dict, output_path="output/"):
226 |     """
227 | 
228 |     :param name:
229 |     :param dict graph_dict: a dict from the name of the list to the list itself.
230 |     :return:
231 |     """
232 |     plt.clf()
233 |     fig, ax = plt.subplots()
234 |     for k, v in graph_dict.items():
235 |         ax.plot(v, label=k)
236 |         # ax.semilogy(v, label=k)
237 |     ax.set_xlabel('iterations')
238 |     # ax.set_ylabel(name)
239 |     ax.set_ylabel('MSE-loss')
240 |     # ax.set_ylabel('PSNR')
241 |     plt.legend()
242 |     plt.savefig(output_path + name + ".png")
243 | 
244 | 
245 | def load(path):
246 |     """Load PIL image."""
247 |     img = Image.open(path)
248 |     return img
249 | 
250 | 
251 | def get_image(path, imsize=-1):
252 |     """Load an image and resize it to a specific size.
253 | 
254 |     Args:
255 |         path: path to image
256 |         imsize: tuple or scalar with dimensions; -1 for `no resize`
257 |     """
258 |     img = load(path)
259 |     if isinstance(imsize, int):
260 |         imsize = (imsize, imsize)
261 | 
262 |     if imsize[0] != -1 and img.size != imsize:
263 |         if imsize[0] > img.size[0]:
264 |             img = img.resize(imsize, Image.BICUBIC)
265 |         else:
266 |             img = img.resize(imsize, Image.LANCZOS)
267 | 
268 |     img_np = pil_to_np(img)
269 |     # 3*460*620
270 |     # print(np.shape(img_np))
271 | 
272 |     return img, img_np
273 | 
274 | 
275 | def prepare_gt(file_name):
276 |     """
277 |     Loads the image and makes its dimensions divisible.
278 |     :param file_name:
279 |     :return: the numpy representation of the image
280 |     """
281 |     img = get_image(file_name, -1)
282 |     # print(img[0].size)
283 | 
284 |     img_pil = img[0].crop([10, 10, img[0].size[0] - 10, img[0].size[1] - 10])
285 | 
286 |     img_pil = crop_image(img_pil, d=32)
287 | 
288 |     # img_pil = get_image(file_name, -1)[0]
289 |     # print(img_pil.size)
290 |     return pil_to_np(img_pil)
291 | 
292 | 
293 | def prepare_image(file_name):
294 |     """
295 |     Loads the image and makes its dimensions divisible.
296 |     :param file_name:
297 |     :return: the numpy representation of the image
298 |     """
299 |     img = get_image(file_name, -1)
300 |     # print(img[0].size)
301 |     # img_pil = img[0]
302 |     img_pil = crop_image(img[0], d=16)
303 |     # img_pil = get_image(file_name, -1)[0]
304 |     # print(img_pil.size)
305 |     return pil_to_np(img_pil)
306 | 
307 | 
308 | # def prepare_video(file_name, folder="output/"):
309 | #     data = skvideo.io.vread(folder + file_name)
310 | #     return crop_torch_image(data.transpose(0, 3, 1, 2).astype(np.float32) / 255.)[:35]
311 | #
312 | #
313 | # def save_video(name, video_np, output_path="output/"):
314 | #     outputdata = video_np * 255
315 | #     outputdata = outputdata.astype(np.uint8)
316 | #     skvideo.io.vwrite(output_path + "{}.mp4".format(name), outputdata.transpose(0, 2, 3, 1))
317 | 
318 | 
319 | def prepare_gray_image(file_name):
320 |     img = prepare_image(file_name)
321 |     return np.array([np.mean(img, axis=0)])
322 | 
323 | 
324 | def pil_to_np(img_PIL, with_transpose=True):
325 |     """
326 |     Converts image in PIL format to np.array.
327 | 
328 |     From H x W x C [0...255] to C x H x W [0..1]
329 |     """
330 |     ar = np.array(img_PIL)
331 |     if len(ar.shape) == 3 and ar.shape[-1] == 4:
332 |         ar = ar[:, :, :3]
333 |         # this is alpha channel
334 |     if with_transpose:
335 |         if len(ar.shape) == 3:
336 |             ar = ar.transpose(2, 0, 1)
337 |         else:
338 |             ar = ar[None, ...]
339 | 
340 |     return ar.astype(np.float32) / 255.
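A quick round-trip sketch for the `pil_to_np` / `np_to_pil` pair in this module (the file names are hypothetical):

```python
import numpy as np
from PIL import Image

from utils.image_io import pil_to_np, np_to_pil

img = Image.open('example.png').convert('RGB')  # hypothetical input file
arr = pil_to_np(img)                            # float32, C x H x W, values in [0, 1]
assert arr.dtype == np.float32 and arr.max() <= 1.0

back = np_to_pil(arr)                           # back to a uint8 PIL image
back.save('example_roundtrip.png')
```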
341 | 342 | 343 | def median(img_np_list): 344 | """ 345 | assumes C x W x H [0..1] 346 | :param img_np_list: 347 | :return: 348 | """ 349 | assert len(img_np_list) > 0 350 | l = len(img_np_list) 351 | shape = img_np_list[0].shape 352 | result = np.zeros(shape) 353 | for c in range(shape[0]): 354 | for w in range(shape[1]): 355 | for h in range(shape[2]): 356 | result[c, w, h] = sorted(i[c, w, h] for i in img_np_list)[l // 2] 357 | return result 358 | 359 | 360 | def average(img_np_list): 361 | """ 362 | assumes C x W x H [0..1] 363 | :param img_np_list: 364 | :return: 365 | """ 366 | assert len(img_np_list) > 0 367 | l = len(img_np_list) 368 | shape = img_np_list[0].shape 369 | result = np.zeros(shape) 370 | for i in img_np_list: 371 | result += i 372 | return result / l 373 | 374 | 375 | def np_to_pil(img_np): 376 | """ 377 | Converts image in np.array format to PIL image. 378 | 379 | From C x W x H [0..1] to W x H x C [0...255] 380 | :param img_np: 381 | :return: 382 | """ 383 | ar = np.clip(img_np * 255, 0, 255).astype(np.uint8) 384 | 385 | if img_np.shape[0] == 1: 386 | ar = ar[0] 387 | else: 388 | assert img_np.shape[0] == 3, img_np.shape 389 | ar = ar.transpose(1, 2, 0) 390 | 391 | return Image.fromarray(ar) 392 | 393 | 394 | def np_to_torch(img_np): 395 | """ 396 | Converts image in numpy.array to torch.Tensor. 397 | 398 | From C x W x H [0..1] to C x W x H [0..1] 399 | 400 | :param img_np: 401 | :return: 402 | """ 403 | return torch.from_numpy(img_np)[None, :] 404 | 405 | 406 | def torch_to_np(img_var): 407 | """ 408 | Converts an image in torch.Tensor format to np.array. 409 | 410 | From 1 x C x W x H [0..1] to C x W x H [0..1] 411 | :param img_var: 412 | :return: 413 | """ 414 | return img_var.detach().cpu().numpy()[0] 415 | -------------------------------------------------------------------------------- /utils/imresize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.ndimage import filters, measurements, interpolation 3 | from math import pi 4 | 5 | 6 | def imresize(im, scale_factor=None, output_shape=None, kernel=None, antialiasing=True, kernel_shift_flag=False): 7 | # First standardize values and fill missing arguments (if needed) by deriving scale from output shape or vice versa 8 | scale_factor, output_shape = fix_scale_and_size(im.shape, output_shape, scale_factor) 9 | 10 | # For a given numeric kernel case, just do convolution and sub-sampling (downscaling only) 11 | if type(kernel) == np.ndarray and scale_factor[0] <= 1: 12 | return numeric_kernel(im, kernel, scale_factor, output_shape, kernel_shift_flag) 13 | 14 | # Choose interpolation method, each method has the matching kernel size 15 | method, kernel_width = { 16 | "cubic": (cubic, 4.0), 17 | "lanczos2": (lanczos2, 4.0), 18 | "lanczos3": (lanczos3, 6.0), 19 | "box": (box, 1.0), 20 | "linear": (linear, 2.0), 21 | None: (cubic, 4.0) # set default interpolation method as cubic 22 | }.get(kernel) 23 | 24 | # Antialiasing is only used when downscaling 25 | antialiasing *= (scale_factor[0] < 1) 26 | 27 | # Sort indices of dimensions according to scale of each dimension. since we are going dim by dim this is efficient 28 | sorted_dims = np.argsort(np.array(scale_factor)).tolist() 29 | 30 | # Iterate over dimensions to calculate local weights for resizing and resize each time in one direction 31 | out_im = np.copy(im) 32 | for dim in sorted_dims: 33 | # No point doing calculations for scale-factor 1. 
nothing will happen anyway 34 | if scale_factor[dim] == 1.0: 35 | continue 36 | 37 | # for each coordinate (along 1 dim), calculate which coordinates in the input image affect its result and the 38 | # weights that multiply the values there to get its result. 39 | weights, field_of_view = contributions(im.shape[dim], output_shape[dim], scale_factor[dim], 40 | method, kernel_width, antialiasing) 41 | 42 | # Use the affecting position values and the set of weights to calculate the result of resizing along this 1 dim 43 | out_im = resize_along_dim(out_im, dim, weights, field_of_view) 44 | 45 | return out_im 46 | 47 | 48 | def fix_scale_and_size(input_shape, output_shape, scale_factor): 49 | # First fixing the scale-factor (if given) to be standardized the function expects (a list of scale factors in the 50 | # same size as the number of input dimensions) 51 | if scale_factor is not None: 52 | # By default, if scale-factor is a scalar we assume 2d resizing and duplicate it. 53 | if np.isscalar(scale_factor): 54 | scale_factor = [scale_factor, scale_factor] 55 | 56 | # We extend the size of scale-factor list to the size of the input by assigning 1 to all the unspecified scales 57 | scale_factor = list(scale_factor) 58 | scale_factor.extend([1] * (len(input_shape) - len(scale_factor))) 59 | 60 | # Fixing output-shape (if given): extending it to the size of the input-shape, by assigning the original input-size 61 | # to all the unspecified dimensions 62 | if output_shape is not None: 63 | output_shape = list(np.uint(np.array(output_shape))) + list(input_shape[len(output_shape):]) 64 | 65 | # Dealing with the case of non-give scale-factor, calculating according to output-shape. note that this is 66 | # sub-optimal, because there can be different scales to the same output-shape. 67 | if scale_factor is None: 68 | scale_factor = 1.0 * np.array(output_shape) / np.array(input_shape) 69 | 70 | # Dealing with missing output-shape. calculating according to scale-factor 71 | if output_shape is None: 72 | output_shape = np.uint(np.ceil(np.array(input_shape) * np.array(scale_factor))) 73 | 74 | return scale_factor, output_shape 75 | 76 | 77 | def contributions(in_length, out_length, scale, kernel, kernel_width, antialiasing): 78 | # This function calculates a set of 'filters' and a set of field_of_view that will later on be applied 79 | # such that each position from the field_of_view will be multiplied with a matching filter from the 80 | # 'weights' based on the interpolation method and the distance of the sub-pixel location from the pixel centers 81 | # around it. This is only done for one dimension of the image. 82 | 83 | # When anti-aliasing is activated (default and only for downscaling) the receptive field is stretched to size of 84 | # 1/sf. this means filtering is more 'low-pass filter'. 85 | fixed_kernel = (lambda arg: scale * kernel(scale * arg)) if antialiasing else kernel 86 | kernel_width *= 1.0 / scale if antialiasing else 1.0 87 | 88 | # These are the coordinates of the output image 89 | out_coordinates = np.arange(1, out_length+1) 90 | 91 | # These are the matching positions of the output-coordinates on the input image coordinates. 92 | # Best explained by example: say we have 4 horizontal pixels for HR and we downscale by SF=2 and get 2 pixels: 93 | # [1,2,3,4] -> [1,2]. Remember each pixel number is the middle of the pixel. 94 | # The scaling is done between the distances and not pixel numbers (the right boundary of pixel 4 is transformed to 95 | # the right boundary of pixel 2. 
pixel 1 in the small image matches the boundary between pixels 1 and 2 in the big 96 | # one and not to pixel 2. This means the position is not just multiplication of the old pos by scale-factor). 97 | # So if we measure distance from the left border, middle of pixel 1 is at distance d=0.5, border between 1 and 2 is 98 | # at d=1, and so on (d = p - 0.5). we calculate (d_new = d_old / sf) which means: 99 | # (p_new-0.5 = (p_old-0.5) / sf) -> p_new = p_old/sf + 0.5 * (1-1/sf) 100 | match_coordinates = 1.0 * out_coordinates / scale + 0.5 * (1 - 1.0 / scale) 101 | 102 | # This is the left boundary to start multiplying the filter from, it depends on the size of the filter 103 | left_boundary = np.floor(match_coordinates - kernel_width / 2) 104 | 105 | # Kernel width needs to be enlarged because when covering has sub-pixel borders, it must 'see' the pixel centers 106 | # of the pixels it only covered a part from. So we add one pixel at each side to consider (weights can zeroize them) 107 | expanded_kernel_width = np.ceil(kernel_width) + 2 108 | 109 | # Determine a set of field_of_view for each each output position, these are the pixels in the input image 110 | # that the pixel in the output image 'sees'. We get a matrix whos horizontal dim is the output pixels (big) and the 111 | # vertical dim is the pixels it 'sees' (kernel_size + 2) 112 | field_of_view = np.squeeze(np.uint(np.expand_dims(left_boundary, axis=1) + np.arange(expanded_kernel_width) - 1)) 113 | 114 | # Assign weight to each pixel in the field of view. A matrix whos horizontal dim is the output pixels and the 115 | # vertical dim is a list of weights matching to the pixel in the field of view (that are specified in 116 | # 'field_of_view') 117 | weights = fixed_kernel(1.0 * np.expand_dims(match_coordinates, axis=1) - field_of_view - 1) 118 | 119 | # Normalize weights to sum up to 1. be careful from dividing by 0 120 | sum_weights = np.sum(weights, axis=1) 121 | sum_weights[sum_weights == 0] = 1.0 122 | weights = 1.0 * weights / np.expand_dims(sum_weights, axis=1) 123 | 124 | # We use this mirror structure as a trick for reflection padding at the boundaries 125 | mirror = np.uint(np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1)))) 126 | field_of_view = mirror[np.mod(field_of_view, mirror.shape[0])] 127 | 128 | # Get rid of weights and pixel positions that are of zero weight 129 | non_zero_out_pixels = np.nonzero(np.any(weights, axis=0)) 130 | weights = np.squeeze(weights[:, non_zero_out_pixels]) 131 | field_of_view = np.squeeze(field_of_view[:, non_zero_out_pixels]) 132 | 133 | # Final products are the relative positions and the matching weights, both are output_size X fixed_kernel_size 134 | return weights, field_of_view 135 | 136 | 137 | def resize_along_dim(im, dim, weights, field_of_view): 138 | # To be able to act on each dim, we swap so that dim 0 is the wanted dim to resize 139 | tmp_im = np.swapaxes(im, dim, 0) 140 | 141 | # We add singleton dimensions to the weight matrix so we can multiply it with the big tensor we get for 142 | # tmp_im[field_of_view.T], (bsxfun style) 143 | weights = np.reshape(weights.T, list(weights.T.shape) + (np.ndim(im) - 1) * [1]) 144 | 145 | # This is a bit of a complicated multiplication: tmp_im[field_of_view.T] is a tensor of order image_dims+1. 146 | # for each pixel in the output-image it matches the positions the influence it from the input image (along 1 dim 147 | # only, this is why it only adds 1 dim to the shape). 
We then multiply, for each pixel, its set of positions with
148 |     # the matching set of weights. We do this with one big element-wise tensor multiplication (MATLAB bsxfun style:
149 |     # matching dims are multiplied element-wise, while singleton dims mean the matching dim is all multiplied by the
150 |     # same number).
151 |     tmp_out_im = np.sum(tmp_im[field_of_view.T] * weights, axis=0)
152 | 
153 |     # Finally we swap back the axes to the original order
154 |     return np.swapaxes(tmp_out_im, dim, 0)
155 | 
156 | 
157 | def numeric_kernel(im, kernel, scale_factor, output_shape, kernel_shift_flag):
158 |     # See kernel_shift function to understand what this is
159 |     if kernel_shift_flag:
160 |         kernel = kernel_shift(kernel, scale_factor)
161 | 
162 |     # First run a correlation (convolution with flipped kernel)
163 |     out_im = np.zeros_like(im)
164 |     for channel in range(np.ndim(im)):
165 |         out_im[:, :, channel] = filters.correlate(im[:, :, channel], kernel)
166 | 
167 |     # Then subsample and return
168 |     return out_im[np.round(np.linspace(0, im.shape[0] - 1 / scale_factor[0], output_shape[0])).astype(int)[:, None],
169 |                   np.round(np.linspace(0, im.shape[1] - 1 / scale_factor[1], output_shape[1])).astype(int), :]
170 | 
171 | 
172 | def kernel_shift(kernel, sf):
173 |     # There are two reasons for shifting the kernel:
174 |     # 1. Center of mass is not in the center of the kernel which creates ambiguity. There is no possible way to know
175 |     #    whether the degradation process included shifting, so we always assume center of mass is center of the kernel.
176 |     # 2. We further shift kernel center so that top left result pixel corresponds to the middle of the sfXsf first
177 |     #    pixels. Default is for odd size to be in the middle of the first pixel and for even sized kernel to be at the
178 |     #    top left corner of the first pixel. That is why a different shift size is needed for odd and even sizes.
179 |     # Given that these two conditions are fulfilled, we are happy and aligned, the way to test it is as follows:
180 |     # The input image, when interpolated (regular bicubic) is exactly aligned with ground truth.
181 | 
182 |     # First calculate the current center of mass for the kernel
183 |     current_center_of_mass = measurements.center_of_mass(kernel)
184 | 
185 |     # The second term ("+ 0.5 * ....") is for applying condition 2 from the comments above
186 |     wanted_center_of_mass = np.array(kernel.shape) / 2 + 0.5 * (sf - (kernel.shape[0] % 2))
187 | 
188 |     # Define the shift vector for the kernel shifting (x,y)
189 |     shift_vec = wanted_center_of_mass - current_center_of_mass
190 | 
191 |     # Before applying the shift, we first pad the kernel so that nothing is lost due to the shift
192 |     # (biggest shift among dims + 1 for safety)
193 |     kernel = np.pad(kernel, int(np.ceil(np.max(shift_vec))) + 1, 'constant')
194 | 
195 |     # Finally shift the kernel and return
196 |     return interpolation.shift(kernel, shift_vec)
197 | 
198 | 
199 | # These next functions are all interpolation methods. x is the distance from the left pixel center.
200 | 
201 | 
202 | def cubic(x):
203 |     absx = np.abs(x)
204 |     absx2 = absx ** 2
205 |     absx3 = absx ** 3
206 |     return ((1.5*absx3 - 2.5*absx2 + 1) * (absx <= 1) +
207 |             (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * ((1 < absx) & (absx <= 2)))
208 | 
209 | 
210 | def lanczos2(x):
211 |     return (((np.sin(pi*x) * np.sin(pi*x/2) + np.finfo(np.float32).eps) /
212 |              ((pi**2 * x**2 / 2) + np.finfo(np.float32).eps))
213 |             * (abs(x) < 2))
214 | 
215 | 
216 | def box(x):
217 |     return ((-0.5 <= x) & (x < 0.5)) * 1.0
218 | 
219 | 
220 | def lanczos3(x):
221 |     return (((np.sin(pi*x) * np.sin(pi*x/3) + np.finfo(np.float32).eps) /
222 |              ((pi**2 * x**2 / 3) + np.finfo(np.float32).eps))
223 |             * (abs(x) < 3))
224 | 
225 | 
226 | def linear(x):
227 |     return (x + 1) * ((-1 <= x) & (x < 0)) + (1 - x) * ((0 <= x) & (x <= 1))
228 | 
229 | 
230 | def np_imresize(im, scale_factor=None, output_shape=None, kernel=None, antialiasing=True, kernel_shift_flag=False):
231 |     return np.clip(imresize(im.transpose(1, 2, 0), scale_factor, output_shape, kernel, antialiasing,
232 |                             kernel_shift_flag).transpose(2, 0, 1), 0, 1)
--------------------------------------------------------------------------------
/utils/schedulers.py:
--------------------------------------------------------------------------------
1 | import math
2 | from collections import Counter
3 | from torch.optim.lr_scheduler import _LRScheduler
4 | import torch
5 | import warnings
6 | from typing import List
7 | 
8 | from torch import nn
9 | from torch.optim import Adam, Optimizer
10 | 
11 | class MultiStepRestartLR(_LRScheduler):
12 |     """ MultiStep with restarts learning rate scheme.
13 | 
14 |     Args:
15 |         optimizer (torch.nn.optimizer): Torch optimizer.
16 |         milestones (list): Iterations that will decrease learning rate.
17 |         gamma (float): Decrease ratio. Default: 0.1.
18 |         restarts (list): Restart iterations. Default: [0].
19 |         restart_weights (list): Restart weights at each restart iteration.
20 |             Default: [1].
21 |         last_epoch (int): Used in _LRScheduler. Default: -1.
22 |     """
23 | 
24 |     def __init__(self,
25 |                  optimizer,
26 |                  milestones,
27 |                  gamma=0.1,
28 |                  restarts=(0, ),
29 |                  restart_weights=(1, ),
30 |                  last_epoch=-1):
31 |         self.milestones = Counter(milestones)
32 |         self.gamma = gamma
33 |         self.restarts = restarts
34 |         self.restart_weights = restart_weights
35 |         assert len(self.restarts) == len(
36 |             self.restart_weights), 'restarts and their weights do not match.'
37 |         super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)
38 | 
39 |     def get_lr(self):
40 |         if self.last_epoch in self.restarts:
41 |             weight = self.restart_weights[self.restarts.index(self.last_epoch)]
42 |             return [
43 |                 group['initial_lr'] * weight
44 |                 for group in self.optimizer.param_groups
45 |             ]
46 |         if self.last_epoch not in self.milestones:
47 |             return [group['lr'] for group in self.optimizer.param_groups]
48 |         return [
49 |             group['lr'] * self.gamma**self.milestones[self.last_epoch]
50 |             for group in self.optimizer.param_groups
51 |         ]
52 | 
53 | class LinearLR(_LRScheduler):
54 |     """ Linearly decays the learning rate to zero over total_iter steps.
55 | 
56 |     Args:
57 |         optimizer (torch.nn.optimizer): Torch optimizer.
58 |         total_iter (int): Total number of iterations over which the learning
59 |             rate decays linearly from its initial value to zero.
60 |         last_epoch (int): Used in _LRScheduler. Default: -1.
53 | class LinearLR(_LRScheduler):
54 |     """ Linearly decays the learning rate from its initial value to 0.
55 | 
56 |     Args:
57 |         optimizer (torch.optim.Optimizer): Torch optimizer.
58 |         total_iter (int): Total number of iterations over which the learning
59 |             rate reaches 0.
60 |         last_epoch (int): Used in _LRScheduler. Default: -1.
61 |     """
62 | 
63 |     def __init__(self,
64 |                  optimizer,
65 |                  total_iter,
66 |                  last_epoch=-1):
67 |         self.total_iter = total_iter
68 |         super(LinearLR, self).__init__(optimizer, last_epoch)
69 | 
70 |     def get_lr(self):
71 |         process = self.last_epoch / self.total_iter
72 |         weight = (1 - process)
73 |         # print('get lr ', [weight * group['initial_lr'] for group in self.optimizer.param_groups])
74 |         return [weight * group['initial_lr'] for group in self.optimizer.param_groups]
75 | 
76 | class VibrateLR(_LRScheduler):
77 |     """ Oscillating ("vibrating") learning rate inside a decaying envelope.
78 | 
79 |     Args:
80 |         optimizer (torch.optim.Optimizer): Torch optimizer.
81 |         total_iter (int): Total number of iterations; it sets both the decay
82 |             of the envelope and the oscillation period (total_iter // 80).
83 |         last_epoch (int): Used in _LRScheduler. Default: -1.
84 |     """
85 | 
86 |     def __init__(self,
87 |                  optimizer,
88 |                  total_iter,
89 |                  last_epoch=-1):
90 |         self.total_iter = total_iter
91 |         super(VibrateLR, self).__init__(optimizer, last_epoch)
92 | 
93 |     def get_lr(self):
94 |         process = self.last_epoch / self.total_iter
95 | 
96 |         f = 0.1
97 |         if process < 3 / 8:
98 |             f = 1 - process * 8 / 3
99 |         elif process < 5 / 8:
100 |             f = 0.2
101 | 
102 |         T = self.total_iter // 80
103 |         Th = T // 2
104 | 
105 |         t = self.last_epoch % T
106 | 
107 |         f2 = t / Th
108 |         if t >= Th:
109 |             f2 = 2 - f2
110 | 
111 |         weight = f * f2
112 | 
113 |         if self.last_epoch < Th:
114 |             weight = max(0.1, weight)
115 | 
116 |         # print('f {}, T {}, Th {}, t {}, f2 {}'.format(f, T, Th, t, f2))
117 |         return [weight * group['initial_lr'] for group in self.optimizer.param_groups]
118 | 
119 | def get_position_from_periods(iteration, cumulative_period):
120 |     """Get the position from a period list.
121 | 
122 |     It will return the index of the right-closest number in the period list.
123 |     For example, if cumulative_period = [100, 200, 300, 400]:
124 |         if iteration == 50, return 0;
125 |         if iteration == 210, return 2;
126 |         if iteration == 300, return 2.
127 | 
128 |     Args:
129 |         iteration (int): Current iteration.
130 |         cumulative_period (list[int]): Cumulative period list.
131 | 
132 |     Returns:
133 |         int: The position of the right-closest number in the period list.
134 |     """
135 |     for i, period in enumerate(cumulative_period):
136 |         if iteration <= period:
137 |             return i
138 | 
139 | 
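Editor's note: the docstring examples above double as a quick test; these asserts (a sketch, not in the file) all pass:

    assert get_position_from_periods(50,  [100, 200, 300, 400]) == 0
    assert get_position_from_periods(210, [100, 200, 300, 400]) == 2
    assert get_position_from_periods(300, [100, 200, 300, 400]) == 2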
140 | class CosineAnnealingRestartLR(_LRScheduler):
141 |     """ Cosine annealing with restarts learning rate scheme.
142 | 
143 |     An example of config:
144 |     periods = [10, 10, 10, 10]
145 |     restart_weights = [1, 0.5, 0.5, 0.5]
146 |     eta_min = 1e-7
147 | 
148 |     It has four cycles of 10 iterations each. At the 10th, 20th and 30th
149 |     iterations, the scheduler will restart with the weights in restart_weights.
150 | 
151 |     Args:
152 |         optimizer (torch.optim.Optimizer): Torch optimizer.
153 |         periods (list): Period of each cosine annealing cycle.
154 |         restart_weights (list): Restart weights at each restart iteration.
155 |             Default: [1].
156 |         eta_min (float): The minimum lr. Default: 0.
157 |         last_epoch (int): Used in _LRScheduler. Default: -1.
158 |     """
159 | 
160 |     def __init__(self,
161 |                  optimizer,
162 |                  periods,
163 |                  restart_weights=(1, ),
164 |                  eta_min=0,
165 |                  last_epoch=-1):
166 |         self.periods = periods
167 |         self.restart_weights = restart_weights
168 |         self.eta_min = eta_min
169 |         assert (len(self.periods) == len(self.restart_weights)
170 |                 ), 'periods and restart_weights should have the same length.'
171 |         self.cumulative_period = [
172 |             sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
173 |         ]
174 |         super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
175 | 
176 |     def get_lr(self):
177 |         idx = get_position_from_periods(self.last_epoch,
178 |                                         self.cumulative_period)
179 |         current_weight = self.restart_weights[idx]
180 |         nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
181 |         current_period = self.periods[idx]
182 | 
183 |         return [
184 |             self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
185 |             (1 + math.cos(math.pi * (
186 |                 (self.last_epoch - nearest_restart) / current_period)))
187 |             for base_lr in self.base_lrs
188 |         ]
189 | 
190 | class CosineAnnealingRestartCyclicLR(_LRScheduler):
191 |     """ Cosine annealing with restarts and a per-cycle minimum learning rate.
192 |     An example of config:
193 |     periods = [10, 10, 10, 10]
194 |     restart_weights = [1, 0.5, 0.5, 0.5]
195 |     eta_mins = [1e-7, 1e-7, 1e-7, 1e-7]
196 |     It has four cycles of 10 iterations each. At the 10th, 20th and 30th
197 |     iterations, the scheduler will restart with the weights in restart_weights.
198 |     Args:
199 |         optimizer (torch.optim.Optimizer): Torch optimizer.
200 |         periods (list): Period of each cosine annealing cycle.
201 |         restart_weights (list): Restart weights at each restart iteration.
202 |             Default: [1].
203 |         eta_mins (list): The minimum lr of each cycle. Default: (0, ).
204 |         last_epoch (int): Used in _LRScheduler. Default: -1.
205 |     """
206 | 
207 |     def __init__(self,
208 |                  optimizer,
209 |                  periods,
210 |                  restart_weights=(1, ),
211 |                  eta_mins=(0, ),
212 |                  last_epoch=-1):
213 |         self.periods = periods
214 |         self.restart_weights = restart_weights
215 |         self.eta_mins = eta_mins
216 |         assert (len(self.periods) == len(self.restart_weights)
217 |                 ), 'periods and restart_weights should have the same length.'
218 |         self.cumulative_period = [
219 |             sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
220 |         ]
221 |         super(CosineAnnealingRestartCyclicLR, self).__init__(optimizer, last_epoch)
222 | 
223 |     def get_lr(self):
224 |         idx = get_position_from_periods(self.last_epoch,
225 |                                         self.cumulative_period)
226 |         current_weight = self.restart_weights[idx]
227 |         nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
228 |         current_period = self.periods[idx]
229 |         eta_min = self.eta_mins[idx]
230 | 
231 |         return [
232 |             eta_min + current_weight * 0.5 * (base_lr - eta_min) *
233 |             (1 + math.cos(math.pi * (
234 |                 (self.last_epoch - nearest_restart) / current_period)))
235 |             for base_lr in self.base_lrs
236 |         ]
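Editor's note: to see the restart behaviour concretely, this sketch (an illustration, not in the file; any small optimizer works) traces the learning rate produced by the docstring config:

    import torch

    p = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.SGD([p], lr=1e-4)
    sched = CosineAnnealingRestartLR(opt, periods=[10, 10, 10, 10],
                                     restart_weights=[1, 0.5, 0.5, 0.5], eta_min=1e-7)
    for i in range(40):
        sched.step()
        # within each 10-iteration cycle the lr decays cosinely toward eta_min,
        # then jumps back to restart_weight * base_lr at iterations 10, 20 and 30
        print(i, opt.param_groups[0]['lr'])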
237 | 
238 | 
239 | class LinearWarmupCosineAnnealingLR(_LRScheduler):
240 |     """Sets the learning rate of each parameter group to follow a linear warmup schedule between warmup_start_lr
241 |     and base_lr, followed by a cosine annealing schedule between base_lr and eta_min.
242 |     .. warning::
243 |         It is recommended to call :func:`.step()` for :class:`LinearWarmupCosineAnnealingLR`
244 |         after each iteration, as calling it after each epoch will keep the starting lr at
245 |         warmup_start_lr for the whole first epoch (which is 0 in most cases).
246 |     .. warning::
247 |         Passing epoch to :func:`.step()` is deprecated and comes with an EPOCH_DEPRECATION_WARNING.
248 |         It calls the :func:`_get_closed_form_lr()` method for this scheduler instead of
249 |         :func:`get_lr()`. Though this does not change the behavior of the scheduler, when passing
250 |         the epoch param to :func:`.step()`, the user should call :func:`.step()` before calling
251 |         the train and validation methods.
252 |     Example:
253 |         >>> layer = nn.Linear(10, 1)
254 |         >>> optimizer = Adam(layer.parameters(), lr=0.02)
255 |         >>> scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=40)
256 |         >>> #
257 |         >>> # the default case
258 |         >>> for epoch in range(40):
259 |         ...     # train(...)
260 |         ...     # validate(...)
261 |         ...     scheduler.step()
262 |         >>> #
263 |         >>> # passing epoch param case
264 |         >>> for epoch in range(40):
265 |         ...     scheduler.step(epoch)
266 |         ...     # train(...)
267 |         ...     # validate(...)
268 |     """
269 | 
270 |     def __init__(
271 |         self,
272 |         optimizer: Optimizer,
273 |         warmup_epochs: int,
274 |         max_epochs: int,
275 |         warmup_start_lr: float = 0.0,
276 |         eta_min: float = 0.0,
277 |         last_epoch: int = -1,
278 |     ) -> None:
279 |         """
280 |         Args:
281 |             optimizer (Optimizer): Wrapped optimizer.
282 |             warmup_epochs (int): Number of epochs (or iterations, if stepped per iteration) for linear warmup.
283 |             max_epochs (int): Total number of epochs (or iterations).
284 |             warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0.
285 |             eta_min (float): Minimum learning rate. Default: 0.
286 |             last_epoch (int): The index of last epoch. Default: -1.
287 |         """
288 |         self.warmup_epochs = warmup_epochs
289 |         self.max_epochs = max_epochs
290 |         self.warmup_start_lr = warmup_start_lr
291 |         self.eta_min = eta_min
292 | 
293 |         super().__init__(optimizer, last_epoch)
294 | 
295 |     def get_lr(self) -> List[float]:
296 |         """Compute learning rate using chainable form of the scheduler."""
297 |         if not self._get_lr_called_within_step:
298 |             warnings.warn(
299 |                 "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.",
300 |                 UserWarning,
301 |             )
302 | 
303 |         if self.last_epoch == 0:
304 |             return [self.warmup_start_lr] * len(self.base_lrs)
305 |         if self.last_epoch < self.warmup_epochs:
306 |             return [
307 |                 group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
308 |                 for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
309 |             ]
310 |         if self.last_epoch == self.warmup_epochs:
311 |             return self.base_lrs
312 |         if (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0:
313 |             return [
314 |                 group["lr"]
315 |                 + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2
316 |                 for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
317 |             ]
318 | 
319 |         return [
320 |             (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
321 |             / (
322 |                 1
323 |                 + math.cos(
324 |                     math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs)
325 |                 )
326 |             )
327 |             * (group["lr"] - self.eta_min)
328 |             + self.eta_min
329 |             for group in self.optimizer.param_groups
330 |         ]
331 | 
332 |     def _get_closed_form_lr(self) -> List[float]:
333 |         """Called when epoch is passed as a param to the `step` function of the scheduler."""
334 |         if self.last_epoch < self.warmup_epochs:
335 |             return [
336 |                 self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
337 |                 for base_lr in self.base_lrs
338 |             ]
339 | 
340 |         return [
341 |             self.eta_min
342 |             + 0.5
343 |             * (base_lr - self.eta_min)
344 |             * (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
345 |             for base_lr in self.base_lrs
346 |         ]
347 | 
348 | 
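Editor's note: the two code paths (the chainable get_lr and the closed form) should agree; this sketch (an editor's check, not in the file) steps through a full schedule and compares them:

    import torch

    layer = torch.nn.Linear(10, 1)
    opt = torch.optim.Adam(layer.parameters(), lr=0.02)
    sched = LinearWarmupCosineAnnealingLR(opt, warmup_epochs=10, max_epochs=40)
    for epoch in range(40):
        sched.step()
        # the chainable recurrence should track the closed form up to float error
        assert abs(opt.param_groups[0]["lr"] - sched._get_closed_form_lr()[0]) < 1e-7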
349 | # warmup + decay as a function
350 | def linear_warmup_decay(warmup_steps, total_steps, cosine=True, linear=False):
351 |     """Linear warmup for warmup_steps, then optionally cosine annealing or linear decay to 0 at total_steps."""
352 |     assert not (linear and cosine)
353 | 
354 |     def fn(step):
355 |         if step < warmup_steps:
356 |             return float(step) / float(max(1, warmup_steps))
357 | 
358 |         if not (cosine or linear):
359 |             # no decay
360 |             return 1.0
361 | 
362 |         progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
363 |         if cosine:
364 |             # cosine decay
365 |             return 0.5 * (1.0 + math.cos(math.pi * progress))
366 | 
367 |         # linear decay
368 |         return 1.0 - progress
369 | 
370 |     return fn
371 | 
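Editor's note: linear_warmup_decay returns a multiplicative factor, which is exactly the shape torch's LambdaLR expects. A minimal usage sketch (the step counts are arbitrary, not from the file):

    import torch

    opt = torch.optim.Adam(torch.nn.Linear(10, 1).parameters(), lr=1e-3)
    sched = torch.optim.lr_scheduler.LambdaLR(
        opt, lr_lambda=linear_warmup_decay(warmup_steps=500, total_steps=10000, cosine=True))
    # stepped once per optimizer step: lr ramps 0 -> 1e-3 over 500 steps,
    # then follows a cosine decay to 0 at step 10000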
--------------------------------------------------------------------------------
/utils/dataset_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import copy
4 | import glob
5 | import numpy as np
6 | from PIL import Image
7 | 
8 | from torch.utils.data import Dataset
9 | from torchvision.transforms import ToPILImage, Compose, RandomCrop, ToTensor
10 | import torch
11 | 
12 | from utils.image_utils import random_augmentation, crop_img
13 | from utils.degradation_utils import Degradation
14 | 
15 | 
16 | class CDD11(Dataset):
17 |     def __init__(self, args, split: str = "train", subset: str = "all"):
18 |         super(CDD11, self).__init__()
19 | 
20 |         self.args = args
21 |         self.toTensor = ToTensor()
22 |         # self.de_type = self.args.de_type
23 |         self.dataset_split = split
24 |         self.subset = subset
25 |         if split == "train":
26 |             self.patch_size = args.patch_size
27 |         else:
28 |             self.patch_size = 64
29 |         self.cdd11_dir = args.cdd11_path
30 | 
31 |         self._init()
32 | 
33 |     def __getitem__(self, index):
34 |         # Randomly select a degradation type, then a degraded image within that type
35 |         if self.dataset_split == "train":
36 |             degradation_type = random.choice(list(self.degraded_dict.keys()))
37 |             degraded_image_path = random.choice(self.degraded_dict[degradation_type])
38 |         else:
39 |             degradation_type = self.subset
40 |             degraded_image_path = self.degraded_dict[degradation_type][index]
41 | 
42 | 
43 | 
44 |         degraded_name = os.path.basename(degraded_image_path)
45 | 
46 |         # Get the corresponding clean image based on the file name
47 |         # (degraded and clean images share the same basename in CDD-11)
48 | 
49 |         clean_image_path = os.path.join(os.path.dirname(self.clean[0]), degraded_name)
50 | 
51 |         # Load the images
52 |         # lr = crop_img(np.array(Image.open(degraded_image_path).convert('RGB')), base=16)
53 |         lr = np.array(Image.open(degraded_image_path).convert('RGB'))
54 |         # hr = crop_img(np.array(Image.open(clean_image_path).convert('RGB')), base=16)
55 |         hr = np.array(Image.open(clean_image_path).convert('RGB'))
56 |         # Apply random augmentation and crop
57 |         if self.dataset_split == "train":
58 |             lr, hr = random_augmentation(*self._crop_patch(lr, hr))
59 | 
60 |         # Convert to tensors
61 |         lr = self.toTensor(lr)
62 |         hr = self.toTensor(hr)
63 | 
64 |         return [clean_image_path, degradation_type], lr, hr
65 | 
66 |     def __len__(self):
67 |         return sum(len(images) for images in self.degraded_dict.values())
68 | 
69 |     def _init(self):
70 |         data_dir = os.path.join(self.cdd11_dir)
71 |         self.clean = sorted(glob.glob(os.path.join(data_dir, f"{self.dataset_split}/clear", "*.png")))
72 | 
73 |         if len(self.clean) == 0:
74 |             raise ValueError(f"No clean images found in {os.path.join(data_dir, f'{self.dataset_split}/clear')}")
75 | 
76 |         self.degraded_dict = {}
77 |         allowed_degradation_folders = self._filter_degradation_folders(data_dir)
78 |         for folder in allowed_degradation_folders:
79 |             folder_name = os.path.basename(folder.strip('/'))
80 |             degraded_images = sorted(glob.glob(os.path.join(folder, "*.png")))
81 | 
82 |             if len(degraded_images) == 0:
83 |                 raise ValueError(f"No images found in {folder_name}")
84 | 
85 |             # scale dataset length
86 |             if self.dataset_split == "train":
87 |                 degraded_images *= 6
88 | 
89 |             self.degraded_dict[folder_name] = degraded_images
90 | 
91 |     def _filter_degradation_folders(self, data_dir):
92 |         """
93 |         Returns the degradation folders selected by self.subset:
94 |         'single', 'double', 'triple', 'all', or an exact folder name.
95 |         """
96 |         degradation_folders = sorted(glob.glob(os.path.join(data_dir, self.dataset_split, "*/")))
97 |         filtered_folders = []
98 | 
99 |         for folder in degradation_folders:
100 |             folder_name = os.path.basename(folder.strip('/'))
101 |             if folder_name == "clear":
102 |                 continue
103 | 
104 |             # Count the number of degradations from the number of underscores in the folder name
105 |             degradation_count = folder_name.count('_') + 1
106 | 
107 |             # Check the degradation type mode and filter accordingly
108 |             if self.subset == "single" and degradation_count == 1:
109 |                 filtered_folders.append(folder)
110 |             elif self.subset == "double" and degradation_count == 2:
111 |                 filtered_folders.append(folder)
112 |             elif self.subset == "triple" and degradation_count == 3:
113 |                 filtered_folders.append(folder)
114 |             elif self.subset == "all":
115 |                 filtered_folders.append(folder)
116 |             # If self.subset is a specific degradation folder name, match it exactly
117 |             elif self.subset not in ["single", "double", "triple", "all"]:
118 |                 if folder_name == self.subset:
119 |                     filtered_folders.append(folder)
120 | 
121 |         print(f"Degradation type mode: {self.subset}")
122 |         print(f"Loading degradation folders: {[os.path.basename(f.strip('/')) for f in filtered_folders]}")
123 |         return filtered_folders
124 | 
125 |     def _crop_patch(self, img_1, img_2):
126 |         # Crop a patch from both images (degraded and clean) at the same location
127 |         H = img_1.shape[0]
128 |         W = img_1.shape[1]
129 |         ind_H = random.randint(0, H - self.args.patch_size)
130 |         ind_W = random.randint(0, W - self.args.patch_size)
131 | 
132 |         patch_1 = img_1[ind_H:ind_H + self.args.patch_size, ind_W:ind_W + self.args.patch_size]
133 |         patch_2 = img_2[ind_H:ind_H + self.args.patch_size, ind_W:ind_W + self.args.patch_size]
134 | 
135 |         return patch_1, patch_2
136 | 
137 | 
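Editor's note: a sketch of how CDD11 is consumed (the args namespace and data path are assumptions, not taken from the file). Folder names such as "haze_rain" count their degradations by underscores, so subset="double" would select them:

    from argparse import Namespace
    from torch.utils.data import DataLoader

    args = Namespace(patch_size=128, cdd11_path="data/CDD-11")  # hypothetical layout
    train_set = CDD11(args, split="train", subset="all")
    loader = DataLoader(train_set, batch_size=8, shuffle=True, num_workers=4)
    for (clean_paths, deg_types), lr, hr in loader:
        pass  # lr, hr: (8, 3, 128, 128) float tensors in [0, 1]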
138 | class AnyIRTrainDataset(Dataset):
139 |     def __init__(self, args):
140 |         super(AnyIRTrainDataset, self).__init__()
141 |         self.args = args
142 |         self.rs_ids = []
143 |         self.hazy_ids = []
144 |         self.D = Degradation(args)
145 |         self.de_temp = 0
146 |         self.de_type = self.args.de_type
147 |         print(self.de_type)
148 | 
149 |         self.de_dict = {'denoise_15': 0, 'denoise_25': 1, 'denoise_50': 2, 'derain': 3, 'dehaze': 4, 'deblur': 5, 'enhance': 6}  # 'enhance' was missing although _init_enhance_ids assigns de_type 6
150 | 
151 |         self._init_ids()
152 |         self._merge_ids()
153 | 
154 |         self.crop_transform = Compose([
155 |             ToPILImage(),
156 |             RandomCrop(args.patch_size),
157 |         ])
158 | 
159 |         self.toTensor = ToTensor()
160 | 
161 |     def _init_ids(self):
162 |         if 'denoise_15' in self.de_type or 'denoise_25' in self.de_type or 'denoise_50' in self.de_type:
163 |             self._init_clean_ids()
164 |         if 'derain' in self.de_type:
165 |             self._init_rs_ids()
166 |         if 'dehaze' in self.de_type:
167 |             self._init_hazy_ids()
168 |         if 'deblur' in self.de_type:
169 |             self._init_deblur_ids()
170 |         if 'enhance' in self.de_type:
171 |             self._init_enhance_ids()
172 | 
173 |         random.shuffle(self.de_type)
174 | 
175 |     def _init_clean_ids(self):
176 |         ref_file = self.args.data_file_dir + "noisy/denoise.txt"
177 |         temp_ids = []
178 |         temp_ids += [id_.strip() for id_ in open(ref_file)]
179 |         clean_ids = []
180 |         name_list = os.listdir(self.args.denoise_dir)
181 |         clean_ids += [self.args.denoise_dir + id_ for id_ in name_list if id_.strip() in temp_ids]
182 | 
183 |         if 'denoise_15' in self.de_type:
184 |             self.s15_ids = [{"clean_id": x, "de_type": 0} for x in clean_ids]
185 |             self.s15_ids = self.s15_ids * 3
186 |             random.shuffle(self.s15_ids)
187 |             self.s15_counter = 0
188 |         if 'denoise_25' in self.de_type:
189 |             self.s25_ids = [{"clean_id": x, "de_type": 1} for x in clean_ids]
190 |             self.s25_ids = self.s25_ids * 3
191 |             random.shuffle(self.s25_ids)
192 |             self.s25_counter = 0
193 |         if 'denoise_50' in self.de_type:
194 |             self.s50_ids = [{"clean_id": x, "de_type": 2} for x in clean_ids]
195 |             self.s50_ids = self.s50_ids * 3
196 |             random.shuffle(self.s50_ids)
197 |             self.s50_counter = 0
198 | 
199 |         self.num_clean = len(clean_ids)
200 |         print("Total Denoise Ids : {}".format(self.num_clean))
201 | 
202 |     def _init_hazy_ids(self):
203 |         temp_ids = []
204 |         hazy = self.args.data_file_dir + "hazy/hazy_outside.txt"
205 |         temp_ids += [self.args.dehaze_dir + id_.strip() for id_ in open(hazy)]
206 |         self.hazy_ids = [{"clean_id": x, "de_type": 4} for x in temp_ids]
207 | 
208 |         self.hazy_counter = 0
209 | 
210 |         self.num_hazy = len(self.hazy_ids)
211 |         print("Total Hazy Ids : {}".format(self.num_hazy))
212 | 
213 |     def _init_deblur_ids(self):
214 |         temp_ids = []
215 | 
216 |         image_list = os.listdir(os.path.join(self.args.gopro_dir, 'blur/'))
217 |         temp_ids = image_list
218 |         self.deblur_ids = [{"clean_id": x, "de_type": 5} for x in temp_ids]
219 |         self.deblur_ids = self.deblur_ids * 5
220 |         self.deblur_counter = 0
221 |         self.num_deblur = len(self.deblur_ids)
222 |         print('Total Blur Ids : {}'.format(self.num_deblur))
223 | 
224 |     def _init_enhance_ids(self):
225 |         temp_ids = []
226 |         image_list = os.listdir(os.path.join(self.args.enhance_dir, 'low/'))
227 |         temp_ids = image_list
228 |         self.enhance_ids = [{"clean_id": x, "de_type": 6} for x in temp_ids]
229 |         self.enhance_ids = self.enhance_ids * 20
230 |         self.num_enhance = len(self.enhance_ids)
231 |         print('Total Enhance Ids : {}'.format(self.num_enhance))
232 | 
233 | 
234 |     def _init_rs_ids(self):
235 |         temp_ids = []
236 |         rs = self.args.data_file_dir + "rainy/rainTrain.txt"
237 |         temp_ids += [self.args.derain_dir + id_.strip() for id_ in open(rs)]
238 |         self.rs_ids = [{"clean_id": x, "de_type": 3} for x in temp_ids]
239 |         self.rs_ids = self.rs_ids * 120
240 | 
241 |         self.rl_counter = 0
242 |         self.num_rl = len(self.rs_ids)
243 |         print("Total Rainy Ids : {}".format(self.num_rl))
244 | 
245 | 
246 |     def _crop_patch(self, img_1, img_2):
247 |         H = img_1.shape[0]
248 |         W = img_1.shape[1]
249 |         ind_H = random.randint(0, H - self.args.patch_size)
250 |         ind_W = random.randint(0, W - self.args.patch_size)
251 | 
252 |         patch_1 = img_1[ind_H:ind_H + self.args.patch_size, ind_W:ind_W + self.args.patch_size]
253 |         patch_2 = img_2[ind_H:ind_H + self.args.patch_size, ind_W:ind_W + self.args.patch_size]
254 | 
255 |         return patch_1, patch_2
256 | 
257 |     def _get_gt_name(self, rainy_name):
258 |         gt_name = rainy_name.split("rainy")[0] + 'gt/norain-' + rainy_name.split('rain-')[-1]
259 |         return gt_name
260 | 
261 |     def _get_deblur_name(self, deblur_name):
262 |         gt_name = deblur_name.replace("blur", "sharp")
263 |         return gt_name
264 | 
265 |     def _get_enhance_name(self, enhance_name):
266 |         gt_name = enhance_name.replace("low", "gt")
267 |         return gt_name
268 | 
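Editor's note: the path-mapping helpers above encode each dataset's layout convention. A sketch of what they produce (the paths are illustrative, and ds stands for an AnyIRTrainDataset instance):

    ds._get_gt_name("Rain100L/rainy/rain-042.png")
    #   -> "Rain100L/gt/norain-042.png"
    ds._get_deblur_name("GoPro/blur/000001.png")
    #   -> "GoPro/sharp/000001.png"
    ds._get_enhance_name("LOL/low/493.png")
    #   -> "LOL/gt/493.png"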
269 |     def _get_nonhazy_name(self, hazy_name):
270 |         dir_name = hazy_name.split("synthetic")[0] + 'original/'
271 |         name = hazy_name.split('/')[-1].split('_')[0]
272 |         suffix = '.' + hazy_name.split('.')[-1]
273 |         nonhazy_name = dir_name + name + suffix
274 |         return nonhazy_name
275 | 
276 |     def _merge_ids(self):
277 |         self.sample_ids = []
278 |         # Guard each noise level separately (previously s25/s50 were added whenever denoise_15 was present)
279 |         if "denoise_15" in self.de_type:
280 |             self.sample_ids += self.s15_ids
281 |         if "denoise_25" in self.de_type:
282 |             self.sample_ids += self.s25_ids
283 |         if "denoise_50" in self.de_type:
284 |             self.sample_ids += self.s50_ids
285 |         if "derain" in self.de_type:
286 |             self.sample_ids += self.rs_ids
287 |         if "dehaze" in self.de_type:
288 |             self.sample_ids += self.hazy_ids
289 |         if "deblur" in self.de_type:
290 |             self.sample_ids += self.deblur_ids
291 |         if "enhance" in self.de_type:
292 |             self.sample_ids += self.enhance_ids
293 | 
294 |         print(len(self.sample_ids))
295 | 
296 |     def __getitem__(self, idx):
297 |         sample = self.sample_ids[idx]
298 |         de_id = sample["de_type"]
299 | 
300 |         if de_id < 3:
301 |             # all three noise levels read the same clean image; sigma is chosen from de_id in single_degrade
302 |             clean_id = sample["clean_id"]
303 | 
304 |             clean_img = crop_img(np.array(Image.open(clean_id).convert('RGB')), base=16)
305 |             clean_patch = self.crop_transform(clean_img)
306 |             clean_patch = np.array(clean_patch)
307 | 
308 |             clean_name = clean_id.split("/")[-1].split('.')[0]
309 | 
310 |             clean_patch = random_augmentation(clean_patch)[0]
311 | 
312 |             degrad_patch = self.D.single_degrade(clean_patch, de_id)
313 |         else:
314 |             if de_id == 3:
315 |                 # Rain streak removal
316 |                 degrad_img = crop_img(np.array(Image.open(sample["clean_id"]).convert('RGB')), base=16)
317 |                 clean_name = self._get_gt_name(sample["clean_id"])
318 |                 clean_img = crop_img(np.array(Image.open(clean_name).convert('RGB')), base=16)
319 |             elif de_id == 4:
320 |                 # Dehazing with the SOTS outdoor training set
321 |                 degrad_img = crop_img(np.array(Image.open(sample["clean_id"]).convert('RGB')), base=16)
322 |                 clean_name = self._get_nonhazy_name(sample["clean_id"])
323 |                 clean_img = crop_img(np.array(Image.open(clean_name).convert('RGB')), base=16)
324 |             elif de_id == 5:
325 |                 # Deblurring with the GoPro set
326 |                 degrad_img = crop_img(np.array(Image.open(os.path.join(self.args.gopro_dir, 'blur/', sample["clean_id"])).convert('RGB')), base=16)
327 |                 clean_img = crop_img(np.array(Image.open(os.path.join(self.args.gopro_dir, 'sharp/', sample["clean_id"])).convert('RGB')), base=16)
328 |                 clean_name = self._get_deblur_name(sample["clean_id"])
329 |             elif de_id == 6:
330 |                 # Enhancement with the LOL training set
331 |                 degrad_img = crop_img(np.array(Image.open(os.path.join(self.args.enhance_dir, 'low/', sample["clean_id"])).convert('RGB')), base=16)
332 |                 clean_img = crop_img(np.array(Image.open(os.path.join(self.args.enhance_dir, 'gt/', sample["clean_id"])).convert('RGB')), base=16)
333 |                 clean_name = self._get_enhance_name(sample["clean_id"])
334 | 
335 | 
336 |             degrad_patch, clean_patch = random_augmentation(*self._crop_patch(degrad_img, clean_img))
337 | 
338 |         clean_patch = self.toTensor(clean_patch)
339 |         degrad_patch = self.toTensor(degrad_patch)
340 | 
341 | 
342 |         return [clean_name, de_id], degrad_patch, clean_patch
343 | 
344 |     def __len__(self):
345 |         return len(self.sample_ids)
346 | 
347 | 
348 | class AnyDnTestDataset(Dataset):
349 |     def __init__(self, args):
350 |         super(AnyDnTestDataset, self).__init__()
351 |         self.args = args
352 |         self.clean_ids = []
353 |         self.sigma = 15
354 | 
355 | 
356 |         self._init_clean_ids()
357 | 
358 |         self.toTensor = ToTensor()
359 | 
360 |     def _init_clean_ids(self):
361 |         name_list = os.listdir(self.args.denoise_path)
362 |         self.clean_ids += [self.args.denoise_path + id_ for id_ in name_list]
363 | 
364 |         self.num_clean = len(self.clean_ids)
365 | 
366 |     def _add_gaussian_noise(self, clean_patch):
367 |         noise = np.random.randn(*clean_patch.shape)
368 |         noisy_patch = np.clip(clean_patch + noise * self.sigma, 0, 255).astype(np.uint8)
369 |         return noisy_patch, clean_patch
370 | 
371 |     def set_sigma(self, sigma):
372 |         self.sigma = sigma
373 | 
374 |     def __getitem__(self, clean_id):
375 |         clean_img = crop_img(np.array(Image.open(self.clean_ids[clean_id]).convert('RGB')), base=16)
376 |         clean_name = self.clean_ids[clean_id].split("/")[-1].split('.')[0]
377 | 
378 |         noisy_img, _ = self._add_gaussian_noise(clean_img)
379 |         clean_img, noisy_img = self.toTensor(clean_img), self.toTensor(noisy_img)
380 | 
381 |         return [clean_name], noisy_img, clean_img
382 | 
383 |     def __len__(self):
384 |         return self.num_clean
385 | 
386 | 
387 | # Module-level helper (previously interleaved between the methods of the class above, which
388 | # orphaned its __len__): tiles an image, runs a pass-through on each tile, and averages overlaps.
389 | def tile_degrad(input_, tile=128, tile_overlap=0):
390 |     sigma_dict = {0: 0, 1: 15, 2: 25, 3: 50}  # unused; kept from the original
391 |     b, c, h, w = input_.shape
392 |     tile = min(tile, h, w)
393 |     assert tile % 8 == 0, "tile size should be multiple of 8"
394 |     stride = tile - tile_overlap
395 |     h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
396 |     w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
397 |     E = torch.zeros(b, c, h, w).type_as(input_)
398 |     W = torch.zeros_like(E)
399 |     for h_idx in h_idx_list:
400 |         for w_idx in w_idx_list:
401 |             in_patch = input_[..., h_idx:h_idx + tile, w_idx:w_idx + tile]
402 |             out_patch = in_patch  # identity pass-through; a model call would go here
403 |             out_patch_mask = torch.ones_like(in_patch)
404 |             E[..., h_idx:(h_idx + tile), w_idx:(w_idx + tile)].add_(out_patch)
405 |             W[..., h_idx:(h_idx + tile), w_idx:(w_idx + tile)].add_(out_patch_mask)
406 |     restored = torch.clamp(E.div_(W), 0, 1)  # average overlapping tiles; `restored` was previously undefined
407 |     return restored
408 | 
409 | 
410 | class AnyIRTestDataset(Dataset):
411 |     def __init__(self, args, task="derain", addnoise=False, sigma=None):
412 |         super(AnyIRTestDataset, self).__init__()
413 |         self.ids = []
414 |         self.task_idx = 0
415 |         self.args = args
416 | 
417 |         self.task_dict = {'derain': 0, 'dehaze': 1, 'deblur': 2, 'enhance': 3}
418 |         self.toTensor = ToTensor()
419 |         self.addnoise = addnoise
420 |         self.sigma = sigma
421 | 
422 |         self.set_dataset(task)
423 | 
424 |     def _add_gaussian_noise(self, clean_patch):
425 |         noise = np.random.randn(*clean_patch.shape)
426 |         noisy_patch = np.clip(clean_patch + noise * self.sigma, 0, 255).astype(np.uint8)
427 |         return noisy_patch, clean_patch
428 | 
429 |     def _init_input_ids(self):
430 |         if self.task_idx == 0:
431 |             self.ids = []
432 |             name_list = os.listdir(self.args.derain_path + 'input/')
433 |             self.ids += [self.args.derain_path + 'input/' + id_ for id_ in name_list]
434 |         elif self.task_idx == 1:
435 |             self.ids = []
436 |             name_list = os.listdir(self.args.dehaze_path + 'input/')
437 |             self.ids += [self.args.dehaze_path + 'input/' + id_ for id_ in name_list]
438 |         elif self.task_idx == 2:
439 |             self.ids = []
440 |             name_list = os.listdir(self.args.gopro_path + 'input/')
441 |             self.ids += [self.args.gopro_path + 'input/' + id_ for id_ in name_list]
442 |         elif self.task_idx == 3:
443 |             self.ids = []
444 |             name_list = os.listdir(self.args.enhance_path + 'input/')
445 |             self.ids += [self.args.enhance_path + 'input/' + id_ for id_ in name_list]
446 | 
447 |         self.length = len(self.ids)
448 | 
449 |     def _get_gt_path(self, degraded_name):
450 |         if self.task_idx == 0:
451 |             dir_name = degraded_name.split("input")[0] + 'target/'
452 |             name = 'no' + degraded_name.split('/')[-1].split('_')[0]
453 |             gt_name = dir_name + name
454 |         elif self.task_idx == 1:
455 |             dir_name = degraded_name.split("input")[0] + 'target/'
456 |             name = degraded_name.split('/')[-1].split('_')[0] + '.png'
457 |             gt_name = dir_name + name
458 |         elif self.task_idx == 2:
459 |             gt_name = degraded_name.replace("input", "target")
460 | 
461 |         elif self.task_idx == 3:
462 |             gt_name = degraded_name.replace("input", "target")
463 | 
464 |         return gt_name
465 | 
466 |     def set_dataset(self, task):
467 |         self.task_idx = self.task_dict[task]
468 |         self._init_input_ids()
469 | 
470 |     def __getitem__(self, idx):
471 |         degraded_path = self.ids[idx]
472 |         clean_path = self._get_gt_path(degraded_path)
473 | 
474 |         degraded_img = crop_img(np.array(Image.open(degraded_path).convert('RGB')), base=16)
475 |         if self.addnoise:
476 |             degraded_img, _ = self._add_gaussian_noise(degraded_img)
477 |         clean_img = crop_img(np.array(Image.open(clean_path).convert('RGB')), base=16)
478 | 
479 |         clean_img, degraded_img = self.toTensor(clean_img), self.toTensor(degraded_img)
480 |         degraded_name = degraded_path.split('/')[-1][:-4]
481 | 
482 |         return [degraded_name], degraded_img, clean_img
483 | 
484 |     def __len__(self):
485 |         return self.length
486 | 
487 | 
488 | class TestSpecificDataset(Dataset):
489 |     def __init__(self, args):
490 |         super(TestSpecificDataset, self).__init__()
491 |         self.args = args
492 |         self.degraded_ids = []
493 |         self._init_clean_ids(args.test_path)
494 | 
495 |         self.toTensor = ToTensor()
496 | 
497 |     def _init_clean_ids(self, root):
498 |         extensions = ['jpg', 'JPG', 'png', 'PNG', 'jpeg', 'JPEG', 'bmp', 'BMP']
499 |         if os.path.isdir(root):
500 |             name_list = []
501 |             for image_file in os.listdir(root):
502 |                 if any([image_file.endswith(ext) for ext in extensions]):
503 |                     name_list.append(image_file)
504 |             if len(name_list) == 0:
505 |                 raise Exception('The input directory does not contain any image files')
506 |             self.degraded_ids += [root + id_ for id_ in name_list]
507 |         else:
508 |             if any([root.endswith(ext) for ext in extensions]):
509 |                 name_list = [root]
510 |             else:
511 |                 raise Exception('Please pass an image file')
512 |             self.degraded_ids = name_list
513 |         print("Total Images : {}".format(len(name_list)))  # previously printed the list itself
514 | 
515 |         self.num_img = len(self.degraded_ids)
516 | 
517 |     def __getitem__(self, idx):
518 |         degraded_img = crop_img(np.array(Image.open(self.degraded_ids[idx]).convert('RGB')), base=16)
519 |         name = self.degraded_ids[idx].split('/')[-1][:-4]
520 | 
521 |         degraded_img = self.toTensor(degraded_img)
522 | 
523 |         return [name], degraded_img
524 | 
525 |     def __len__(self):
526 |         return self.num_img
--------------------------------------------------------------------------------
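Editor's note: a minimal end-to-end sketch of driving TestSpecificDataset for inference (the sample path and `model` are hypothetical, not from the repository):

    import torch
    from argparse import Namespace
    from torch.utils.data import DataLoader

    args = Namespace(test_path="samples/")   # a directory (with trailing slash) or a single image file
    test_set = TestSpecificDataset(args)
    loader = DataLoader(test_set, batch_size=1)
    with torch.no_grad():
        for meta, degraded in loader:
            name = meta[0][0]                # default collate wraps the [name] list
            restored = model(degraded)       # `model` is an assumed restoration network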