├── utils
│   ├── __init__.py
│   ├── degradation_utils.py
│   ├── val_utils.py
│   ├── pytorch_ssim
│   │   └── __init__.py
│   ├── loss_utils.py
│   ├── image_utils.py
│   ├── image_io.py
│   ├── imresize.py
│   ├── schedulers.py
│   └── dataset_utils.py
├── requirements.txt
├── .gitignore
└── README.md
/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | --extra-index-url https://download.pytorch.org/whl/cu118
2 |
3 | einops==0.8.0
4 | fvcore
5 | lightning==2.0.1
6 | lightning-cloud==0.5.32
7 | lightning-utilities==0.8.0
8 | matplotlib
9 | ninja==1.11.1
10 | numpy==1.26.3
11 | opencv-python==4.7.0.68
12 | pandas==1.5.3
13 | pydantic==1.10.7
14 | scikit-image
15 | scikit-video
16 | scikit-learn
17 | seaborn
18 | tensorboard
19 | timm
20 | torch==2.0.0
21 | torchaudio==2.0.1
22 | torchvision==0.15.1
23 | torchmetrics
24 | triton==2.0.0
25 | tqdm
26 | wandb==0.18.7
27 | warmup_scheduler==0.3
28 | lpips
29 | typeguard
30 | gdown
--------------------------------------------------------------------------------
/utils/degradation_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision.transforms import ToPILImage, Compose, RandomCrop, ToTensor, Grayscale
3 |
4 | from PIL import Image
5 | import random
6 | import numpy as np
7 |
8 | from utils.image_utils import crop_img
9 |
10 |
11 | class Degradation(object):
12 | def __init__(self, args):
13 | super(Degradation, self).__init__()
14 | self.args = args
15 | self.toTensor = ToTensor()
16 | self.crop_transform = Compose([
17 | ToPILImage(),
18 | RandomCrop(args.patch_size),
19 | ])
20 |
21 | def _add_gaussian_noise(self, clean_patch, sigma):
22 | # noise = torch.randn(*(clean_patch.shape))
23 | # clean_patch = self.toTensor(clean_patch)
24 | noise = np.random.randn(*clean_patch.shape)
25 | noisy_patch = np.clip(clean_patch + noise * sigma, 0, 255).astype(np.uint8)
26 | # noisy_patch = torch.clamp(clean_patch + noise * sigma, 0, 255).type(torch.int32)
27 | return noisy_patch, clean_patch
28 |
29 | def _degrade_by_type(self, clean_patch, degrade_type):
30 | if degrade_type == 0:
31 | # denoise sigma=15
32 | degraded_patch, clean_patch = self._add_gaussian_noise(clean_patch, sigma=15)
33 | elif degrade_type == 1:
34 | # denoise sigma=25
35 | degraded_patch, clean_patch = self._add_gaussian_noise(clean_patch, sigma=25)
36 | elif degrade_type == 2:
37 | # denoise sigma=50
38 | degraded_patch, clean_patch = self._add_gaussian_noise(clean_patch, sigma=50)
39 |
40 | return degraded_patch, clean_patch
41 |
42 | def degrade(self, clean_patch_1, clean_patch_2, degrade_type=None):
43 |         if degrade_type is None:
44 |             # Randomly pick one of the implemented degradation types:
45 |             # 0/1/2 -> Gaussian noise with sigma 15/25/50.
46 |             degrade_type = random.randint(0, 2)
47 | 
48 | degrad_patch_1, _ = self._degrade_by_type(clean_patch_1, degrade_type)
49 | degrad_patch_2, _ = self._degrade_by_type(clean_patch_2, degrade_type)
50 | return degrad_patch_1, degrad_patch_2
51 |
52 |     def single_degrade(self, clean_patch, degrade_type=None):
53 |         if degrade_type is None:
54 |             # Randomly pick one of the implemented degradation types:
55 |             # 0/1/2 -> Gaussian noise with sigma 15/25/50.
56 |             degrade_type = random.randint(0, 2)
57 | 
58 | degrad_patch_1, _ = self._degrade_by_type(clean_patch, degrade_type)
59 | return degrad_patch_1
60 |
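61 | 
62 | if __name__ == '__main__':
63 |     # Minimal usage sketch (illustrative only, not part of the pipeline): any
64 |     # object exposing a `patch_size` attribute works as `args` here.
65 |     from types import SimpleNamespace
66 | 
67 |     degrader = Degradation(SimpleNamespace(patch_size=128))
68 |     clean = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
69 |     noisy = degrader.single_degrade(clean, degrade_type=1)  # Gaussian, sigma=25
70 |     print(noisy.shape, noisy.dtype)  # (256, 256, 3) uint8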
--------------------------------------------------------------------------------
/utils/val_utils.py:
--------------------------------------------------------------------------------
1 |
2 | import time
3 | import numpy as np
4 | from skimage.metrics import peak_signal_noise_ratio, structural_similarity
5 | from skvideo.measure import niqe
6 |
7 |
8 | class AverageMeter():
9 | """ Computes and stores the average and current value """
10 |
11 | def __init__(self):
12 | self.reset()
13 |
14 | def reset(self):
15 | """ Reset all statistics """
16 | self.val = 0
17 | self.avg = 0
18 | self.sum = 0
19 | self.count = 0
20 |
21 | def update(self, val, n=1):
22 | """ Update statistics """
23 | self.val = val
24 | self.sum += val * n
25 | self.count += n
26 | self.avg = self.sum / self.count
27 |
28 |
29 | def accuracy(output, target, topk=(1,)):
30 | """ Computes the precision@k for the specified values of k """
31 | maxk = max(topk)
32 | batch_size = target.size(0)
33 |
34 | _, pred = output.topk(maxk, 1, True, True)
35 | pred = pred.t()
36 | # one-hot case
37 | if target.ndimension() > 1:
38 | target = target.max(1)[1]
39 |
40 | correct = pred.eq(target.view(1, -1).expand_as(pred))
41 |
42 | res = []
43 | for k in topk:
44 |         correct_k = correct[:k].reshape(-1).float().sum(0)
45 | res.append(correct_k.mul_(1.0 / batch_size))
46 |
47 | return res
48 |
49 |
50 | def compute_psnr_ssim(recovered, clean):
51 |     assert recovered.shape == clean.shape
52 |     recovered = np.clip(recovered.detach().cpu().numpy(), 0, 1)
53 |     clean = np.clip(clean.detach().cpu().numpy(), 0, 1)
54 | 
55 |     # (B, C, H, W) -> (B, H, W, C) for the scikit-image metrics
56 |     recovered = recovered.transpose(0, 2, 3, 1)
57 |     clean = clean.transpose(0, 2, 3, 1)
58 |     psnr = 0
59 |     ssim = 0
60 | 
61 |     for i in range(recovered.shape[0]):
62 |         psnr += peak_signal_noise_ratio(clean[i], recovered[i], data_range=1)
63 |         # channel_axis replaces the deprecated multichannel=True (scikit-image >= 0.19)
64 |         ssim += structural_similarity(clean[i], recovered[i], data_range=1, channel_axis=2)
65 | 
66 |     return psnr / recovered.shape[0], ssim / recovered.shape[0], recovered.shape[0]
67 |
68 |
69 | def compute_niqe(image):
70 | image = np.clip(image.detach().cpu().numpy(), 0, 1)
71 | image = image.transpose(0, 2, 3, 1)
72 | niqe_val = niqe(image)
73 |
74 | return niqe_val.mean()
75 |
76 | class timer():
77 | def __init__(self):
78 | self.acc = 0
79 | self.tic()
80 |
81 | def tic(self):
82 | self.t0 = time.time()
83 |
84 | def toc(self):
85 | return time.time() - self.t0
86 |
87 | def hold(self):
88 | self.acc += self.toc()
89 |
90 | def release(self):
91 | ret = self.acc
92 | self.acc = 0
93 |
94 | return ret
95 |
96 | def reset(self):
97 | self.acc = 0
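98 | 
99 | 
100 | if __name__ == '__main__':
101 |     # Minimal usage sketch (illustrative only): compute_psnr_ssim expects
102 |     # (B, C, H, W) tensors in [0, 1]; AverageMeter accumulates the results.
103 |     import torch
104 | 
105 |     x = torch.rand(2, 3, 64, 64)
106 |     y = (x + 0.01 * torch.randn_like(x)).clamp(0, 1)
107 |     psnr, ssim, n = compute_psnr_ssim(y, x)
108 | 
109 |     meter = AverageMeter()
110 |     meter.update(psnr, n)
111 |     print(meter.avg, ssim)  # PSNR around 40 dB, SSIM close to 1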
--------------------------------------------------------------------------------
/utils/pytorch_ssim/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from math import exp
5 | # NOTE: torch.autograd.Variable is deprecated; plain tensors are used below.
6 |
7 | # Matlab style 1D gaussian filter.
8 | def gaussian(window_size, sigma):
9 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)])
10 | return gauss/gauss.sum()
11 |
12 | # Matlab style 2D gaussian window (outer product of the 1D filter), one copy per channel.
13 | def create_window(window_size, channel):
14 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
15 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
16 |     window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
17 | return window
18 |
19 | def _ssim(img1, img2, window, window_size, channel, size_average = True):
20 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel)
21 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel)
22 |
23 | mu1_sq = mu1.pow(2)
24 | mu2_sq = mu2.pow(2)
25 | mu1_mu2 = mu1*mu2
26 |
27 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq
28 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq
29 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2
30 |
31 | C1 = 0.01**2
32 | C2 = 0.03**2
33 |
34 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2))
35 |
36 |
37 |     # experimental variant (left disabled):
38 |     # ssim_map = torch.exp(1 + ssim_map)
39 |
40 | if size_average:
41 | return ssim_map.mean()
42 | else:
43 | return ssim_map.mean(1).mean(1).mean(1)
44 |
45 | class SSIM(torch.nn.Module):
46 | def __init__(self, window_size = 11, size_average = True):
47 | super(SSIM, self).__init__()
48 | self.window_size = window_size
49 | self.size_average = size_average
50 | self.channel = 1
51 | self.window = create_window(window_size, self.channel)
52 |
53 | def forward(self, img1, img2):
54 | (_, channel, _, _) = img1.size()
55 |
56 | if channel == self.channel and self.window.data.type() == img1.data.type():
57 | window = self.window
58 | else:
59 | window = create_window(self.window_size, channel)
60 |
61 | if img1.is_cuda:
62 | window = window.cuda(img1.get_device())
63 | window = window.type_as(img1)
64 |
65 | self.window = window
66 | self.channel = channel
67 |
68 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
69 |
70 | def ssim(img1, img2, window_size = 11, size_average = True):
71 | (_, channel, _, _) = img1.size()
72 | window = create_window(window_size, channel)
73 |
74 | if img1.is_cuda:
75 | window = window.cuda(img1.get_device())
76 | window = window.type_as(img1)
77 |
78 | return _ssim(img1, img2, window, window_size, channel, size_average)
79 |
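80 | 
81 | if __name__ == '__main__':
82 |     # Minimal usage sketch (illustrative only): inputs are (B, C, H, W) in [0, 1];
83 |     # SSIM of an image with itself is ~1.0, and added noise lowers it.
84 |     img = torch.rand(1, 3, 64, 64)
85 |     noisy = (img + 0.1 * torch.randn_like(img)).clamp(0, 1)
86 |     print(ssim(img, img).item())    # ~1.0
87 |     print(ssim(img, noisy).item())  # noticeably below 1.0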
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | /wandb
86 | /logs
87 | /train_ckpt
88 | /joblogs
89 | /output
90 | # pyenv
91 | # For a library or package, you might want to ignore these files since the code is
92 | # intended to run in multiple environments; otherwise, check them in:
93 | # .python-version
94 |
95 | # pipenv
96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
99 | # install all needed dependencies.
100 | #Pipfile.lock
101 |
102 | # poetry
103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
104 | # This is especially recommended for binary packages to ensure reproducibility, and is more
105 | # commonly ignored for libraries.
106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
107 | #poetry.lock
108 |
109 | # pdm
110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
111 | #pdm.lock
112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
113 | # in version control.
114 | # https://pdm.fming.dev/#use-with-ide
115 | .pdm.toml
116 |
117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
118 | __pypackages__/
119 |
120 | # Celery stuff
121 | celerybeat-schedule
122 | celerybeat.pid
123 |
124 | # SageMath parsed files
125 | *.sage.py
126 |
127 | # Environments
128 | .env
129 | .venv
130 | env/
131 | venv/
132 | ENV/
133 | env.bak/
134 | venv.bak/
135 |
136 | # Spyder project settings
137 | .spyderproject
138 | .spyproject
139 |
140 | # Rope project settings
141 | .ropeproject
142 |
143 | # mkdocs documentation
144 | /site
145 |
146 | # mypy
147 | .mypy_cache/
148 | .dmypy.json
149 | dmypy.json
150 |
151 | # Pyre type checker
152 | .pyre/
153 |
154 | # pytype static type analyzer
155 | .pytype/
156 |
157 | # Cython debug symbols
158 | cython_debug/
159 |
160 | # PyCharm
161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
163 | # and can be added to the global gitignore or merged into this file. For a more nuclear
164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
165 | #.idea/
166 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AnyIR
2 | ### Any Image Restoration via Efficient Spatial-Frequency Degradation Adaptation
3 |
4 | The official PyTorch implementation of AnyIR for all-in-one image restoration.
5 |
6 | #### [Bin Ren 1,2,3](https://amazingren.github.io/), [Eduard Zamfir 4](https://eduardzamfir.github.io), [Zongwei Wu 4](https://sites.google.com/view/zwwu/accueil), [Yawei Li 4](https://yaweili.bitbucket.io/), [Yidi Li 3](https://liyidi.github.io/), [Danda Pani Paudel 3](https://people.ee.ethz.ch/~paudeld/), [Radu Timofte 4](https://www.informatik.uni-wuerzburg.de/computervision/), [Ming-Hsuan Yang 7](https://scholar.google.com/citations?user=p9-ohHsAAAAJ&hl=en), [Luc Van Gool 3](https://scholar.google.com/citations?user=TwMib_QAAAAJ&hl=en), and [Nicu Sebe 2](https://scholar.google.com/citations?user=stFCYOAAAAAJ&hl=en)
7 |
8 | 1 University of Pisa, Italy,
9 | 2 University of Trento, Italy,
10 | 3 INSAIT, Sofia University "St. Kliment Ohridski", Bulgaria,
11 | 4 University of Würzburg, Germany,
12 | 5 ETH Zürich, Switzerland,
13 | 6 Taiyuan University of Technology, China,
14 | 7 University of California, Merced, USA
15 |
16 |
17 | [arXiv Paper](https://arxiv.org/pdf/2407.13372)
18 |
19 |
20 | ## Latest
21 | - `08/01/2025`: Unfortunately, although this work was recommended `Accept` by the AC at ACM MM 2025, the PC finally rejected it without giving a reason, so this work is still under review.
22 | - `07/18/2024`: Repository created. Our code will be made publicly available upon acceptance.
23 |
24 |
25 | ## Method
26 |
27 |
28 |
29 | ### Abstract
30 |
31 | Restoring any degraded image efficiently via just one model has become increasingly significant and impactful, especially with the proliferation of mobile devices. Traditional solutions typically involve training dedicated models per degradation, resulting in inefficiency and redundancy. More recent approaches either introduce additional modules to learn visual prompts - significantly increasing the size of the model - or incorporate cross-modal transfer from large language models trained on vast datasets, adding complexity to the system architecture. In contrast, our approach, termed AnyIR, takes a unified path that leverages inherent similarity across various degradations to enable both efficient and comprehensive restoration through a joint embedding mechanism, without scaling up the model or relying on large language models.
32 | Specifically, we examine the sub-latent space of each input, identifying key components and reweighting them first in a gated manner. To fuse intrinsic degradation awareness and contextualized attention, a spatial-frequency parallel fusion strategy is proposed to enhance spatially aware local-global interactions and enrich restoration details from the frequency perspective. Extensive benchmarking in the all-in-one restoration setting confirms AnyIR's SOTA performance, reducing model complexity by around **82%** in parameters and **85%** in FLOPs compared to the baseline solution.
33 | Our code will be available upon acceptance.
34 |
35 |
36 |
37 | ## Installation
38 |
39 | ### Environments
40 | ```
41 | # Step 1: Create the virtual environment via micromamba or conda:
42 | micromamba create -n anyir python=3.9 -y
43 | # or
44 | conda create -n anyir python=3.9 -y
45 | 
46 | # Step 2: Install PyTorch and the other dependencies
47 | pip install -r requirements.txt
48 | 
49 | # Step 3: Set up CUDA 11.8 (adjust the paths to your installation)
50 | export LD_LIBRARY_PATH=/opt/modules/nvidia-cuda-11.8/lib64:$LD_LIBRARY_PATH
51 | export PATH=/opt/modules/nvidia-cuda-11.8/bin:$PATH
52 |
53 | ```
54 |
55 |
56 | ### Datasets
57 |
58 |
59 | ### Checkpoints Downloads:
60 |
61 |
62 | ### Visual Results Downloads:
63 |
64 |
65 | ### Training
66 |
67 |
68 | ### Evaluation:
69 | (I). 3-Degradation Setting:
70 |
71 | (II). 5-Degradation Setting:
72 |
73 | (III). Mix-Degradation Setting:
74 |
75 | (IV). Real-World Setting:
76 |
77 |
78 |
79 |
80 |
81 | ## Citation
82 |
83 | If you find our work helpful, please consider citing the following paper and/or ⭐ the repo.
84 | ```
85 | @misc{ren2025any,
86 | title={Any Image Restoration via Efficient Spatial-Frequency Degradation Adaptation},
87 | author={Ren, Bin and Zamfir, Eduard and Wu, Zongwei and Li, Yawei and Li, Yidi and Paudel, Danda Pani and Timofte, Radu and Yang, Ming-Hsuan and Van Gool, Luc and Sebe, Nicu},
88 | year={2025},
89 | eprint={2504.14249},
90 | archivePrefix={arXiv},
91 | primaryClass={cs.CV}
92 | }
93 | ```
94 |
95 |
96 | ## Acknowledgements
97 |
98 | This code is built on [PromptIR](https://github.com/va1shn9v/PromptIR) and [AirNet](https://github.com/XLearning-SCU/2022-CVPR-AirNet).
--------------------------------------------------------------------------------
/utils/loss_utils.py:
--------------------------------------------------------------------------------
1 | import math
2 |
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn.functional import mse_loss
6 |
7 |
8 | class GANLoss(nn.Module):
9 | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
10 | tensor=torch.FloatTensor):
11 | super(GANLoss, self).__init__()
12 | self.real_label = target_real_label
13 | self.fake_label = target_fake_label
14 | self.real_label_var = None
15 | self.fake_label_var = None
16 | self.Tensor = tensor
17 | if use_lsgan:
18 | self.loss = nn.MSELoss()
19 | else:
20 | self.loss = nn.BCELoss()
21 |
22 | def get_target_tensor(self, input, target_is_real):
23 | target_tensor = None
24 | if target_is_real:
25 |             create_label = ((self.real_label_var is None) or (self.real_label_var.numel() != input.numel()))
26 | # pdb.set_trace()
27 | if create_label:
28 | real_tensor = self.Tensor(input.size()).fill_(self.real_label)
29 | # self.real_label_var = Variable(real_tensor, requires_grad=False)
30 | # self.real_label_var = torch.Tensor(real_tensor)
31 | self.real_label_var = real_tensor
32 | target_tensor = self.real_label_var
33 | else:
34 | # pdb.set_trace()
35 | create_label = ((self.fake_label_var is None) or (self.fake_label_var.numel() != input.numel()))
36 | if create_label:
37 | fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
38 | # self.fake_label_var = Variable(fake_tensor, requires_grad=False)
39 | # self.fake_label_var = torch.Tensor(fake_tensor)
40 | self.fake_label_var = fake_tensor
41 | target_tensor = self.fake_label_var
42 | return target_tensor
43 |
44 | def __call__(self, input, target_is_real):
45 | target_tensor = self.get_target_tensor(input, target_is_real)
46 | # pdb.set_trace()
47 | return self.loss(input, target_tensor)
48 |
49 |
50 |
51 | class FocalL1Loss(nn.Module):
52 | def __init__(self, gamma=2.0, epsilon=1e-6, alpha=0.1):
53 | """
54 | Focal L1 Loss with adjusted weighting for output values in [0, 1].
55 |
56 | Args:
57 | gamma (float): Focusing parameter. Larger gamma focuses more on hard examples.
58 | epsilon (float): Small constant to prevent weights from being zero.
59 | alpha (float): Scaling factor to normalize error values.
60 | """
61 | super(FocalL1Loss, self).__init__()
62 | self.gamma = gamma
63 | self.epsilon = epsilon
64 | self.alpha = alpha # Scaling factor to prevent error values from being too small.
65 |
66 | def forward(self, pred, target):
67 | """
68 | Compute the Focal L1 Loss between the predicted and target images.
69 |
70 | Args:
71 | pred (torch.Tensor): Predicted image [b, c, h, w].
72 | target (torch.Tensor): Ground truth image [b, c, h, w].
73 |
74 | Returns:
75 | torch.Tensor: Scalar Focal L1 Loss.
76 | """
77 | # Compute the absolute error (L1 Loss) and scale it by alpha
78 | abs_err = torch.abs(pred - target) / self.alpha
79 |
80 | # Apply a logarithmic transformation to the error to prevent very small weights
81 | focal_weight = (torch.log(1 + abs_err + self.epsilon)) ** self.gamma
82 |
83 | # Compute the weighted loss
84 | focal_l1_loss = focal_weight * abs_err
85 |
86 | # Return the mean loss across all pixels
87 | return focal_l1_loss.mean()
88 |
89 |
90 | class FFTLoss(nn.Module):
91 | def __init__(self, loss_weight=1.0, reduction='mean'):
92 | super(FFTLoss, self).__init__()
93 | self.loss_weight = loss_weight
94 | self.criterion = torch.nn.L1Loss(reduction=reduction)
95 |
96 | def forward(self, pred, target):
97 | pred_fft = torch.fft.rfft2(pred)
98 | target_fft = torch.fft.rfft2(target)
99 |
100 | pred_fft = torch.stack([pred_fft.real, pred_fft.imag], dim=-1)
101 | target_fft = torch.stack([target_fft.real, target_fft.imag], dim=-1)
102 |
103 | return self.loss_weight * self.criterion(pred_fft, target_fft)
104 |
105 |
106 | class TemperatureScheduler:
107 | def __init__(self, start_temp, end_temp, total_steps):
108 | """
109 | Scheduler for Gumbel-Softmax temperature that decreases using a cosine annealing schedule.
110 |
111 | Args:
112 | - start_temp (float): Initial temperature (e.g., 5.0).
113 | - end_temp (float): Final temperature (e.g., 0.01).
114 | - total_steps (int): Total number of steps/epochs to anneal over.
115 | """
116 | self.start_temp = start_temp
117 | self.end_temp = end_temp
118 | self.total_steps = total_steps
119 |
120 | def get_temperature(self, step):
121 | """
122 | Get the temperature value for the current step, following a cosine annealing schedule.
123 |
124 | Args:
125 | - step (int): Current step or epoch.
126 |
127 | Returns:
128 | - temperature (float): The temperature for the Gumbel-Softmax at this step.
129 | """
130 | if step >= self.total_steps:
131 | return self.end_temp
132 |
133 | # Cosine annealing formula to compute the temperature
134 | cos_inner = math.pi * step / self.total_steps
135 | #temp = self.end_temp + 0.5 * (self.start_temp - self.end_temp) * (1 + math.cos(cos_inner))
136 | temp = self.start_temp + 0.5 * (self.end_temp - self.start_temp) * (1 - math.cos(cos_inner))
137 |
138 | return temp
139 |
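140 | 
141 | if __name__ == '__main__':
142 |     # Minimal usage sketch (illustrative only): the two image losses on random
143 |     # (B, C, H, W) tensors, plus the cosine temperature schedule.
144 |     pred = torch.rand(2, 3, 32, 32)
145 |     target = torch.rand(2, 3, 32, 32)
146 |     print(FocalL1Loss()(pred, target).item())
147 |     print(FFTLoss()(pred, target).item())
148 | 
149 |     sched = TemperatureScheduler(start_temp=5.0, end_temp=0.01, total_steps=100)
150 |     # anneals 5.0 -> ~2.5 -> 0.01 across the schedule
151 |     print(sched.get_temperature(0), sched.get_temperature(50), sched.get_temperature(100))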
--------------------------------------------------------------------------------
/utils/image_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Created on 2020/9/8
3 |
4 | @author: Boyun Li
5 | """
6 | import os
7 | import numpy as np
8 | import torch
9 | import random
10 | import torch.nn as nn
11 | from torch.nn import init
12 | from PIL import Image
13 |
14 | class EdgeComputation(nn.Module):
15 | def __init__(self, test=False):
16 | super(EdgeComputation, self).__init__()
17 | self.test = test
18 | def forward(self, x):
19 | if self.test:
20 | x_diffx = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1])
21 | x_diffy = torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :])
22 |
23 |             # zeros on the same device/dtype as the input
24 |             y = torch.zeros_like(x)
25 | 
26 | y[:, :, :, 1:] += x_diffx
27 | y[:, :, :, :-1] += x_diffx
28 | y[:, :, 1:, :] += x_diffy
29 | y[:, :, :-1, :] += x_diffy
30 | y = torch.sum(y, 1, keepdim=True) / 3
31 | y /= 4
32 | return y
33 | else:
34 | x_diffx = torch.abs(x[:, :, 1:] - x[:, :, :-1])
35 | x_diffy = torch.abs(x[:, 1:, :] - x[:, :-1, :])
36 |
37 |             y = torch.zeros_like(x)
38 | 
39 | y[:, :, 1:] += x_diffx
40 | y[:, :, :-1] += x_diffx
41 | y[:, 1:, :] += x_diffy
42 | y[:, :-1, :] += x_diffy
43 | y = torch.sum(y, 0) / 3
44 | y /= 4
45 | return y.unsqueeze(0)
46 |
47 |
48 | # randomly crop a patch from image
49 | def crop_patch(im, pch_size):
50 | H = im.shape[0]
51 | W = im.shape[1]
52 | ind_H = random.randint(0, H - pch_size)
53 | ind_W = random.randint(0, W - pch_size)
54 | pch = im[ind_H:ind_H + pch_size, ind_W:ind_W + pch_size]
55 | return pch
56 |
57 |
58 | # crop an image to the multiple of base
59 | def crop_img(image, base=64):
60 | h = image.shape[0]
61 | w = image.shape[1]
62 | crop_h = h % base
63 | crop_w = w % base
64 | return image[crop_h // 2:h - crop_h + crop_h // 2, crop_w // 2:w - crop_w + crop_w // 2, :]
65 |
66 |
67 | # image (H, W, C) -> patches (B, H, W, C)
68 | def slice_image2patches(image, patch_size=64, overlap=0):
69 | assert image.shape[0] % patch_size == 0 and image.shape[1] % patch_size == 0
70 | H = image.shape[0]
71 | W = image.shape[1]
72 | patches = []
73 | image_padding = np.pad(image, ((overlap, overlap), (overlap, overlap), (0, 0)), mode='edge')
74 | for h in range(H // patch_size):
75 | for w in range(W // patch_size):
76 | idx_h = [h * patch_size, (h + 1) * patch_size + overlap]
77 | idx_w = [w * patch_size, (w + 1) * patch_size + overlap]
78 | patches.append(np.expand_dims(image_padding[idx_h[0]:idx_h[1], idx_w[0]:idx_w[1], :], axis=0))
79 | return np.concatenate(patches, axis=0)
80 |
81 |
82 | # patches (B, H, W, C) -> image (H, W, C)
83 | def splice_patches2image(patches, image_size, overlap=0):
84 | assert len(image_size) > 1
85 | assert patches.shape[-3] == patches.shape[-2]
86 | H = image_size[0]
87 | W = image_size[1]
88 | patch_size = patches.shape[-2] - overlap
89 | image = np.zeros(image_size)
90 | idx = 0
91 | for h in range(H // patch_size):
92 | for w in range(W // patch_size):
93 | image[h * patch_size:(h + 1) * patch_size, w * patch_size:(w + 1) * patch_size, :] = patches[idx,
94 | overlap:patch_size + overlap,
95 | overlap:patch_size + overlap,
96 | :]
97 | idx += 1
98 | return image
99 |
100 |
101 | # def data_augmentation(image, mode):
102 | # if mode == 0:
103 | # # original
104 | # out = image.numpy()
105 | # elif mode == 1:
106 | # # flip up and down
107 | # out = np.flipud(image)
108 | # elif mode == 2:
109 | # # rotate counterwise 90 degree
110 | # out = np.rot90(image, axes=(1, 2))
111 | # elif mode == 3:
112 | # # rotate 90 degree and flip up and down
113 | # out = np.rot90(image, axes=(1, 2))
114 | # out = np.flipud(out)
115 | # elif mode == 4:
116 | # # rotate 180 degree
117 | # out = np.rot90(image, k=2, axes=(1, 2))
118 | # elif mode == 5:
119 | # # rotate 180 degree and flip
120 | # out = np.rot90(image, k=2, axes=(1, 2))
121 | # out = np.flipud(out)
122 | # elif mode == 6:
123 | # # rotate 270 degree
124 | # out = np.rot90(image, k=3, axes=(1, 2))
125 | # elif mode == 7:
126 | # # rotate 270 degree and flip
127 | # out = np.rot90(image, k=3, axes=(1, 2))
128 | # out = np.flipud(out)
129 | # else:
130 | # raise Exception('Invalid choice of image transformation')
131 | # return out
132 |
133 | def data_augmentation(image, mode):
134 | if mode == 0:
135 | # original
136 | out = image.numpy()
137 | elif mode == 1:
138 | # flip up and down
139 | out = np.flipud(image)
140 | elif mode == 2:
141 | # rotate counterwise 90 degree
142 | out = np.rot90(image)
143 | elif mode == 3:
144 | # rotate 90 degree and flip up and down
145 | out = np.rot90(image)
146 | out = np.flipud(out)
147 | elif mode == 4:
148 | # rotate 180 degree
149 | out = np.rot90(image, k=2)
150 | elif mode == 5:
151 | # rotate 180 degree and flip
152 | out = np.rot90(image, k=2)
153 | out = np.flipud(out)
154 | elif mode == 6:
155 | # rotate 270 degree
156 | out = np.rot90(image, k=3)
157 | elif mode == 7:
158 | # rotate 270 degree and flip
159 | out = np.rot90(image, k=3)
160 | out = np.flipud(out)
161 | else:
162 | raise Exception('Invalid choice of image transformation')
163 | return out
164 |
165 |
166 | # def random_augmentation(*args):
167 | # out = []
168 | # if random.randint(0, 1) == 1:
169 | # flag_aug = random.randint(1, 7)
170 | # for data in args:
171 | # out.append(data_augmentation(data, flag_aug).copy())
172 | # else:
173 | # for data in args:
174 | # out.append(data)
175 | # return out
176 |
177 | def random_augmentation(*args):
178 | out = []
179 | flag_aug = random.randint(1, 7)
180 | for data in args:
181 | out.append(data_augmentation(data, flag_aug).copy())
182 | return out
183 |
184 |
185 | def weights_init_normal_(m):
186 |     classname = m.__class__.__name__
187 |     # In-place init_ APIs (the alias variants are deprecated); the old BatchNorm
188 |     # call uniform(1.0, 0.02) had inverted bounds and is replaced by N(1.0, 0.02).
189 |     if classname.find('Conv') != -1:
190 |         init.normal_(m.weight.data, 0.0, 0.02)
191 |     elif classname.find('Linear') != -1:
192 |         init.normal_(m.weight.data, 0.0, 0.02)
193 |     elif classname.find('BatchNorm2d') != -1:
194 |         init.normal_(m.weight.data, 1.0, 0.02)
195 |         init.constant_(m.bias.data, 0.0)
196 | 
197 | 
198 | def weights_init_normal(m):
199 |     classname = m.__class__.__name__
200 |     if classname.find('Conv2d') != -1:
201 |         m.apply(weights_init_normal_)
202 |     elif classname.find('Linear') != -1:
203 |         init.normal_(m.weight.data, 0.0, 0.02)
204 |     elif classname.find('BatchNorm2d') != -1:
205 |         init.normal_(m.weight.data, 1.0, 0.02)
206 |         init.constant_(m.bias.data, 0.0)
207 | 
208 | 
209 | def weights_init_xavier(m):
210 |     classname = m.__class__.__name__
211 |     if classname.find('Conv') != -1:
212 |         init.xavier_normal_(m.weight.data, gain=1)
213 |     elif classname.find('Linear') != -1:
214 |         init.xavier_normal_(m.weight.data, gain=1)
215 |     elif classname.find('BatchNorm2d') != -1:
216 |         init.normal_(m.weight.data, 1.0, 0.02)
217 |         init.constant_(m.bias.data, 0.0)
218 | 
219 | 
220 | def weights_init_kaiming(m):
221 |     classname = m.__class__.__name__
222 |     if classname.find('Conv') != -1:
223 |         init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
224 |     elif classname.find('Linear') != -1:
225 |         init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
226 |     elif classname.find('BatchNorm2d') != -1:
227 |         init.normal_(m.weight.data, 1.0, 0.02)
228 |         init.constant_(m.bias.data, 0.0)
229 | 
230 | 
231 | def weights_init_orthogonal(m):
232 |     classname = m.__class__.__name__
233 |     if classname.find('Conv') != -1:
234 |         init.orthogonal_(m.weight.data, gain=1)
235 |     elif classname.find('Linear') != -1:
236 |         init.orthogonal_(m.weight.data, gain=1)
237 |     elif classname.find('BatchNorm2d') != -1:
238 |         init.normal_(m.weight.data, 1.0, 0.02)
239 |
240 |
241 | def init_weights(net, init_type='normal'):
242 | print('initialization method [%s]' % init_type)
243 | if init_type == 'normal':
244 | net.apply(weights_init_normal)
245 | elif init_type == 'xavier':
246 | net.apply(weights_init_xavier)
247 | elif init_type == 'kaiming':
248 | net.apply(weights_init_kaiming)
249 | elif init_type == 'orthogonal':
250 | net.apply(weights_init_orthogonal)
251 | else:
252 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
253 |
254 |
255 | def np_to_torch(img_np):
256 | """
257 | Converts image in numpy.array to torch.Tensor.
258 |
259 |     From C x W x H [0..1] to 1 x C x W x H [0..1]
260 |
261 | :param img_np:
262 | :return:
263 | """
264 | return torch.from_numpy(img_np)[None, :]
265 |
266 |
267 | def torch_to_np(img_var):
268 | """
269 | Converts an image in torch.Tensor format to np.array.
270 |
271 |     Keeps the batch dimension: B x C x W x H [0..1] -> B x C x W x H [0..1]
272 | :param img_var:
273 | :return:
274 | """
275 | return img_var.detach().cpu().numpy()
276 | # return img_var.detach().cpu().numpy()[0]
277 |
278 |
279 | def save_image(name, image_np, output_path="output/normal/"):
280 |     # makedirs creates nested dirs; os.mkdir failed when "output/" did not exist
281 |     os.makedirs(output_path, exist_ok=True)
282 |
283 | p = np_to_pil(image_np)
284 | p.save(output_path + "{}.png".format(name))
285 |
286 |
287 | def np_to_pil(img_np):
288 | """
289 | Converts image in np.array format to PIL image.
290 |
291 | From C x W x H [0..1] to W x H x C [0...255]
292 | :param img_np:
293 | :return:
294 | """
295 | ar = np.clip(img_np * 255, 0, 255).astype(np.uint8)
296 |
297 | if img_np.shape[0] == 1:
298 | ar = ar[0]
299 | else:
300 | assert img_np.shape[0] == 3, img_np.shape
301 | ar = ar.transpose(1, 2, 0)
302 |
303 | return Image.fromarray(ar)
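304 | 
305 | 
306 | if __name__ == '__main__':
307 |     # Minimal usage sketch (illustrative only): crop to a multiple of 64, cut a
308 |     # random patch, then apply a random flip/rotation augmentation.
309 |     img = np.random.rand(300, 500, 3).astype(np.float32)
310 |     img = crop_img(img, base=64)           # -> (256, 448, 3)
311 |     patch = crop_patch(img, pch_size=128)  # -> (128, 128, 3)
312 |     aug, = random_augmentation(patch)
313 |     print(img.shape, patch.shape, aug.shape)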
--------------------------------------------------------------------------------
/utils/image_io.py:
--------------------------------------------------------------------------------
1 | import glob
2 |
3 | import torch
4 | import torchvision
5 | import matplotlib
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | from PIL import Image
9 |
10 | # import skvideo.io
11 |
12 | matplotlib.use('agg')
13 |
14 |
15 | def prepare_hazy_image(file_name):
16 | img_pil = crop_image(get_image(file_name, -1)[0], d=32)
17 | return pil_to_np(img_pil)
18 |
19 |
20 | def prepare_gt_img(file_name, SOTS=True):
21 | if SOTS:
22 | img_pil = crop_image(crop_a_image(get_image(file_name, -1)[0], d=10), d=32)
23 | else:
24 | img_pil = crop_image(get_image(file_name, -1)[0], d=32)
25 |
26 | return pil_to_np(img_pil)
27 |
28 |
29 | def crop_a_image(img, d=10):
30 | bbox = [
31 | int((d)),
32 | int((d)),
33 | int((img.size[0] - d)),
34 | int((img.size[1] - d)),
35 | ]
36 | img_cropped = img.crop(bbox)
37 | return img_cropped
38 |
39 |
40 | def crop_image(img, d=32):
41 | """
42 | Make dimensions divisible by d
43 |
44 | :param pil img:
45 | :param d:
46 | :return:
47 | """
48 |
49 | new_size = (img.size[0] - img.size[0] % d,
50 | img.size[1] - img.size[1] % d)
51 |
52 | bbox = [
53 | int((img.size[0] - new_size[0]) / 2),
54 | int((img.size[1] - new_size[1]) / 2),
55 | int((img.size[0] + new_size[0]) / 2),
56 | int((img.size[1] + new_size[1]) / 2),
57 | ]
58 |
59 | img_cropped = img.crop(bbox)
60 | return img_cropped
61 |
62 |
63 | def crop_np_image(img_np, d=32):
64 | return torch_to_np(crop_torch_image(np_to_torch(img_np), d))
65 |
66 |
67 | def crop_torch_image(img, d=32):
68 | """
69 | Make dimensions divisible by d
70 | image is [1, 3, W, H] or [3, W, H]
71 | :param pil img:
72 | :param d:
73 | :return:
74 | """
75 | new_size = (img.shape[-2] - img.shape[-2] % d,
76 | img.shape[-1] - img.shape[-1] % d)
77 | pad = ((img.shape[-2] - new_size[-2]) // 2, (img.shape[-1] - new_size[-1]) // 2)
78 |
79 | if len(img.shape) == 4:
80 | return img[:, :, pad[-2]: pad[-2] + new_size[-2], pad[-1]: pad[-1] + new_size[-1]]
81 | assert len(img.shape) == 3
82 | return img[:, pad[-2]: pad[-2] + new_size[-2], pad[-1]: pad[-1] + new_size[-1]]
83 |
84 |
85 | def get_params(opt_over, net, net_input, downsampler=None):
86 | """
87 | Returns parameters that we want to optimize over.
88 | :param opt_over: comma separated list, e.g. "net,input" or "net"
89 | :param net: network
90 | :param net_input: torch.Tensor that stores input `z`
91 | :param downsampler:
92 | :return:
93 | """
94 |
95 | opt_over_list = opt_over.split(',')
96 | params = []
97 |
98 | for opt in opt_over_list:
99 |
100 | if opt == 'net':
101 | params += [x for x in net.parameters()]
102 | elif opt == 'down':
103 | assert downsampler is not None
104 | params = [x for x in downsampler.parameters()]
105 | elif opt == 'input':
106 | net_input.requires_grad = True
107 | params += [net_input]
108 | else:
109 |             assert False, 'unrecognized opt_over option: {}'.format(opt)
110 |
111 | return params
112 |
113 |
114 | def get_image_grid(images_np, nrow=8):
115 | """
116 | Creates a grid from a list of images by concatenating them.
117 | :param images_np:
118 | :param nrow:
119 | :return:
120 | """
121 | images_torch = [torch.from_numpy(x).type(torch.FloatTensor) for x in images_np]
122 | torch_grid = torchvision.utils.make_grid(images_torch, nrow)
123 |
124 | return torch_grid.numpy()
125 |
126 |
127 | def plot_image_grid(name, images_np, interpolation='lanczos', output_path="output/"):
128 | """
129 | Draws images in a grid
130 |
131 | Args:
132 | images_np: list of images, each image is np.array of size 3xHxW or 1xHxW
133 | nrow: how many images will be in one row
134 | interpolation: interpolation used in plt.imshow
135 | """
136 | assert len(images_np) == 2
137 | n_channels = max(x.shape[0] for x in images_np)
138 | assert (n_channels == 3) or (n_channels == 1), "images should have 1 or 3 channels"
139 |
140 | images_np = [x if (x.shape[0] == n_channels) else np.concatenate([x, x, x], axis=0) for x in images_np]
141 |
142 | grid = get_image_grid(images_np, 2)
143 |
144 | if images_np[0].shape[0] == 1:
145 | plt.imshow(grid[0], cmap='gray', interpolation=interpolation)
146 | else:
147 | plt.imshow(grid.transpose(1, 2, 0), interpolation=interpolation)
148 |
149 | plt.savefig(output_path + "{}.png".format(name))
150 |
151 |
152 | def save_image_np(name, image_np, output_path="output/"):
153 | p = np_to_pil(image_np)
154 | p.save(output_path + "{}.png".format(name))
155 |
156 |
157 | def save_image_tensor(image_tensor, output_path="output/"):
158 | image_np = torch_to_np(image_tensor)
159 | # print(image_np.shape)
160 | p = np_to_pil(image_np)
161 | p.save(output_path)
162 |
163 |
164 | def video_to_images(file_name, name):
165 |     # NOTE: relies on the skvideo-based prepare_video helper commented out below.
166 |     video = prepare_video(file_name)
167 |     for i, f in enumerate(video):
168 |         save_image_np(name + "_{0:03d}".format(i), f)
169 |
170 | def images_to_video(images_dir, name, gray=True):
171 |     # NOTE: relies on the skvideo-based save_video helper commented out below.
172 |     num = len(glob.glob(images_dir + "/*.jpg"))
173 |     c = []
174 |     for i in range(num):
175 |         if gray:
176 |             img = prepare_gray_image(images_dir + "/" + name + "_{}.jpg".format(i))
177 |         else:
178 |             img = prepare_image(images_dir + "/" + name + "_{}.jpg".format(i))
179 |         c.append(img)
180 |     save_video(name, np.array(c))
181 |
182 |
183 | def save_heatmap(name, image_np):
184 | cmap = plt.get_cmap('jet')
185 |
186 | rgba_img = cmap(image_np)
187 | rgb_img = np.delete(rgba_img, 3, 2)
188 |     save_image_np(name, rgb_img.transpose(2, 0, 1))
189 |
190 |
191 | def save_graph(name, graph_list, output_path="output/"):
192 | plt.clf()
193 | plt.plot(graph_list)
194 | plt.savefig(output_path + name + ".png")
195 |
196 |
197 | def create_augmentations(np_image):
198 | """
199 | convention: original, left, upside-down, right, rot1, rot2, rot3
200 | :param np_image:
201 | :return:
202 | """
203 | aug = [np_image.copy(), np.rot90(np_image, 1, (1, 2)).copy(),
204 | np.rot90(np_image, 2, (1, 2)).copy(), np.rot90(np_image, 3, (1, 2)).copy()]
205 | flipped = np_image[:, ::-1, :].copy()
206 | aug += [flipped.copy(), np.rot90(flipped, 1, (1, 2)).copy(), np.rot90(flipped, 2, (1, 2)).copy(),
207 | np.rot90(flipped, 3, (1, 2)).copy()]
208 | return aug
209 |
210 |
211 | def create_video_augmentations(np_video):
212 | """
213 | convention: original, left, upside-down, right, rot1, rot2, rot3
214 | :param np_video:
215 | :return:
216 | """
217 | aug = [np_video.copy(), np.rot90(np_video, 1, (2, 3)).copy(),
218 | np.rot90(np_video, 2, (2, 3)).copy(), np.rot90(np_video, 3, (2, 3)).copy()]
219 | flipped = np_video[:, :, ::-1, :].copy()
220 | aug += [flipped.copy(), np.rot90(flipped, 1, (2, 3)).copy(), np.rot90(flipped, 2, (2, 3)).copy(),
221 | np.rot90(flipped, 3, (2, 3)).copy()]
222 | return aug
223 |
224 |
225 | def save_graphs(name, graph_dict, output_path="output/"):
226 | """
227 |
228 | :param name:
229 | :param dict graph_dict: a dict from the name of the list to the list itself.
230 | :return:
231 | """
232 | plt.clf()
233 | fig, ax = plt.subplots()
234 | for k, v in graph_dict.items():
235 | ax.plot(v, label=k)
236 | # ax.semilogy(v, label=k)
237 | ax.set_xlabel('iterations')
238 | # ax.set_ylabel(name)
239 | ax.set_ylabel('MSE-loss')
240 | # ax.set_ylabel('PSNR')
241 | plt.legend()
242 | plt.savefig(output_path + name + ".png")
243 |
244 |
245 | def load(path):
246 | """Load PIL image."""
247 | img = Image.open(path)
248 | return img
249 |
250 |
251 | def get_image(path, imsize=-1):
252 | """Load an image and resize to a cpecific size.
253 |
254 | Args:
255 | path: path to image
256 | imsize: tuple or scalar with dimensions; -1 for `no resize`
257 | """
258 | img = load(path)
259 | if isinstance(imsize, int):
260 | imsize = (imsize, imsize)
261 |
262 | if imsize[0] != -1 and img.size != imsize:
263 | if imsize[0] > img.size[0]:
264 | img = img.resize(imsize, Image.BICUBIC)
265 | else:
266 |             img = img.resize(imsize, Image.LANCZOS)  # Image.ANTIALIAS was removed in Pillow 10
267 |
268 | img_np = pil_to_np(img)
269 | # 3*460*620
270 | # print(np.shape(img_np))
271 |
272 | return img, img_np
273 |
274 |
275 | def prepare_gt(file_name):
276 |     """
277 |     Loads an image, crops a 10-pixel border, and makes its dimensions divisible by 32.
278 |     :param file_name:
279 |     :return: the numpy representation of the image
280 |     """
281 | img = get_image(file_name, -1)
282 | # print(img[0].size)
283 |
284 | img_pil = img[0].crop([10, 10, img[0].size[0] - 10, img[0].size[1] - 10])
285 |
286 | img_pil = crop_image(img_pil, d=32)
287 |
288 | # img_pil = get_image(file_name, -1)[0]
289 | # print(img_pil.size)
290 | return pil_to_np(img_pil)
291 |
292 |
293 | def prepare_image(file_name):
294 |     """
295 |     Loads an image and makes its dimensions divisible by 16.
296 |     :param file_name:
297 |     :return: the numpy representation of the image
298 |     """
299 | img = get_image(file_name, -1)
300 | # print(img[0].size)
301 | # img_pil = img[0]
302 | img_pil = crop_image(img[0], d=16)
303 | # img_pil = get_image(file_name, -1)[0]
304 | # print(img_pil.size)
305 | return pil_to_np(img_pil)
306 |
307 |
308 | # def prepare_video(file_name, folder="output/"):
309 | # data = skvideo.io.vread(folder + file_name)
310 | # return crop_torch_image(data.transpose(0, 3, 1, 2).astype(np.float32) / 255.)[:35]
311 | #
312 | #
313 | # def save_video(name, video_np, output_path="output/"):
314 | # outputdata = video_np * 255
315 | # outputdata = outputdata.astype(np.uint8)
316 | # skvideo.io.vwrite(output_path + "{}.mp4".format(name), outputdata.transpose(0, 2, 3, 1))
317 |
318 |
319 | def prepare_gray_image(file_name):
320 | img = prepare_image(file_name)
321 | return np.array([np.mean(img, axis=0)])
322 |
323 |
324 | def pil_to_np(img_PIL, with_transpose=True):
325 | """
326 | Converts image in PIL format to np.array.
327 |
328 | From W x H x C [0...255] to C x W x H [0..1]
329 | """
330 | ar = np.array(img_PIL)
331 | if len(ar.shape) == 3 and ar.shape[-1] == 4:
332 | ar = ar[:, :, :3]
333 | # this is alpha channel
334 | if with_transpose:
335 | if len(ar.shape) == 3:
336 | ar = ar.transpose(2, 0, 1)
337 | else:
338 | ar = ar[None, ...]
339 |
340 | return ar.astype(np.float32) / 255.
341 |
342 |
343 | def median(img_np_list):
344 | """
345 | assumes C x W x H [0..1]
346 | :param img_np_list:
347 | :return:
348 | """
349 | assert len(img_np_list) > 0
350 | l = len(img_np_list)
351 | shape = img_np_list[0].shape
352 | result = np.zeros(shape)
353 | for c in range(shape[0]):
354 | for w in range(shape[1]):
355 | for h in range(shape[2]):
356 | result[c, w, h] = sorted(i[c, w, h] for i in img_np_list)[l // 2]
357 | return result
358 |
359 |
360 | def average(img_np_list):
361 | """
362 | assumes C x W x H [0..1]
363 | :param img_np_list:
364 | :return:
365 | """
366 | assert len(img_np_list) > 0
367 | l = len(img_np_list)
368 | shape = img_np_list[0].shape
369 | result = np.zeros(shape)
370 | for i in img_np_list:
371 | result += i
372 | return result / l
373 |
374 |
375 | def np_to_pil(img_np):
376 | """
377 | Converts image in np.array format to PIL image.
378 |
379 | From C x W x H [0..1] to W x H x C [0...255]
380 | :param img_np:
381 | :return:
382 | """
383 | ar = np.clip(img_np * 255, 0, 255).astype(np.uint8)
384 |
385 | if img_np.shape[0] == 1:
386 | ar = ar[0]
387 | else:
388 | assert img_np.shape[0] == 3, img_np.shape
389 | ar = ar.transpose(1, 2, 0)
390 |
391 | return Image.fromarray(ar)
392 |
393 |
394 | def np_to_torch(img_np):
395 | """
396 | Converts image in numpy.array to torch.Tensor.
397 |
398 | From C x W x H [0..1] to C x W x H [0..1]
399 |
400 | :param img_np:
401 | :return:
402 | """
403 | return torch.from_numpy(img_np)[None, :]
404 |
405 |
406 | def torch_to_np(img_var):
407 | """
408 | Converts an image in torch.Tensor format to np.array.
409 |
410 | From 1 x C x W x H [0..1] to C x W x H [0..1]
411 | :param img_var:
412 | :return:
413 | """
414 | return img_var.detach().cpu().numpy()[0]
415 |
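416 | 
417 | if __name__ == '__main__':
418 |     # Minimal round-trip sketch (illustrative only): PIL -> np -> torch -> np -> PIL.
419 |     pil_img = Image.new('RGB', (96, 64), color=(128, 64, 32))
420 |     img_np = pil_to_np(pil_img)           # (3, 64, 96) float32 in [0, 1]
421 |     img_t = np_to_torch(img_np)           # (1, 3, 64, 96)
422 |     back = np_to_pil(torch_to_np(img_t))  # back to a 96x64 PIL image
423 |     print(img_np.shape, tuple(img_t.shape), back.size)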
--------------------------------------------------------------------------------
/utils/imresize.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.ndimage import correlate, center_of_mass, shift  # the filters/measurements/interpolation namespaces are deprecated
3 | from math import pi
4 |
5 |
6 | def imresize(im, scale_factor=None, output_shape=None, kernel=None, antialiasing=True, kernel_shift_flag=False):
7 | # First standardize values and fill missing arguments (if needed) by deriving scale from output shape or vice versa
8 | scale_factor, output_shape = fix_scale_and_size(im.shape, output_shape, scale_factor)
9 |
10 | # For a given numeric kernel case, just do convolution and sub-sampling (downscaling only)
11 | if type(kernel) == np.ndarray and scale_factor[0] <= 1:
12 | return numeric_kernel(im, kernel, scale_factor, output_shape, kernel_shift_flag)
13 |
14 | # Choose interpolation method, each method has the matching kernel size
15 | method, kernel_width = {
16 | "cubic": (cubic, 4.0),
17 | "lanczos2": (lanczos2, 4.0),
18 | "lanczos3": (lanczos3, 6.0),
19 | "box": (box, 1.0),
20 | "linear": (linear, 2.0),
21 | None: (cubic, 4.0) # set default interpolation method as cubic
22 | }.get(kernel)
23 |
24 | # Antialiasing is only used when downscaling
25 | antialiasing *= (scale_factor[0] < 1)
26 |
27 | # Sort indices of dimensions according to scale of each dimension. since we are going dim by dim this is efficient
28 | sorted_dims = np.argsort(np.array(scale_factor)).tolist()
29 |
30 | # Iterate over dimensions to calculate local weights for resizing and resize each time in one direction
31 | out_im = np.copy(im)
32 | for dim in sorted_dims:
33 | # No point doing calculations for scale-factor 1. nothing will happen anyway
34 | if scale_factor[dim] == 1.0:
35 | continue
36 |
37 | # for each coordinate (along 1 dim), calculate which coordinates in the input image affect its result and the
38 | # weights that multiply the values there to get its result.
39 | weights, field_of_view = contributions(im.shape[dim], output_shape[dim], scale_factor[dim],
40 | method, kernel_width, antialiasing)
41 |
42 | # Use the affecting position values and the set of weights to calculate the result of resizing along this 1 dim
43 | out_im = resize_along_dim(out_im, dim, weights, field_of_view)
44 |
45 | return out_im
46 |
47 |
48 | def fix_scale_and_size(input_shape, output_shape, scale_factor):
49 |     # First, fix the scale-factor (if given) into the standardized form the function expects: a list of
50 |     # scale factors with the same length as the number of input dimensions.
51 | if scale_factor is not None:
52 | # By default, if scale-factor is a scalar we assume 2d resizing and duplicate it.
53 | if np.isscalar(scale_factor):
54 | scale_factor = [scale_factor, scale_factor]
55 |
56 | # We extend the size of scale-factor list to the size of the input by assigning 1 to all the unspecified scales
57 | scale_factor = list(scale_factor)
58 | scale_factor.extend([1] * (len(input_shape) - len(scale_factor)))
59 |
60 | # Fixing output-shape (if given): extending it to the size of the input-shape, by assigning the original input-size
61 | # to all the unspecified dimensions
62 | if output_shape is not None:
63 | output_shape = list(np.uint(np.array(output_shape))) + list(input_shape[len(output_shape):])
64 |
65 |     # Dealing with the case of a non-given scale-factor, calculating according to output-shape. Note that this is
66 |     # sub-optimal, because different scales can map to the same output-shape.
67 | if scale_factor is None:
68 | scale_factor = 1.0 * np.array(output_shape) / np.array(input_shape)
69 |
70 | # Dealing with missing output-shape. calculating according to scale-factor
71 | if output_shape is None:
72 | output_shape = np.uint(np.ceil(np.array(input_shape) * np.array(scale_factor)))
73 |
74 | return scale_factor, output_shape
75 |
76 |
77 | def contributions(in_length, out_length, scale, kernel, kernel_width, antialiasing):
78 | # This function calculates a set of 'filters' and a set of field_of_view that will later on be applied
79 | # such that each position from the field_of_view will be multiplied with a matching filter from the
80 | # 'weights' based on the interpolation method and the distance of the sub-pixel location from the pixel centers
81 | # around it. This is only done for one dimension of the image.
82 |
83 | # When anti-aliasing is activated (default and only for downscaling) the receptive field is stretched to size of
84 | # 1/sf. this means filtering is more 'low-pass filter'.
85 | fixed_kernel = (lambda arg: scale * kernel(scale * arg)) if antialiasing else kernel
86 | kernel_width *= 1.0 / scale if antialiasing else 1.0
87 |
88 | # These are the coordinates of the output image
89 | out_coordinates = np.arange(1, out_length+1)
90 |
91 | # These are the matching positions of the output-coordinates on the input image coordinates.
92 | # Best explained by example: say we have 4 horizontal pixels for HR and we downscale by SF=2 and get 2 pixels:
93 | # [1,2,3,4] -> [1,2]. Remember each pixel number is the middle of the pixel.
94 | # The scaling is done between the distances and not pixel numbers (the right boundary of pixel 4 is transformed to
95 | # the right boundary of pixel 2. pixel 1 in the small image matches the boundary between pixels 1 and 2 in the big
96 | # one and not to pixel 2. This means the position is not just multiplication of the old pos by scale-factor).
97 | # So if we measure distance from the left border, middle of pixel 1 is at distance d=0.5, border between 1 and 2 is
98 | # at d=1, and so on (d = p - 0.5). we calculate (d_new = d_old / sf) which means:
99 | # (p_new-0.5 = (p_old-0.5) / sf) -> p_new = p_old/sf + 0.5 * (1-1/sf)
100 | match_coordinates = 1.0 * out_coordinates / scale + 0.5 * (1 - 1.0 / scale)
101 |
102 | # This is the left boundary to start multiplying the filter from, it depends on the size of the filter
103 | left_boundary = np.floor(match_coordinates - kernel_width / 2)
104 |
105 | # Kernel width needs to be enlarged because when covering has sub-pixel borders, it must 'see' the pixel centers
106 | # of the pixels it only covered a part from. So we add one pixel at each side to consider (weights can zeroize them)
107 | expanded_kernel_width = np.ceil(kernel_width) + 2
108 |
109 |     # Determine a set of field_of_view for each output position; these are the pixels in the input image
110 |     # that the pixel in the output image 'sees'. We get a matrix whose horizontal dim is the output pixels (big) and the
111 |     # vertical dim is the pixels it 'sees' (kernel_size + 2)
112 | field_of_view = np.squeeze(np.uint(np.expand_dims(left_boundary, axis=1) + np.arange(expanded_kernel_width) - 1))
113 |
114 |     # Assign a weight to each pixel in the field of view. A matrix whose horizontal dim is the output pixels and the
115 | # vertical dim is a list of weights matching to the pixel in the field of view (that are specified in
116 | # 'field_of_view')
117 | weights = fixed_kernel(1.0 * np.expand_dims(match_coordinates, axis=1) - field_of_view - 1)
118 |
119 |     # Normalize weights to sum to 1, taking care not to divide by 0
120 | sum_weights = np.sum(weights, axis=1)
121 | sum_weights[sum_weights == 0] = 1.0
122 | weights = 1.0 * weights / np.expand_dims(sum_weights, axis=1)
123 |
124 | # We use this mirror structure as a trick for reflection padding at the boundaries
125 | mirror = np.uint(np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1))))
126 | field_of_view = mirror[np.mod(field_of_view, mirror.shape[0])]
127 |
128 | # Get rid of weights and pixel positions that are of zero weight
129 | non_zero_out_pixels = np.nonzero(np.any(weights, axis=0))
130 | weights = np.squeeze(weights[:, non_zero_out_pixels])
131 | field_of_view = np.squeeze(field_of_view[:, non_zero_out_pixels])
132 |
133 | # Final products are the relative positions and the matching weights, both are output_size X fixed_kernel_size
134 | return weights, field_of_view
135 |
136 |
137 | def resize_along_dim(im, dim, weights, field_of_view):
138 | # To be able to act on each dim, we swap so that dim 0 is the wanted dim to resize
139 | tmp_im = np.swapaxes(im, dim, 0)
140 |
141 | # We add singleton dimensions to the weight matrix so we can multiply it with the big tensor we get for
142 | # tmp_im[field_of_view.T], (bsxfun style)
143 | weights = np.reshape(weights.T, list(weights.T.shape) + (np.ndim(im) - 1) * [1])
144 |
145 | # This is a bit of a complicated multiplication: tmp_im[field_of_view.T] is a tensor of order image_dims+1.
146 | # for each pixel in the output-image it matches the positions the influence it from the input image (along 1 dim
147 | # only, this is why it only adds 1 dim to the shape). We then multiply, for each pixel, its set of positions with
148 | # the matching set of weights. we do this by this big tensor element-wise multiplication (MATLAB bsxfun style:
149 | # matching dims are multiplied element-wise while singletons mean that the matching dim is all multiplied by the
150 | # same number
151 | tmp_out_im = np.sum(tmp_im[field_of_view.T] * weights, axis=0)
152 |
153 | # Finally we swap back the axes to the original order
154 | return np.swapaxes(tmp_out_im, dim, 0)
155 |
156 |
157 | def numeric_kernel(im, kernel, scale_factor, output_shape, kernel_shift_flag):
158 | # See kernel_shift function to understand what this is
159 | if kernel_shift_flag:
160 | kernel = kernel_shift(kernel, scale_factor)
161 |
162 | # First run a correlation (convolution with flipped kernel)
163 | out_im = np.zeros_like(im)
164 |     for channel in range(im.shape[2]):  # iterate over channels (ndim only coincided with 3 for RGB)
165 |         out_im[:, :, channel] = correlate(im[:, :, channel], kernel)
166 |
167 | # Then subsample and return
168 | return out_im[np.round(np.linspace(0, im.shape[0] - 1 / scale_factor[0], output_shape[0])).astype(int)[:, None],
169 | np.round(np.linspace(0, im.shape[1] - 1 / scale_factor[1], output_shape[1])).astype(int), :]
170 |
171 |
172 | def kernel_shift(kernel, sf):
173 | # There are two reasons for shifting the kernel:
174 | # 1. Center of mass is not in the center of the kernel which creates ambiguity. There is no possible way to know
175 | # the degradation process included shifting so we always assume center of mass is center of the kernel.
176 | # 2. We further shift kernel center so that top left result pixel corresponds to the middle of the sfXsf first
177 | # pixels. Default is for odd size to be in the middle of the first pixel and for even sized kernel to be at the
178 |     #    top left corner of the first pixel. That is why a different shift size is needed for odd and even sizes.
179 | # Given that these two conditions are fulfilled, we are happy and aligned, the way to test it is as follows:
180 | # The input image, when interpolated (regular bicubic) is exactly aligned with ground truth.
181 |
182 | # First calculate the current center of mass for the kernel
183 |     current_center_of_mass = center_of_mass(kernel)
184 |
185 | # The second ("+ 0.5 * ....") is for applying condition 2 from the comments above
186 | wanted_center_of_mass = np.array(kernel.shape) / 2 + 0.5 * (sf - (kernel.shape[0] % 2))
187 |
188 | # Define the shift vector for the kernel shifting (x,y)
189 | shift_vec = wanted_center_of_mass - current_center_of_mass
190 |
191 | # Before applying the shift, we first pad the kernel so that nothing is lost due to the shift
192 | # (biggest shift among dims + 1 for safety)
193 |     kernel = np.pad(kernel, int(np.ceil(np.max(shift_vec))) + 1, 'constant')  # np.int was removed in NumPy 1.24
194 |
195 | # Finally shift the kernel and return
196 |     return shift(kernel, shift_vec)
197 |
198 |
199 | # These next functions are all interpolation methods. x is the distance from the left pixel center
200 |
201 |
202 | def cubic(x):
203 | absx = np.abs(x)
204 | absx2 = absx ** 2
205 | absx3 = absx ** 3
206 | return ((1.5*absx3 - 2.5*absx2 + 1) * (absx <= 1) +
207 | (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * ((1 < absx) & (absx <= 2)))
208 |
209 |
210 | def lanczos2(x):
211 | return (((np.sin(pi*x) * np.sin(pi*x/2) + np.finfo(np.float32).eps) /
212 | ((pi**2 * x**2 / 2) + np.finfo(np.float32).eps))
213 | * (abs(x) < 2))
214 |
215 |
216 | def box(x):
217 | return ((-0.5 <= x) & (x < 0.5)) * 1.0
218 |
219 |
220 | def lanczos3(x):
221 | return (((np.sin(pi*x) * np.sin(pi*x/3) + np.finfo(np.float32).eps) /
222 | ((pi**2 * x**2 / 3) + np.finfo(np.float32).eps))
223 | * (abs(x) < 3))
224 |
225 |
226 | def linear(x):
227 | return (x + 1) * ((-1 <= x) & (x < 0)) + (1 - x) * ((0 <= x) & (x <= 1))
228 |
229 |
230 | def np_imresize(im, scale_factor=None, output_shape=None, kernel=None, antialiasing=True, kernel_shift_flag=False):
231 | return np.clip(imresize(im.transpose(1, 2, 0), scale_factor, output_shape, kernel, antialiasing,
232 | kernel_shift_flag).transpose(2, 0, 1), 0, 1)
--------------------------------------------------------------------------------
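
A minimal usage sketch for the np_imresize wrapper above, assuming a CHW float image in [0, 1] (the wrapper clips its output to that range); the array and sizes are illustrative placeholders:

import numpy as np

from utils.imresize import np_imresize

# Hypothetical 3 x 64 x 64 CHW float image in [0, 1].
im = np.random.rand(3, 64, 64).astype(np.float32)

# 2x upscaling with the default cubic kernel.
up = np_imresize(im, scale_factor=2.0)
assert up.shape == (3, 128, 128)

# Downscale back by giving an explicit output shape instead of a scale
# factor; antialiasing (on by default) only takes effect when scale < 1.
down = np_imresize(up, output_shape=(64, 64))
assert down.shape == (3, 64, 64)
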
/utils/schedulers.py:
--------------------------------------------------------------------------------
1 | import math
2 | from collections import Counter
3 | from torch.optim.lr_scheduler import _LRScheduler
4 | import torch
5 | import warnings
6 | from typing import List
7 |
8 | from torch import nn
9 | from torch.optim import Adam, Optimizer
10 |
11 | class MultiStepRestartLR(_LRScheduler):
12 | """ MultiStep with restarts learning rate scheme.
13 |
14 | Args:
15 |         optimizer (torch.optim.Optimizer): Torch optimizer.
16 | milestones (list): Iterations that will decrease learning rate.
17 | gamma (float): Decrease ratio. Default: 0.1.
18 | restarts (list): Restart iterations. Default: [0].
19 | restart_weights (list): Restart weights at each restart iteration.
20 | Default: [1].
21 | last_epoch (int): Used in _LRScheduler. Default: -1.
22 | """
23 |
24 | def __init__(self,
25 | optimizer,
26 | milestones,
27 | gamma=0.1,
28 | restarts=(0, ),
29 | restart_weights=(1, ),
30 | last_epoch=-1):
31 | self.milestones = Counter(milestones)
32 | self.gamma = gamma
33 | self.restarts = restarts
34 | self.restart_weights = restart_weights
35 | assert len(self.restarts) == len(
36 | self.restart_weights), 'restarts and their weights do not match.'
37 | super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)
38 |
39 | def get_lr(self):
40 | if self.last_epoch in self.restarts:
41 | weight = self.restart_weights[self.restarts.index(self.last_epoch)]
42 | return [
43 | group['initial_lr'] * weight
44 | for group in self.optimizer.param_groups
45 | ]
46 | if self.last_epoch not in self.milestones:
47 | return [group['lr'] for group in self.optimizer.param_groups]
48 | return [
49 | group['lr'] * self.gamma**self.milestones[self.last_epoch]
50 | for group in self.optimizer.param_groups
51 | ]
52 |
53 | class LinearLR(_LRScheduler):
54 |     """Linear decay learning rate scheme: decays linearly from the
55 |     initial lr to 0 over total_iter iterations.
56 | 
57 |     Args:
58 |         optimizer (torch.optim.Optimizer): Torch optimizer.
59 |         total_iter (int): Total number of iterations.
60 |         last_epoch (int): Used in _LRScheduler. Default: -1.
61 |     """
62 |
63 | def __init__(self,
64 | optimizer,
65 | total_iter,
66 | last_epoch=-1):
67 | self.total_iter = total_iter
68 | super(LinearLR, self).__init__(optimizer, last_epoch)
69 |
70 | def get_lr(self):
71 | process = self.last_epoch / self.total_iter
72 | weight = (1 - process)
73 | # print('get lr ', [weight * group['initial_lr'] for group in self.optimizer.param_groups])
74 | return [weight * group['initial_lr'] for group in self.optimizer.param_groups]
75 |
76 | class VibrateLR(_LRScheduler):
77 |     """Oscillating ("vibrating") learning rate scheme: a triangular wave
78 |     of period total_iter // 80, scaled by a piecewise-decaying envelope.
79 | 
80 |     Args:
81 |         optimizer (torch.optim.Optimizer): Torch optimizer.
82 |         total_iter (int): Total number of iterations.
83 |         last_epoch (int): Used in _LRScheduler. Default: -1.
84 |     """
85 |
86 | def __init__(self,
87 | optimizer,
88 | total_iter,
89 | last_epoch=-1):
90 | self.total_iter = total_iter
91 | super(VibrateLR, self).__init__(optimizer, last_epoch)
92 |
93 | def get_lr(self):
94 | process = self.last_epoch / self.total_iter
95 |
96 | f = 0.1
97 | if process < 3 / 8:
98 | f = 1 - process * 8 / 3
99 | elif process < 5 / 8:
100 | f = 0.2
101 |
102 | T = self.total_iter // 80
103 | Th = T // 2
104 |
105 | t = self.last_epoch % T
106 |
107 | f2 = t / Th
108 | if t >= Th:
109 | f2 = 2 - f2
110 |
111 | weight = f * f2
112 |
113 | if self.last_epoch < Th:
114 | weight = max(0.1, weight)
115 |
116 | # print('f {}, T {}, Th {}, t {}, f2 {}'.format(f, T, Th, t, f2))
117 | return [weight * group['initial_lr'] for group in self.optimizer.param_groups]
118 |
119 | def get_position_from_periods(iteration, cumulative_period):
120 | """Get the position from a period list.
121 |
122 | It will return the index of the right-closest number in the period list.
123 | For example, the cumulative_period = [100, 200, 300, 400],
124 | if iteration == 50, return 0;
125 | if iteration == 210, return 2;
126 | if iteration == 300, return 2.
127 |
128 | Args:
129 | iteration (int): Current iteration.
130 | cumulative_period (list[int]): Cumulative period list.
131 |
132 | Returns:
133 | int: The position of the right-closest number in the period list.
134 | """
135 | for i, period in enumerate(cumulative_period):
136 | if iteration <= period:
137 | return i
138 |
139 |
140 | class CosineAnnealingRestartLR(_LRScheduler):
141 | """ Cosine annealing with restarts learning rate scheme.
142 |
143 | An example of config:
144 | periods = [10, 10, 10, 10]
145 | restart_weights = [1, 0.5, 0.5, 0.5]
146 | eta_min=1e-7
147 |
148 |     It has four cycles, each with 10 iterations. At iterations 10, 20 and 30,
149 |     the scheduler restarts with the corresponding weight in restart_weights.
150 |
151 | Args:
152 |         optimizer (torch.optim.Optimizer): Torch optimizer.
153 |         periods (list): Period for each cosine annealing cycle.
154 |         restart_weights (list): Restart weights at each restart iteration.
155 |             Default: [1].
156 |         eta_min (float): The minimum lr. Default: 0.
157 | last_epoch (int): Used in _LRScheduler. Default: -1.
158 | """
159 |
160 | def __init__(self,
161 | optimizer,
162 | periods,
163 | restart_weights=(1, ),
164 | eta_min=0,
165 | last_epoch=-1):
166 | self.periods = periods
167 | self.restart_weights = restart_weights
168 | self.eta_min = eta_min
169 | assert (len(self.periods) == len(self.restart_weights)
170 | ), 'periods and restart_weights should have the same length.'
171 | self.cumulative_period = [
172 | sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
173 | ]
174 | super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
175 |
176 | def get_lr(self):
177 | idx = get_position_from_periods(self.last_epoch,
178 | self.cumulative_period)
179 | current_weight = self.restart_weights[idx]
180 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
181 | current_period = self.periods[idx]
182 |
183 | return [
184 | self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
185 | (1 + math.cos(math.pi * (
186 | (self.last_epoch - nearest_restart) / current_period)))
187 | for base_lr in self.base_lrs
188 | ]
189 |
190 | class CosineAnnealingRestartCyclicLR(_LRScheduler):
191 | """ Cosine annealing with restarts learning rate scheme.
192 | An example of config:
193 | periods = [10, 10, 10, 10]
194 | restart_weights = [1, 0.5, 0.5, 0.5]
195 | eta_min=1e-7
196 |     It has four cycles, each with 10 iterations. At iterations 10, 20 and 30,
197 |     the scheduler restarts with the corresponding weight in restart_weights.
198 | Args:
199 |         optimizer (torch.optim.Optimizer): Torch optimizer.
200 |         periods (list): Period for each cosine annealing cycle.
201 |         restart_weights (list): Restart weights at each restart iteration.
202 |             Default: [1].
203 |         eta_mins (list): The minimum lr of each cycle. Default: [0].
204 | last_epoch (int): Used in _LRScheduler. Default: -1.
205 | """
206 |
207 | def __init__(self,
208 | optimizer,
209 | periods,
210 | restart_weights=(1, ),
211 | eta_mins=(0, ),
212 | last_epoch=-1):
213 | self.periods = periods
214 | self.restart_weights = restart_weights
215 | self.eta_mins = eta_mins
216 | assert (len(self.periods) == len(self.restart_weights)
217 | ), 'periods and restart_weights should have the same length.'
218 | self.cumulative_period = [
219 | sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
220 | ]
221 | super(CosineAnnealingRestartCyclicLR, self).__init__(optimizer, last_epoch)
222 |
223 | def get_lr(self):
224 | idx = get_position_from_periods(self.last_epoch,
225 | self.cumulative_period)
226 | current_weight = self.restart_weights[idx]
227 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
228 | current_period = self.periods[idx]
229 | eta_min = self.eta_mins[idx]
230 |
231 | return [
232 | eta_min + current_weight * 0.5 * (base_lr - eta_min) *
233 | (1 + math.cos(math.pi * (
234 | (self.last_epoch - nearest_restart) / current_period)))
235 | for base_lr in self.base_lrs
236 | ]
237 |
238 |
239 | class LinearWarmupCosineAnnealingLR(_LRScheduler):
240 | """Sets the learning rate of each parameter group to follow a linear warmup schedule between warmup_start_lr
241 | and base_lr followed by a cosine annealing schedule between base_lr and eta_min.
242 | .. warning::
243 |         It is recommended to call :func:`.step()` for :class:`LinearWarmupCosineAnnealingLR`
244 |         after each iteration; calling it only once per epoch keeps the learning rate at
245 |         warmup_start_lr (0 in most cases) for the entire first epoch.
246 | .. warning::
247 |         Passing an epoch to :func:`.step()` is deprecated and comes with an EPOCH_DEPRECATION_WARNING.
248 |         It calls the :func:`_get_closed_form_lr()` method for this scheduler instead of
249 |         :func:`get_lr()`. Though this does not change the behavior of the scheduler, when passing
250 |         an epoch to :func:`.step()`, the user should call :func:`.step()` before calling the
251 |         train and validation methods.
252 | Example:
253 | >>> layer = nn.Linear(10, 1)
254 | >>> optimizer = Adam(layer.parameters(), lr=0.02)
255 | >>> scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=40)
256 | >>> #
257 | >>> # the default case
258 | >>> for epoch in range(40):
259 | ... # train(...)
260 | ... # validate(...)
261 | ... scheduler.step()
262 | >>> #
263 | >>> # passing epoch param case
264 | >>> for epoch in range(40):
265 | ... scheduler.step(epoch)
266 | ... # train(...)
267 | ... # validate(...)
268 | """
269 |
270 | def __init__(
271 | self,
272 | optimizer: Optimizer,
273 | warmup_epochs: int,
274 | max_epochs: int,
275 | warmup_start_lr: float = 0.0,
276 | eta_min: float = 0.0,
277 | last_epoch: int = -1,
278 | ) -> None:
279 | """
280 | Args:
281 | optimizer (Optimizer): Wrapped optimizer.
282 |             warmup_epochs (int): Maximum number of iterations for linear warmup.
283 |             max_epochs (int): Maximum number of iterations.
284 | warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0.
285 | eta_min (float): Minimum learning rate. Default: 0.
286 | last_epoch (int): The index of last epoch. Default: -1.
287 | """
288 | self.warmup_epochs = warmup_epochs
289 | self.max_epochs = max_epochs
290 | self.warmup_start_lr = warmup_start_lr
291 | self.eta_min = eta_min
292 |
293 | super().__init__(optimizer, last_epoch)
294 |
295 | def get_lr(self) -> List[float]:
296 | """Compute learning rate using chainable form of the scheduler."""
297 | if not self._get_lr_called_within_step:
298 | warnings.warn(
299 | "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.",
300 | UserWarning,
301 | )
302 |
303 | if self.last_epoch == 0:
304 | return [self.warmup_start_lr] * len(self.base_lrs)
305 | if self.last_epoch < self.warmup_epochs:
306 | return [
307 | group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
308 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
309 | ]
310 | if self.last_epoch == self.warmup_epochs:
311 | return self.base_lrs
312 | if (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0:
313 | return [
314 | group["lr"]
315 | + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2
316 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
317 | ]
318 |
319 | return [
320 | (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
321 | / (
322 | 1
323 | + math.cos(
324 | math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs)
325 | )
326 | )
327 | * (group["lr"] - self.eta_min)
328 | + self.eta_min
329 | for group in self.optimizer.param_groups
330 | ]
331 |
332 | def _get_closed_form_lr(self) -> List[float]:
333 | """Called when epoch is passed as a param to the `step` function of the scheduler."""
334 | if self.last_epoch < self.warmup_epochs:
335 | return [
336 | self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
337 | for base_lr in self.base_lrs
338 | ]
339 |
340 | return [
341 | self.eta_min
342 | + 0.5
343 | * (base_lr - self.eta_min)
344 | * (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
345 | for base_lr in self.base_lrs
346 | ]
347 |
348 |
349 | # warmup + decay as a multiplicative-factor function (e.g. to wrap in torch.optim.lr_scheduler.LambdaLR)
350 | def linear_warmup_decay(warmup_steps, total_steps, cosine=True, linear=False):
351 | """Linear warmup for warmup_steps, optionally with cosine annealing or linear decay to 0 at total_steps."""
352 | assert not (linear and cosine)
353 |
354 | def fn(step):
355 | if step < warmup_steps:
356 | return float(step) / float(max(1, warmup_steps))
357 |
358 | if not (cosine or linear):
359 | # no decay
360 | return 1.0
361 |
362 | progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
363 | if cosine:
364 | # cosine decay
365 | return 0.5 * (1.0 + math.cos(math.pi * progress))
366 |
367 | # linear decay
368 | return 1.0 - progress
369 |
370 | return fn
371 |
--------------------------------------------------------------------------------
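
A minimal sketch, assuming a toy model and placeholder step counts, of how the warmup utilities above are typically attached to an optimizer; linear_warmup_decay returns a multiplicative factor, so it is wrapped in torch.optim.lr_scheduler.LambdaLR:

from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR

from utils.schedulers import LinearWarmupCosineAnnealingLR, linear_warmup_decay

model = nn.Linear(10, 1)  # toy model, stands in for the restoration network
optimizer = Adam(model.parameters(), lr=2e-4)

# Class-based: 10 warmup epochs, then cosine annealing until epoch 100.
scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=100)

# Function-based: per-step factor wrapped in LambdaLR.
scheduler = LambdaLR(optimizer, lr_lambda=linear_warmup_decay(1000, 100000, cosine=True))

# Restart-style alternative: four 10-iteration cosine cycles (see its docstring).
# scheduler = CosineAnnealingRestartLR(optimizer, periods=[10, 10, 10, 10],
#                                      restart_weights=[1, 0.5, 0.5, 0.5], eta_min=1e-7)

for step in range(5):   # stand-in for the training loop
    optimizer.step()    # normally preceded by forward/backward
    scheduler.step()
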
/utils/dataset_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import copy
4 | import glob
5 | import numpy as np
6 | from PIL import Image
7 |
8 | from torch.utils.data import Dataset
9 | from torchvision.transforms import ToPILImage, Compose, RandomCrop, ToTensor
10 | import torch
11 |
12 | from utils.image_utils import random_augmentation, crop_img
13 | from utils.degradation_utils import Degradation
14 |
15 |
16 | class CDD11(Dataset):
17 | def __init__(self, args, split: str = "train", subset: str = "all"):
18 | super(CDD11, self).__init__()
19 |
20 | self.args = args
21 | self.toTensor = ToTensor()
22 | # self.de_type = self.args.de_type
23 | self.dataset_split = split
24 | self.subset = subset
25 | if split == "train":
26 | self.patch_size = args.patch_size
27 | else:
28 | self.patch_size = 64
29 | self.cdd11_dir = args.cdd11_path
30 |
31 | self._init()
32 |
33 | def __getitem__(self, index):
34 | # Randomly select a degradation type
35 | if self.dataset_split == "train":
36 | degradation_type = random.choice(list(self.degraded_dict.keys()))
37 | degraded_image_path = random.choice(self.degraded_dict[degradation_type])
38 | else:
39 | degradation_type = self.subset
40 | degraded_image_path = self.degraded_dict[degradation_type][index]
41 |
42 |         # Select the matching clean image: degraded and clean files
43 |         # share the same base name, so the basename of the degraded
44 |         # path is enough to build the clean path (all clean images
45 |         # live in one directory, taken from self.clean[0]).
46 |         image_name = os.path.basename(degraded_image_path)
47 |         clean_image_path = os.path.join(os.path.dirname(self.clean[0]),
48 |                                         image_name)
49 | 
50 |
51 | # Load the images
52 | #lr = crop_img(np.array(Image.open(degraded_image_path).convert('RGB')), base=16)
53 | lr = np.array(Image.open(degraded_image_path).convert('RGB'))
54 | #hr = crop_img(np.array(Image.open(clean_image_path).convert('RGB')), base=16)
55 | hr = np.array(Image.open(clean_image_path).convert('RGB'))
56 | # Apply random augmentation and crop
57 | if self.dataset_split == "train":
58 | lr, hr = random_augmentation(*self._crop_patch(lr, hr))
59 |
60 | # Convert to tensors
61 | lr = self.toTensor(lr)
62 | hr = self.toTensor(hr)
63 |
64 | return [clean_image_path, degradation_type], lr, hr
65 |
66 | def __len__(self):
67 | return sum(len(images) for images in self.degraded_dict.values())
68 |
69 | def _init(self):
70 |         data_dir = self.cdd11_dir
71 | self.clean = sorted(glob.glob(os.path.join(data_dir, f"{self.dataset_split}/clear", "*.png")))
72 |
73 | if len(self.clean) == 0:
74 | raise ValueError(f"No clean images found in {os.path.join(data_dir, f'{self.dataset_split}/clear')}")
75 |
76 | self.degraded_dict = {}
77 | allowed_degradation_folders = self._filter_degradation_folders(data_dir)
78 | for folder in allowed_degradation_folders:
79 | folder_name = os.path.basename(folder.strip('/'))
80 | degraded_images = sorted(glob.glob(os.path.join(folder, "*.png")))
81 |
82 | if len(degraded_images) == 0:
83 | raise ValueError(f"No images found in {folder_name}")
84 |
85 | # scale dataset length
86 | if self.dataset_split == "train":
87 | degraded_images *= 6
88 |
89 | self.degraded_dict[folder_name] = degraded_images
90 |
91 | def _filter_degradation_folders(self, data_dir):
92 | """
93 |         Returns the degradation folders selected by self.subset:
94 |         'single', 'double', 'triple', 'all', or an exact folder name.
95 | """
96 | degradation_folders = sorted(glob.glob(os.path.join(data_dir, self.dataset_split, "*/")))
97 | filtered_folders = []
98 |
99 | for folder in degradation_folders:
100 | folder_name = os.path.basename(folder.strip('/'))
101 | if folder_name == "clear":
102 | continue
103 |
104 | # Count the number of degradations based on the number of underscores in the folder name
105 | degradation_count = folder_name.count('_') + 1
106 |
107 | # Check the degradation type mode and filter accordingly
108 | if self.subset == "single" and degradation_count == 1:
109 | filtered_folders.append(folder)
110 | elif self.subset == "double" and degradation_count == 2:
111 | filtered_folders.append(folder)
112 | elif self.subset == "triple" and degradation_count == 3:
113 | filtered_folders.append(folder)
114 | elif self.subset == "all":
115 | filtered_folders.append(folder)
116 | # If self.subset is a specific degradation folder name, match it exactly
117 | elif self.subset not in ["single", "double", "triple", "all"]:
118 | if folder_name == self.subset:
119 | filtered_folders.append(folder)
120 |
121 | print(f"Degradation type mode: {self.subset}")
122 | print(f"Loading degradation folders: {[os.path.basename(f.strip('/')) for f in filtered_folders]}")
123 | return filtered_folders
124 |
125 | def _crop_patch(self, img_1, img_2):
126 | # Crop a patch from both images (degraded and clean) at the same location
127 | H = img_1.shape[0]
128 | W = img_1.shape[1]
129 | ind_H = random.randint(0, H - self.args.patch_size)
130 | ind_W = random.randint(0, W - self.args.patch_size)
131 |
132 | patch_1 = img_1[ind_H:ind_H + self.args.patch_size, ind_W:ind_W + self.args.patch_size]
133 | patch_2 = img_2[ind_H:ind_H + self.args.patch_size, ind_W:ind_W + self.args.patch_size]
134 |
135 | return patch_1, patch_2
136 |
137 |
138 | class AnyIRTrainDataset(Dataset):
139 | def __init__(self, args):
140 | super(AnyIRTrainDataset, self).__init__()
141 | self.args = args
142 | self.rs_ids = []
143 | self.hazy_ids = []
144 | self.D = Degradation(args)
145 | self.de_temp = 0
146 | self.de_type = self.args.de_type
147 | print(self.de_type)
148 |
149 |         self.de_dict = {'denoise_15': 0, 'denoise_25': 1, 'denoise_50': 2, 'derain': 3, 'dehaze': 4, 'deblur': 5, 'enhance': 6}
150 |
151 | self._init_ids()
152 | self._merge_ids()
153 |
154 | self.crop_transform = Compose([
155 | ToPILImage(),
156 | RandomCrop(args.patch_size),
157 | ])
158 |
159 | self.toTensor = ToTensor()
160 |
161 | def _init_ids(self):
162 | if 'denoise_15' in self.de_type or 'denoise_25' in self.de_type or 'denoise_50' in self.de_type:
163 | self._init_clean_ids()
164 | if 'derain' in self.de_type:
165 | self._init_rs_ids()
166 | if 'dehaze' in self.de_type:
167 | self._init_hazy_ids()
168 | if 'deblur' in self.de_type:
169 | self._init_deblur_ids()
170 | if 'enhance' in self.de_type:
171 | self._init_enhance_ids()
172 |
173 | random.shuffle(self.de_type)
174 |
175 | def _init_clean_ids(self):
176 | ref_file = self.args.data_file_dir + "noisy/denoise.txt"
177 | temp_ids = []
178 | temp_ids+= [id_.strip() for id_ in open(ref_file)]
179 | clean_ids = []
180 | name_list = os.listdir(self.args.denoise_dir)
181 | clean_ids += [self.args.denoise_dir + id_ for id_ in name_list if id_.strip() in temp_ids]
182 |
183 | if 'denoise_15' in self.de_type:
184 | self.s15_ids = [{"clean_id": x,"de_type":0} for x in clean_ids]
185 | self.s15_ids = self.s15_ids * 3
186 | random.shuffle(self.s15_ids)
187 | self.s15_counter = 0
188 | if 'denoise_25' in self.de_type:
189 | self.s25_ids = [{"clean_id": x,"de_type":1} for x in clean_ids]
190 | self.s25_ids = self.s25_ids * 3
191 | random.shuffle(self.s25_ids)
192 | self.s25_counter = 0
193 | if 'denoise_50' in self.de_type:
194 | self.s50_ids = [{"clean_id": x,"de_type":2} for x in clean_ids]
195 | self.s50_ids = self.s50_ids * 3
196 | random.shuffle(self.s50_ids)
197 | self.s50_counter = 0
198 |
199 | self.num_clean = len(clean_ids)
200 | print("Total Denoise Ids : {}".format(self.num_clean))
201 |
202 | def _init_hazy_ids(self):
203 | temp_ids = []
204 | hazy = self.args.data_file_dir + "hazy/hazy_outside.txt"
205 | temp_ids+= [self.args.dehaze_dir + id_.strip() for id_ in open(hazy)]
206 | self.hazy_ids = [{"clean_id" : x,"de_type":4} for x in temp_ids]
207 |
208 | self.hazy_counter = 0
209 |
210 | self.num_hazy = len(self.hazy_ids)
211 | print("Total Hazy Ids : {}".format(self.num_hazy))
212 |
213 | def _init_deblur_ids(self):
214 | temp_ids = []
215 |
216 | image_list = os.listdir(os.path.join(self.args.gopro_dir, 'blur/'))
217 | temp_ids = image_list
218 | self.deblur_ids = [{"clean_id" : x,"de_type":5} for x in temp_ids]
219 | self.deblur_ids = self.deblur_ids * 5
220 | self.deblur_counter = 0
221 | self.num_deblur = len(self.deblur_ids)
222 | print('Total Blur Ids : {}'.format(self.num_deblur))
223 |
224 | def _init_enhance_ids(self):
225 | temp_ids = []
226 | image_list = os.listdir(os.path.join(self.args.enhance_dir, 'low/'))
227 | temp_ids = image_list
228 |         self.enhance_ids = [{"clean_id": x, "de_type": 6} for x in temp_ids]
229 | self.enhance_ids = self.enhance_ids * 20
230 | self.num_enhance = len(self.enhance_ids)
231 | print('Total enhance Ids : {}'.format(self.num_enhance))
232 |
233 |
234 | def _init_rs_ids(self):
235 | temp_ids = []
236 | rs = self.args.data_file_dir + "rainy/rainTrain.txt"
237 | temp_ids+= [self.args.derain_dir + id_.strip() for id_ in open(rs)]
238 | self.rs_ids = [{"clean_id":x,"de_type":3} for x in temp_ids]
239 | self.rs_ids = self.rs_ids * 120
240 |
241 | self.rl_counter = 0
242 | self.num_rl = len(self.rs_ids)
243 | print("Total Rainy Ids : {}".format(self.num_rl))
244 |
245 |
246 | def _crop_patch(self, img_1, img_2):
247 | H = img_1.shape[0]
248 | W = img_1.shape[1]
249 | ind_H = random.randint(0, H - self.args.patch_size)
250 | ind_W = random.randint(0, W - self.args.patch_size)
251 |
252 | patch_1 = img_1[ind_H:ind_H + self.args.patch_size, ind_W:ind_W + self.args.patch_size]
253 | patch_2 = img_2[ind_H:ind_H + self.args.patch_size, ind_W:ind_W + self.args.patch_size]
254 |
255 | return patch_1, patch_2
256 |
257 | def _get_gt_name(self, rainy_name):
258 | gt_name = rainy_name.split("rainy")[0] + 'gt/norain-' + rainy_name.split('rain-')[-1]
259 | return gt_name
260 |
261 | def _get_deblur_name(self, deblur_name):
262 | gt_name = deblur_name.replace("blur", "sharp")
263 | return gt_name
264 |
265 | def _get_enhance_name(self, enhance_name):
266 | gt_name = enhance_name.replace("low", "gt")
267 | return gt_name
268 |
269 | def _get_nonhazy_name(self, hazy_name):
270 | dir_name = hazy_name.split("synthetic")[0] + 'original/'
271 | name = hazy_name.split('/')[-1].split('_')[0]
272 | suffix = '.' + hazy_name.split('.')[-1]
273 | nonhazy_name = dir_name + name + suffix
274 | return nonhazy_name
275 |
276 | def _merge_ids(self):
277 | self.sample_ids = []
278 | if "denoise_15" in self.de_type:
279 | self.sample_ids += self.s15_ids
280 | self.sample_ids += self.s25_ids
281 | self.sample_ids += self.s50_ids
282 | if "derain" in self.de_type:
283 | self.sample_ids+= self.rs_ids
284 | if "dehaze" in self.de_type:
285 | self.sample_ids+= self.hazy_ids
286 | if "deblur" in self.de_type:
287 | self.sample_ids += self.deblur_ids
288 | if "enhance" in self.de_type:
289 | self.sample_ids += self.enhance_ids
290 |
291 | print(len(self.sample_ids))
292 |
293 | def __getitem__(self, idx):
294 | sample = self.sample_ids[idx]
295 | de_id = sample["de_type"]
296 |
297 | if de_id < 3:
298 |             # de_id 0, 1 and 2 are the three denoising levels
299 |             # (sigma 15 / 25 / 50); they all read the same clean
300 |             # image, and single_degrade later picks the sigma that
301 |             # matches de_id.
302 |             clean_id = sample["clean_id"]
303 | 
304 |
305 | clean_img = crop_img(np.array(Image.open(clean_id).convert('RGB')), base=16)
306 | clean_patch = self.crop_transform(clean_img)
307 |             clean_patch = np.array(clean_patch)
308 |
309 | clean_name = clean_id.split("/")[-1].split('.')[0]
310 |
311 | clean_patch = random_augmentation(clean_patch)[0]
312 |
313 | degrad_patch = self.D.single_degrade(clean_patch, de_id)
314 | else:
315 | if de_id == 3:
316 | # Rain Streak Removal
317 | degrad_img = crop_img(np.array(Image.open(sample["clean_id"]).convert('RGB')), base=16)
318 | clean_name = self._get_gt_name(sample["clean_id"])
319 | clean_img = crop_img(np.array(Image.open(clean_name).convert('RGB')), base=16)
320 | elif de_id == 4:
321 | # Dehazing with SOTS outdoor training set
322 | degrad_img = crop_img(np.array(Image.open(sample["clean_id"]).convert('RGB')), base=16)
323 | clean_name = self._get_nonhazy_name(sample["clean_id"])
324 | clean_img = crop_img(np.array(Image.open(clean_name).convert('RGB')), base=16)
325 | elif de_id == 5:
326 | # Deblur with Gopro set
327 | degrad_img = crop_img(np.array(Image.open(os.path.join(self.args.gopro_dir, 'blur/', sample["clean_id"])).convert('RGB')), base=16)
328 | clean_img = crop_img(np.array(Image.open(os.path.join(self.args.gopro_dir, 'sharp/', sample["clean_id"])).convert('RGB')), base=16)
329 | clean_name = self._get_deblur_name(sample["clean_id"])
330 | elif de_id == 6:
331 | # Enhancement with LOL training set
332 | degrad_img = crop_img(np.array(Image.open(os.path.join(self.args.enhance_dir, 'low/', sample["clean_id"])).convert('RGB')), base=16)
333 | clean_img = crop_img(np.array(Image.open(os.path.join(self.args.enhance_dir, 'gt/', sample["clean_id"])).convert('RGB')), base=16)
334 | clean_name = self._get_enhance_name(sample["clean_id"])
335 |
336 |
337 | degrad_patch, clean_patch = random_augmentation(*self._crop_patch(degrad_img, clean_img))
338 |
339 | clean_patch = self.toTensor(clean_patch)
340 | degrad_patch = self.toTensor(degrad_patch)
341 |
342 |
343 | return [clean_name, de_id], degrad_patch, clean_patch
344 |
345 | def __len__(self):
346 | return len(self.sample_ids)
347 |
348 |
349 | class AnyDnTestDataset(Dataset):
350 | def __init__(self, args):
351 | super(AnyDnTestDataset, self).__init__()
352 | self.args = args
353 | self.clean_ids = []
354 | self.sigma = 15
355 |
356 | self._init_clean_ids()
357 |
358 | self.toTensor = ToTensor()
359 |
360 | def _init_clean_ids(self):
361 | name_list = os.listdir(self.args.denoise_path)
362 | self.clean_ids += [self.args.denoise_path + id_ for id_ in name_list]
363 |
364 | self.num_clean = len(self.clean_ids)
365 |
366 | def _add_gaussian_noise(self, clean_patch):
367 | noise = np.random.randn(*clean_patch.shape)
368 | noisy_patch = np.clip(clean_patch + noise * self.sigma, 0, 255).astype(np.uint8)
369 | return noisy_patch, clean_patch
370 |
371 | def set_sigma(self, sigma):
372 | self.sigma = sigma
373 |
374 |     def __getitem__(self, idx):
375 |         clean_img = crop_img(np.array(Image.open(self.clean_ids[idx]).convert('RGB')), base=16)
376 |         clean_name = self.clean_ids[idx].split("/")[-1].split('.')[0]
377 |
378 | noisy_img, _ = self._add_gaussian_noise(clean_img)
379 | clean_img, noisy_img = self.toTensor(clean_img), self.toTensor(noisy_img)
380 |
381 | return [clean_name], noisy_img, clean_img
382 |     def tile_degrad(input_, tile=128, tile_overlap=0):
383 | sigma_dict = {0:0,1:15,2:25,3:50}
384 | b, c, h, w = input_.shape
385 | tile = min(tile, h, w)
386 | assert tile % 8 == 0, "tile size should be multiple of 8"
387 |
388 | stride = tile - tile_overlap
389 | h_idx_list = list(range(0, h-tile, stride)) + [h-tile]
390 | w_idx_list = list(range(0, w-tile, stride)) + [w-tile]
391 | E = torch.zeros(b, c, h, w).type_as(input_)
392 | W = torch.zeros_like(E)
393 | s = 0
394 | for h_idx in h_idx_list:
395 | for w_idx in w_idx_list:
396 | in_patch = input_[..., h_idx:h_idx+tile, w_idx:w_idx+tile]
397 | out_patch = in_patch
398 | out_patch_mask = torch.ones_like(in_patch)
399 |
400 | E[..., h_idx:(h_idx+tile), w_idx:(w_idx+tile)].add_(out_patch)
401 | W[..., h_idx:(h_idx+tile), w_idx:(w_idx+tile)].add_(out_patch_mask)
402 |
403 |         restored = torch.clamp(E.div(W), 0, 1)  # average overlapping tiles, then clamp
404 | return restored
405 | def __len__(self):
406 | return self.num_clean
407 |
408 |
409 | class AnyIRTestDataset(Dataset):
410 |     def __init__(self, args, task="derain", addnoise=False, sigma=None):
411 | super(AnyIRTestDataset, self).__init__()
412 | self.ids = []
413 | self.task_idx = 0
414 | self.args = args
415 |
416 | self.task_dict = {'derain': 0, 'dehaze': 1, 'deblur': 2, 'enhance': 3}
417 | self.toTensor = ToTensor()
418 | self.addnoise = addnoise
419 | self.sigma = sigma
420 |
421 | self.set_dataset(task)
422 | def _add_gaussian_noise(self, clean_patch):
423 | noise = np.random.randn(*clean_patch.shape)
424 | noisy_patch = np.clip(clean_patch + noise * self.sigma, 0, 255).astype(np.uint8)
425 | return noisy_patch, clean_patch
426 |
427 | def _init_input_ids(self):
428 | if self.task_idx == 0:
429 | self.ids = []
430 | name_list = os.listdir(self.args.derain_path + 'input/')
431 |             # print(name_list)
432 |             # print(self.args.derain_path)  # debug output, disabled
433 | self.ids += [self.args.derain_path + 'input/' + id_ for id_ in name_list]
434 | elif self.task_idx == 1:
435 | self.ids = []
436 | name_list = os.listdir(self.args.dehaze_path + 'input/')
437 | self.ids += [self.args.dehaze_path + 'input/' + id_ for id_ in name_list]
438 | elif self.task_idx == 2:
439 | self.ids = []
440 | name_list = os.listdir(self.args.gopro_path +'input/')
441 | self.ids += [self.args.gopro_path + 'input/' + id_ for id_ in name_list]
442 | elif self.task_idx == 3:
443 | self.ids = []
444 | name_list = os.listdir(self.args.enhance_path + 'input/')
445 | self.ids += [self.args.enhance_path + 'input/' + id_ for id_ in name_list]
446 |
447 | self.length = len(self.ids)
448 |
449 | def _get_gt_path(self, degraded_name):
450 | if self.task_idx == 0:
451 | dir_name = degraded_name.split("input")[0] + 'target/'
452 | name = 'no' + degraded_name.split('/')[-1].split('_')[0]
453 | gt_name = dir_name + name
454 | elif self.task_idx == 1:
455 | dir_name = degraded_name.split("input")[0] + 'target/'
456 | name = degraded_name.split('/')[-1].split('_')[0] + '.png'
457 | gt_name = dir_name + name
458 | elif self.task_idx == 2:
459 | gt_name = degraded_name.replace("input", "target")
460 |
461 | elif self.task_idx == 3:
462 | gt_name = degraded_name.replace("input", "target")
463 |
464 | return gt_name
465 |
466 | def set_dataset(self, task):
467 | self.task_idx = self.task_dict[task]
468 | self._init_input_ids()
469 |
470 | def __getitem__(self, idx):
471 | degraded_path = self.ids[idx]
472 | clean_path = self._get_gt_path(degraded_path)
473 |
474 | degraded_img = crop_img(np.array(Image.open(degraded_path).convert('RGB')), base=16)
475 | if self.addnoise:
476 | degraded_img,_ = self._add_gaussian_noise(degraded_img)
477 | clean_img = crop_img(np.array(Image.open(clean_path).convert('RGB')), base=16)
478 |
479 | clean_img, degraded_img = self.toTensor(clean_img), self.toTensor(degraded_img)
480 | degraded_name = degraded_path.split('/')[-1][:-4]
481 |
482 | return [degraded_name], degraded_img, clean_img
483 |
484 | def __len__(self):
485 | return self.length
486 |
487 |
488 | class TestSpecificDataset(Dataset):
489 | def __init__(self, args):
490 | super(TestSpecificDataset, self).__init__()
491 | self.args = args
492 | self.degraded_ids = []
493 | self._init_clean_ids(args.test_path)
494 |
495 | self.toTensor = ToTensor()
496 |
497 | def _init_clean_ids(self, root):
498 | extensions = ['jpg', 'JPG', 'png', 'PNG', 'jpeg', 'JPEG', 'bmp', 'BMP']
499 | if os.path.isdir(root):
500 | name_list = []
501 | for image_file in os.listdir(root):
502 | if any([image_file.endswith(ext) for ext in extensions]):
503 | name_list.append(image_file)
504 | if len(name_list) == 0:
505 | raise Exception('The input directory does not contain any image files')
506 |             self.degraded_ids += [os.path.join(root, id_) for id_ in name_list]
507 |         else:
508 |             if any([root.endswith(ext) for ext in extensions]):
509 |                 name_list = [root]
510 |             else:
511 |                 raise Exception('Please pass an image file')
512 |             self.degraded_ids = name_list
513 |         print("Total Images : {}".format(len(name_list)))
514 |
515 | self.num_img = len(self.degraded_ids)
516 |
517 | def __getitem__(self, idx):
518 | degraded_img = crop_img(np.array(Image.open(self.degraded_ids[idx]).convert('RGB')), base=16)
519 | name = self.degraded_ids[idx].split('/')[-1][:-4]
520 |
521 | degraded_img = self.toTensor(degraded_img)
522 |
523 | return [name], degraded_img
524 |
525 | def __len__(self):
526 | return self.num_img
--------------------------------------------------------------------------------
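
A minimal sketch of wiring CDD11 above into a DataLoader. The args namespace and the data root are placeholders; cdd11_path must contain <split>/clear plus one folder per degradation type, as _init expects:

from types import SimpleNamespace

from torch.utils.data import DataLoader

from utils.dataset_utils import CDD11

# Placeholder configuration; CDD11 only reads patch_size and cdd11_path here.
args = SimpleNamespace(patch_size=128, cdd11_path="data/CDD-11")

train_set = CDD11(args, split="train", subset="all")
train_loader = DataLoader(train_set, batch_size=8, shuffle=True, num_workers=4)

# Each batch is ([clean_paths, degradation_types], degraded, clean).
(paths_and_types, lr, hr) = next(iter(train_loader))
print(lr.shape, hr.shape)  # torch.Size([8, 3, 128, 128]) for both
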