├── src
│   ├── __init__.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── classifier.py
│   │   ├── common.py
│   │   └── edsr.py
│   ├── tool
│   │   ├── __init__.py
│   │   └── prepare_tta_data.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── utils_tta.py
│   │   ├── utils_blindsr_plus.py
│   │   ├── diffjpeg.py
│   │   ├── utils_blindsr.py
│   │   ├── basicsr_degradations.py
│   │   └── utils_image.py
│   ├── scripts
│   │   ├── train_classifier.sh
│   │   ├── tta_div2kc_x4.sh
│   │   ├── tta_div2kc_x2.sh
│   │   ├── tta_div2kmc_x2.sh
│   │   └── test_checkpoints.sh
│   ├── data
│   │   ├── benchmark.py
│   │   ├── div2k.py
│   │   ├── common.py
│   │   ├── __init__.py
│   │   ├── div2kc.py
│   │   ├── div2kmc.py
│   │   ├── patchkernel.py
│   │   ├── srdata.py
│   │   └── div2kmd.py
│   ├── test_srtta.py
│   ├── train_classifier.py
│   ├── srtta.py
│   └── main_tta.py
├── figures
│   └── srtta_overview.png
├── requirements.txt
├── LICENSE
├── .gitignore
└── README.md

/src/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/model/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/tool/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 | 
--------------------------------------------------------------------------------
/figures/srtta_overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/DengZeshuai/SRTTA/HEAD/figures/srtta_overview.png
--------------------------------------------------------------------------------
/src/scripts/train_classifier.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=0 python train_classifier.py --bs 256 --total_epoch 400 \
2 |     --save_dir experiments/degradation_classifier \
3 |     --lr 0.001 --cache --lf 0.001 --train_dtypes single+multi --img_size 224
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | basicsr==1.3.4.9
2 | imageio==2.26.0
3 | matplotlib==3.7.1
4 | numpy==1.23.0
5 | opencv_contrib_python==4.7.0.72
6 | opencv_python==4.7.0.72
7 | pandas==1.5.3
8 | scipy==1.9.1
9 | scikit-image  # the PyPI "skimage==0.0" stub is not installable; scikit-image provides the skimage package
10 | torch==1.10.1+cu111
11 | torchvision==0.11.2+cu111
12 | tqdm==4.65.0
--------------------------------------------------------------------------------
/src/scripts/tta_div2kc_x4.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=4 python main_tta.py --exp_name tta_x4_parameter_reset --lr 5e-5 \
2 |     --iterations 10 --batch_size 32 --patch_size 64 --pre_train checkpoints/EDSR_baseline_x4.pt \
3 |     --teacher_weight 1 --fisher_restore --fisher_ratio 0.5 --params_reset \
4 |     --scale 4 --tta_data DIV2KC
5 | 
6 | CUDA_VISIBLE_DEVICES=4 python main_tta.py --exp_name tta_x4_lifelong --lr 5e-5 \
7 |     --iterations 10 --batch_size 32 --patch_size 64 --pre_train checkpoints/EDSR_baseline_x4.pt \
8 |     --teacher_weight 1 --fisher_restore --fisher_ratio 0.5 \
9 |     --scale 4 --tta_data DIV2KC
--------------------------------------------------------------------------------
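Editor's note: the x4 script above and the x2 script below each pair a parameter-reset run (`--params_reset`, the model is reset between corruption domains) with a lifelong run (the model keeps adapting across domains), and both use `--fisher_restore --fisher_ratio 0.5`. As a rough illustration only, the sketch below shows what Fisher-guided restoration with a 0.5 ratio could look like: the least-important half of each weight tensor is restored to its pre-trained values after adaptation. This is a hedged reading of the flags, not the implementation in `main_tta.py`; `fisher_diag` and `restore_by_fisher` are illustrative names.

```python
# Minimal sketch of Fisher-guided parameter restoration (assumed semantics of
# --fisher_restore / --fisher_ratio; NOT the repo's actual main_tta.py code).
import torch
import torch.nn as nn

def fisher_diag(model: nn.Module, loss: torch.Tensor) -> dict:
    """Diagonal Fisher approximation: squared gradients of the adaptation loss."""
    params = {n: p for n, p in model.named_parameters() if p.requires_grad}
    grads = torch.autograd.grad(loss, list(params.values()), allow_unused=True)
    return {n: (g.detach() ** 2 if g is not None else None)
            for n, g in zip(params, grads)}

@torch.no_grad()
def restore_by_fisher(model: nn.Module, pretrained: dict, fisher: dict, ratio: float = 0.5):
    """Restore the `ratio` fraction of least-important entries of each tensor
    back to the pre-trained weights, keeping the important adapted entries."""
    for name, p in model.named_parameters():
        f = fisher.get(name)
        if f is None:
            continue
        k = max(1, int(f.numel() * ratio))
        thresh = f.flatten().kthvalue(k).values  # k-th smallest importance
        mask = f <= thresh                       # least-important entries
        p[mask] = pretrained[name].to(p.device)[mask]
```

Here `pretrained` would be a snapshot of the EDSR checkpoint taken before adaptation, e.g. `{n: p.detach().clone() for n, p in model.named_parameters()}`; in the parameter-reset scripts the whole model would additionally be re-initialized from that checkpoint before each new corruption domain.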
/src/scripts/tta_div2kc_x2.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=1 python main_tta.py --exp_name SRTTA_reset_DIV2KC_x2 --lr 5e-5 \
2 |     --iterations 10 --batch_size 32 --patch_size 64 --pre_train checkpoints/EDSR_baseline_x2.pt \
3 |     --teacher_weight 1 --fisher_restore --fisher_ratio 0.5 --params_reset \
4 |     --scale 2 --tta_data DIV2KC
5 | 
6 | CUDA_VISIBLE_DEVICES=1 python main_tta.py --exp_name SRTTA_lifelong_DIV2KC_x2 --lr 5e-5 \
7 |     --iterations 10 --batch_size 32 --patch_size 64 --pre_train checkpoints/EDSR_baseline_x2.pt \
8 |     --teacher_weight 1 --fisher_restore --fisher_ratio 0.5 \
9 |     --scale 2 --tta_data DIV2KC
--------------------------------------------------------------------------------
/src/data/benchmark.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | from data import common
4 | from data import srdata
5 | 
6 | import numpy as np
7 | 
8 | import torch
9 | import torch.utils.data as data
10 | 
11 | class Benchmark(srdata.SRData):
12 |     def __init__(self, args, name='', train=True, benchmark=True):
13 |         super(Benchmark, self).__init__(
14 |             args, name=name, train=train, benchmark=True
15 |         )
16 | 
17 |     def _set_filesystem(self, dir_data):
18 |         self.apath = os.path.join(dir_data, 'benchmark', self.name)
19 |         self.dir_hr = os.path.join(self.apath, 'HR')
20 |         self.dir_lr = os.path.join(self.apath, 'LR_bicubic')
21 |         self.ext = ('', '.png')
22 | 
--------------------------------------------------------------------------------
/src/scripts/tta_div2kmc_x2.sh:
--------------------------------------------------------------------------------
1 | CUDA_VISIBLE_DEVICES=1 python main_tta.py --exp_name SRTTA_reset_DIV2KMC_x2 --lr 5e-5 \
2 |     --iterations 10 --batch_size 32 --patch_size 64 --pre_train checkpoints/EDSR_baseline_x2.pt \
3 |     --teacher_weight 1 --fisher_restore --fisher_ratio 0.5 --params_reset \
4 |     --scale 2 --tta_data DIV2KMC --data_test DIV2KMC+Set5
5 | 
6 | # For the lifelong setting, we first adapt to DIV2K-C and then use the adapted model to adapt to DIV2K-MC
7 | CUDA_VISIBLE_DEVICES=1 python main_tta.py --exp_name SRTTA_lifelong_DIV2KMC_x2 --lr 5e-5 \
8 |     --iterations 10 --batch_size 32 --patch_size 64 --pre_train $PATH_TO_ADAPTED_CHECKPOINTS_IN_DIV2KC \
9 |     --teacher_weight 1 --fisher_restore --fisher_ratio 0.5 \
10 |     --scale 2 --tta_data DIV2KMC --data_test DIV2KMC+Set5
--------------------------------------------------------------------------------
/src/model/classifier.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 | import torchvision
4 | 
5 | 
6 | class Classifier(nn.Module):
7 |     def __init__(self, n_cls=3, feat_dim=128):
8 |         super(Classifier, self).__init__()
9 |         self.class_num = n_cls
10 |         self.loss = nn.BCELoss(reduction="mean")
11 |         self.backbone = torchvision.models.resnet50(pretrained=True)
12 |         self.head = nn.Linear(1000, feat_dim)
13 |         self.classifier = nn.Linear(feat_dim, self.class_num)
14 | 
15 |     def forward(self, x):
16 |         x = F.normalize(self.head(self.backbone(x)), dim=1)
17 |         res = self.classifier(x)
18 |         return res
19 | 
20 |     def computeLoss(self, pred, gt):
21 |         loss = self.loss(pred.sigmoid(), gt)
22 |         return loss
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 Zeshuai Deng
4 | 
5 | Permission is hereby
granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/data/div2k.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data import srdata 3 | 4 | class DIV2K(srdata.SRData): 5 | def __init__(self, args, name='DIV2K', train=True, benchmark=False): 6 | data_range = [r.split('-') for r in args.data_range.split('/')] 7 | if train: 8 | data_range = data_range[0] 9 | else: 10 | if args.test_only and len(data_range) == 1: 11 | data_range = data_range[0] 12 | else: 13 | data_range = data_range[1] 14 | 15 | self.begin, self.end = list(map(lambda x: int(x), data_range)) 16 | super(DIV2K, self).__init__( 17 | args, name=name, train=train, benchmark=benchmark 18 | ) 19 | 20 | def _scan(self): 21 | names_hr, names_lr = super(DIV2K, self)._scan() 22 | names_hr = names_hr[self.begin - 1:self.end] 23 | names_lr = [n[self.begin - 1:self.end] for n in names_lr] 24 | 25 | return names_hr, names_lr 26 | 27 | def _set_filesystem(self, dir_data): 28 | super(DIV2K, self)._set_filesystem(dir_data) 29 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') 30 | self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic') -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | # PyInstaller 27 | # Usually these files are written by a python script from a template 28 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
29 | *.manifest 30 | *.spec 31 | ./history/* 32 | # Installer logs 33 | pip-log.txt 34 | pip-delete-this-directory.txt 35 | /mnt/cephfs/home/chenzhuokun/git/SR_TTA/src/model/exp_transform/* 36 | /mnt/cephfs/home/chenzhuokun/git/SR_TTA/src/model/exp_cls/* 37 | # Unit test / coverage reports 38 | htmlcov/ 39 | .tox/ 40 | .coverage 41 | .coverage.* 42 | .cache 43 | nosetests.xml 44 | coverage.xml 45 | *,cover 46 | *.xls 47 | *.xlsx 48 | # Translations 49 | *.mo 50 | *.pot 51 | logs/* 52 | # Django stuff: 53 | *.log 54 | 55 | # Sphinx documentation 56 | docs/_build/ 57 | experiment/ 58 | # PyBuilder 59 | target/ 60 | .history/* 61 | # PyTorch 62 | *.pt 63 | *.pdf 64 | *.swp 65 | .vscode 66 | datasets 67 | src/checkpoints -------------------------------------------------------------------------------- /src/scripts/test_checkpoints.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=6 2 | # x2 lifelong 3 | python test_srtta.py --test_only --base_model checkpoints/EDSR_baseline_x2.pt --dir_data ../datasets \ 4 | --pre_train checkpoints/reproduce/SRTTA_lifelong_DIV2KC_x2 \ 5 | --cls_model checkpoints/classifier.pt 6 | 7 | # x2 parameter-reset 8 | python test_srtta.py --test_only --base_model checkpoints/EDSR_baseline_x2.pt --dir_data ../datasets \ 9 | --pre_train checkpoints/reproduce/SRTTA_reset_DIV2KC_x2 \ 10 | --cls_model checkpoints/classifier.pt 11 | 12 | # x2 lifelong multi-corruptions 13 | python test_srtta.py --test_only --base_model checkpoints/EDSR_baseline_x2.pt --dir_data ../datasets \ 14 | --pre_train checkpoints/reproduce/SRTTA_lifelong_DIV2KMC_x2 --corruptions multi \ 15 | --cls_model checkpoints/classifier.pt --tta_data DIV2KMC --data_test DIV2KMC 16 | 17 | # x2 parameter-reset multi-corruptions 18 | python test_srtta.py --test_only --base_model checkpoints/EDSR_baseline_x2.pt --dir_data ../datasets \ 19 | --pre_train checkpoints/reproduce/SRTTA_reset_DIV2KMC_x2 --corruptions multi \ 20 | --cls_model checkpoints/classifier.pt --tta_data DIV2KMC --data_test DIV2KMC 21 | 22 | # x4 lifelong 23 | python test_srtta.py --test_only --base_model checkpoints/EDSR_baseline_x4.pt --dir_data ../datasets \ 24 | --pre_train checkpoints/reproduce/SRTTA_lifelong_DIV2KC_x4 \ 25 | --scale 4 --cls_model checkpoints/classifier.pt 26 | 27 | # x4 parameter-reset 28 | python test_srtta.py --test_only --base_model checkpoints/EDSR_baseline_x4.pt --dir_data ../datasets \ 29 | --pre_train checkpoints/reproduce/SRTTA_reset_DIV2KC_x4 \ 30 | --scale 4 --cls_model checkpoints/classifier.pt -------------------------------------------------------------------------------- /src/data/common.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import skimage.color as sc 5 | 6 | import torch 7 | 8 | def get_patch(*args, patch_size=96, scale=2, multi=False, input_large=False): 9 | ih, iw = args[0].shape[:2] 10 | 11 | if not input_large: 12 | p = scale if multi else 1 13 | tp = p * patch_size 14 | ip = tp // scale 15 | else: 16 | tp = patch_size 17 | ip = patch_size 18 | 19 | ix = random.randrange(0, iw - ip + 1) 20 | iy = random.randrange(0, ih - ip + 1) 21 | 22 | if not input_large: 23 | tx, ty = scale * ix, scale * iy 24 | else: 25 | tx, ty = ix, iy 26 | 27 | ret = [ 28 | args[0][iy:iy + ip, ix:ix + ip, :], 29 | *[a[ty:ty + tp, tx:tx + tp, :] for a in args[1:]] 30 | ] 31 | 32 | return ret 33 | 34 | def set_channel(*args, n_channels=3): 35 | def _set_channel(img): 36 | if 
img.ndim == 2: 37 | img = np.expand_dims(img, axis=2) 38 | 39 | c = img.shape[2] 40 | if n_channels == 1 and c == 3: 41 | img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2) 42 | elif n_channels == 3 and c == 1: 43 | img = np.concatenate([img] * n_channels, 2) 44 | 45 | return img 46 | 47 | return [_set_channel(a) for a in args] 48 | 49 | def np2Tensor(*args, rgb_range=255): 50 | def _np2Tensor(img): 51 | np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) 52 | tensor = torch.from_numpy(np_transpose).float() 53 | tensor.mul_(rgb_range / 255) 54 | 55 | return tensor 56 | 57 | return [_np2Tensor(a) for a in args] 58 | 59 | def augment(*args, hflip=True, rot=True): 60 | hflip = hflip and random.random() < 0.5 61 | vflip = rot and random.random() < 0.5 62 | rot90 = rot and random.random() < 0.5 63 | 64 | def _augment(img): 65 | if hflip: img = img[:, ::-1, :] 66 | if vflip: img = img[::-1, :, :] 67 | if rot90: img = img.transpose(1, 0, 2) 68 | 69 | return img 70 | 71 | return [_augment(a) for a in args] 72 | 73 | -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | #from dataloader import MSDataLoader 3 | from torch.utils.data import dataloader 4 | from torch.utils.data import ConcatDataset 5 | 6 | # This is a simple wrapper function for ConcatDataset 7 | class MyConcatDataset(ConcatDataset): 8 | def __init__(self, datasets): 9 | super(MyConcatDataset, self).__init__(datasets) 10 | self.train = datasets[0].train 11 | 12 | def set_scale(self, idx_scale): 13 | for d in self.datasets: 14 | if hasattr(d, 'set_scale'): d.set_scale(idx_scale) 15 | 16 | class Data: 17 | def __init__(self, args): 18 | self.loader_train = None 19 | if not args.test_only: 20 | datasets = [] 21 | for d in args.data_train: 22 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' 23 | m = import_module('data.' + module_name.lower()) 24 | datasets.append(getattr(m, module_name)(args, name=d)) 25 | 26 | self.loader_train = dataloader.DataLoader( 27 | MyConcatDataset(datasets), 28 | batch_size=args.batch_size, 29 | shuffle=True, 30 | pin_memory=not args.cpu, 31 | num_workers=args.n_threads, 32 | ) 33 | 34 | self.loader_test = [] 35 | for d in args.data_test: 36 | if d in ['Set5', 'Set14', 'B100', 'Urban100']: 37 | m = import_module('data.benchmark') 38 | testset = getattr(m, 'Benchmark')(args, train=False, name=d) 39 | else: 40 | module_name = d if d.find('DIV2K-Q') < 0 else 'DIV2KJPEG' 41 | m = import_module('data.' 
+ module_name.lower()) 42 | testset = getattr(m, module_name)(args, train=False, name=d) 43 | 44 | self.loader_test.append( 45 | dataloader.DataLoader( 46 | testset, 47 | batch_size=1, 48 | shuffle=False, 49 | pin_memory=not args.cpu, 50 | num_workers=args.n_threads, 51 | ) 52 | ) 53 | -------------------------------------------------------------------------------- /src/data/div2kc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | try: 4 | import imageio.v2 as imageio 5 | except: 6 | import imageio 7 | import torch 8 | import numpy as np 9 | from data import srdata, common 10 | 11 | class DIV2KC(srdata.SRData): 12 | def __init__(self, args, name='DIV2KC', train=False, benchmark=False, corruption="GaussianBlur", lr_only=False): 13 | super(DIV2KC, self).__init__( 14 | args, name=name, train=train, benchmark=benchmark 15 | ) 16 | self.corruption = corruption 17 | self.lr_only = lr_only 18 | 19 | def set_corruption(self, corruption="GaussianBlur"): 20 | self.corruption = corruption 21 | self._set_filesystem(self.args.dir_data, corruption) 22 | self.images_hr, self.images_lr = self._scan() 23 | 24 | def set_lr_only(self, lr_only=True): 25 | self.lr_only = lr_only 26 | 27 | def get_img_paths(self): 28 | return self.images_hr 29 | 30 | def _set_filesystem(self, dir_data, corruption='GaussianBlur'): 31 | if not hasattr(self, 'corruption'): 32 | self.corruption = corruption 33 | self.apath = os.path.join(dir_data, self.name) 34 | self.dir_hr = os.path.join(self.apath, 'gt') 35 | s = max(self.scale) if isinstance(self.scale, list) else self.scale 36 | self.dir_lr = os.path.join(self.apath, 'corruptions', self.corruption, 'X{}'.format(s)) 37 | self.ext = ('.png', '.png') 38 | 39 | def _scan(self): 40 | names_hr = sorted( 41 | glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0])) 42 | ) 43 | names_lr = [[] for _ in self.scale] 44 | for f in names_hr: 45 | filename, _ = os.path.splitext(os.path.basename(f)) 46 | for si, s in enumerate(self.scale): 47 | names_lr[si].append(os.path.join(self.dir_lr, filename + self.ext[1])) 48 | 49 | return names_hr, names_lr 50 | 51 | def __getitem__(self, idx): 52 | if self.lr_only: 53 | f_lr = self.images_lr[self.idx_scale][idx] 54 | filename = os.path.splitext(os.path.basename(f_lr))[0] 55 | lr = imageio.imread(f_lr) 56 | if lr.shape[2] > 3: lr = lr[:, :, :3] 57 | # convert torch to numpy 58 | lr = np.ascontiguousarray(lr.transpose((2, 0, 1))) 59 | lr = torch.from_numpy(lr).float() 60 | 61 | return lr, -1, filename 62 | else: 63 | lr, hr, filename = self._load_file(idx) 64 | pair = self.get_patch(lr, hr) 65 | pair = common.set_channel(*pair, n_channels=self.args.n_colors) 66 | pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range) 67 | 68 | return pair_t[0], pair_t[1], filename -------------------------------------------------------------------------------- /src/data/div2kmc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | try: 4 | import imageio.v2 as imageio 5 | except: 6 | import imageio 7 | import torch 8 | import numpy as np 9 | from data import srdata, common 10 | 11 | class DIV2KMC(srdata.SRData): 12 | def __init__(self, args, name='DIV2KMC', train=False, benchmark=False, corruption="GaussianBlur", lr_only=False): 13 | super(DIV2KMC, self).__init__( 14 | args, name=name, train=train, benchmark=benchmark 15 | ) 16 | self.corruption = corruption 17 | self.lr_only = lr_only 18 | 19 | def set_corruption(self, 
corruption="GaussianBlur"): 20 | self.corruption = corruption 21 | self._set_filesystem(self.args.dir_data, corruption) 22 | self.images_hr, self.images_lr = self._scan() 23 | 24 | def set_lr_only(self, lr_only=True): 25 | self.lr_only = lr_only 26 | 27 | def get_img_paths(self): 28 | return self.images_hr 29 | 30 | def _set_filesystem(self, dir_data, corruption='GaussianBlur'): 31 | if not hasattr(self, 'corruption'): 32 | self.corruption = corruption 33 | self.apath = os.path.join(dir_data, self.name) 34 | self.dir_hr = os.path.join(self.apath, 'gt') 35 | s = max(self.scale) if isinstance(self.scale, list) else self.scale 36 | self.dir_lr = os.path.join(self.apath, 'corruptions', self.corruption, 'X{}'.format(s)) 37 | self.ext = ('.png', '.png') 38 | 39 | def _scan(self): 40 | names_hr = sorted( 41 | glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0])) 42 | ) 43 | names_lr = [[] for _ in self.scale] 44 | for f in names_hr: 45 | filename, _ = os.path.splitext(os.path.basename(f)) 46 | for si, s in enumerate(self.scale): 47 | names_lr[si].append(os.path.join(self.dir_lr, filename + self.ext[1])) 48 | 49 | return names_hr, names_lr 50 | 51 | def __getitem__(self, idx): 52 | if self.lr_only: 53 | f_lr = self.images_lr[self.idx_scale][idx] 54 | filename = os.path.splitext(os.path.basename(f_lr))[0] 55 | lr = imageio.imread(f_lr) 56 | if lr.shape[2] > 3: lr = lr[:, :, :3] 57 | # convert torch to numpy 58 | lr = np.ascontiguousarray(lr.transpose((2, 0, 1))) 59 | lr = torch.from_numpy(lr).float() 60 | 61 | return lr, -1, filename 62 | else: 63 | lr, hr, filename = self._load_file(idx) 64 | pair = self.get_patch(lr, hr) 65 | pair = common.set_channel(*pair, n_channels=self.args.n_colors) 66 | pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range) 67 | 68 | return pair_t[0], pair_t[1], filename -------------------------------------------------------------------------------- /src/model/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 8 | return nn.Conv2d( 9 | in_channels, out_channels, kernel_size, 10 | padding=(kernel_size//2), bias=bias) 11 | 12 | class MeanShift(nn.Conv2d): 13 | def __init__( 14 | self, rgb_range, 15 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): 16 | 17 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 18 | std = torch.Tensor(rgb_std) 19 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) 20 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std 21 | for p in self.parameters(): 22 | p.requires_grad = False 23 | 24 | class BasicBlock(nn.Sequential): 25 | def __init__( 26 | self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False, 27 | bn=True, act=nn.ReLU(True)): 28 | 29 | m = [conv(in_channels, out_channels, kernel_size, bias=bias)] 30 | if bn: 31 | m.append(nn.BatchNorm2d(out_channels)) 32 | if act is not None: 33 | m.append(act) 34 | 35 | super(BasicBlock, self).__init__(*m) 36 | 37 | class ResBlock(nn.Module): 38 | def __init__( 39 | self, conv, n_feats, kernel_size, 40 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 41 | 42 | super(ResBlock, self).__init__() 43 | m = [] 44 | for i in range(2): 45 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 46 | if bn: 47 | m.append(nn.BatchNorm2d(n_feats)) 48 | if i == 0: 49 | m.append(act) 50 | 51 | self.body = 
nn.Sequential(*m) 52 | self.res_scale = res_scale 53 | 54 | def forward(self, x): 55 | res = self.body(x).mul(self.res_scale) 56 | res += x 57 | 58 | return res 59 | 60 | class Upsampler(nn.Sequential): 61 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): 62 | 63 | m = [] 64 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 65 | for _ in range(int(math.log(scale, 2))): 66 | m.append(conv(n_feats, 4 * n_feats, 3, bias)) 67 | m.append(nn.PixelShuffle(2)) 68 | if bn: 69 | m.append(nn.BatchNorm2d(n_feats)) 70 | if act == 'relu': 71 | m.append(nn.ReLU(True)) 72 | elif act == 'prelu': 73 | m.append(nn.PReLU(n_feats)) 74 | 75 | elif scale == 3: 76 | m.append(conv(n_feats, 9 * n_feats, 3, bias)) 77 | m.append(nn.PixelShuffle(3)) 78 | if bn: 79 | m.append(nn.BatchNorm2d(n_feats)) 80 | if act == 'relu': 81 | m.append(nn.ReLU(True)) 82 | elif act == 'prelu': 83 | m.append(nn.PReLU(n_feats)) 84 | else: 85 | raise NotImplementedError 86 | 87 | super(Upsampler, self).__init__(*m) 88 | 89 | -------------------------------------------------------------------------------- /src/data/patchkernel.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import random 4 | import math 5 | import time 6 | import torch 7 | try: 8 | import imageio.v2 as imageio 9 | except: 10 | import imageio 11 | import numpy as np 12 | import torch.utils.data as data 13 | import utils.utils_image as util 14 | import utils.basicsr_degradations as degradation 15 | 16 | 17 | class PatchKernel(data.Dataset): 18 | def __init__(self, args, name=None, train=True, task_idx=0): 19 | self.args = args 20 | self.scale = max(args.scale) 21 | self.task_idx = task_idx 22 | self.train = train 23 | self.kernel_range = args.kernel_range # kernel size ranges from 7 to 21 24 | self.kernel_list = args.kernel_list 25 | self.kernel_prob = args.kernel_prob 26 | self.blur_sigma = args.blur_sigma 27 | self.betag_range = args.betag_range 28 | self.betap_range = args.betap_range 29 | 30 | def set_image_path(self, image_path, image=None): 31 | self.image_path = image_path 32 | if image is not None: 33 | self.image = image 34 | else: 35 | self.image = imageio.imread(image_path)[:, :, :3] 36 | 37 | def set_image(self, image): 38 | if isinstance(image, torch.Tensor): 39 | self.image = util.tensor2single3(image) 40 | else: 41 | self.image = image 42 | 43 | def set_task(self, task_idx): 44 | self.task_idx = task_idx 45 | 46 | def __len__(self): 47 | return self.args.batch_size 48 | 49 | def __getitem__(self, idx): 50 | patch = self.get_patch(self.image, self.args.patch_size, self.scale) 51 | patch = self.augment(patch) 52 | if 1 in self.task_idx: 53 | kernel_size = random.choice(self.kernel_range) 54 | kernel = degradation.random_mixed_kernels( 55 | self.kernel_list, 56 | self.kernel_prob, 57 | kernel_size, 58 | self.blur_sigma, 59 | self.blur_sigma, [-math.pi, math.pi], 60 | self.betag_range, 61 | self.betap_range, 62 | noise_range=None) 63 | 64 | # pad kernel 65 | pad_size = (21 - kernel_size) // 2 66 | kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size))) 67 | kernel = torch.FloatTensor(kernel) 68 | else: 69 | kernel = torch.zeros(1) 70 | 71 | patch = util.single2tensor3(patch) 72 | 73 | return patch, kernel 74 | 75 | def get_patch(self, img, patch_size=64, scale=2): 76 | h, w = img.shape[:2] 77 | rw = random.randrange(0, w - patch_size + 1) 78 | rh = random.randrange(0, h - patch_size + 1) 79 | 80 | patch = img[rh:rh + patch_size, rw:rw + patch_size, 
:].copy() 81 | return patch 82 | 83 | def augment(self, img_in, hflip=True, rot=True): 84 | hflip = hflip and random.random() < 0.5 85 | vflip = rot and random.random() < 0.5 86 | rot90 = rot and random.random() < 0.5 87 | 88 | def _augment(img): 89 | if hflip: img = img[:, ::-1, :] 90 | if vflip: img = img[::-1, :, :] 91 | if rot90: img = img.transpose(1, 0, 2) 92 | 93 | return img 94 | 95 | return _augment(img_in) -------------------------------------------------------------------------------- /src/data/srdata.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import random 4 | import pickle 5 | import time 6 | from data import common 7 | 8 | import numpy as np 9 | try: 10 | import imageio.v2 as imageio 11 | except: 12 | import imageio 13 | import torch 14 | import torch.utils.data as data 15 | 16 | class SRData(data.Dataset): 17 | def __init__(self, args, name='', train=True, benchmark=False): 18 | self.args = args 19 | self.name = name 20 | self.train = train 21 | self.split = 'train' if train else 'test' 22 | self.do_eval = True 23 | self.benchmark = benchmark 24 | self.scale = args.scale 25 | self.idx_scale = 0 26 | 27 | self._set_filesystem(args.dir_data) 28 | if args.ext.find('img') < 0: 29 | path_bin = os.path.join(self.apath, 'bin') 30 | os.makedirs(path_bin, exist_ok=True) 31 | 32 | list_hr, list_lr = self._scan() 33 | self.images_hr, self.images_lr = list_hr, list_lr 34 | 35 | if train: 36 | n_patches = args.batch_size * args.test_every 37 | n_images = len(args.data_train) * len(self.images_hr) 38 | if n_images == 0: 39 | self.repeat = 0 40 | else: 41 | self.repeat = max(n_patches // n_images, 1) 42 | 43 | # Below functions as used to prepare images 44 | def _scan(self): 45 | names_hr = sorted( 46 | glob.glob(os.path.join(self.dir_hr, '*' + self.ext[0])) 47 | ) 48 | names_lr = [[] for _ in self.scale] 49 | for f in names_hr: 50 | filename, _ = os.path.splitext(os.path.basename(f)) 51 | for si, s in enumerate(self.scale): 52 | names_lr[si].append(os.path.join( 53 | self.dir_lr, 'X{}/{}x{}{}'.format( 54 | s, filename, s, self.ext[1] 55 | ) 56 | )) 57 | 58 | return names_hr, names_lr 59 | 60 | def _set_filesystem(self, dir_data): 61 | self.apath = os.path.join(dir_data, self.name) 62 | self.dir_hr = os.path.join(self.apath, 'HR') 63 | self.dir_lr = os.path.join(self.apath, 'LR_bicubic') 64 | self.ext = ('.png', '.png') 65 | 66 | def __getitem__(self, idx): 67 | lr, hr, filename = self._load_file(idx) 68 | pair = self.get_patch(lr, hr) 69 | if len(pair)>2: 70 | pair=pair[0:2] 71 | pair = common.set_channel(*pair, n_channels=self.args.n_colors) 72 | pair_t = common.np2Tensor(*pair, rgb_range=self.args.rgb_range) 73 | 74 | return pair_t[0], pair_t[1], filename 75 | 76 | def __len__(self): 77 | if self.train: 78 | return len(self.images_hr) * self.repeat 79 | else: 80 | return len(self.images_hr) 81 | 82 | def _get_index(self, idx): 83 | if self.train: 84 | return idx % len(self.images_hr) 85 | else: 86 | return idx 87 | 88 | def _load_file(self, idx): 89 | idx = self._get_index(idx) 90 | f_hr = self.images_hr[idx] 91 | f_lr = self.images_lr[self.idx_scale][idx] 92 | 93 | filename, _ = os.path.splitext(os.path.basename(f_hr)) 94 | hr = imageio.imread(f_hr) 95 | lr = imageio.imread(f_lr) 96 | 97 | # some images have 4 channels 98 | if hr.shape[2] > 3: hr = hr[:, :, :3] 99 | if lr.shape[2] > 3: lr = lr[:, :, :3] 100 | 101 | return lr, hr, filename 102 | 103 | def get_patch(self, lr, hr): 104 | scale = 
self.scale[self.idx_scale] 105 | if self.train: 106 | lr, hr = common.get_patch( 107 | lr, hr, 108 | patch_size=self.args.patch_size, 109 | scale=scale, 110 | multi=(len(self.scale) > 1) 111 | ) 112 | if not self.args.no_augment: 113 | lr, hr = common.augment(lr, hr) 114 | else: 115 | ih, iw = lr.shape[:2] 116 | hr = hr[0:ih * scale, 0:iw * scale] 117 | return lr, hr 118 | 119 | def set_scale(self, idx_scale): 120 | self.idx_scale = idx_scale 121 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Efficient Test-Time Adaptation for Super-Resolution with Second-Order Degradation and Reconstruction (NeurIPS, 2023) 2 | 3 | This repository is the official PyTorch implementation of SRTTA with application to image super-resolution in test-time environment ([arXiv](https://arxiv.org/abs/2310.19011)). 4 | 5 | --- 6 | >Image super-resolution (SR) aims to learn a mapping from low-resolution (LR) to high-resolution (HR) using paired HR-LR training images. 7 | Conventional SR methods typically gather the paired training data by synthesizing LR images from HR images using a predetermined degradation model, e.g., Bicubic down-sampling. 8 | However, the realistic degradation type of test images may mismatch with the training-time degradation type 9 | due to the dynamic changes of the real-world scenarios, 10 | resulting in inferior-quality SR images. 11 | To address this, existing methods attempt to estimate the degradation model and train an image-specific model, which, however, is quite time-consuming and impracticable to handle rapidly changing domain shifts. 12 | Moreover, these methods largely concentrate on the estimation of one degradation type (e.g., blur degradation), overlooking other degradation types like noise and JPEG in real-world test-time scenarios, thus limiting their practicality. 13 | To tackle these problems, we present an efficient test-time adaptation framework for SR, named SRTTA, which is able to quickly adapt SR models to test domains with different/unknown degradation types. 14 | Specifically, we design a second-order degradation scheme to construct paired data based on the degradation type of the test image, which is predicted by a pre-trained degradation classifier. 15 | Then, we adapt the SR model by implementing feature-level reconstruction learning from the initial test image to its second-order degraded counterparts, which helps the SR model generate plausible HR images. 16 | Extensive experiments are conducted on newly synthesized corrupted DIV2K datasets with 8 different degradations and several real-world datasets, demonstrating that our SRTTA framework achieves an impressive improvement over existing methods with satisfying speed. 17 | > 18 | 19 | ## Requirements 20 | * Python 3.8, Pytorch 1.10 21 | * More details (See [requirements.txt](requirements.txt)) 22 | 23 | ## Datasets 24 | Download the dataset from the following links and put them in ./datasets. 25 | 26 | * BaiduYun : https://pan.baidu.com/s/1dJL941VJDo5nwXv5CEQzbQ (code:08v7) 27 | 28 | * GoogleDrive : https://drive.google.com/drive/folders/1lC9h4DdP3wrKIDRAMYUHPW-GFOMBbdxx?usp=drive_link 29 | 30 | ## Adapted checkpoints 31 | We provide adapted checkpoints to reproduce the results of the paper. You can download from the following links. 
32 | 33 | * BaiduYun : https://pan.baidu.com/s/1IWKguxWE2KX7Wa1tKITZrA (code:6w14) 34 | 35 | * GoogleDrive : https://drive.google.com/drive/folders/1ks1rlb0HvBQRGeLLyDeLkIOdE8tKcCWS?usp=drive_link 36 | 37 | ## Evaluation on Adapted models 38 | Download the checkpoints from the above links and put them in src/checkpoints, then run the following commands. 39 | 40 | ```shell 41 | cd src 42 | bash scripts/test_checkpoints.sh 43 | ``` 44 | 45 | ## Run test-time adaptation on DIV2K-C and DIV2K-MC 46 | Download the pretrained models of EDSR (we also provide them in the above links), then put them in src/checkpoints and run the commands below. 47 | 48 | ```shell 49 | # test-time adaptation on DIV2K-C 50 | cd src 51 | # for x2 scale 52 | bash scripts/tta_div2kc_x2.sh 53 | # for x4 scale 54 | bash scripts/tta_div2kc_x4.sh 55 | 56 | # test-time adaptation on DIV2K-MC 57 | bash scripts/tta_div2kmc_x2.sh 58 | ``` 59 | 60 | ## Citation 61 | ```bibtex 62 | @inproceedings{ 63 | deng2023efficient, 64 | title={Efficient Test-Time Adaptation for Super-Resolution with Second-Order Degradation and Reconstruction}, 65 | author={Zeshuai Deng and Zhuokun Chen and Shuaicheng Niu and Thomas H. Li and Bohan Zhuang and Mingkui Tan}, 66 | booktitle={Thirty-seventh Conference on Neural Information Processing Systems}, 67 | year={2023}, 68 | url={https://openreview.net/forum?id=IZRlMABK4l} 69 | } 70 | ``` 71 | 72 | ## Acknowledgement 73 | The codes are based on [EDSR-PyTorch](https://github.com/sanghyun-son/EDSR-PyTorch). Thanks for their great efforts. -------------------------------------------------------------------------------- /src/model/edsr.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | import torch 3 | import os 4 | import torch.nn as nn 5 | 6 | url = { 7 | 'r16f64x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x2-1bc95232.pt', 8 | 'r16f64x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x3-abf2a44e.pt', 9 | 'r16f64x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x4-6b446fab.pt', 10 | 'r32f256x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x2-0edfb8a3.pt', 11 | 'r32f256x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x3-ea3ef2c6.pt', 12 | 'r32f256x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x4-4f62e9ef.pt' 13 | } 14 | 15 | def make_model(args, parent=False): 16 | return EDSR(args) 17 | 18 | class EDSR(nn.Module): 19 | def __init__(self, args, conv=common.default_conv): 20 | super(EDSR, self).__init__() 21 | self.args = args 22 | n_resblocks = args.n_resblocks 23 | n_feats = args.n_feats 24 | kernel_size = 3 25 | scale = max(args.scale) 26 | self.scale = scale 27 | act = nn.ReLU(True) 28 | url_name = 'r{}f{}x{}'.format(n_resblocks, n_feats, scale) 29 | if url_name in url: 30 | self.url = url[url_name] 31 | else: 32 | self.url = None 33 | self.sub_mean = common.MeanShift(args.rgb_range) 34 | self.add_mean = common.MeanShift(args.rgb_range, sign=1) 35 | 36 | # define head module 37 | m_head = [conv(args.n_colors, n_feats, kernel_size)] 38 | 39 | # define body module 40 | m_body = [ 41 | common.ResBlock( 42 | conv, n_feats, kernel_size, act=act, res_scale=args.res_scale 43 | ) for _ in range(n_resblocks) 44 | ] 45 | m_body.append(conv(n_feats, n_feats, kernel_size)) 46 | 47 | # define tail module 48 | m_tail = [ 49 | common.Upsampler(conv, scale, n_feats, act=False), 50 | conv(n_feats, args.n_colors, kernel_size) 51 | ] 52 | 53 | self.head = nn.Sequential(*m_head) 54 | self.body = 
nn.Sequential(*m_body) 55 | self.tail = nn.Sequential(*m_tail) 56 | 57 | self.upsampler = m_tail[0] # use for aux forward 58 | 59 | 60 | def forward(self, x, aux_forward=False): 61 | if aux_forward: 62 | return self.aux_forward(x) 63 | else: 64 | return self.sr_forward(x) 65 | 66 | def aux_forward(self, x): 67 | x = self.sub_mean(x) 68 | x = self.head(x) 69 | 70 | res = self.body(x) 71 | res += x 72 | res = self.upsampler(res) 73 | 74 | return res 75 | 76 | def sr_forward(self, x): 77 | x = self.sub_mean(x) 78 | x = self.head(x) 79 | 80 | res = self.body(x) 81 | res += x 82 | 83 | x = self.tail(res) 84 | x = self.add_mean(x) 85 | 86 | return x 87 | 88 | def save(self, apath, epoch, is_best=False): 89 | save_dirs = [os.path.join(apath, 'model_latest.pt')] 90 | 91 | if is_best: 92 | save_dirs.append(os.path.join(apath, 'model_best.pt')) 93 | if self.save_models: 94 | save_dirs.append( 95 | os.path.join(apath, 'model_{}.pt'.format(epoch)) 96 | ) 97 | 98 | for s in save_dirs: 99 | torch.save(self.model.state_dict(), s) 100 | 101 | def load_state_dict(self, state_dict, strict=True): 102 | own_state = self.state_dict() 103 | for name, param in state_dict.items(): 104 | if name in own_state: 105 | if isinstance(param, nn.Parameter): 106 | param = param.data 107 | try: 108 | own_state[name].copy_(param) 109 | except Exception: 110 | if name.find('tail') == -1: 111 | raise RuntimeError('While copying the parameter named {}, ' 112 | 'whose dimensions in the model are {} and ' 113 | 'whose dimensions in the checkpoint are {}.' 114 | .format(name, own_state[name].size(), param.size())) 115 | elif strict: 116 | if name.find('tail') == -1: 117 | raise KeyError('unexpected key "{}" in state_dict' 118 | .format(name)) 119 | 120 | -------------------------------------------------------------------------------- /src/data/div2kmd.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import cv2 4 | import random 5 | import torch 6 | import numpy as np 7 | from tqdm import tqdm 8 | from torchvision import transforms 9 | from torch.utils.data import Dataset 10 | 11 | import utils.utils_tta as util 12 | 13 | 14 | def trainTranform(img_size=224): 15 | transform = transforms.Compose([ 16 | transforms.RandomCrop(size=img_size), 17 | transforms.RandomApply([ 18 | transforms.RandomRotation( 19 | 90, resample=False, expand=False, center=None) 20 | ], p=0.8) 21 | ]) 22 | 23 | return transform 24 | 25 | 26 | class DIV2KMD(Dataset): 27 | def __init__(self, base_path, img_size=224, train=True, cache=False): 28 | super(DIV2KMD, self).__init__() 29 | self.img_size = img_size 30 | self.single_type = util.get_corruptions() 31 | self.mixed_types = util.get_mixed_corruptions() 32 | self.corruptions = self.single_type + self.mixed_types 33 | self.classes = ["blur", "noise", "jpeg"] 34 | if train: 35 | self.img_paths = list(sorted([os.path.join(base_path, p) for p in os.listdir(base_path)])) 36 | self.img_paths = self.img_paths[0:800] 37 | self.transform = trainTranform(img_size) 38 | else: 39 | self.img_paths = [] 40 | self.labels = [] 41 | for cor in os.listdir(base_path): 42 | img_names = sorted(glob.glob(os.path.join(base_path, cor, 'X2', '*.png'))) 43 | for img_name in img_names: 44 | img_label = [] 45 | self.img_paths.append(f"{img_name}") 46 | if "blur" in cor.lower(): 47 | cor_cls = "blur" 48 | img_label.append(self.classes.index(cor_cls)) # class 0 49 | if "noise" in cor.lower(): 50 | cor_cls = "noise" 51 | img_label.append(self.classes.index(cor_cls)) # 
class 1
52 |                     if "jpeg" in cor.lower():
53 |                         cor_cls = "jpeg"
54 |                         img_label.append(self.classes.index(cor_cls))  # class 2
55 |                     self.labels.append(img_label)
56 | 
57 |         self.train = train
58 |         self.cache = cache
59 |         if self.cache:
60 |             print("caching images...")
61 |             self.imgs = [cv2.imread(path)[:, :, [2, 1, 0]] for path in tqdm(self.img_paths)]
62 | 
63 |     def degrade_img(self, img, cor_type):
64 |         """Preprocess high-resolution images with random degradation."""
65 |         img = img.copy()
66 |         if not self.train: return img
67 | 
68 |         if cor_type in self.single_type:
69 |             img = util.preprocess_img(img, scale=2, corruption=cor_type)
70 |         else:
71 |             operators = ['blur', 'down', 'noise', 'jpeg']
72 |             for op in operators:
73 |                 if op == 'blur' and op in cor_type.lower():
74 |                     degradation_type = random.choice(
75 |                         ['GaussianBlur', 'DefocusBlur', 'GlassBlur'])
76 |                     img = util.preprocess_img(img, scale=1, corruption=degradation_type)
77 |                 if op == 'down':
78 |                     # downsample the image using bicubic interpolation
79 |                     img = util.preprocess_img(img, scale=2, corruption='Original')
80 |                 if op == 'noise' and op in cor_type.lower():
81 |                     degradation_type = random.choice(
82 |                         ['GaussianNoise', 'PoissonNoise', 'ImpulseNoise', 'SpeckleNoise'])
83 |                     img = util.preprocess_img(img, scale=1, corruption=degradation_type)
84 |                 if op == 'jpeg' and op in cor_type.lower():
85 |                     degradation_type = random.choice(['JPEG'])
86 |                     img = util.preprocess_img(img, scale=1, corruption=degradation_type)
87 |         return img
88 | 
89 |     def __getitem__(self, idx):
90 |         if self.cache:
91 |             img = self.imgs[idx].copy()
92 |         else:
93 |             try:
94 |                 img = cv2.imread(self.img_paths[idx])[:, :, [2, 1, 0]]
95 |             except Exception as e:
96 |                 raise IOError(f"Failed to read image {self.img_paths[idx]}") from e
97 | 
98 |         img = np.ascontiguousarray(img).astype(np.float32)
99 |         selected_types = random.choices(self.corruptions, k=1)[0]
100 |         if self.train:
101 |             # convert numpy array to tensor, HWC -> CHW
102 |             input_img = torch.from_numpy(img).permute(2, 0, 1)
103 |             # randomly crop images
104 |             input_img = self.transform(input_img)
105 |             # degrade images with the selected corruption type
106 |             input_img = self.degrade_img(input_img.permute(1, 2, 0).numpy(), selected_types)
107 |             # convert numpy array to tensor, normalize to [0, 1]
108 |             input_img = torch.from_numpy(input_img).permute(2, 0, 1).float() / 255.
109 | 
110 |             if "origin" in selected_types.lower():
111 |                 # label for clean image
112 |                 label = torch.zeros(3).float()
113 |             else:
114 |                 label = torch.zeros(3).float()
115 |                 for ci, c in enumerate(self.classes):
116 |                     if c in selected_types.lower():
117 |                         label[ci] = 1.
118 |         else:
119 |             input_img = torch.from_numpy(img).permute(2,0,1) / 255.
120 | if len(self.labels[idx]) == 0: 121 | label = torch.zeros(3).float() 122 | else: 123 | label = torch.nn.functional.one_hot(torch.tensor( 124 | [p for p in self.labels[idx]]), len(self.classes)).sum(0).bool().float() 125 | return input_img, label, self.img_paths[idx] 126 | 127 | def __len__(self): 128 | return len(self.img_paths) -------------------------------------------------------------------------------- /src/test_srtta.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import data 4 | import torch 5 | import argparse 6 | import numpy as np 7 | from tqdm import tqdm 8 | from model.edsr import EDSR 9 | from model.classifier import Classifier 10 | import utils.utils_tta as util 11 | from utils.utils_tta import test_all, get_corruptions 12 | 13 | def run(args): 14 | if isinstance(args.scale,int): 15 | args.scale=[args.scale] 16 | args.data_test = args.data_test.split('+') 17 | val_dataloader = data.Data(args) 18 | test_loaders = val_dataloader.loader_test 19 | model = EDSR(args).cuda() 20 | corruptions = [] 21 | if "single" in args.corruptions: 22 | corruptions = corruptions + get_corruptions()[1:] 23 | if "multi" in args.corruptions: 24 | corruptions = corruptions + ["BlurJPEG","BlurNoise","NoiseJPEG","BlurNoiseJPEG"] 25 | origin_model = EDSR(args).cuda() 26 | origin_model = util.load_model(args.base_model, origin_model) 27 | if args.cls_model: 28 | cls_model=Classifier().eval().cuda() 29 | state_dict = torch.load(args.cls_model) 30 | cls_model.load_state_dict(state_dict,strict=True) 31 | else: 32 | cls_model = None 33 | corruption_psnrs = {} 34 | corruption_ssims = {} 35 | corruption_srs = {} 36 | log_txt = {} 37 | for corruption in tqdm(corruptions): 38 | model = util.load_model(f"{args.pre_train}/state_{corruption}_last.pt", model) 39 | for _, t_data in enumerate(test_loaders): 40 | data_name = t_data.dataset.name 41 | if data_name not in corruption_psnrs: 42 | corruption_psnrs[data_name] = [] 43 | if data_name not in corruption_ssims: 44 | corruption_ssims[data_name] = [] 45 | if data_name not in corruption_srs: 46 | corruption_srs[data_name] = {} 47 | if data_name not in log_txt: 48 | log_txt[data_name] = "" 49 | if data_name == "Set5": 50 | psnrs, ssims, srs = test_all(args, model, t_data, return_sr=True) 51 | # record metric 52 | res_txt=f"{data_name} : PSNR = {np.mean(psnrs)} SSIM = {np.mean(ssims)}" 53 | log_txt[data_name] += res_txt + "\n" 54 | print(res_txt) 55 | else: 56 | t_data.dataset.set_corruption(corruption) 57 | psnrs, ssims,srs = test_all(args, model, t_data, 58 | origin_model=origin_model, cls_model=cls_model, return_sr=True) 59 | # record metric 60 | res_txt = f"{data_name} - {corruption} : PSNR = {np.mean(psnrs)} SSIM = {np.mean(ssims)}" 61 | log_txt[data_name] += res_txt + "\n" 62 | print(res_txt) 63 | corruption_psnrs[data_name].append(np.mean(psnrs)) 64 | corruption_ssims[data_name].append(np.mean(ssims)) 65 | corruption_srs[data_name][corruption]=srs 66 | 67 | print("finish testing.") 68 | for _, t_data in enumerate(test_loaders): 69 | data_name = t_data.dataset.name 70 | print(f"------------------Results of {data_name}---------------------") 71 | log_txt[data_name] += f"{data_name} - AVERAGE : PSNR = {np.mean(corruption_psnrs[data_name])} SSIM = {np.mean(corruption_ssims[data_name])}" 72 | print(log_txt[data_name]) 73 | if args.save_dir: 74 | print("start saving results...") 75 | for _, t_data in enumerate(tqdm(test_loaders)): 76 | data_name = t_data.dataset.name 77 | for corruption in 
corruptions: 78 | os.makedirs(f"{args.save_dir}/{corruption}",exist_ok=True) 79 | img_names=[img_name.split("/")[-1] for img_name in t_data.dataset.images_lr[0]] 80 | for idx,sr in enumerate(corruption_srs[data_name][corruption]): 81 | cv2.imwrite(f"{args.save_dir}/{corruption}/{img_names[idx]}",sr.numpy().astype(np.uint8)[:,:,[2,1,0]]) 82 | print("finish saving results.") 83 | 84 | if __name__ == "__main__": 85 | parser = argparse.ArgumentParser() 86 | parser.add_argument('--pre_train', type=str, default='checkpoints/SRTTA/srtta_lifelong_x2.pt', 87 | help='pre-trained model directory') 88 | parser.add_argument('--dir_data', type=str, default='/chenzhuokun/datasets/', help='dataset directory') 89 | parser.add_argument('--base_model', type=str, default='checkpoints/EDSR_baseline_x2.pt', help='path to base model') 90 | parser.add_argument('--cls_model', type=str, default=None, help='path to cls model') 91 | parser.add_argument('--corruptions', type=str,default="single", help='multi-corruption') 92 | parser.add_argument('--multi-corruption', action='store_true', help='multi-corruption',default=True) 93 | parser.add_argument('--tta_data', type=str, default='DIV2KC', help='test dataset name') 94 | parser.add_argument('--data_train', type=str, default='PatchKernel', help='train dataset name') 95 | parser.add_argument('--data_test', type=str, default='DIV2KC', help='test dataset name') 96 | parser.add_argument('--cpu', action='store_true', help='use cpu only') 97 | parser.add_argument('--save_dir', type=str, default="", help='train dataset name') 98 | parser.add_argument('--debug', action='store_true', help='set this option to debugs the code') 99 | parser.add_argument('--model', default='EDSR', help='model name') 100 | parser.add_argument('--ext', type=str, default='img', help='dataset file extension') 101 | parser.add_argument('--batch_size', type=int, default=32, help='input batch size for training') 102 | parser.add_argument('--test_only', action='store_true', help='set this option to test the model',default=True) 103 | parser.add_argument('--n_threads', type=int, default=12, help='number of threads for data loading') 104 | parser.add_argument('--n_resblocks', type=int, default=16, help='number of residual blocks') 105 | parser.add_argument('--n_feats', type=int, default=64, help='number of feature maps') 106 | parser.add_argument('--res_scale', type=float, default=1, help='residual scaling') 107 | parser.add_argument('--rgb_range', type=int, default=255, help='maximum value of RGB') 108 | parser.add_argument('--n_colors', type=int, default=3, help='number of color channels to use') 109 | parser.add_argument('--scale', type=int, default=[2], help='super resolution scale') 110 | args = parser.parse_args() 111 | run(args) -------------------------------------------------------------------------------- /src/tool/prepare_tta_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import cv2 4 | import math 5 | import numpy as np 6 | import torch 7 | import random 8 | 9 | from tqdm import tqdm 10 | import argparse 11 | import imageio 12 | 13 | from multiprocessing import Pool 14 | 15 | import sys 16 | sys.path.insert(1, '..') 17 | import utils.utils_image as util 18 | import utils.utils_blindsr as blindsr 19 | import utils.utils_blindsr_plus as blindsr_plus 20 | 21 | def get_corruptions(): 22 | """ Original: bicubic downsample """ 23 | corruptions = ['Original', 'GaussianBlur', 'DefocusBlur', 'GlassBlur', 24 | 'GaussianNoise', 
'PoissonNoise', 'ImpulseNoise', 'SpeckleNoise', 'JPEG'] 25 | return corruptions 26 | 27 | 28 | def preprocess_img(img, scale=2, corruption='Original'): 29 | corruptions = get_corruptions() 30 | corr_idx = corruptions.index(corruption) 31 | 32 | img = util.uint2single(img) 33 | img_lr = None 34 | if corr_idx == 0: 35 | img_lr = blindsr.bicubic_degradation(img, scale) 36 | 37 | if corruption.lower().find('blur') >= 0: 38 | img_blur = img 39 | if corr_idx == 1: 40 | img_blur = blindsr.add_blur(img_blur, scale) 41 | if corr_idx == 2: 42 | img_blur = blindsr_plus.add_defocus_blur(img_blur) 43 | if corr_idx == 3: 44 | img_blur = blindsr_plus.add_glass_blur(img_blur) 45 | img_lr = blindsr.bicubic_degradation(img_blur, scale) 46 | 47 | if corruption.lower().find('noise') >= 0: 48 | img_lr = blindsr.bicubic_degradation(img, scale) 49 | if corr_idx == 4: 50 | img_lr = blindsr.add_Gaussian_noise(img_lr, noise_level1=2, noise_level2=25) 51 | 52 | if corr_idx == 5: 53 | img_lr = blindsr_plus.add_scale_Poisson_noise(img_lr) 54 | 55 | if corr_idx == 6: 56 | img_lr = blindsr_plus.add_impluse_noise(img_lr) 57 | 58 | if corr_idx == 7: 59 | img_lr = blindsr.add_speckle_noise(img_lr) 60 | 61 | if corr_idx == 8: 62 | img_lr = blindsr.bicubic_degradation(img, scale) 63 | img_lr = blindsr.add_JPEG_noise(img_lr) 64 | 65 | if img_lr is None: # defalut bicubic 66 | print("The corruption are not support, default bicubic") 67 | img_lr = blindsr.bicubic_degradation(img, scale) 68 | 69 | # signle2uint 70 | img_lr = util.single2uint(img_lr) # 0-1 to 0-255 71 | 72 | return img_lr 73 | 74 | def process(img_path, save_dir, scale=2, corruption='Original'): 75 | img_hr = imageio.imread(img_path) 76 | if img_hr.shape[2] > 3: img_hr = img_hr[:,:,:3] 77 | img_lr = preprocess_img(img_hr, scale, corruption=corruption) 78 | basename = os.path.splitext(os.path.basename(img_path))[0] 79 | save_path = os.path.join(save_dir, basename + '.png') 80 | imageio.imwrite(save_path, img_lr) 81 | 82 | return basename, corruption 83 | 84 | 85 | def get_image_paths(input_dir): 86 | img_paths = sorted(glob.glob(os.path.join(input_dir, "*.png"))) 87 | return img_paths 88 | 89 | def main(args): 90 | img_paths = get_image_paths(args.input_dir) 91 | corruptions = get_corruptions() 92 | 93 | if len(img_paths) < len(corruptions) * args.n_per_corruption: 94 | repeat = len(corruptions) * args.n_per_corruption // len(img_paths) + 1 95 | img_paths = img_paths * repeat 96 | 97 | if args.n_workers > 1: 98 | print(f'Read images with multiprocessing, #thread: {args.n_workers} ...') 99 | pool = Pool(args.n_workers) 100 | 101 | pbar = tqdm(total=len(img_paths)*len(args.scale), unit='image', ncols=100) 102 | 103 | def callback(args): 104 | """get the image data and update pbar.""" 105 | basename, corruption = args 106 | pbar.update(1) 107 | pbar.set_description(f'Processing {basename} with {corruption} ...') 108 | 109 | for s in args.scale: 110 | start_idx = 0 111 | for corruption in corruptions: 112 | if args.debug and corruption != args.corruption: 113 | continue 114 | 115 | end_idx = start_idx + args.n_per_corruption 116 | assert end_idx < len(img_paths) 117 | save_dir = os.path.join(args.output_dir, corruption, 'X{}'.format(s)) 118 | if not os.path.exists(save_dir): 119 | os.makedirs(save_dir) 120 | 121 | for img_path in img_paths[start_idx:end_idx]: 122 | if args.n_workers > 1: 123 | pool.apply_async( 124 | process, 125 | args=(img_path, save_dir, s, corruption), 126 | callback=callback 127 | ) 128 | else: 129 | print("Processing {} with {} 
...".format(os.path.basename(img_path), corruption)) 130 | process(img_path, save_dir, s, corruption) 131 | 132 | if args.n_workers > 1: 133 | pool.close() 134 | pool.join() 135 | pbar.close() 136 | 137 | print(f'\nFinish processing.') 138 | 139 | 140 | def set_seed(seed): 141 | random.seed(seed) 142 | np.random.seed(seed) 143 | torch.manual_seed(seed) 144 | if torch.cuda.device_count() == 1: 145 | torch.cuda.manual_seed(seed) 146 | else: 147 | torch.cuda.manual_seed_all(seed) 148 | 149 | 150 | 151 | if __name__ == "__main__": 152 | parser = argparse.ArgumentParser() 153 | parser.add_argument('--input_dir', type=str, default='/mnt/cephfs/home/dengzeshuai/data/sr/DIV2KRK/gt_rename/', help='Input folder') 154 | parser.add_argument('--output_dir', type=str, default='/mnt/cephfs/home/dengzeshuai/data/sr/DIV2KRK/corruptions/', help='Output folder') 155 | parser.add_argument('--scale', type=str, default='2', help='super resolution scale') 156 | parser.add_argument('--debug', action='store_true', help='set this option to debugs the code') 157 | parser.add_argument('--corruption', type=str, default='ImpluseNoise', help='the type of ImpluseNoise for debugging') 158 | parser.add_argument('--n_per_corruption', type=int, default=100, help='number of aux head to train') 159 | parser.add_argument('--n_workers', type=int, default=12, help='number of workers to process image') 160 | parser.add_argument('--seed', type=int, default=0, help='random seed for reproduce') 161 | args = parser.parse_args() 162 | args.scale = list(map(lambda x: int(x), args.scale.split('+'))) 163 | 164 | set_seed(args.seed) 165 | 166 | main(args) 167 | -------------------------------------------------------------------------------- /src/train_classifier.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import argparse 4 | import torch 5 | from torch.utils.data import DataLoader 6 | from tqdm import tqdm 7 | from torch.utils.tensorboard import SummaryWriter 8 | from utils.utils_tta import set_seed 9 | from data.div2kmd import DIV2KMD 10 | from model.classifier import Classifier 11 | import utils.utils_tta as util 12 | import warnings 13 | warnings.filterwarnings("ignore") 14 | 15 | def train_classifier(args, classifier, logger, writer,device="cuda"): 16 | classifier.train() 17 | 18 | train_dataset = DIV2KMD(base_path=args.training_dir,cache=args.cache) 19 | test_dataset = DIV2KMD(base_path=args.test_dir,train=False, cache=args.cache) 20 | 21 | optimizer = torch.optim.Adam(classifier.parameters(), lr=args.lr) 22 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 23 | optimizer=optimizer, T_max=args.total_epoch,eta_min=args.lr * args.lf) 24 | train_dataloader = DataLoader(train_dataset, 25 | batch_size=args.bs, num_workers=args.worker, shuffle=True) 26 | test_dataloader = DataLoader(test_dataset, 27 | batch_size=1, num_workers=args.worker, shuffle=False) 28 | best_acc = 0 29 | eval_record_threds = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9] 30 | for epoch in range(args.total_epoch): 31 | pbar = tqdm(train_dataloader, total=len(train_dataset)//args.bs + 1, ncols=100) 32 | loss_sum = 0 33 | idx_iter = 0 34 | for imgs,labels,paths in pbar: 35 | imgs = imgs.cuda() # batch[0] = batch[0][0][0] 36 | labels = labels.cuda() # # batch[1] 37 | 38 | optimizer.zero_grad() 39 | pred = classifier(imgs) 40 | loss = classifier.computeLoss(pred, labels) 41 | loss.backward() 42 | optimizer.step() 43 | scheduler.step(epoch) 44 | loss_sum += loss 45 | pbar.set_postfix({"epoch": 
f"{epoch}", "loss": f"{loss:.2f}"}) 46 | idx_iter += 1 47 | writer.add_scalar('train/train_loss', loss_sum/len(pbar), epoch) 48 | 49 | if epoch % 20 == 0 or epoch == args.total_epoch - 1: 50 | right_dict = {} 51 | for eval_record_thred in eval_record_threds: 52 | right_dict[eval_record_thred] = 0 53 | classifier.eval() 54 | pbar = tqdm(enumerate(test_dataloader), total=len(test_dataset),ncols=100) 55 | sum = 0 56 | loss_ce_sum = 0 57 | for _, (imgs, labels,paths) in pbar: 58 | imgs = imgs.cuda() 59 | labels = labels.cuda() 60 | with torch.no_grad(): 61 | pred = classifier(imgs) 62 | loss_ce_sum += classifier.computeLoss(pred, labels) 63 | pred = pred.sigmoid() 64 | pred = pred[0] 65 | for eval_record_thred in eval_record_threds: 66 | right_dict[eval_record_thred]+=((labels[0].float().to(device)!=(((pred)>eval_record_thred))).sum().item()==0) 67 | sum += imgs.shape[0] 68 | pbar.set_postfix({"epoch": f"{epoch}","mode": "eval"}) 69 | acc05 = float(right_dict[eval_record_thred]) / float(sum) 70 | if acc05 > best_acc: 71 | best_acc=acc05 72 | torch.save(classifier.state_dict(),os.path.join(args.save_dir, "best.pt")) 73 | torch.save(classifier.state_dict(),os.path.join(args.save_dir, "last.pt")) 74 | 75 | log_txt = f"Epoch {epoch}" 76 | for eval_record_thred in eval_record_threds: 77 | acc = float(right_dict[eval_record_thred])/float(sum) 78 | log_txt += f", acc-{eval_record_thred:.1f}={acc:.4f}" 79 | logger.info(log_txt) 80 | writer.add_scalar('val/acc', acc, epoch) 81 | writer.add_scalar('val/loss_ce', loss_ce_sum/len(pbar), epoch) 82 | classifier.train() 83 | return classifier 84 | 85 | def evaluate_classifier(args, classifier, device="cuda"): 86 | classifier.eval() 87 | test_dataset = DIV2KMD(base_path=args.test_dir,train=False, cache=args.cache) 88 | test_dataloader = DataLoader(test_dataset, 89 | batch_size=1, num_workers=args.worker, shuffle=False) 90 | pbar=tqdm(enumerate(test_dataloader), total=len(test_dataset),ncols=100) 91 | right=0 92 | sum=0 93 | time_all=0 94 | for _,(imgs, labels,paths) in pbar: 95 | imgs = imgs.to(device) 96 | labels = labels.to(device) 97 | with torch.no_grad(): 98 | start_inference=time.time() 99 | pred=classifier(imgs) 100 | time_all+=time.time()-start_inference 101 | pred_right=((labels[0].to(device)!=(pred.sigmoid().mean(0)>0.5)).sum().item()==0) 102 | right+=pred_right 103 | sum+=imgs.shape[0] 104 | pbar.set_postfix({"epoch": f"{0}","mode": "eval"}) 105 | acc=float(right)/float(sum) 106 | print(f"accuracy = {acc}") 107 | print(f"avg time : {time_all/sum}") 108 | 109 | def main(args): 110 | set_seed(args.seed) 111 | 112 | if not os.path.exists(args.save_dir): 113 | os.makedirs(args.save_dir) 114 | 115 | logger = util.get_logger(f"{args.save_dir}/log.txt") 116 | writer = SummaryWriter(args.save_dir) 117 | 118 | classifier = Classifier().cuda() 119 | 120 | classifier = train_classifier(args, classifier, logger, writer) 121 | 122 | evaluate_classifier(args, classifier) 123 | 124 | if __name__ == "__main__": 125 | parser = argparse.ArgumentParser() 126 | parser.add_argument('--training_dir', type=str, default='../datasets/DIV2K/DIV2K_train_HR', help='dataset directory') 127 | parser.add_argument('--test_dir', type=str, default='../datasets/DIV2KC/corruptions', help='dataset directory') 128 | parser.add_argument('--bs', type=int, default=16, help='batch size') 129 | parser.add_argument('--lr', type=float, default=0.01, help='learning rate') 130 | parser.add_argument('--lf', type=float, default=0.01, help='lr final=lr*lf') 131 | parser.add_argument('--worker', 
type=int, default=8, help='number of workers') 132 | parser.add_argument('--seed', type=int, default=0, help='seed for reproduction') 133 | parser.add_argument('--total_epoch', type=int, default=300, help='total training epochs') 134 | parser.add_argument('--save_dir', type=str, default='../experiment/degradation_classifier/', help='directory for saving checkpoints and logs') 135 | parser.add_argument('--train_dtypes', type=str, default='single+multi', help='degradation types for preparing training data') 136 | parser.add_argument('--cache', action="store_true", help='cache images') 137 | parser.add_argument('--img_size', type=int, default=224, help='the size of training images') 138 | parser.add_argument('--shuffle', action="store_true", help='shuffle the order in which images are processed') 139 | args = parser.parse_args() 140 | 141 | main(args) 142 | -------------------------------------------------------------------------------- /src/utils/utils_tta.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import time 4 | try: 5 | import imageio.v2 as imageio 6 | except ImportError: 7 | import imageio 8 | from tqdm import tqdm 9 | import numpy as np 10 | import logging 11 | import random 12 | import torch 13 | 14 | import utils.utils_image as util 15 | import utils.utils_blindsr as blindsr 16 | import utils.utils_blindsr_plus as blindsr_plus 17 | 18 | def get_corruptions(): 19 | """ Single corruptions; 'Original' means plain bicubic downsampling """ 20 | corruptions = ['Original', 'GaussianBlur', 'DefocusBlur', 'GlassBlur', 21 | 'GaussianNoise', 'PoissonNoise', 'ImpulseNoise', 'SpeckleNoise', 'JPEG'] 22 | return corruptions 23 | 24 | def get_mixed_corruptions(): 25 | """ Mixed corruptions used for DIV2K-MC """ 26 | corruptions = ["BlurJPEG", "BlurNoise", "NoiseJPEG", "BlurNoiseJPEG"] 27 | return corruptions 28 | 29 | def preprocess_img(img, scale=2, corruption='Original'): 30 | corruptions = get_corruptions() 31 | corr_idx = corruptions.index(corruption) 32 | 33 | img = util.uint2single(img) 34 | img_lr = None 35 | if corr_idx == 0: 36 | if scale != 1: 37 | img_lr = blindsr.bicubic_degradation(img, scale) 38 | else: 39 | img_lr = img 40 | 41 | if corruption.lower().find('blur') >= 0: 42 | if corr_idx == 1: 43 | img_blur = blindsr.add_blur(img, scale) 44 | if corr_idx == 2: 45 | img_blur = blindsr_plus.add_defocus_blur(img) 46 | if corr_idx == 3: 47 | img_blur = blindsr_plus.add_glass_blur(img) 48 | 49 | if scale != 1: 50 | img_lr = blindsr.bicubic_degradation(img_blur, scale) 51 | else: 52 | img_lr = img_blur 53 | 54 | if corruption.lower().find('noise') >= 0: 55 | if scale != 1: 56 | img_lr = blindsr.bicubic_degradation(img, scale) 57 | else: 58 | img_lr = img 59 | 60 | if corr_idx == 4: 61 | img_lr = blindsr.add_Gaussian_noise(img_lr, noise_level1=2, noise_level2=25) 62 | 63 | if corr_idx == 5: 64 | img_lr = blindsr_plus.add_scale_Poisson_noise(img_lr) 65 | 66 | if corr_idx == 6: 67 | img_lr = blindsr_plus.add_impluse_noise(img_lr) 68 | 69 | if corr_idx == 7: 70 | img_lr = blindsr.add_speckle_noise(img_lr) 71 | 72 | if corruption.lower().find('jpeg') >= 0: 73 | if scale != 1: 74 | img_lr = blindsr.bicubic_degradation(img, scale) 75 | else: 76 | img_lr = img 77 | img_lr = blindsr.add_JPEG_noise(img_lr) 78 | 79 | if img_lr is None: # default to bicubic 80 | print("The corruption '{}' is not supported; falling back to bicubic".format(corruption)) 81 | if scale != 1: 82 | img_lr = blindsr.bicubic_degradation(img, scale) 83 | else: 84 | img_lr = img 85 | 86 | # single2uint 87 | img_lr = util.single2uint(img_lr) # 0-1 to 0-255 88 | 89 | return img_lr 90 | 91 | 92 |
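# A minimal usage sketch for preprocess_img (the file name '0801.png' below is
# only illustrative, not a file shipped with the repo):
#
#   img_hr = read_img('0801.png')  # HxWxC, uint8
#   img_lr = preprocess_img(img_hr, scale=2, corruption='GaussianNoise')
#   write_img('0801_lr.png', img_lr)  # uint8 LR image with Gaussian noise
#
# With corruption='Original' the image is only bicubically downscaled; with
# scale=1 the corruption is applied at the original resolution.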
93 | 94 | 95 | def read_img(img_path): 96 | img = imageio.imread(img_path) 97 | # some images have 4 channels 98 | if img.shape[2] > 3: img = img[:, :, :3] 99 | return img 100 | 101 | def write_img(save_path, image): 102 | imageio.imwrite(save_path, image) 103 | 104 | 105 | def get_paths(input_dir, target_dir, scale=2): 106 | if not input_dir == target_dir: 107 | lr_paths = sorted(glob.glob(os.path.join(input_dir, "*"))) 108 | gt_paths = sorted(glob.glob(os.path.join(target_dir, "*"))) 109 | # check whether each lr path corresponds to its gt path 110 | for lr_path, gt_path in zip(lr_paths, gt_paths): 111 | lr_name = os.path.basename(lr_path) 112 | gt_name = os.path.basename(gt_path) 113 | assert lr_name == gt_name 114 | else: 115 | scale = max(scale) if isinstance(scale, list) else scale # scale is now a parameter (it was previously read from an undefined `args`) 116 | lr_paths = sorted(glob.glob(os.path.join(input_dir, "*_LR{}.png".format(scale)))) 117 | gt_paths = sorted(glob.glob(os.path.join(target_dir, "*_HR.png"))) 118 | for lr_path, gt_path in zip(lr_paths, gt_paths): 119 | lr_name = os.path.basename(lr_path) 120 | gt_name = os.path.basename(gt_path).replace('_HR', '_LR{}'.format(scale)) 121 | assert lr_name == gt_name 122 | 123 | return lr_paths, gt_paths 124 | 125 | 126 | def test_one(args, model, lr_img, hr_img, sr_img=None, return_sr=False): 127 | model.eval() 128 | if not isinstance(lr_img, torch.Tensor): 129 | lr_img = util.single2tensor4(lr_img) 130 | if not args.cpu: lr_img = lr_img.cuda() 131 | 132 | if sr_img is None: 133 | with torch.no_grad(): 134 | sr_img = model.sr_forward(lr_img) 135 | 136 | sr_img = util.quantize(sr_img.cpu().squeeze(0).permute(1, 2, 0), args.rgb_range) 137 | sr_img_y = util.rgb2ycbcr(sr_img.numpy() / 255.) * 255. # normalize to 0-1 to avoid the round() operation 138 | 139 | if isinstance(hr_img, torch.Tensor): 140 | hr_img = hr_img.squeeze().permute(1, 2, 0).numpy() 141 | hr_img_y = util.rgb2ycbcr(hr_img / 255.) * 255. # normalize to 0-1 to avoid the round() operation 142 | 143 | psnr_, ssim_ = util.calc_psnr_ssim(sr_img_y, hr_img_y) 144 | 145 | if return_sr: 146 | return psnr_, ssim_, sr_img 147 | else: 148 | return psnr_, ssim_ 149 | 150 | 151 | def test_all(args, model, test_loader, origin_model=None, cls_model=None, return_sr=False): 152 | model_update = model 153 | psnrs, ssims, srs = [], [], [] 154 | for idx, (lr, gt, _) in tqdm(enumerate(test_loader), ncols=80, total=len(test_loader)): 155 | if cls_model: 156 | cls_pred = cls_model(lr[:,[2,1,0]].cuda()/255.) 
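# Note on the line above: the [2,1,0] index flips the channel order before
# classification (the SR pipeline and the classifier apparently use opposite
# RGB/BGR conventions), and /255. rescales the 0-255 input to [0,1] for the
# ResNet backbone. Each class logit is thresholded at sigmoid > 0.5; if no
# degradation class fires, the image is treated as clean and the frozen
# pretrained model is used below.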
157 | if (cls_pred.sigmoid() > 0.5).sum().item() == 0: 158 | # no degradation detected: directly use the pretrained model to upscale clean images 159 | model = origin_model 160 | else: 161 | model = model_update 162 | 163 | if return_sr: 164 | psnr_, ssim_, sr = test_one(args, model, lr, gt, return_sr=True) 165 | srs.append(sr) 166 | else: 167 | psnr_, ssim_ = test_one(args, model, lr, gt) 168 | psnrs.append(psnr_) 169 | ssims.append(ssim_) 170 | 171 | if return_sr: 172 | return psnrs, ssims, srs 173 | else: 174 | return psnrs, ssims 175 | 176 | def merge_list(psnrs, ssims): 177 | """return a list of "PSNR/SSIM" strings: [mean, img_0, img_1, ..., img_N]""" 178 | str_list = [] 179 | for psnr, ssim in zip(psnrs, ssims): 180 | psnr_ssim = "{:.3f}/{:.4f}".format(psnr, ssim) 181 | str_list.append(psnr_ssim) 182 | str_list.insert(0, "{:.3f}/{:.4f}".format(np.mean(psnrs), np.mean(ssims))) 183 | return str_list 184 | 185 | def transform_tensor(img, op, undo=False): 186 | """ 187 | params: 188 | img: BxCxHxW, [0-255], Tensor 189 | op: transform operator 190 | """ 191 | if undo: 192 | if op.find('t') >= 0: # transpose back first 193 | img = img.permute((0, 1, 3, 2)).contiguous() 194 | 195 | if op.find('v') >= 0: 196 | img = torch.flip(img, dims=[3]).contiguous() 197 | if op.find('h') >= 0: 198 | img = torch.flip(img, dims=[2]).contiguous() 199 | 200 | if not undo: 201 | if op.find('t') >= 0: # transpose last 202 | img = img.permute((0, 1, 3, 2)).contiguous() 203 | 204 | return img 205 | 206 | def augment_transform(img_in, undo=False): 207 | """ transform the input tensor 8 times (all flip/transpose combinations) 208 | params: 209 | img_in: BxCxHxW, [0-255], Tensor 210 | """ 211 | img_outs = [] 212 | tran_ops = ['', 'v', 'h', 't', 'vh', 'vt', 'ht', 'vht'] # augment 213 | for op in tran_ops: 214 | img_outs.append(transform_tensor(img_in, op, undo)) 215 | return img_outs, tran_ops 216 | 217 | 218 | def load_model(pre_train, model): 219 | print("Loading pre-trained model from {}".format(pre_train)) 220 | assert os.path.exists(pre_train), pre_train 221 | if os.path.exists(pre_train): 222 | if "baseline" in pre_train.lower(): 223 | state_dict = torch.load(pre_train) 224 | else: 225 | try: 226 | state_dict = torch.load(pre_train)['model'] 227 | except KeyError: 228 | state_dict = torch.load(pre_train) 229 | state_dict = {k.replace("upsample.", "upsampler."): v for k, v in state_dict.items()} 230 | model.load_state_dict(state_dict, strict=True) 231 | return model 232 | 233 | 234 | def get_logger(save_path): 235 | logger = logging.getLogger() 236 | logger.setLevel(level = logging.INFO) 237 | handler = logging.FileHandler(save_path) 238 | handler.setLevel(logging.INFO) 239 | logger.addHandler(handler) 240 | console = logging.StreamHandler() 241 | console.setLevel(logging.INFO) 242 | logger.addHandler(console) 243 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 244 | handler.setFormatter(formatter) 245 | console.setFormatter(formatter) 246 | logger.disabled = False 247 | return logger 248 | 249 | def set_seed(seed): 250 | random.seed(seed) 251 | np.random.seed(seed) 252 | torch.manual_seed(seed) 253 | if torch.cuda.device_count() == 1: 254 | torch.cuda.manual_seed(seed) 255 | else: 256 | torch.cuda.manual_seed_all(seed) 257 | torch.backends.cudnn.benchmark = False 258 | torch.backends.cudnn.deterministic = True -------------------------------------------------------------------------------- /src/utils/utils_blindsr_plus.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import special 3 | from scipy import ndimage 4 | from skimage.filters import 
gaussian 5 | import random 6 | import cv2 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | import utils.utils_image as util 11 | import utils.utils_blindsr as blindsr 12 | import utils.basicsr_degradations as degradations 13 | 14 | 15 | def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0): 16 | """2D sinc filter, ref: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter 17 | 18 | Args: 19 | cutoff (float): cutoff frequency in radians (pi is max) 20 | kernel_size (int): horizontal and vertical size, must be odd. 21 | pad_to (int): pad kernel size to desired size, must be odd or zero. 22 | """ 23 | assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' 24 | kernel = np.fromfunction( 25 | lambda x, y: cutoff * special.j1(cutoff * np.sqrt( 26 | (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)) / (2 * np.pi * np.sqrt( 27 | (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)), [kernel_size, kernel_size]) 28 | kernel[(kernel_size - 1) // 2, (kernel_size - 1) // 2] = cutoff**2 / (4 * np.pi) 29 | kernel = kernel / np.sum(kernel) 30 | if pad_to > kernel_size: 31 | pad_size = (pad_to - kernel_size) // 2 32 | kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size))) 33 | return kernel 34 | 35 | 36 | def filter2D(img, kernel): 37 | """PyTorch version of cv2.filter2D 38 | 39 | Args: 40 | img (Tensor): (b, c, h, w) 41 | kernel (Tensor): (b, k, k) 42 | """ 43 | k = kernel.size(-1) 44 | b, c, h, w = img.size() 45 | if k % 2 == 1: 46 | img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect') 47 | else: 48 | raise ValueError('Wrong kernel size') 49 | 50 | ph, pw = img.size()[-2:] 51 | 52 | if kernel.size(0) == 1: 53 | # apply the same kernel to all batch images 54 | img = img.view(b * c, 1, ph, pw) 55 | kernel = kernel.view(1, 1, k, k) 56 | return F.conv2d(img, kernel, padding=0).view(b, c, h, w) 57 | else: 58 | img = img.view(1, b * c, ph, pw) 59 | kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k) 60 | return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w) 61 | 62 | 63 | def add_sinc_blur(img): 64 | """ 65 | Add ringing and overshoot artifacts to the image, using a 2D sinc kernel 66 | params: 67 | img: HxWxC, [0, 1] 68 | """ 69 | kernel_range = [2 * v + 1 for v in range(3, 11)] 70 | kernel_size = random.choice(kernel_range) 71 | 72 | if kernel_size < 13: 73 | omega_c = np.random.uniform(np.pi / 3, np.pi) 74 | else: 75 | omega_c = np.random.uniform(np.pi / 5, np.pi) 76 | kernel = circular_lowpass_kernel(omega_c, kernel_size, pad_to=21) 77 | 78 | # convert to tensor 79 | kernel = torch.FloatTensor(kernel) 80 | img = util.single2tensor4(img) 81 | 82 | img_out = filter2D(img, kernel) 83 | img_out = torch.clamp(img_out, 0, 1) 84 | 85 | # convert to numpy 86 | img_out = util.tensor2single3(img_out) 87 | 88 | return img_out 89 | 90 | 91 | def add_resize_blur(img, scale_range=2, prob=0.5): 92 | """Randomly upsample and then downsample the image 93 | params: 94 | img: HxWxC, [0, 1] 95 | """ 96 | h, w = img.shape[:2] 97 | 98 | if np.random.random() < prob: # use the `prob` argument (it was previously hard-coded to 0.5) 99 | # random upsampling 100 | up_sf = random.uniform(1.1, scale_range) 101 | # upsample with a random interpolation method 102 | img = cv2.resize(img, (int(up_sf*w), int(up_sf*h)), interpolation=random.choice([1, 2, 3])) 103 | 104 | # random downsampling 105 | down_sf = random.uniform(1/scale_range, 0.9) 106 | 107 | # downsample with a random interpolation method 108 | img = cv2.resize(img, (int(down_sf*w), int(down_sf*h)), 
interpolation=random.choice([1, 2, 3])) 109 | 110 | # resize back to the original size 111 | img = cv2.resize(img, (w, h), interpolation=random.choice([1, 2, 3])) 112 | 113 | img = np.clip(img, 0.0, 1.0) 114 | 115 | return img 116 | 117 | 118 | def gaussian_blur(img, scale=2): 119 | wd = 2.0 + 0.2*scale 120 | # isotropic gaussian blur with random size and sigma 121 | k = blindsr.fspecial('gaussian', 2*random.randint(2,11)+3, wd*random.random()) 122 | img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') 123 | return img 124 | 125 | 126 | def add_glass_blur(img, scale=2): 127 | # wd = 2.0 + 0.2*scale 128 | # rand_sigma = random.random() * wd 129 | rand_sigma = random.uniform(1, 2) 130 | rand_range = random.randint(2, 6) 131 | rand_iter = random.randint(2, 3) 132 | # round and clip image for counting vals correctly 133 | # img = gaussian(img, sigma=rand_sigma, channel_axis=True) 134 | img = gaussian(img, sigma=rand_sigma) 135 | # img = np.clip((img * 255.0).round(), 0, 255) / 255. 136 | 137 | h, w = img.shape[:2] 138 | # locally shuffle pixels (use fresh loop variables so h and w are not clobbered between iterations) 139 | for _ in range(rand_iter): 140 | for hh in range(h - rand_range, rand_range, -1): 141 | for ww in range(w - rand_range, rand_range, -1): 142 | # random relative position 143 | dx, dy = np.random.randint(-rand_range, rand_range, size=(2,)) 144 | h_prime, w_prime = hh + dy, ww + dx 145 | # swap pixel values 146 | img[hh, ww], img[h_prime, w_prime] = img[h_prime, w_prime], img[hh, ww] 147 | 148 | # round and clip image for counting vals correctly 149 | # img = gaussian_blur(img, scale) 150 | # img = gaussian(img, sigma=rand_sigma, channel_axis=True) 151 | img = gaussian(img, sigma=rand_sigma) 152 | img = np.clip(img, 0, 1) 153 | 154 | return img 155 | 156 | 157 | def disk(radius, alias_blur=0.1, dtype=np.float32): 158 | if radius <= 8: 159 | L = np.arange(-8, 8 + 1) 160 | ksize = (3, 3) 161 | else: 162 | L = np.arange(-radius, radius + 1) 163 | ksize = (5, 5) 164 | X, Y = np.meshgrid(L, L) 165 | aliased_disk = np.array((X ** 2 + Y ** 2) <= radius ** 2, dtype=dtype) 166 | aliased_disk /= np.sum(aliased_disk) 167 | 168 | # supersample disk to antialias 169 | return cv2.GaussianBlur(aliased_disk, ksize=ksize, sigmaX=alias_blur) # TODO: verify that this operation needs no extra randomness 170 | 171 | 172 | def add_defocus_blur(img): 173 | rand_radius = random.randint(3, 10) 174 | rand_alias_blur = np.random.uniform(0.1, 0.5) 175 | 176 | kernel = disk(radius=rand_radius, alias_blur=rand_alias_blur) 177 | img = cv2.filter2D(img, -1, kernel) 178 | img = np.clip(img, 0, 1) 179 | 180 | return img 181 | 182 | 183 | # keep the lower bound of the scale from being too small when generating test data 184 | def add_scale_Poisson_noise(img, scale_range=(0.5, 3), gray_prob=0.5): 185 | """ 186 | params: 187 | img: HxWxC, [0, 1] 188 | """ 189 | img = degradations.random_add_poisson_noise( 190 | img, scale_range, gray_prob, clip=True, rounds=False) 191 | 192 | return img 193 | 194 | 195 | def _bernoulli(p, shape): 196 | """ 197 | https://github.com/scikit-image/scikit-image/blob/main/skimage/util/noise.py#L191 198 | 199 | Bernoulli trials at a given probability of a given size. 200 | This function is meant as a lower-memory alternative to calls such as 201 | `np.random.choice([True, False], size=image.shape, p=[p, 1-p])`. 202 | While `np.random.choice` can handle many classes, for the 2-class case 203 | (Bernoulli trials), this function is much more efficient. 204 | Parameters 205 | ---------- 206 | p : float 207 | The probability that any given trial returns `True`. 
208 | shape : int or tuple of ints 209 | The shape of the ndarray to return. 210 | 211 | Returns 212 | ------- 213 | out : ndarray[bool] 214 | The results of Bernoulli trials in the given `size` where success 215 | occurs with probability `p`. 216 | """ 217 | if p == 0: 218 | return np.zeros(shape, dtype=bool) 219 | if p == 1: 220 | return np.ones(shape, dtype=bool) 221 | return np.random.random(shape) <= p 222 | 223 | 224 | def add_impluse_noise(img, noise_prob=(0.03, 0.25), salt_vs_pepper_prob=(0, 1), gray_prob=0.5): 225 | """ 226 | params: 227 | img: HxWxC, [0, 1] 228 | """ 229 | img_out = img.copy() 230 | 231 | rand_prob = np.random.uniform(noise_prob[0], noise_prob[1]) 232 | salt_vs_pepper = np.random.uniform(salt_vs_pepper_prob[0], salt_vs_pepper_prob[1]) 233 | 234 | if np.random.random() < gray_prob: # grayscale noise: the same mask is applied to all three channels 235 | flipped = _bernoulli(rand_prob, (*img.shape[:2], 1)) 236 | salted = _bernoulli(salt_vs_pepper, (*img.shape[:2], 1)) 237 | flipped = np.repeat(flipped, 3, axis=2) 238 | salted = np.repeat(salted, 3, axis=2) 239 | peppered = ~salted 240 | img_out[flipped & salted] = 1. 241 | img_out[flipped & peppered] = 0. 242 | else: 243 | # flipped = np.random.binomial(n=1, p=rand_prob, size=img.shape) 244 | # salted = np.random.binomial(n=1, p=salt_vs_pepper, size=img.shape) 245 | flipped = _bernoulli(rand_prob, img.shape) 246 | salted = _bernoulli(salt_vs_pepper, img.shape) 247 | peppered = ~salted 248 | img_out[flipped & salted] = 1. 249 | img_out[flipped & peppered] = 0. 250 | 251 | return img_out 252 | -------------------------------------------------------------------------------- /src/srtta.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch 4 | import torch.optim as optim 5 | import torch.nn.functional as F 6 | from torchvision import transforms 7 | from copy import deepcopy 8 | import utils.utils_tta as utils_tta 9 | import utils.utils_image as util 10 | 11 | import utils.utils_blindsr_plus as blindsr_plus 12 | import utils.basicsr_degradations as degradations 13 | from utils.diffjpeg import DiffJPEG 14 | 15 | 16 | import logging 17 | logger = logging.getLogger(__name__) 18 | 19 | def configure_model(args, model): 20 | """Configure model for use with tta.""" 21 | # train mode; only the selected parameters are updated during adaptation 22 | model.train() 23 | 24 | train_params = ['body'] # only update the body 25 | 26 | for k, v in model.named_parameters(): 27 | prefix, block_index = k.split('.')[:2] 28 | if prefix in train_params: 29 | 30 | logger.info('train params: {}'.format(k)) 31 | v.requires_grad = True 32 | else: 33 | logger.info('freezing params: {}'.format(k)) 34 | v.requires_grad = False # freeze the other layers 35 | 36 | return model 37 | 38 | def collect_params(model): 39 | """Collect all trainable parameters. 40 | 41 | Walk the model's modules and collect all parameters. 42 | Return the parameters and their names. 43 | 44 | Note: other choices of parameterization are possible! 
45 | """ 46 | params = [] 47 | names = [] 48 | for nm, m in model.named_modules(): 49 | if True: #isinstance(m, nn.BatchNorm2d): collect all 50 | for np, p in m.named_parameters(): 51 | if np in ['weight', 'bias'] and p.requires_grad: 52 | params.append(p) 53 | names.append(f"{nm}.{np}") 54 | # print(nm, np) 55 | return params, names 56 | 57 | def compute_loss(pred, target, eps=1e-3): 58 | """ L1 Charbonnier loss """ 59 | return torch.sqrt(((pred - target)**2) + eps).mean() 60 | 61 | def load_model_and_optimizer(model, optimizer, model_state, optimizer_state): 62 | """Restore the model and optimizer states from copies.""" 63 | model.load_state_dict(model_state, strict=True) 64 | optimizer.load_state_dict(optimizer_state) 65 | 66 | def create_ema_model(model): 67 | """Copy the model and optimizer states for resetting after adaptation.""" 68 | ema_model = deepcopy(model) 69 | for param in ema_model.parameters(): 70 | param.detach_() 71 | return ema_model 72 | 73 | class SRTTA(): 74 | def __init__(self, args, model, optimizer, fisher=None, cls_model=None): 75 | self.args = args 76 | self.model = model 77 | self.origin_model = model 78 | self.fisher = fisher 79 | self.cls_model = cls_model 80 | self.optimizer = optimizer 81 | 82 | self.compute_loss = compute_loss 83 | self.model_state = deepcopy(model.state_dict()) 84 | self.optimizer_state = deepcopy(optimizer.state_dict()) 85 | 86 | self.noise_range = args.noise_range 87 | self.jpeg_range = args.jpeg_range 88 | self.jpeger = DiffJPEG(differentiable=False) 89 | self.fishers = {} 90 | if not self.args.cpu: 91 | self.jpeger = self.jpeger.cuda() 92 | 93 | # teacher model, this do not need the gradient 94 | self.model_teacher = create_ema_model(self.model) 95 | 96 | def __call__(self, train_loader, img_lr, img_gt, img_name, corruption=None): 97 | # test-time adaptation 98 | 99 | if self.cls_model: 100 | classes = ["origin", "blur", "noise", "jpeg"] 101 | with torch.no_grad(): 102 | input_img = img_lr.clone() 103 | input_img = input_img[:, [2,1,0]].cuda() / 255. 
104 | cls_pred = self.cls_model(input_img) 105 | degradation_types = [classes[i] for i in range(1, len(classes)) if (cls_pred.sigmoid()>0.5)[0][i-1]] 106 | if len(degradation_types) == 0: 107 | degradation_types = ['origin'] 108 | 109 | corruptions = [] 110 | for dtype_ in degradation_types: 111 | if dtype_.lower().find("noise") >= 0: 112 | corruption = 'GaussianNoise' 113 | elif dtype_.lower().find("jpeg") >= 0: 114 | corruption = "JPEG" 115 | elif dtype_.lower().find("blur") >= 0: 116 | corruption = 'GaussianBlur' 117 | elif dtype_.lower().find("origin") >= 0: 118 | corruption = "Original" 119 | corruptions.append(corruption) 120 | else: 121 | corruptions = [corruption] 122 | 123 | logger.info(f"use {','.join(corruptions)} for {img_name}") 124 | for iteration in range(self.args.iterations): 125 | self.test_time_adaptation(train_loader, corruptions) 126 | 127 | # test results 128 | if not isinstance(img_lr, torch.Tensor): 129 | img_lr = util.single2tensor4(img_lr) 130 | if not self.args.cpu: img_lr = img_lr.cuda() 131 | 132 | self.model.eval() 133 | with torch.no_grad(): 134 | sr_img = self.model.sr_forward(img_lr) 135 | 136 | # record metric 137 | with torch.no_grad(): 138 | psnr, ssim = utils_tta.test_one(self.args, self.model, img_lr, img_gt, sr_img=sr_img) 139 | 140 | if iteration == self.args.iterations - 1: 141 | logger.info("Adapted PSNR/SSIM: {:.3f}/{:.4f} on {} ({}) after {} iters".format( 142 | psnr, ssim, img_name, corruption, iteration + 1)) 143 | 144 | return sr_img 145 | 146 | def test_time_adaptation(self, train_loader, corruptions=None): 147 | total_loss = 0 148 | self.model.train() 149 | self.optimizer.zero_grad() 150 | 151 | if corruptions is not None: 152 | task_idxs = [] 153 | for corruption_ in corruptions: 154 | if corruption_.lower().find('blur') >= 0: 155 | task_idx = 1 156 | elif corruption_.lower().find('noise') >= 0: 157 | task_idx = 2 158 | elif corruption_.lower().find('jpeg') >= 0: 159 | task_idx = 3 160 | else: 161 | return 0 # 'Original': do not adapt on clean data 162 | task_idxs.append(task_idx) 163 | train_loader.dataset.datasets[0].set_task(task_idxs) 164 | loader_iter = iter(train_loader) 165 | img_label, kernel = next(loader_iter) 166 | if not self.args.cpu: 167 | img_label = img_label.cuda() 168 | if kernel.dim() > 2: kernel = kernel.cuda() 169 | with torch.no_grad(): 170 | img_in = self.preprocess(img_label, kernel, task_idxs).contiguous() 171 | out_tea_gt = self.model_teacher(img_label, aux_forward=True) 172 | 173 | with torch.no_grad(): 174 | out_student_gt = self.model(img_label, aux_forward=True) 175 | out_student = self.model(img_in, aux_forward=True) 176 | # compute the student loss 177 | loss = self.compute_loss(out_student, out_student_gt.detach()) 178 | # compute the teacher loss 179 | if self.args.teacher_weight > 0: 180 | loss += self.args.teacher_weight * self.compute_loss(out_student, out_tea_gt) 181 | loss.backward() 182 | total_loss += loss.item() 183 | 184 | self.optimizer.step() 185 | 186 | if self.args.fisher_restore: 187 | self.fisher_restoration() 188 | 189 | return total_loss 190 | 191 | def preprocess(self, img_gt, kernel, task_idx=1): 192 | # convert to [0-1] 193 | img_in = img_gt / 255. 
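# task_idx convention used below: 1 = blur (convolve with the sampled
# kernel), 2 = Gaussian noise, 3 = JPEG compression; a list applies the
# degradations in sequence, e.g. task_idx=[1, 3] reproduces a
# blur-then-JPEG pipeline on the ground-truth patches.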
194 | # add blur 195 | if task_idx==1 or (isinstance(task_idx,list) and 1 in task_idx): 196 | img_in = blindsr_plus.filter2D(img_in, kernel) 197 | # add noise 198 | if task_idx==2 or (isinstance(task_idx,list) and 2 in task_idx): 199 | img_in = degradations.random_add_gaussian_noise_pt( 200 | img_in, sigma_range=self.noise_range, clip=False, rounds=False, gray_prob=0.4) 201 | # add jpeg 202 | if task_idx==3 or (isinstance(task_idx,list) and 3 in task_idx): 203 | jpeg_p = img_in.new_zeros(img_in.size(0)).uniform_(*self.jpeg_range) 204 | img_in = self.jpeger(img_in, quality=jpeg_p).contiguous() 205 | 206 | # convert to [0-255] 207 | img_in = (torch.clamp(img_in, 0, 1) * 255.0).round() 208 | 209 | return img_in.detach() 210 | 211 | def fisher_restoration(self): 212 | """Restore the important params back to the original model""" 213 | for nm, m in self.model.named_modules(): 214 | for npp, p in m.named_parameters(): 215 | if npp in ['weight', 'bias'] and p.requires_grad: 216 | # fishers[name]: [fisher, mask] 217 | mask = self.fishers[f"{nm}.{npp}"][-1] 218 | with torch.no_grad(): 219 | p.data = self.model_state[f"{nm}.{npp}"] * mask + p * (1.-mask) 220 | 221 | def reset_parameters(self): 222 | """Restore the model and optimizer states from copies.""" 223 | self.model.load_state_dict(self.model_state, strict=True) 224 | self.optimizer.load_state_dict(self.optimizer_state) 225 | 226 | def compute_fisher(self, test_loaders): 227 | if len(self.fishers) > 0: return self.fishers 228 | 229 | fishers = {} 230 | fisher_optimizer = optim.Adam(self.model.parameters()) 231 | for idx_data, t_data in enumerate(test_loaders): 232 | if t_data.dataset.name != "Set5": continue 233 | for idx, (img_lr, _, filename) in enumerate(t_data, start=1): 234 | if not self.args.cpu: img_lr = img_lr.cuda() 235 | 236 | fisher_optimizer.zero_grad() 237 | tran_imgs, tran_ops = utils_tta.augment_transform(img_lr) 238 | tran_imgs.reverse() # the last item is img_lr 239 | tran_ops.reverse() 240 | 241 | # compute the consistency loss 242 | sr_imgs = [] 243 | for aug_idx, (tran_img, op) in enumerate(zip(tran_imgs, tran_ops), start=1): 244 | if aug_idx < len(tran_imgs): 245 | with torch.no_grad(): 246 | sr_img = self.model(tran_img) 247 | sr_img = utils_tta.transform_tensor(sr_img, op, undo=True) 248 | sr_imgs.append(sr_img) 249 | else: 250 | sr_img = self.model(tran_img) 251 | sr_imgs.append(sr_img) 252 | 253 | sr_pseudo = torch.cat(sr_imgs, dim=0).mean(dim=0, keepdim=True).detach() 254 | loss = self.compute_loss(sr_imgs[-1], sr_pseudo) 255 | loss.backward() 256 | 257 | # compute fisher: accumulate the squared gradients over the images (idx), then average 258 | for name, param in self.model.named_parameters(): 259 | if param.grad is not None: 260 | if idx > 1: 261 | fisher = param.grad.data.clone().detach() ** 2 + fishers[name] 262 | else: 263 | fisher = param.grad.data.clone().detach() ** 2 264 | if idx == len(t_data): 265 | fisher = fisher / idx 266 | fishers.update({name: fisher}) 267 | 268 | # compute the mask based on the fisher 269 | for name, param in self.model.named_parameters(): 270 | if param.grad is not None: 271 | fisher = fishers[name].flatten() # TODO: check that flatten() ordering matches the view() below 272 | _, mask_idx = torch.topk(fisher, k=int(len(fisher) * self.args.fisher_ratio)) 273 | mask = param.new_zeros(param.shape).flatten() # ensure the mask and p are on the same device 274 | mask[mask_idx] = 1 275 | mask = mask.view(param.shape) 276 | self.fishers.update({name: [fisher, mask]}) 277 | 278 | 279 | fisher_optimizer.zero_grad() 280 | 281 | return self.fishers
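# Sketch of how the Fisher mask is consumed by fisher_restoration() above:
# for a parameter p with mask m (1 = important), the restoration keeps the
# original weights at important positions and the adapted weights elsewhere,
#   p <- p_original * m + p_adapted * (1 - m)
# so with fisher_ratio=0.5 the top-50% most important entries are pulled
# back to the pre-trained values after every adaptation step.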
282 | 283 | def resume(self, resume_path): 284 | if resume_path is None: return None, -1 # nothing to resume from 285 | resume_state = torch.load(resume_path) 286 | load_model_and_optimizer(self.model, self.optimizer, 287 | resume_state['model'], resume_state['optimizer']) 288 | self.model_state = resume_state['ori_model'] 289 | self.optimizer_state = resume_state['ori_optimizer'] 290 | corruption = resume_state['corruption'] 291 | iter_idx = resume_state['iter_idx'] 292 | 293 | return corruption, iter_idx 294 | 295 | def save(self, corruption='GaussianBlur', iter_idx=0): 296 | state = {} 297 | state['model'] = self.model.state_dict() 298 | state['optimizer'] = self.optimizer.state_dict() 299 | state['ori_model'] = self.model_state 300 | state['ori_optimizer'] = self.optimizer_state 301 | state['corruption'] = corruption 302 | state['iter_idx'] = iter_idx 303 | 304 | save_path = os.path.join(self.args.save_dir, "state_{}_last.pt".format(corruption)) 305 | torch.save(state, save_path) 306 | -------------------------------------------------------------------------------- /src/main_tta.py: -------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | import argparse 4 | from datetime import datetime 5 | import torch 6 | import random 7 | import torch.optim as optim 8 | from torch.utils.data import dataloader 9 | import numpy as np 10 | import pandas as pd 11 | from tqdm import tqdm 12 | import time 13 | from copy import deepcopy 14 | from model.classifier import Classifier 15 | import srtta 16 | import data 17 | from data.div2kc import DIV2KC 18 | from data.div2kmc import DIV2KMC 19 | import utils.utils_tta as util 20 | from model.edsr import EDSR 21 | 22 | import logging 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | def init_logger(args): 27 | if not os.path.exists(args.save_dir): 28 | os.makedirs(args.save_dir) 29 | 30 | current_time = datetime.now().strftime("%y%m%d_%H%M%S") 31 | log_dest = "{}_{}.txt".format( 32 | os.path.splitext(args.log_file)[0], current_time) 33 | logging.basicConfig( 34 | level=logging.INFO, 35 | format="[%(asctime)s] [%(filename)s: %(lineno)4d]: %(message)s", 36 | datefmt="%y/%m/%d %H:%M:%S", 37 | handlers=[ 38 | logging.FileHandler(os.path.join(args.save_dir, log_dest)), 39 | logging.StreamHandler() 40 | ]) 41 | 42 | logger.info(args) 43 | 44 | 45 | def setup_srtta(args, model, cls_model=None): 46 | """Set up SRTTA adaptation. 47 | 48 | Configure which parameters of the model are trainable, 49 | collect those parameters for gradient-based optimization, 50 | set up the optimizer, and wrap everything in an SRTTA instance. 
51 | """ 52 | model = srtta.configure_model(args, model) 53 | params, param_names = srtta.collect_params(model) 54 | optimizer = optim.Adam(params, args.lr, args.betas) 55 | srtta_model = srtta.SRTTA(args, model, optimizer, cls_model=cls_model) 56 | # logger.info(f"params for adaptation: %s", param_names) 57 | 58 | return srtta_model 59 | 60 | def create_metric_file(args, img_paths, corruption='GaussianBlur'): 61 | # log file 62 | if not os.path.exists(args.save_dir): 63 | os.makedirs(args.save_dir) 64 | 65 | log_file = os.path.join(args.save_dir, "{}_{}".format(corruption, args.metric_file)) 66 | header = ['adapted_img', 'iteration', 'Original', 'mean'] 67 | header += [os.path.basename(p) for p in img_paths] # use image_name as column index 68 | metric_frame = pd.DataFrame(columns=header) 69 | 70 | if os.path.exists(log_file) and not args.reset: 71 | print("The log file exists, don't overwrite it.") 72 | return None # assure not to save the different results into the same dir 73 | metric_frame.to_csv(log_file, mode='a', index=False, header=header) 74 | 75 | return log_file, header 76 | 77 | def record_metric(log_file, header, img_name, iteration, ori_psnr, ori_ssim, psnrs, ssims): 78 | # logger for test model on all image 79 | ori_metric = "{:.3f}/{:.4f}".format(ori_psnr, ori_ssim) 80 | metric_array = util.merge_list(psnrs, ssims) 81 | metric_array = [img_name, iteration, ori_metric] + metric_array # insert iteration and loss 82 | metric_array = [metric_array] # two-dimentional list for new row dataFrame 83 | if len(header) != len(metric_array[0]): 84 | header = header[:len(metric_array[0])] 85 | metric_frame = pd.DataFrame(metric_array, index=None, columns=header) 86 | metric_frame.to_csv(log_file, mode='a', index=False, header=False) 87 | 88 | def main(args): 89 | init_logger(args) 90 | 91 | if args.corruption is None: 92 | assert args.tta_data in ["DIV2KC","DIV2KMC"] 93 | if args.tta_data == "DIV2KC": 94 | corruptions = util.get_corruptions()[1:] # maybe affect the reproduction 95 | else: 96 | corruptions = ["BlurJPEG", "BlurNoise", "NoiseJPEG", "BlurNoiseJPEG"] 97 | else: 98 | corruptions = args.corruption 99 | 100 | logger.info("Corruptions: {}".format(corruptions)) 101 | 102 | #### init test data ### 103 | val_dataloader = data.Data(args) 104 | trainset = val_dataloader.loader_train.dataset.datasets[0] 105 | train_loader = val_dataloader.loader_train 106 | test_loaders = val_dataloader.loader_test # contain two test set 107 | 108 | #### init tta data ### 109 | if args.tta_data=="DIV2KC": 110 | tta_data = DIV2KC(args, name="DIV2KC", train=False) 111 | else: 112 | tta_data = DIV2KMC(args, name="DIV2KMC", train=False) 113 | tta_dataloader = dataloader.DataLoader( 114 | tta_data, batch_size=1, shuffle=True, # shuffle=True, 115 | pin_memory=not args.cpu, num_workers=args.n_threads) 116 | 117 | # config model 118 | model = EDSR(args) 119 | model = util.load_model(args.pre_train, model) 120 | origin_model = deepcopy(model) 121 | cls_model=Classifier().eval() 122 | state_dict = torch.load(args.classifier) 123 | cls_model.load_state_dict(state_dict,strict=True) 124 | 125 | if not args.cpu: 126 | model = model.cuda() 127 | origin_model = origin_model.cuda() 128 | cls_model = cls_model.cuda() 129 | 130 | srtta_model = setup_srtta(args, model, cls_model=cls_model) 131 | 132 | if args.resume is not None: 133 | finished_corruption, finished_img_iter = srtta_model.resume(args.resume) 134 | start_idx = corruptions.index(finished_corruption) 135 | corruptions = corruptions[start_idx:] 136 | 137 | 
for idx, corruption in enumerate(corruptions): 138 | if args.params_reset and idx != 0: 139 | srtta_model.reset_parameters() 140 | 141 | # evaluate the model before test-time adaptation 142 | ori_psnr, ori_ssim = 0, 0 143 | for _, t_data in enumerate(test_loaders): 144 | if args.resume is not None and finished_img_iter >= 0: break 145 | data_name = t_data.dataset.name 146 | if data_name in ["DIV2KC","DIV2KMC"]: 147 | # set the corruption type for the val dataset 148 | t_data.dataset.set_corruption(corruption) 149 | 150 | psnrs, ssims = util.test_all(args, model, t_data) 151 | 152 | logger.info("Original PSNR/SSIM: {:.3f}/{:.4f} on {} for {} data".format( 153 | np.mean(psnrs), np.mean(ssims), data_name, corruption)) 154 | if data_name == "Set5": 155 | ori_psnr, ori_ssim = np.mean(psnrs), np.mean(ssims) 156 | else: 157 | tta_psnrs, tta_ssims = psnrs, ssims 158 | 159 | metric_file, header = create_metric_file(args, tta_data.get_img_paths(), corruption) 160 | record_metric(metric_file, header, -1, -1, ori_psnr, ori_ssim, tta_psnrs, tta_ssims) 161 | 162 | # compute fisher information to select the important params 163 | if args.fisher_restore: 164 | srtta_model.compute_fisher(test_loaders) 165 | 166 | if corruption.lower() == 'original': continue # do not adapt on clean data 167 | 168 | # set the corruption for the tta dataset 169 | tta_data.set_corruption(corruption) 170 | 171 | if args.resume is not None and finished_img_iter == len(tta_dataloader): 172 | logger.info(f"Corruption: {corruption} has already been adapted, skipping to the next") 173 | finished_img_iter = -1 174 | continue 175 | 176 | for iter_idx, (img_lr, img_gt, filename) in enumerate(tta_dataloader): 177 | if args.resume is not None and finished_img_iter >= 0: 178 | if iter_idx + 1 <= finished_img_iter: 179 | continue 180 | else: 181 | finished_img_iter = -1 182 | 183 | trainset.set_image(img_lr) 184 | 185 | # adaptation 186 | img_out = srtta_model(train_loader, img_lr, img_gt, filename[0], corruption) 187 | 188 | if args.save_results: 189 | img_out = util.quantize(img_out) 190 | img_out = img_out.byte().permute(1, 2, 0).cpu() 191 | util.write_img(os.path.join(args.save_dir, filename[0] + '.png'), img_out.numpy()) 192 | 193 | # test results 194 | if iter_idx == len(tta_dataloader) - 1 or iter_idx % args.test_interval == 0: 195 | ori_psnr, ori_ssim = 0, 0 196 | for _, t_data in enumerate(test_loaders): 197 | data_name = t_data.dataset.name 198 | if data_name == "Set5": 199 | psnrs, ssims = util.test_all(args, model, t_data) 200 | else: 201 | psnrs, ssims = util.test_all(args, model, t_data, origin_model=origin_model, cls_model=cls_model) 202 | 203 | logger.info("Adapted PSNR/SSIM: {:.3f}/{:.4f} on {}-{} for {} data".format( 204 | np.mean(psnrs), np.mean(ssims), data_name, filename[0], corruption)) 205 | if data_name == "Set5": 206 | ori_psnr, ori_ssim = np.mean(psnrs), np.mean(ssims) 207 | else: 208 | tta_psnrs, tta_ssims = psnrs, ssims 209 | 210 | # record metric 211 | record_metric(metric_file, header, filename[0], args.iterations, ori_psnr, ori_ssim, tta_psnrs, tta_ssims) 212 | 213 | # save model params 214 | srtta_model.save(corruption, iter_idx) 215 | 216 | logger.info("Finished test-time adaptation.") 217 | 218 | if __name__ == '__main__': 219 | parser = argparse.ArgumentParser() 220 | parser.add_argument('--save_dir', type=str, default='../experiment/save/', help='directory for saving logs and results') 221 | parser.add_argument('--metric_file', type=str, default='tta_metrics.csv', help='path to the file for recording metrics') 222 | 
parser.add_argument('--log_file', type=str, default='logger.txt', help='filename of the log file') 223 | parser.add_argument('--reset', action='store_true', help='overwrite an existing metric file') 224 | parser.add_argument('--exp_name', type=str, default='debug', help='exp name') 225 | 226 | # model options 227 | parser.add_argument('--model', default='EDSR', help='model name') 228 | parser.add_argument('--pre_train', type=str, default='checkpoints/EDSR_baseline_x2.pt', help='path to the pre-trained SR model') 229 | parser.add_argument('--classifier', type=str, default='checkpoints/classifier.pt', help='path to the pre-trained degradation classifier') 230 | parser.add_argument('--n_resblocks', type=int, default=16, help='number of residual blocks') 231 | parser.add_argument('--n_feats', type=int, default=64, help='number of feature maps') 232 | parser.add_argument('--res_scale', type=float, default=1, help='residual scaling') 233 | parser.add_argument('--scale', type=str, default='2', help='super resolution scale') 234 | parser.add_argument('--n_colors', type=int, default=3, help='number of color channels to use') 235 | parser.add_argument('--rgb_range', type=int, default=255, help='maximum value of RGB') 236 | 237 | # data options 238 | parser.add_argument('--dir_data', type=str, default='../datasets/', help='dataset directory') 239 | parser.add_argument('--data_train', type=str, default='PatchKernel', help='train dataset name') 240 | parser.add_argument('--data_test', type=str, default='DIV2KC+Set5', help='test dataset name') 241 | parser.add_argument('--tta_data', type=str, default='DIV2KC', help='TTA dataset name') 242 | parser.add_argument('--corruption', type=str, default=None, help='the type of corruption data') 243 | parser.add_argument('--ext', type=str, default='img', help='dataset file extension') 244 | parser.add_argument('--batch_size', type=int, default=32, help='input batch size for training') 245 | parser.add_argument('--patch_size', type=int, default=96, help='output patch size') 246 | parser.add_argument('--cpu', action='store_true', help='use cpu only') 247 | 248 | # training options 249 | parser.add_argument('--n_fixed_blocks', type=int, default=0, help='the number of last resblocks that are not updated during tta') 250 | parser.add_argument('--iterations', type=int, default=10, help='the number of iterations to adapt on each test image') 251 | parser.add_argument('--lr', type=float, default=1e-4, help='learning rate') 252 | parser.add_argument('--betas', type=tuple, default=(0.9, 0.999), help='ADAM betas') 253 | parser.add_argument('--seed', type=int, default=1004, help='random seed') 254 | parser.add_argument('--save_results', action='store_true', help='save output results') 255 | parser.add_argument('--params_reset', action='store_true', help='flag to reset the parameters for each domain') 256 | parser.add_argument('--fisher_restore', action='store_true', help='flag to restore the most important (Fisher) parameters from the original model') 257 | parser.add_argument('--fisher_ratio', type=float, default=0.3, help='ratio of parameters restored based on Fisher importance') 258 | parser.add_argument('--test_only', action='store_true', help='set this option to test the model') 259 | parser.add_argument('--n_threads', type=int, default=12, help='number of threads for data loading') 260 | parser.add_argument('--resume', type=str, default=None, help='path to a checkpoint to resume from') 261 | parser.add_argument('--test_interval', type=int, default=100, help='interval (in images) between metric evaluations') 262 | 
parser.add_argument('--multi-corruption', action='store_true', help='enable multi-corruption adaptation') 263 | parser.add_argument('--teacher_weight', type=float, default=1, help='the weight of the teacher loss') 264 | 265 | args = parser.parse_args() 266 | args.scale = list(map(lambda x: int(x), args.scale.split('+'))) 267 | args.data_train = args.data_train.split('+') 268 | args.data_test = args.data_test.split('+') 269 | if args.corruption is not None: 270 | args.corruption = args.corruption.split('+') 271 | 272 | ### hyper-parameters for the random degradations 273 | args.kernel_range = [2 * v + 1 for v in range(3, 11)] # kernel size ranges from 7 to 21 274 | args.kernel_list = ['iso', 'aniso'] 275 | args.kernel_prob = [0.5, 0.5] 276 | args.blur_sigma = [0.2, 3] 277 | args.betag_range = [0.5, 4] 278 | args.betap_range = [1, 2] 279 | args.noise_range = [1, 30] 280 | args.jpeg_range = [30, 95] 281 | print(args) 282 | 283 | date_str = time.strftime("%Y%m%d", time.localtime()) 284 | args.save_dir = f"../experiment/{date_str}_{args.exp_name}/" 285 | 286 | util.set_seed(args.seed) 287 | main(args) -------------------------------------------------------------------------------- /src/utils/diffjpeg.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/mlomnitz/DiffJPEG 3 | 4 | For images not divisible by 8 5 | https://dsp.stackexchange.com/questions/35339/jpeg-dct-padding/35343#35343 6 | """ 7 | import itertools 8 | import numpy as np 9 | import torch 10 | import torch.nn as nn 11 | from torch.nn import functional as F 12 | 13 | # ------------------------ utils ------------------------# 14 | y_table = np.array( 15 | [[16, 11, 10, 16, 24, 40, 51, 61], [12, 12, 14, 19, 26, 58, 60, 55], [14, 13, 16, 24, 40, 57, 69, 56], 16 | [14, 17, 22, 29, 51, 87, 80, 62], [18, 22, 37, 56, 68, 109, 103, 77], [24, 35, 55, 64, 81, 104, 113, 92], 17 | [49, 64, 78, 87, 103, 121, 120, 101], [72, 92, 95, 98, 112, 100, 103, 99]], 18 | dtype=np.float32).T 19 | y_table = nn.Parameter(torch.from_numpy(y_table)) 20 | c_table = np.empty((8, 8), dtype=np.float32) 21 | c_table.fill(99) 22 | c_table[:4, :4] = np.array([[17, 18, 24, 47], [18, 21, 26, 66], [24, 26, 56, 99], [47, 66, 99, 99]]).T 23 | c_table = nn.Parameter(torch.from_numpy(c_table)) 24 | 25 | 26 | def diff_round(x): 27 | """ Differentiable rounding function 28 | """ 29 | return torch.round(x) + (x - torch.round(x))**3 30 | 31 | 32 | def quality_to_factor(quality): 33 | """ Calculate the factor corresponding to a quality setting 34 | 35 | Args: 36 | quality(float): Quality for jpeg compression. 37 | 38 | Returns: 39 | float: Compression factor. 40 | """ 41 | if quality < 50: 42 | quality = 5000. / quality 43 | else: 44 | quality = 200. - quality * 2 45 | return quality / 100. 
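# Worked examples for quality_to_factor (standard IJG-style quality scaling):
#   quality=10 -> 5000/10 = 500 -> factor 5.0 (strong compression)
#   quality=50 -> 200 - 100 = 100 -> factor 1.0 (tables used as-is)
#   quality=95 -> 200 - 190 = 10  -> factor 0.1 (mild compression)
# The factor scales the quantization tables in the Y/CQuantize modules below,
# so a larger factor means coarser quantization and stronger JPEG artifacts.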
46 | 47 | 48 | # ------------------------ compression ------------------------# 49 | class RGB2YCbCrJpeg(nn.Module): 50 | """ Converts RGB image to YCbCr 51 | """ 52 | 53 | def __init__(self): 54 | super(RGB2YCbCrJpeg, self).__init__() 55 | matrix = np.array([[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5], [0.5, -0.418688, -0.081312]], 56 | dtype=np.float32).T 57 | self.shift = nn.Parameter(torch.tensor([0., 128., 128.])) 58 | self.matrix = nn.Parameter(torch.from_numpy(matrix)) 59 | 60 | def forward(self, image): 61 | """ 62 | Args: 63 | image(Tensor): batch x 3 x height x width 64 | 65 | Returns: 66 | Tensor: batch x height x width x 3 67 | """ 68 | image = image.permute(0, 2, 3, 1) 69 | result = torch.tensordot(image, self.matrix, dims=1) + self.shift 70 | return result.view(image.shape) 71 | 72 | 73 | class ChromaSubsampling(nn.Module): 74 | """ Chroma subsampling on CbCr channels 75 | """ 76 | 77 | def __init__(self): 78 | super(ChromaSubsampling, self).__init__() 79 | 80 | def forward(self, image): 81 | """ 82 | Args: 83 | image(tensor): batch x height x width x 3 84 | 85 | Returns: 86 | y(tensor): batch x height x width 87 | cb(tensor): batch x height/2 x width/2 88 | cr(tensor): batch x height/2 x width/2 89 | """ 90 | image_2 = image.permute(0, 3, 1, 2).clone() 91 | cb = F.avg_pool2d(image_2[:, 1, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False) 92 | cr = F.avg_pool2d(image_2[:, 2, :, :].unsqueeze(1), kernel_size=2, stride=(2, 2), count_include_pad=False) 93 | cb = cb.permute(0, 2, 3, 1) 94 | cr = cr.permute(0, 2, 3, 1) 95 | return image[:, :, :, 0], cb.squeeze(3), cr.squeeze(3) 96 | 97 | 98 | class BlockSplitting(nn.Module): 99 | """ Splitting image into patches 100 | """ 101 | 102 | def __init__(self): 103 | super(BlockSplitting, self).__init__() 104 | self.k = 8 105 | 106 | def forward(self, image): 107 | """ 108 | Args: 109 | image(tensor): batch x height x width 110 | 111 | Returns: 112 | Tensor: batch x h*w/64 x h x w 113 | """ 114 | height, _ = image.shape[1:3] 115 | batch_size = image.shape[0] 116 | image_reshaped = image.view(batch_size, height // self.k, self.k, -1, self.k) 117 | image_transposed = image_reshaped.permute(0, 1, 3, 2, 4) 118 | return image_transposed.contiguous().view(batch_size, -1, self.k, self.k) 119 | 120 | 121 | class DCT8x8(nn.Module): 122 | """ Discrete Cosine Transformation 123 | """ 124 | 125 | def __init__(self): 126 | super(DCT8x8, self).__init__() 127 | tensor = np.zeros((8, 8, 8, 8), dtype=np.float32) 128 | for x, y, u, v in itertools.product(range(8), repeat=4): 129 | tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos((2 * y + 1) * v * np.pi / 16) 130 | alpha = np.array([1. 
/ np.sqrt(2)] + [1] * 7) 131 | self.tensor = nn.Parameter(torch.from_numpy(tensor).float()) 132 | self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float()) 133 | 134 | def forward(self, image): 135 | """ 136 | Args: 137 | image(tensor): batch x height x width 138 | 139 | Returns: 140 | Tensor: batch x height x width 141 | """ 142 | image = image - 128 143 | result = self.scale * torch.tensordot(image, self.tensor, dims=2) 144 | result.view(image.shape) 145 | return result 146 | 147 | 148 | class YQuantize(nn.Module): 149 | """ JPEG Quantization for Y channel 150 | 151 | Args: 152 | rounding(function): rounding function to use 153 | """ 154 | 155 | def __init__(self, rounding): 156 | super(YQuantize, self).__init__() 157 | self.rounding = rounding 158 | self.y_table = y_table 159 | 160 | def forward(self, image, factor=1): 161 | """ 162 | Args: 163 | image(tensor): batch x height x width 164 | 165 | Returns: 166 | Tensor: batch x height x width 167 | """ 168 | if isinstance(factor, (int, float)): 169 | image = image.float() / (self.y_table * factor) 170 | else: 171 | b = factor.size(0) 172 | table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1) 173 | image = image.float() / table 174 | image = self.rounding(image) 175 | return image 176 | 177 | 178 | class CQuantize(nn.Module): 179 | """ JPEG Quantization for CbCr channels 180 | 181 | Args: 182 | rounding(function): rounding function to use 183 | """ 184 | 185 | def __init__(self, rounding): 186 | super(CQuantize, self).__init__() 187 | self.rounding = rounding 188 | self.c_table = c_table 189 | 190 | def forward(self, image, factor=1): 191 | """ 192 | Args: 193 | image(tensor): batch x height x width 194 | 195 | Returns: 196 | Tensor: batch x height x width 197 | """ 198 | if isinstance(factor, (int, float)): 199 | image = image.float() / (self.c_table * factor) 200 | else: 201 | b = factor.size(0) 202 | table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1) 203 | image = image.float() / table 204 | image = self.rounding(image) 205 | return image 206 | 207 | 208 | class CompressJpeg(nn.Module): 209 | """Full JPEG compression algorithm 210 | 211 | Args: 212 | rounding(function): rounding function to use 213 | """ 214 | 215 | def __init__(self, rounding=torch.round): 216 | super(CompressJpeg, self).__init__() 217 | self.l1 = nn.Sequential(RGB2YCbCrJpeg(), ChromaSubsampling()) 218 | self.l2 = nn.Sequential(BlockSplitting(), DCT8x8()) 219 | self.c_quantize = CQuantize(rounding=rounding) 220 | self.y_quantize = YQuantize(rounding=rounding) 221 | 222 | def forward(self, image, factor=1): 223 | """ 224 | Args: 225 | image(tensor): batch x 3 x height x width 226 | 227 | Returns: 228 | dict(tensor): Compressed tensor with batch x h*w/64 x 8 x 8. 
229 | """ 230 | y, cb, cr = self.l1(image * 255) 231 | components = {'y': y, 'cb': cb, 'cr': cr} 232 | for k in components.keys(): 233 | comp = self.l2(components[k]) 234 | if k in ('cb', 'cr'): 235 | comp = self.c_quantize(comp, factor=factor) 236 | else: 237 | comp = self.y_quantize(comp, factor=factor) 238 | 239 | components[k] = comp 240 | 241 | return components['y'], components['cb'], components['cr'] 242 | 243 | 244 | # ------------------------ decompression ------------------------# 245 | 246 | 247 | class YDequantize(nn.Module): 248 | """Dequantize Y channel 249 | """ 250 | 251 | def __init__(self): 252 | super(YDequantize, self).__init__() 253 | self.y_table = y_table 254 | 255 | def forward(self, image, factor=1): 256 | """ 257 | Args: 258 | image(tensor): batch x height x width 259 | 260 | Returns: 261 | Tensor: batch x height x width 262 | """ 263 | if isinstance(factor, (int, float)): 264 | out = image * (self.y_table * factor) 265 | else: 266 | b = factor.size(0) 267 | table = self.y_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1) 268 | out = image * table 269 | return out 270 | 271 | 272 | class CDequantize(nn.Module): 273 | """Dequantize CbCr channel 274 | """ 275 | 276 | def __init__(self): 277 | super(CDequantize, self).__init__() 278 | self.c_table = c_table 279 | 280 | def forward(self, image, factor=1): 281 | """ 282 | Args: 283 | image(tensor): batch x height x width 284 | 285 | Returns: 286 | Tensor: batch x height x width 287 | """ 288 | if isinstance(factor, (int, float)): 289 | out = image * (self.c_table * factor) 290 | else: 291 | b = factor.size(0) 292 | table = self.c_table.expand(b, 1, 8, 8) * factor.view(b, 1, 1, 1) 293 | out = image * table 294 | return out 295 | 296 | 297 | class iDCT8x8(nn.Module): 298 | """Inverse discrete Cosine Transformation 299 | """ 300 | 301 | def __init__(self): 302 | super(iDCT8x8, self).__init__() 303 | alpha = np.array([1. 
/ np.sqrt(2)] + [1] * 7)
304 | self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float())
305 | tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
306 | for x, y, u, v in itertools.product(range(8), repeat=4):
307 | tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos((2 * v + 1) * y * np.pi / 16)
308 | self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
309 | 
310 | def forward(self, image):
311 | """
312 | Args:
313 | image(tensor): batch x (height*width/64) x 8 x 8
314 | 
315 | Returns:
316 | Tensor: batch x (height*width/64) x 8 x 8
317 | """
318 | image = image * self.alpha
319 | result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128
320 | result = result.view(image.shape)  # assign the view; the bare `result.view(...)` was a no-op
321 | return result
322 | 
323 | 
324 | class BlockMerging(nn.Module):
325 | """Merge patches into image
326 | """
327 | 
328 | def __init__(self):
329 | super(BlockMerging, self).__init__()
330 | 
331 | def forward(self, patches, height, width):
332 | """
333 | Args:
334 | patches(tensor): batch x (height*width/64) x 8 x 8
335 | height(int)
336 | width(int)
337 | 
338 | Returns:
339 | Tensor: batch x height x width
340 | """
341 | k = 8
342 | batch_size = patches.shape[0]
343 | image_reshaped = patches.view(batch_size, height // k, width // k, k, k)
344 | image_transposed = image_reshaped.permute(0, 1, 3, 2, 4)
345 | return image_transposed.contiguous().view(batch_size, height, width)
346 | 
347 | 
348 | class ChromaUpsampling(nn.Module):
349 | """Upsample chroma layers
350 | """
351 | 
352 | def __init__(self):
353 | super(ChromaUpsampling, self).__init__()
354 | 
355 | def forward(self, y, cb, cr):
356 | """
357 | Args:
358 | y(tensor): y channel image
359 | cb(tensor): cb channel
360 | cr(tensor): cr channel
361 | 
362 | Returns:
363 | Tensor: batch x height x width x 3
364 | """
365 | 
366 | def repeat(x, k=2):
367 | height, width = x.shape[1:3]
368 | x = x.unsqueeze(-1)
369 | x = x.repeat(1, 1, k, k)
370 | x = x.view(-1, height * k, width * k)
371 | return x
372 | 
373 | cb = repeat(cb)
374 | cr = repeat(cr)
375 | return torch.cat([y.unsqueeze(3), cb.unsqueeze(3), cr.unsqueeze(3)], dim=3)
376 | 
377 | 
378 | class YCbCr2RGBJpeg(nn.Module):
379 | """Converts YCbCr image to RGB JPEG
380 | """
381 | 
382 | def __init__(self):
383 | super(YCbCr2RGBJpeg, self).__init__()
384 | 
385 | matrix = np.array([[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]], dtype=np.float32).T
386 | self.shift = nn.Parameter(torch.tensor([0, -128., -128.]))
387 | self.matrix = nn.Parameter(torch.from_numpy(matrix))
388 | 
389 | def forward(self, image):
390 | """
391 | Args:
392 | image(tensor): batch x height x width x 3
393 | 
394 | Returns:
395 | Tensor: batch x 3 x height x width
396 | """
397 | result = torch.tensordot(image + self.shift, self.matrix, dims=1)
398 | return result.view(image.shape).permute(0, 3, 1, 2)
399 | 
400 | 
401 | class DeCompressJpeg(nn.Module):
402 | """Full JPEG decompression algorithm
403 | 
404 | Args:
405 | rounding(function): rounding function to use
406 | """
407 | 
408 | def __init__(self, rounding=torch.round):  # `rounding` is kept for API symmetry with CompressJpeg; decompression itself never rounds
409 | super(DeCompressJpeg, self).__init__()
410 | self.c_dequantize = CDequantize()
411 | self.y_dequantize = YDequantize()
412 | self.idct = iDCT8x8()
413 | self.merging = BlockMerging()
414 | self.chroma = ChromaUpsampling()
415 | self.colors = YCbCr2RGBJpeg()
416 | 
417 | def forward(self, y, cb, cr, imgh, imgw, factor=1):
418 | """
419 | Args:
420 | y, cb, cr (tensor): quantized DCT coefficients, batch x (h*w/64) x 8 x 8
421 | imgh(int)
422 | imgw(int)
423 | factor(float)
424 | 
425 | Returns:
426 | Tensor: batch x 3 x height x width
427 | """
428 | components = {'y': y, 'cb': cb, 'cr': cr}
429 | for k in components.keys():
430 | if k in ('cb', 'cr'):
431 | comp = self.c_dequantize(components[k], factor=factor)
432 | height, width = int(imgh / 2), int(imgw / 2)
433 | else:
434 | comp = self.y_dequantize(components[k], factor=factor)
435 | height, width = imgh, imgw
436 | comp = self.idct(comp)
437 | components[k] = self.merging(comp, height, width)
438 | # upsample the half-resolution chroma planes and convert YCbCr back to RGB
439 | image = self.chroma(components['y'], components['cb'], components['cr'])
440 | image = self.colors(image)
441 | 
442 | image = torch.clamp(image, 0., 255.)  # equivalent to the original torch.min/torch.max clamp
443 | return image / 255
444 | 
445 | 
446 | # ------------------------ main DiffJPEG ------------------------ #
447 | 
448 | 
449 | class DiffJPEG(nn.Module):
450 | """Differentiable JPEG; its result differs slightly from cv2's JPEG codec.
451 | DiffJPEG supports batch processing.
452 | 
453 | Args:
454 | differentiable(bool): if True, uses a custom differentiable rounding function; if False, uses the standard torch.round.
455 | """
456 | 
457 | def __init__(self, differentiable=True):
458 | super(DiffJPEG, self).__init__()
459 | if differentiable:
460 | rounding = diff_round
461 | else:
462 | rounding = torch.round
463 | 
464 | self.compress = CompressJpeg(rounding=rounding)
465 | self.decompress = DeCompressJpeg(rounding=rounding)
466 | 
467 | def forward(self, x, quality):
468 | """
469 | Args:
470 | x (Tensor): Input image, bchw, rgb, [0, 1]
471 | quality(float): Quality factor for the JPEG compression scheme.
472 | """
473 | factor = quality
474 | if isinstance(factor, (int, float)):
475 | factor = quality_to_factor(factor)
476 | else:
477 | for i in range(factor.size(0)):
478 | factor[i] = quality_to_factor(factor[i])
479 | h, w = x.size()[-2:]
480 | h_pad, w_pad = 0, 0
481 | # pad to a multiple of 16: chroma is 2x subsampled and the DCT works on 8x8 blocks, so 2 * 8 = 16
482 | if h % 16 != 0:
483 | h_pad = 16 - h % 16
484 | if w % 16 != 0:
485 | w_pad = 16 - w % 16
486 | x = F.pad(x, (0, w_pad, 0, h_pad), mode='constant', value=0)
487 | 
488 | y, cb, cr = self.compress(x, factor=factor)
489 | recovered = self.decompress(y, cb, cr, (h + h_pad), (w + w_pad), factor=factor)
490 | recovered = recovered[:, :, 0:h, 0:w]
491 | return recovered
492 | 
493 | 
494 | if __name__ == '__main__':
495 | import cv2
496 | 
497 | from basicsr.utils import img2tensor, tensor2img
498 | 
499 | img_gt = cv2.imread('test.png') / 255.
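    # NOTE: a minimal CPU sanity check (sketch); `_x`/`_y` are illustrative names, and
    # quality_to_factor (defined earlier in this file) is assumed to map a JPEG quality
    # in [0, 100] to a quantization scale. DiffJPEG is pure PyTorch, so this part needs
    # no GPU: a non-differentiable round trip should preserve shape and stay in [0, 1].
    with torch.no_grad():
        _x = torch.rand(1, 3, 64, 64)  # 64 is a multiple of 16, so the padding branch is skipped
        _y = DiffJPEG(differentiable=False)(_x, quality=50)
        assert _y.shape == _x.shape and 0.0 <= _y.min().item() and _y.max().item() <= 1.0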
500 | 501 | # -------------- cv2 -------------- # 502 | encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 20] 503 | _, encimg = cv2.imencode('.jpg', img_gt * 255., encode_param) 504 | img_lq = np.float32(cv2.imdecode(encimg, 1)) 505 | cv2.imwrite('cv2_JPEG_20.png', img_lq) 506 | 507 | # -------------- DiffJPEG -------------- # 508 | jpeger = DiffJPEG(differentiable=False).cuda() 509 | img_gt = img2tensor(img_gt) 510 | img_gt = torch.stack([img_gt, img_gt]).cuda() 511 | quality = img_gt.new_tensor([20, 40]) 512 | out = jpeger(img_gt, quality=quality) 513 | 514 | cv2.imwrite('pt_JPEG_20.png', tensor2img(out[0])) 515 | cv2.imwrite('pt_JPEG_40.png', tensor2img(out[1])) 516 | -------------------------------------------------------------------------------- /src/utils/utils_blindsr.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | import cv2 4 | import torch 5 | 6 | from utils import utils_image as util 7 | # import utils_image as util 8 | 9 | import random 10 | from scipy import ndimage 11 | import scipy 12 | import scipy.stats as ss 13 | from scipy.interpolate import interp2d 14 | from scipy.linalg import orth 15 | 16 | 17 | 18 | 19 | """ 20 | # -------------------------------------------- 21 | # Super-Resolution 22 | # -------------------------------------------- 23 | # 24 | # Kai Zhang (cskaizhang@gmail.com) 25 | # https://github.com/cszn 26 | # From 2019/03--2021/08 27 | # -------------------------------------------- 28 | """ 29 | 30 | def modcrop_np(img, sf): 31 | ''' 32 | Args: 33 | img: numpy image, WxH or WxHxC 34 | sf: scale factor 35 | 36 | Return: 37 | cropped image 38 | ''' 39 | w, h = img.shape[:2] 40 | im = np.copy(img) 41 | return im[:w - w % sf, :h - h % sf, ...] 42 | 43 | 44 | """ 45 | # -------------------------------------------- 46 | # anisotropic Gaussian kernels 47 | # -------------------------------------------- 48 | """ 49 | def analytic_kernel(k): 50 | """Calculate the X4 kernel from the X2 kernel (for proof see appendix in paper)""" 51 | k_size = k.shape[0] 52 | # Calculate the big kernels size 53 | big_k = np.zeros((3 * k_size - 2, 3 * k_size - 2)) 54 | # Loop over the small kernel to fill the big one 55 | for r in range(k_size): 56 | for c in range(k_size): 57 | big_k[2 * r:2 * r + k_size, 2 * c:2 * c + k_size] += k[r, c] * k 58 | # Crop the edges of the big kernel to ignore very small values and increase run time of SR 59 | crop = k_size // 2 60 | cropped_big_k = big_k[crop:-crop, crop:-crop] 61 | # Normalize to 1 62 | return cropped_big_k / cropped_big_k.sum() 63 | 64 | 65 | def anisotropic_Gaussian(ksize=15, theta=np.pi, l1=6, l2=6): 66 | """ generate an anisotropic Gaussian kernel 67 | Args: 68 | ksize : e.g., 15, kernel size 69 | theta : [0, pi], rotation angle range 70 | l1 : [0.1,50], scaling of eigenvalues 71 | l2 : [0.1,l1], scaling of eigenvalues 72 | If l1 = l2, will get an isotropic Gaussian kernel. 
73 | 74 | Returns: 75 | k : kernel 76 | """ 77 | 78 | v = np.dot(np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]), np.array([1., 0.])) 79 | V = np.array([[v[0], v[1]], [v[1], -v[0]]]) 80 | D = np.array([[l1, 0], [0, l2]]) 81 | Sigma = np.dot(np.dot(V, D), np.linalg.inv(V)) 82 | k = gm_blur_kernel(mean=[0, 0], cov=Sigma, size=ksize) 83 | 84 | return k 85 | 86 | 87 | def gm_blur_kernel(mean, cov, size=15): 88 | center = size / 2.0 + 0.5 89 | k = np.zeros([size, size]) 90 | for y in range(size): 91 | for x in range(size): 92 | cy = y - center + 1 93 | cx = x - center + 1 94 | k[y, x] = ss.multivariate_normal.pdf([cx, cy], mean=mean, cov=cov) 95 | 96 | k = k / np.sum(k) 97 | return k 98 | 99 | 100 | def shift_pixel(x, sf, upper_left=True): 101 | """shift pixel for super-resolution with different scale factors 102 | Args: 103 | x: WxHxC or WxH 104 | sf: scale factor 105 | upper_left: shift direction 106 | """ 107 | h, w = x.shape[:2] 108 | shift = (sf-1)*0.5 109 | xv, yv = np.arange(0, w, 1.0), np.arange(0, h, 1.0) 110 | if upper_left: 111 | x1 = xv + shift 112 | y1 = yv + shift 113 | else: 114 | x1 = xv - shift 115 | y1 = yv - shift 116 | 117 | x1 = np.clip(x1, 0, w-1) 118 | y1 = np.clip(y1, 0, h-1) 119 | 120 | if x.ndim == 2: 121 | x = interp2d(xv, yv, x)(x1, y1) 122 | if x.ndim == 3: 123 | for i in range(x.shape[-1]): 124 | x[:, :, i] = interp2d(xv, yv, x[:, :, i])(x1, y1) 125 | 126 | return x 127 | 128 | 129 | def blur(x, k): 130 | ''' 131 | x: image, NxcxHxW 132 | k: kernel, Nx1xhxw 133 | ''' 134 | n, c = x.shape[:2] 135 | p1, p2 = (k.shape[-2]-1)//2, (k.shape[-1]-1)//2 136 | x = torch.nn.functional.pad(x, pad=(p1, p2, p1, p2), mode='replicate') 137 | k = k.repeat(1,c,1,1) 138 | k = k.view(-1, 1, k.shape[2], k.shape[3]) 139 | x = x.view(1, -1, x.shape[2], x.shape[3]) 140 | x = torch.nn.functional.conv2d(x, k, bias=None, stride=1, padding=0, groups=n*c) 141 | x = x.view(n, c, x.shape[2], x.shape[3]) 142 | 143 | return x 144 | 145 | 146 | 147 | def gen_kernel(k_size=np.array([15, 15]), scale_factor=np.array([4, 4]), min_var=0.6, max_var=10., noise_level=0): 148 | """" 149 | # modified version of https://github.com/assafshocher/BlindSR_dataset_generator 150 | # Kai Zhang 151 | # min_var = 0.175 * sf # variance of the gaussian kernel will be sampled between min_var and max_var 152 | # max_var = 2.5 * sf 153 | """ 154 | # Set random eigen-vals (lambdas) and angle (theta) for COV matrix 155 | lambda_1 = min_var + np.random.rand() * (max_var - min_var) 156 | lambda_2 = min_var + np.random.rand() * (max_var - min_var) 157 | theta = np.random.rand() * np.pi # random theta 158 | noise = -noise_level + np.random.rand(*k_size) * noise_level * 2 159 | 160 | # Set COV matrix using Lambdas and Theta 161 | LAMBDA = np.diag([lambda_1, lambda_2]) 162 | Q = np.array([[np.cos(theta), -np.sin(theta)], 163 | [np.sin(theta), np.cos(theta)]]) 164 | SIGMA = Q @ LAMBDA @ Q.T 165 | INV_SIGMA = np.linalg.inv(SIGMA)[None, None, :, :] 166 | 167 | # Set expectation position (shifting kernel for aligned image) 168 | MU = k_size // 2 - 0.5*(scale_factor - 1) # - 0.5 * (scale_factor - k_size % 2) 169 | MU = MU[None, None, :, None] 170 | 171 | # Create meshgrid for Gaussian 172 | [X,Y] = np.meshgrid(range(k_size[0]), range(k_size[1])) 173 | Z = np.stack([X, Y], 2)[:, :, :, None] 174 | 175 | # Calcualte Gaussian for every pixel of the kernel 176 | ZZ = Z-MU 177 | ZZ_t = ZZ.transpose(0,1,3,2) 178 | raw_kernel = np.exp(-0.5 * np.squeeze(ZZ_t @ INV_SIGMA @ ZZ)) * (1 + noise) 179 | 180 | # shift the 
kernel so it will be centered 181 | #raw_kernel_centered = kernel_shift(raw_kernel, scale_factor) 182 | 183 | # Normalize the kernel and return 184 | #kernel = raw_kernel_centered / np.sum(raw_kernel_centered) 185 | kernel = raw_kernel / np.sum(raw_kernel) 186 | return kernel 187 | 188 | 189 | def fspecial_gaussian(hsize, sigma): 190 | hsize = [hsize, hsize] 191 | siz = [(hsize[0]-1.0)/2.0, (hsize[1]-1.0)/2.0] 192 | std = sigma 193 | [x, y] = np.meshgrid(np.arange(-siz[1], siz[1]+1), np.arange(-siz[0], siz[0]+1)) 194 | arg = -(x*x + y*y)/(2*std*std) 195 | h = np.exp(arg) 196 | h[h < scipy.finfo(float).eps * h.max()] = 0 197 | sumh = h.sum() 198 | if sumh != 0: 199 | h = h/sumh 200 | return h 201 | 202 | 203 | def fspecial_laplacian(alpha): 204 | alpha = max([0, min([alpha,1])]) 205 | h1 = alpha/(alpha+1) 206 | h2 = (1-alpha)/(alpha+1) 207 | h = [[h1, h2, h1], [h2, -4/(alpha+1), h2], [h1, h2, h1]] 208 | h = np.array(h) 209 | return h 210 | 211 | 212 | def fspecial(filter_type, *args, **kwargs): 213 | ''' 214 | python code from: 215 | https://github.com/ronaldosena/imagens-medicas-2/blob/40171a6c259edec7827a6693a93955de2bd39e76/Aulas/aula_2_-_uniform_filter/matlab_fspecial.py 216 | ''' 217 | if filter_type == 'gaussian': 218 | return fspecial_gaussian(*args, **kwargs) 219 | if filter_type == 'laplacian': 220 | return fspecial_laplacian(*args, **kwargs) 221 | 222 | """ 223 | # -------------------------------------------- 224 | # degradation models 225 | # -------------------------------------------- 226 | """ 227 | 228 | 229 | def bicubic_degradation(x, sf=3): 230 | ''' 231 | Args: 232 | x: HxWxC image, [0, 1] 233 | sf: down-scale factor 234 | 235 | Return: 236 | bicubicly downsampled LR image 237 | ''' 238 | x = util.imresize_np(x, scale=1/sf) 239 | return x 240 | 241 | 242 | def srmd_degradation(x, k, sf=3): 243 | ''' blur + bicubic downsampling 244 | 245 | Args: 246 | x: HxWxC image, [0, 1] 247 | k: hxw, double 248 | sf: down-scale factor 249 | 250 | Return: 251 | downsampled LR image 252 | 253 | Reference: 254 | @inproceedings{zhang2018learning, 255 | title={Learning a single convolutional super-resolution network for multiple degradations}, 256 | author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, 257 | booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, 258 | pages={3262--3271}, 259 | year={2018} 260 | } 261 | ''' 262 | x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') # 'nearest' | 'mirror' 263 | x = bicubic_degradation(x, sf=sf) 264 | return x 265 | 266 | 267 | def dpsr_degradation(x, k, sf=3): 268 | 269 | ''' bicubic downsampling + blur 270 | 271 | Args: 272 | x: HxWxC image, [0, 1] 273 | k: hxw, double 274 | sf: down-scale factor 275 | 276 | Return: 277 | downsampled LR image 278 | 279 | Reference: 280 | @inproceedings{zhang2019deep, 281 | title={Deep Plug-and-Play Super-Resolution for Arbitrary Blur Kernels}, 282 | author={Zhang, Kai and Zuo, Wangmeng and Zhang, Lei}, 283 | booktitle={IEEE Conference on Computer Vision and Pattern Recognition}, 284 | pages={1671--1681}, 285 | year={2019} 286 | } 287 | ''' 288 | x = bicubic_degradation(x, sf=sf) 289 | x = ndimage.filters.convolve(x, np.expand_dims(k, axis=2), mode='wrap') 290 | return x 291 | 292 | 293 | def classical_degradation(x, k, sf=3): 294 | ''' blur + downsampling 295 | 296 | Args: 297 | x: HxWxC image, [0, 1]/[0, 255] 298 | k: hxw, double 299 | sf: down-scale factor 300 | 301 | Return: 302 | downsampled LR image 303 | ''' 304 | x = ndimage.filters.convolve(x, np.expand_dims(k, 
axis=2), mode='wrap') 305 | #x = filters.correlate(x, np.expand_dims(np.flip(k), axis=2)) 306 | st = 0 307 | return x[st::sf, st::sf, ...] 308 | 309 | 310 | def add_sharpening(img, weight=0.5, radius=50, threshold=10): 311 | """USM sharpening. borrowed from real-ESRGAN 312 | Input image: I; Blurry image: B. 313 | 1. K = I + weight * (I - B) 314 | 2. Mask = 1 if abs(I - B) > threshold, else: 0 315 | 3. Blur mask: 316 | 4. Out = Mask * K + (1 - Mask) * I 317 | Args: 318 | img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. 319 | weight (float): Sharp weight. Default: 1. 320 | radius (float): Kernel size of Gaussian blur. Default: 50. 321 | threshold (int): 322 | """ 323 | if radius % 2 == 0: 324 | radius += 1 325 | blur = cv2.GaussianBlur(img, (radius, radius), 0) 326 | residual = img - blur 327 | mask = np.abs(residual) * 255 > threshold 328 | mask = mask.astype('float32') 329 | soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) 330 | 331 | K = img + weight * residual 332 | K = np.clip(K, 0, 1) 333 | return soft_mask * K + (1 - soft_mask) * img 334 | 335 | 336 | def add_blur(img, sf=4): 337 | wd2 = 4.0 + sf 338 | wd = 2.0 + 0.2*sf 339 | if random.random() < 0.5: 340 | l1 = wd2*random.random() 341 | l2 = wd2*random.random() 342 | k = anisotropic_Gaussian(ksize=2*random.randint(2,11)+3, theta=random.random()*np.pi, l1=l1, l2=l2) 343 | else: 344 | k = fspecial('gaussian', 2*random.randint(2,11)+3, wd*random.random()) 345 | img = ndimage.filters.convolve(img, np.expand_dims(k, axis=2), mode='mirror') 346 | 347 | return img 348 | 349 | 350 | def add_resize(img, sf=4): 351 | rnum = np.random.rand() 352 | if rnum > 0.8: # up 353 | sf1 = random.uniform(1, 2) 354 | elif rnum < 0.7: # down 355 | sf1 = random.uniform(0.5/sf, 1) 356 | else: 357 | sf1 = 1.0 358 | img = cv2.resize(img, (int(sf1*img.shape[1]), int(sf1*img.shape[0])), interpolation=random.choice([1, 2, 3])) 359 | img = np.clip(img, 0.0, 1.0) 360 | 361 | return img 362 | 363 | 364 | def add_Gaussian_noise(img, noise_level1=2, noise_level2=25): 365 | noise_level = random.randint(noise_level1, noise_level2) 366 | rnum = np.random.rand() 367 | if rnum > 0.6: # add color Gaussian noise 368 | img += np.random.normal(0, noise_level/255.0, img.shape).astype(np.float32) 369 | elif rnum < 0.4: # add grayscale Gaussian noise 370 | img += np.random.normal(0, noise_level/255.0, (*img.shape[:2], 1)).astype(np.float32) 371 | else: # add noise 372 | L = noise_level2/255. 373 | D = np.diag(np.random.rand(3)) 374 | U = orth(np.random.rand(3,3)) 375 | conv = np.dot(np.dot(np.transpose(U), D), U) 376 | img += np.random.multivariate_normal([0,0,0], np.abs(L**2*conv), img.shape[:2]).astype(np.float32) 377 | img = np.clip(img, 0.0, 1.0) 378 | return img 379 | 380 | 381 | def add_speckle_noise(img, noise_level1=2, noise_level2=25): 382 | noise_level = random.randint(noise_level1, noise_level2) 383 | img = np.clip(img, 0.0, 1.0) 384 | rnum = random.random() 385 | if rnum > 0.6: 386 | img += img*np.random.normal(0, noise_level/255.0, img.shape).astype(np.float32) 387 | elif rnum < 0.4: 388 | img += img*np.random.normal(0, noise_level/255.0, (*img.shape[:2], 1)).astype(np.float32) 389 | else: 390 | L = noise_level2/255. 
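        # NOTE: the lines below build correlated color noise: U is a random orthonormal
        # basis (scipy.linalg.orth) and D a random non-negative diagonal, so
        # conv = U^T @ D @ U is symmetric positive semi-definite; L^2 sets the overall
        # noise variance on the [0, 1] intensity scale.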
391 | D = np.diag(np.random.rand(3))
392 | U = orth(np.random.rand(3,3))
393 | conv = np.dot(np.dot(np.transpose(U), D), U)
394 | img += img*np.random.multivariate_normal([0,0,0], np.abs(L**2*conv), img.shape[:2]).astype(np.float32)
395 | img = np.clip(img, 0.0, 1.0)
396 | return img
397 | 
398 | 
399 | def add_Poisson_noise(img):
400 | img = np.clip((img * 255.0).round(), 0, 255) / 255.
401 | vals = 10**(2*random.random()+2.0) # exponent uniform in [2, 4], i.e. vals in [1e2, 1e4]
402 | if random.random() < 0.5:
403 | img = np.random.poisson(img * vals).astype(np.float32) / vals
404 | else:
405 | img_gray = np.dot(img[...,:3], [0.299, 0.587, 0.114])
406 | img_gray = np.clip((img_gray * 255.0).round(), 0, 255) / 255.
407 | noise_gray = np.random.poisson(img_gray * vals).astype(np.float32) / vals - img_gray
408 | img += noise_gray[:, :, np.newaxis]
409 | img = np.clip(img, 0.0, 1.0)
410 | return img
411 | 
412 | 
413 | def add_JPEG_noise(img, min_q=30):
414 | quality_factor = random.randint(min_q, 95)
415 | img = cv2.cvtColor(util.single2uint(img), cv2.COLOR_RGB2BGR)
416 | result, encimg = cv2.imencode('.jpg', img, [int(cv2.IMWRITE_JPEG_QUALITY), quality_factor])
417 | img = cv2.imdecode(encimg, 1)
418 | img = cv2.cvtColor(util.uint2single(img), cv2.COLOR_BGR2RGB)
419 | return img
420 | 
421 | 
422 | def random_crop(lq, hq, sf=4, lq_patchsize=64):
423 | h, w = lq.shape[:2]
424 | rnd_h = random.randint(0, h-lq_patchsize)
425 | rnd_w = random.randint(0, w-lq_patchsize)
426 | lq = lq[rnd_h:rnd_h + lq_patchsize, rnd_w:rnd_w + lq_patchsize, :]
427 | 
428 | rnd_h_H, rnd_w_H = int(rnd_h * sf), int(rnd_w * sf)
429 | hq = hq[rnd_h_H:rnd_h_H + lq_patchsize*sf, rnd_w_H:rnd_w_H + lq_patchsize*sf, :]
430 | return lq, hq
431 | 
432 | 
433 | def degradation_bsrgan(img, sf=4, lq_patchsize=72, isp_model=None):
434 | """
435 | This is the degradation model of BSRGAN from the paper
436 | "Designing a Practical Degradation Model for Deep Blind Image Super-Resolution"
437 | ----------
438 | img: HXWXC, [0, 1], its size should be larger than (lq_patchsize x sf) x (lq_patchsize x sf)
439 | sf: scale factor
440 | isp_model: camera ISP model
441 | 
442 | Returns
443 | -------
444 | img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
445 | hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
446 | """
447 | isp_prob, jpeg_prob, scale2_prob = 0.25, 0.9, 0.25
448 | sf_ori = sf
449 | 
450 | h1, w1 = img.shape[:2]
451 | img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...] # mod crop
452 | h, w = img.shape[:2]
453 | 
454 | if h < lq_patchsize*sf or w < lq_patchsize*sf:
455 | raise ValueError(f'img size ({h1}X{w1}) is too small!')
456 | 
457 | hq = img.copy()
458 | 
459 | if sf == 4 and random.random() < scale2_prob: # downsample1
460 | if np.random.rand() < 0.5:
461 | img = cv2.resize(img, (int(1/2*img.shape[1]), int(1/2*img.shape[0])), interpolation=random.choice([1,2,3]))
462 | else:
463 | img = util.imresize_np(img, 1/2, True)
464 | img = np.clip(img, 0.0, 1.0)
465 | sf = 2
466 | 
467 | shuffle_order = random.sample(range(7), 7)
468 | idx1, idx2 = shuffle_order.index(2), shuffle_order.index(3)
469 | if idx1 > idx2: # keep downsample2 (op 2) before downsample3 (op 3)
470 | shuffle_order[idx1], shuffle_order[idx2] = shuffle_order[idx2], shuffle_order[idx1]
471 | 
472 | for i in shuffle_order:
473 | 
474 | if i == 0:
475 | img = add_blur(img, sf=sf)
476 | 
477 | elif i == 1:
478 | img = add_blur(img, sf=sf)
479 | 
480 | elif i == 2:
481 | a, b = img.shape[1], img.shape[0]
482 | # downsample2
483 | if random.random() < 0.75:
484 | sf1 = random.uniform(1,2*sf)
485 | img = cv2.resize(img, (int(1/sf1*img.shape[1]), int(1/sf1*img.shape[0])), interpolation=random.choice([1,2,3]))
486 | else:
487 | k = fspecial('gaussian', 25, random.uniform(0.1, 0.6*sf))
488 | k_shifted = shift_pixel(k, sf)
489 | k_shifted = k_shifted/k_shifted.sum() # blur with shifted kernel
490 | img = ndimage.filters.convolve(img, np.expand_dims(k_shifted, axis=2), mode='mirror')
491 | img = img[0::sf, 0::sf, ...] # nearest downsampling
492 | img = np.clip(img, 0.0, 1.0)
493 | 
494 | elif i == 3:
495 | # downsample3
496 | img = cv2.resize(img, (int(1/sf*a), int(1/sf*b)), interpolation=random.choice([1,2,3]))
497 | img = np.clip(img, 0.0, 1.0)
498 | 
499 | elif i == 4:
500 | # add Gaussian noise
501 | img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25)
502 | 
503 | elif i == 5:
504 | # add JPEG noise
505 | if random.random() < jpeg_prob:
506 | img = add_JPEG_noise(img)
507 | 
508 | elif i == 6:
509 | # add processed camera sensor noise
510 | if random.random() < isp_prob and isp_model is not None:
511 | with torch.no_grad():
512 | img, hq = isp_model.forward(img.copy(), hq)
513 | 
514 | # add final JPEG compression noise
515 | img = add_JPEG_noise(img)
516 | 
517 | # random crop
518 | img, hq = random_crop(img, hq, sf_ori, lq_patchsize)
519 | 
520 | return img, hq
521 | 
522 | 
523 | 
524 | 
525 | def degradation_bsrgan_plus(img, sf=4, shuffle_prob=0.5, use_sharp=True, lq_patchsize=64, isp_model=None):
526 | """
527 | This is an extended degradation model that combines
528 | the degradation models of BSRGAN and Real-ESRGAN
529 | ----------
530 | img: HXWXC, [0, 1], its size should be larger than (lq_patchsize x sf) x (lq_patchsize x sf)
531 | sf: scale factor
532 | shuffle_prob: probability of globally shuffling the degradation order
533 | use_sharp: whether to apply USM sharpening to the image first
534 | 
535 | Returns
536 | -------
537 | img: low-quality patch, size: lq_patchsizeXlq_patchsizeXC, range: [0, 1]
538 | hq: corresponding high-quality patch, size: (lq_patchsizexsf)X(lq_patchsizexsf)XC, range: [0, 1]
539 | """
540 | 
541 | h1, w1 = img.shape[:2]
542 | img = img.copy()[:h1 - h1 % sf, :w1 - w1 % sf, ...]
# mod crop 543 | h, w = img.shape[:2] 544 | 545 | if h < lq_patchsize*sf or w < lq_patchsize*sf: 546 | raise ValueError(f'img size ({h1}X{w1}) is too small!') 547 | 548 | if use_sharp: 549 | img = add_sharpening(img) 550 | hq = img.copy() 551 | 552 | if random.random() < shuffle_prob: 553 | shuffle_order = random.sample(range(13), 13) 554 | else: 555 | shuffle_order = list(range(13)) 556 | # local shuffle for noise, JPEG is always the last one 557 | shuffle_order[2:6] = random.sample(shuffle_order[2:6], len(range(2, 6))) 558 | shuffle_order[9:13] = random.sample(shuffle_order[9:13], len(range(9, 13))) 559 | 560 | poisson_prob, speckle_prob, isp_prob = 0.1, 0.1, 0.1 561 | 562 | for i in shuffle_order: 563 | if i == 0: 564 | img = add_blur(img, sf=sf) 565 | elif i == 1: 566 | img = add_resize(img, sf=sf) 567 | elif i == 2: 568 | img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) 569 | elif i == 3: 570 | if random.random() < poisson_prob: 571 | img = add_Poisson_noise(img) 572 | elif i == 4: 573 | if random.random() < speckle_prob: 574 | img = add_speckle_noise(img) 575 | elif i == 5: 576 | if random.random() < isp_prob and isp_model is not None: 577 | with torch.no_grad(): 578 | img, hq = isp_model.forward(img.copy(), hq) 579 | elif i == 6: 580 | img = add_JPEG_noise(img) 581 | elif i == 7: 582 | img = add_blur(img, sf=sf) 583 | elif i == 8: 584 | img = add_resize(img, sf=sf) 585 | elif i == 9: 586 | img = add_Gaussian_noise(img, noise_level1=2, noise_level2=25) 587 | elif i == 10: 588 | if random.random() < poisson_prob: 589 | img = add_Poisson_noise(img) 590 | elif i == 11: 591 | if random.random() < speckle_prob: 592 | img = add_speckle_noise(img) 593 | elif i == 12: 594 | if random.random() < isp_prob and isp_model is not None: 595 | with torch.no_grad(): 596 | img, hq = isp_model.forward(img.copy(), hq) 597 | else: 598 | print('check the shuffle!') 599 | 600 | # resize to desired size 601 | img = cv2.resize(img, (int(1/sf*hq.shape[1]), int(1/sf*hq.shape[0])), interpolation=random.choice([1, 2, 3])) 602 | 603 | # add final JPEG compression noise 604 | img = add_JPEG_noise(img) 605 | 606 | # random crop 607 | img, hq = random_crop(img, hq, sf, lq_patchsize) 608 | 609 | return img, hq 610 | 611 | 612 | 613 | if __name__ == '__main__': 614 | img = util.imread_uint('utils/test.png', 3) 615 | img = util.uint2single(img) 616 | sf = 4 617 | 618 | for i in range(20): 619 | img_lq, img_hq = degradation_bsrgan(img, sf=sf, lq_patchsize=72) 620 | print(i) 621 | lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf*img_lq.shape[1]), int(sf*img_lq.shape[0])), interpolation=0) 622 | img_concat = np.concatenate([lq_nearest, util.single2uint(img_hq)], axis=1) 623 | util.imsave(img_concat, str(i)+'.png') 624 | 625 | # for i in range(10): 626 | # img_lq, img_hq = degradation_bsrgan_plus(img, sf=sf, shuffle_prob=0.1, use_sharp=True, lq_patchsize=64) 627 | # print(i) 628 | # lq_nearest = cv2.resize(util.single2uint(img_lq), (int(sf*img_lq.shape[1]), int(sf*img_lq.shape[0])), interpolation=0) 629 | # img_concat = np.concatenate([lq_nearest, util.single2uint(img_hq)], axis=1) 630 | # util.imsave(img_concat, str(i)+'.png') 631 | 632 | # run utils/utils_blindsr.py 633 | -------------------------------------------------------------------------------- /src/utils/basicsr_degradations.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import numpy as np 4 | import random 5 | import torch 6 | from scipy import special 7 | from 
scipy.stats import multivariate_normal 8 | from torchvision.transforms.functional_tensor import rgb_to_grayscale 9 | 10 | # -------------------------------------------------------------------- # 11 | # --------------------------- blur kernels --------------------------- # 12 | # -------------------------------------------------------------------- # 13 | 14 | 15 | # --------------------------- util functions --------------------------- # 16 | def sigma_matrix2(sig_x, sig_y, theta): 17 | """Calculate the rotated sigma matrix (two dimensional matrix). 18 | 19 | Args: 20 | sig_x (float): 21 | sig_y (float): 22 | theta (float): Radian measurement. 23 | 24 | Returns: 25 | ndarray: Rotated sigma matrix. 26 | """ 27 | d_matrix = np.array([[sig_x**2, 0], [0, sig_y**2]]) 28 | u_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]]) 29 | return np.dot(u_matrix, np.dot(d_matrix, u_matrix.T)) 30 | 31 | 32 | def mesh_grid(kernel_size): 33 | """Generate the mesh grid, centering at zero. 34 | 35 | Args: 36 | kernel_size (int): 37 | 38 | Returns: 39 | xy (ndarray): with the shape (kernel_size, kernel_size, 2) 40 | xx (ndarray): with the shape (kernel_size, kernel_size) 41 | yy (ndarray): with the shape (kernel_size, kernel_size) 42 | """ 43 | ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.) 44 | xx, yy = np.meshgrid(ax, ax) 45 | xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), yy.reshape(kernel_size * kernel_size, 46 | 1))).reshape(kernel_size, kernel_size, 2) 47 | return xy, xx, yy 48 | 49 | 50 | def pdf2(sigma_matrix, grid): 51 | """Calculate PDF of the bivariate Gaussian distribution. 52 | 53 | Args: 54 | sigma_matrix (ndarray): with the shape (2, 2) 55 | grid (ndarray): generated by :func:`mesh_grid`, 56 | with the shape (K, K, 2), K is the kernel size. 57 | 58 | Returns: 59 | kernel (ndarrray): un-normalized kernel. 60 | """ 61 | inverse_sigma = np.linalg.inv(sigma_matrix) 62 | kernel = np.exp(-0.5 * np.sum(np.dot(grid, inverse_sigma) * grid, 2)) 63 | return kernel 64 | 65 | 66 | def cdf2(d_matrix, grid): 67 | """Calculate the CDF of the standard bivariate Gaussian distribution. 68 | Used in skewed Gaussian distribution. 69 | 70 | Args: 71 | d_matrix (ndarrasy): skew matrix. 72 | grid (ndarray): generated by :func:`mesh_grid`, 73 | with the shape (K, K, 2), K is the kernel size. 74 | 75 | Returns: 76 | cdf (ndarray): skewed cdf. 77 | """ 78 | rv = multivariate_normal([0, 0], [[1, 0], [0, 1]]) 79 | grid = np.dot(grid, d_matrix) 80 | cdf = rv.cdf(grid) 81 | return cdf 82 | 83 | 84 | def bivariate_Gaussian(kernel_size, sig_x, sig_y, theta, grid=None, isotropic=True): 85 | """Generate a bivariate isotropic or anisotropic Gaussian kernel. 86 | 87 | In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored. 88 | 89 | Args: 90 | kernel_size (int): 91 | sig_x (float): 92 | sig_y (float): 93 | theta (float): Radian measurement. 94 | grid (ndarray, optional): generated by :func:`mesh_grid`, 95 | with the shape (K, K, 2), K is the kernel size. Default: None 96 | isotropic (bool): 97 | 98 | Returns: 99 | kernel (ndarray): normalized kernel. 
100 | """ 101 | if grid is None: 102 | grid, _, _ = mesh_grid(kernel_size) 103 | if isotropic: 104 | sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]]) 105 | else: 106 | sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) 107 | kernel = pdf2(sigma_matrix, grid) 108 | kernel = kernel / np.sum(kernel) 109 | return kernel 110 | 111 | 112 | def bivariate_generalized_Gaussian(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True): 113 | """Generate a bivariate generalized Gaussian kernel. 114 | Described in `Parameter Estimation For Multivariate Generalized 115 | Gaussian Distributions`_ 116 | by Pascal et. al (2013). 117 | 118 | In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored. 119 | 120 | Args: 121 | kernel_size (int): 122 | sig_x (float): 123 | sig_y (float): 124 | theta (float): Radian measurement. 125 | beta (float): shape parameter, beta = 1 is the normal distribution. 126 | grid (ndarray, optional): generated by :func:`mesh_grid`, 127 | with the shape (K, K, 2), K is the kernel size. Default: None 128 | 129 | Returns: 130 | kernel (ndarray): normalized kernel. 131 | 132 | .. _Parameter Estimation For Multivariate Generalized Gaussian 133 | Distributions: https://arxiv.org/abs/1302.6498 134 | """ 135 | if grid is None: 136 | grid, _, _ = mesh_grid(kernel_size) 137 | if isotropic: 138 | sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]]) 139 | else: 140 | sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) 141 | inverse_sigma = np.linalg.inv(sigma_matrix) 142 | kernel = np.exp(-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta)) 143 | kernel = kernel / np.sum(kernel) 144 | return kernel 145 | 146 | 147 | def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotropic=True): 148 | """Generate a plateau-like anisotropic kernel. 149 | 1 / (1+x^(beta)) 150 | 151 | Ref: https://stats.stackexchange.com/questions/203629/is-there-a-plateau-shaped-distribution 152 | 153 | In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored. 154 | 155 | Args: 156 | kernel_size (int): 157 | sig_x (float): 158 | sig_y (float): 159 | theta (float): Radian measurement. 160 | beta (float): shape parameter, beta = 1 is the normal distribution. 161 | grid (ndarray, optional): generated by :func:`mesh_grid`, 162 | with the shape (K, K, 2), K is the kernel size. Default: None 163 | 164 | Returns: 165 | kernel (ndarray): normalized kernel. 166 | """ 167 | if grid is None: 168 | grid, _, _ = mesh_grid(kernel_size) 169 | if isotropic: 170 | sigma_matrix = np.array([[sig_x**2, 0], [0, sig_x**2]]) 171 | else: 172 | sigma_matrix = sigma_matrix2(sig_x, sig_y, theta) 173 | inverse_sigma = np.linalg.inv(sigma_matrix) 174 | kernel = np.reciprocal(np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta) + 1) 175 | kernel = kernel / np.sum(kernel) 176 | return kernel 177 | 178 | 179 | def random_bivariate_Gaussian(kernel_size, 180 | sigma_x_range, 181 | sigma_y_range, 182 | rotation_range, 183 | noise_range=None, 184 | isotropic=True): 185 | """Randomly generate bivariate isotropic or anisotropic Gaussian kernels. 186 | 187 | In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored. 188 | 189 | Args: 190 | kernel_size (int): 191 | sigma_x_range (tuple): [0.6, 5] 192 | sigma_y_range (tuple): [0.6, 5] 193 | rotation range (tuple): [-math.pi, math.pi] 194 | noise_range(tuple, optional): multiplicative kernel noise, 195 | [0.75, 1.25]. 
Default: None 196 | 197 | Returns: 198 | kernel (ndarray): 199 | """ 200 | assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' 201 | assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.' 202 | sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1]) 203 | if isotropic is False: 204 | assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.' 205 | assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.' 206 | sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1]) 207 | rotation = np.random.uniform(rotation_range[0], rotation_range[1]) 208 | else: 209 | sigma_y = sigma_x 210 | rotation = 0 211 | 212 | kernel = bivariate_Gaussian(kernel_size, sigma_x, sigma_y, rotation, isotropic=isotropic) 213 | 214 | # add multiplicative noise 215 | if noise_range is not None: 216 | assert noise_range[0] < noise_range[1], 'Wrong noise range.' 217 | noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape) 218 | kernel = kernel * noise 219 | kernel = kernel / np.sum(kernel) 220 | return kernel 221 | 222 | 223 | def random_bivariate_generalized_Gaussian(kernel_size, 224 | sigma_x_range, 225 | sigma_y_range, 226 | rotation_range, 227 | beta_range, 228 | noise_range=None, 229 | isotropic=True): 230 | """Randomly generate bivariate generalized Gaussian kernels. 231 | 232 | In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored. 233 | 234 | Args: 235 | kernel_size (int): 236 | sigma_x_range (tuple): [0.6, 5] 237 | sigma_y_range (tuple): [0.6, 5] 238 | rotation range (tuple): [-math.pi, math.pi] 239 | beta_range (tuple): [0.5, 8] 240 | noise_range(tuple, optional): multiplicative kernel noise, 241 | [0.75, 1.25]. Default: None 242 | 243 | Returns: 244 | kernel (ndarray): 245 | """ 246 | assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' 247 | assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.' 248 | sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1]) 249 | if isotropic is False: 250 | assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.' 251 | assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.' 252 | sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1]) 253 | rotation = np.random.uniform(rotation_range[0], rotation_range[1]) 254 | else: 255 | sigma_y = sigma_x 256 | rotation = 0 257 | 258 | # assume beta_range[0] < 1 < beta_range[1] 259 | if np.random.uniform() < 0.5: 260 | beta = np.random.uniform(beta_range[0], 1) 261 | else: 262 | beta = np.random.uniform(1, beta_range[1]) 263 | 264 | kernel = bivariate_generalized_Gaussian(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic) 265 | 266 | # add multiplicative noise 267 | if noise_range is not None: 268 | assert noise_range[0] < noise_range[1], 'Wrong noise range.' 269 | noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape) 270 | kernel = kernel * noise 271 | kernel = kernel / np.sum(kernel) 272 | return kernel 273 | 274 | 275 | def random_bivariate_plateau(kernel_size, 276 | sigma_x_range, 277 | sigma_y_range, 278 | rotation_range, 279 | beta_range, 280 | noise_range=None, 281 | isotropic=True): 282 | """Randomly generate bivariate plateau kernels. 283 | 284 | In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored. 
285 | 286 | Args: 287 | kernel_size (int): 288 | sigma_x_range (tuple): [0.6, 5] 289 | sigma_y_range (tuple): [0.6, 5] 290 | rotation range (tuple): [-math.pi/2, math.pi/2] 291 | beta_range (tuple): [1, 4] 292 | noise_range(tuple, optional): multiplicative kernel noise, 293 | [0.75, 1.25]. Default: None 294 | 295 | Returns: 296 | kernel (ndarray): 297 | """ 298 | assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' 299 | assert sigma_x_range[0] < sigma_x_range[1], 'Wrong sigma_x_range.' 300 | sigma_x = np.random.uniform(sigma_x_range[0], sigma_x_range[1]) 301 | if isotropic is False: 302 | assert sigma_y_range[0] < sigma_y_range[1], 'Wrong sigma_y_range.' 303 | assert rotation_range[0] < rotation_range[1], 'Wrong rotation_range.' 304 | sigma_y = np.random.uniform(sigma_y_range[0], sigma_y_range[1]) 305 | rotation = np.random.uniform(rotation_range[0], rotation_range[1]) 306 | else: 307 | sigma_y = sigma_x 308 | rotation = 0 309 | 310 | # TODO: this may be not proper 311 | if np.random.uniform() < 0.5: 312 | beta = np.random.uniform(beta_range[0], 1) 313 | else: 314 | beta = np.random.uniform(1, beta_range[1]) 315 | 316 | kernel = bivariate_plateau(kernel_size, sigma_x, sigma_y, rotation, beta, isotropic=isotropic) 317 | # add multiplicative noise 318 | if noise_range is not None: 319 | assert noise_range[0] < noise_range[1], 'Wrong noise range.' 320 | noise = np.random.uniform(noise_range[0], noise_range[1], size=kernel.shape) 321 | kernel = kernel * noise 322 | kernel = kernel / np.sum(kernel) 323 | 324 | return kernel 325 | 326 | 327 | def random_mixed_kernels(kernel_list, 328 | kernel_prob, 329 | kernel_size=21, 330 | sigma_x_range=(0.6, 5), 331 | sigma_y_range=(0.6, 5), 332 | rotation_range=(-math.pi, math.pi), 333 | betag_range=(0.5, 8), 334 | betap_range=(0.5, 8), 335 | noise_range=None): 336 | """Randomly generate mixed kernels. 337 | 338 | Args: 339 | kernel_list (tuple): a list name of kernel types, 340 | support ['iso', 'aniso', 'skew', 'generalized', 'plateau_iso', 341 | 'plateau_aniso'] 342 | kernel_prob (tuple): corresponding kernel probability for each 343 | kernel type 344 | kernel_size (int): 345 | sigma_x_range (tuple): [0.6, 5] 346 | sigma_y_range (tuple): [0.6, 5] 347 | rotation range (tuple): [-math.pi, math.pi] 348 | beta_range (tuple): [0.5, 8] 349 | noise_range(tuple, optional): multiplicative kernel noise, 350 | [0.75, 1.25]. 
Default: None 351 | 352 | Returns: 353 | kernel (ndarray): 354 | """ 355 | kernel_type = random.choices(kernel_list, kernel_prob)[0] 356 | if kernel_type == 'iso': 357 | kernel = random_bivariate_Gaussian( 358 | kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True) 359 | elif kernel_type == 'aniso': 360 | kernel = random_bivariate_Gaussian( 361 | kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False) 362 | elif kernel_type == 'generalized_iso': 363 | kernel = random_bivariate_generalized_Gaussian( 364 | kernel_size, 365 | sigma_x_range, 366 | sigma_y_range, 367 | rotation_range, 368 | betag_range, 369 | noise_range=noise_range, 370 | isotropic=True) 371 | elif kernel_type == 'generalized_aniso': 372 | kernel = random_bivariate_generalized_Gaussian( 373 | kernel_size, 374 | sigma_x_range, 375 | sigma_y_range, 376 | rotation_range, 377 | betag_range, 378 | noise_range=noise_range, 379 | isotropic=False) 380 | elif kernel_type == 'plateau_iso': 381 | kernel = random_bivariate_plateau( 382 | kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True) 383 | elif kernel_type == 'plateau_aniso': 384 | kernel = random_bivariate_plateau( 385 | kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False) 386 | return kernel 387 | 388 | 389 | np.seterr(divide='ignore', invalid='ignore') 390 | 391 | 392 | def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0): 393 | """2D sinc filter, ref: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter 394 | 395 | Args: 396 | cutoff (float): cutoff frequency in radians (pi is max) 397 | kernel_size (int): horizontal and vertical size, must be odd. 398 | pad_to (int): pad kernel size to desired size, must be odd or zero. 399 | """ 400 | assert kernel_size % 2 == 1, 'Kernel size must be an odd number.' 401 | kernel = np.fromfunction( 402 | lambda x, y: cutoff * special.j1(cutoff * np.sqrt( 403 | (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)) / (2 * np.pi * np.sqrt( 404 | (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)), [kernel_size, kernel_size]) 405 | kernel[(kernel_size - 1) // 2, (kernel_size - 1) // 2] = cutoff**2 / (4 * np.pi) 406 | kernel = kernel / np.sum(kernel) 407 | if pad_to > kernel_size: 408 | pad_size = (pad_to - kernel_size) // 2 409 | kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size))) 410 | return kernel 411 | 412 | 413 | # ------------------------------------------------------------- # 414 | # --------------------------- noise --------------------------- # 415 | # ------------------------------------------------------------- # 416 | 417 | # ----------------------- Gaussian Noise ----------------------- # 418 | 419 | 420 | def generate_gaussian_noise(img, sigma=10, gray_noise=False): 421 | """Generate Gaussian noise. 422 | 423 | Args: 424 | img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. 425 | sigma (float): Noise scale (measured in range 255). Default: 10. 426 | 427 | Returns: 428 | (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1], 429 | float32. 430 | """ 431 | if gray_noise: 432 | noise = np.float32(np.random.randn(*(img.shape[0:2]))) * sigma / 255. 433 | noise = np.expand_dims(noise, axis=2).repeat(3, axis=2) 434 | else: 435 | noise = np.float32(np.random.randn(*(img.shape))) * sigma / 255. 
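    # NOTE: sigma is specified on the 8-bit [0, 255] scale, hence the /255. rescaling;
    # gray noise is replicated across all three channels, so the returned array always
    # matches img's HxWx3 shape.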
436 | return noise 437 | 438 | 439 | def add_gaussian_noise(img, sigma=10, clip=True, rounds=False, gray_noise=False): 440 | """Add Gaussian noise. 441 | 442 | Args: 443 | img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. 444 | sigma (float): Noise scale (measured in range 255). Default: 10. 445 | 446 | Returns: 447 | (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1], 448 | float32. 449 | """ 450 | noise = generate_gaussian_noise(img, sigma, gray_noise) 451 | out = img + noise 452 | if clip and rounds: 453 | out = np.clip((out * 255.0).round(), 0, 255) / 255. 454 | elif clip: 455 | out = np.clip(out, 0, 1) 456 | elif rounds: 457 | out = (out * 255.0).round() / 255. 458 | return out 459 | 460 | 461 | def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0): 462 | """Add Gaussian noise (PyTorch version). 463 | 464 | Args: 465 | img (Tensor): Shape (b, c, h, w), range[0, 1], float32. 466 | scale (float | Tensor): Noise scale. Default: 1.0. 467 | 468 | Returns: 469 | (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1], 470 | float32. 471 | """ 472 | b, _, h, w = img.size() 473 | if not isinstance(sigma, (float, int)): 474 | sigma = sigma.view(img.size(0), 1, 1, 1) 475 | if isinstance(gray_noise, (float, int)): 476 | cal_gray_noise = gray_noise > 0 477 | else: 478 | gray_noise = gray_noise.view(b, 1, 1, 1) 479 | cal_gray_noise = torch.sum(gray_noise) > 0 480 | 481 | if cal_gray_noise: 482 | noise_gray = torch.randn(*img.size()[2:4], dtype=img.dtype, device=img.device) * sigma / 255. 483 | noise_gray = noise_gray.view(b, 1, h, w) 484 | 485 | # always calculate color noise 486 | noise = torch.randn(*img.size(), dtype=img.dtype, device=img.device) * sigma / 255. 487 | 488 | if cal_gray_noise: 489 | noise = noise * (1 - gray_noise) + noise_gray * gray_noise 490 | return noise 491 | 492 | 493 | def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds=False): 494 | """Add Gaussian noise (PyTorch version). 495 | 496 | Args: 497 | img (Tensor): Shape (b, c, h, w), range[0, 1], float32. 498 | scale (float | Tensor): Noise scale. Default: 1.0. 499 | 500 | Returns: 501 | (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1], 502 | float32. 503 | """ 504 | noise = generate_gaussian_noise_pt(img, sigma, gray_noise) 505 | out = img + noise 506 | if clip and rounds: 507 | out = torch.clamp((out * 255.0).round(), 0, 255) / 255. 508 | elif clip: 509 | out = torch.clamp(out, 0, 1) 510 | elif rounds: 511 | out = (out * 255.0).round() / 255. 512 | return out 513 | 514 | 515 | # ----------------------- Random Gaussian Noise ----------------------- # 516 | def random_generate_gaussian_noise(img, sigma_range=(0, 10), gray_prob=0): 517 | sigma = np.random.uniform(sigma_range[0], sigma_range[1]) 518 | if np.random.uniform() < gray_prob: 519 | gray_noise = True 520 | else: 521 | gray_noise = False 522 | return generate_gaussian_noise(img, sigma, gray_noise) 523 | 524 | 525 | def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False): 526 | noise = random_generate_gaussian_noise(img, sigma_range, gray_prob) 527 | out = img + noise 528 | if clip and rounds: 529 | out = np.clip((out * 255.0).round(), 0, 255) / 255. 530 | elif clip: 531 | out = np.clip(out, 0, 1) 532 | elif rounds: 533 | out = (out * 255.0).round() / 255. 
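    # NOTE: `rounds` quantizes to the 8-bit grid (multiples of 1/255) to mimic a uint8
    # save/load round trip, while `clip` keeps values inside [0, 1].
    # A typical call (sketch; assumes img is a float32 HxWx3 array in [0, 1]):
    #   noisy = random_add_gaussian_noise(img, sigma_range=(1, 30), gray_prob=0.4)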
534 | return out 535 | 536 | 537 | def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_prob=0): 538 | sigma = torch.rand( 539 | img.size(0), dtype=img.dtype, device=img.device) * (sigma_range[1] - sigma_range[0]) + sigma_range[0] 540 | gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device) 541 | gray_noise = (gray_noise < gray_prob).float() 542 | return generate_gaussian_noise_pt(img, sigma, gray_noise) 543 | 544 | 545 | def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0, clip=True, rounds=False): 546 | noise = random_generate_gaussian_noise_pt(img, sigma_range, gray_prob) 547 | out = img + noise 548 | if clip and rounds: 549 | out = torch.clamp((out * 255.0).round(), 0, 255) / 255. 550 | elif clip: 551 | out = torch.clamp(out, 0, 1) 552 | elif rounds: 553 | out = (out * 255.0).round() / 255. 554 | return out 555 | 556 | 557 | # ----------------------- Poisson (Shot) Noise ----------------------- # 558 | 559 | 560 | def generate_poisson_noise(img, scale=1.0, gray_noise=False): 561 | """Generate poisson noise. 562 | 563 | Ref: https://github.com/scikit-image/scikit-image/blob/main/skimage/util/noise.py#L37-L219 564 | 565 | Args: 566 | img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. 567 | scale (float): Noise scale. Default: 1.0. 568 | gray_noise (bool): Whether generate gray noise. Default: False. 569 | 570 | Returns: 571 | (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1], 572 | float32. 573 | """ 574 | if gray_noise: 575 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 576 | # round and clip image for counting vals correctly 577 | img = np.clip((img * 255.0).round(), 0, 255) / 255. 578 | vals = len(np.unique(img)) 579 | vals = 2**np.ceil(np.log2(vals)) 580 | out = np.float32(np.random.poisson(img * vals) / float(vals)) 581 | noise = out - img 582 | if gray_noise: 583 | noise = np.repeat(noise[:, :, np.newaxis], 3, axis=2) 584 | return noise * scale 585 | 586 | 587 | def add_poisson_noise(img, scale=1.0, clip=True, rounds=False, gray_noise=False): 588 | """Add poisson noise. 589 | 590 | Args: 591 | img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. 592 | scale (float): Noise scale. Default: 1.0. 593 | gray_noise (bool): Whether generate gray noise. Default: False. 594 | 595 | Returns: 596 | (Numpy array): Returned noisy image, shape (h, w, c), range[0, 1], 597 | float32. 598 | """ 599 | noise = generate_poisson_noise(img, scale, gray_noise) 600 | out = img + noise 601 | if clip and rounds: 602 | out = np.clip((out * 255.0).round(), 0, 255) / 255. 603 | elif clip: 604 | out = np.clip(out, 0, 1) 605 | elif rounds: 606 | out = (out * 255.0).round() / 255. 607 | return out 608 | 609 | 610 | def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0): 611 | """Generate a batch of poisson noise (PyTorch version) 612 | 613 | Args: 614 | img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32. 615 | scale (float | Tensor): Noise scale. Number or Tensor with shape (b). 616 | Default: 1.0. 617 | gray_noise (float | Tensor): 0-1 number or Tensor with shape (b). 618 | 0 for False, 1 for True. Default: 0. 619 | 620 | Returns: 621 | (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1], 622 | float32. 
623 | """ 624 | b, _, h, w = img.size() 625 | if isinstance(gray_noise, (float, int)): 626 | cal_gray_noise = gray_noise > 0 627 | else: 628 | gray_noise = gray_noise.view(b, 1, 1, 1) 629 | cal_gray_noise = torch.sum(gray_noise) > 0 630 | if cal_gray_noise: 631 | img_gray = rgb_to_grayscale(img, num_output_channels=1) 632 | # round and clip image for counting vals correctly 633 | img_gray = torch.clamp((img_gray * 255.0).round(), 0, 255) / 255. 634 | # use for-loop to get the unique values for each sample 635 | vals_list = [len(torch.unique(img_gray[i, :, :, :])) for i in range(b)] 636 | vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list] 637 | vals = img_gray.new_tensor(vals_list).view(b, 1, 1, 1) 638 | out = torch.poisson(img_gray * vals) / vals 639 | noise_gray = out - img_gray 640 | noise_gray = noise_gray.expand(b, 3, h, w) 641 | 642 | # always calculate color noise 643 | # round and clip image for counting vals correctly 644 | img = torch.clamp((img * 255.0).round(), 0, 255) / 255. 645 | # use for-loop to get the unique values for each sample 646 | vals_list = [len(torch.unique(img[i, :, :, :])) for i in range(b)] 647 | vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list] 648 | vals = img.new_tensor(vals_list).view(b, 1, 1, 1) 649 | out = torch.poisson(img * vals) / vals 650 | noise = out - img 651 | if cal_gray_noise: 652 | noise = noise * (1 - gray_noise) + noise_gray * gray_noise 653 | if not isinstance(scale, (float, int)): 654 | scale = scale.view(b, 1, 1, 1) 655 | return noise * scale 656 | 657 | 658 | def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_noise=0): 659 | """Add poisson noise to a batch of images (PyTorch version). 660 | 661 | Args: 662 | img (Tensor): Input image, shape (b, c, h, w), range [0, 1], float32. 663 | scale (float | Tensor): Noise scale. Number or Tensor with shape (b). 664 | Default: 1.0. 665 | gray_noise (float | Tensor): 0-1 number or Tensor with shape (b). 666 | 0 for False, 1 for True. Default: 0. 667 | 668 | Returns: 669 | (Tensor): Returned noisy image, shape (b, c, h, w), range[0, 1], 670 | float32. 671 | """ 672 | noise = generate_poisson_noise_pt(img, scale, gray_noise) 673 | out = img + noise 674 | if clip and rounds: 675 | out = torch.clamp((out * 255.0).round(), 0, 255) / 255. 676 | elif clip: 677 | out = torch.clamp(out, 0, 1) 678 | elif rounds: 679 | out = (out * 255.0).round() / 255. 680 | return out 681 | 682 | 683 | # ----------------------- Random Poisson (Shot) Noise ----------------------- # 684 | 685 | 686 | def random_generate_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0): 687 | scale = np.random.uniform(scale_range[0], scale_range[1]) 688 | if np.random.uniform() < gray_prob: 689 | gray_noise = True 690 | else: 691 | gray_noise = False 692 | return generate_poisson_noise(img, scale, gray_noise) 693 | 694 | 695 | def random_add_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False): 696 | noise = random_generate_poisson_noise(img, scale_range, gray_prob) 697 | out = img + noise 698 | if clip and rounds: 699 | out = np.clip((out * 255.0).round(), 0, 255) / 255. 700 | elif clip: 701 | out = np.clip(out, 0, 1) 702 | elif rounds: 703 | out = (out * 255.0).round() / 255. 
704 | return out 705 | 706 | 707 | def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0): 708 | scale = torch.rand( 709 | img.size(0), dtype=img.dtype, device=img.device) * (scale_range[1] - scale_range[0]) + scale_range[0] 710 | gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device) 711 | gray_noise = (gray_noise < gray_prob).float() 712 | return generate_poisson_noise_pt(img, scale, gray_noise) 713 | 714 | 715 | def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, clip=True, rounds=False): 716 | noise = random_generate_poisson_noise_pt(img, scale_range, gray_prob) 717 | out = img + noise 718 | if clip and rounds: 719 | out = torch.clamp((out * 255.0).round(), 0, 255) / 255. 720 | elif clip: 721 | out = torch.clamp(out, 0, 1) 722 | elif rounds: 723 | out = (out * 255.0).round() / 255. 724 | return out 725 | 726 | 727 | # ------------------------------------------------------------------------ # 728 | # --------------------------- JPEG compression --------------------------- # 729 | # ------------------------------------------------------------------------ # 730 | 731 | 732 | def add_jpg_compression(img, quality=90): 733 | """Add JPG compression artifacts. 734 | 735 | Args: 736 | img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. 737 | quality (float): JPG compression quality. 0 for lowest quality, 100 for 738 | best quality. Default: 90. 739 | 740 | Returns: 741 | (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1], 742 | float32. 743 | """ 744 | img = np.clip(img, 0, 1) 745 | encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality] 746 | _, encimg = cv2.imencode('.jpg', img * 255., encode_param) 747 | img = np.float32(cv2.imdecode(encimg, 1)) / 255. 748 | return img 749 | 750 | 751 | def random_add_jpg_compression(img, quality_range=(90, 100)): 752 | """Randomly add JPG compression artifacts. 753 | 754 | Args: 755 | img (Numpy array): Input image, shape (h, w, c), range [0, 1], float32. 756 | quality_range (tuple[float] | list[float]): JPG compression quality 757 | range. 0 for lowest quality, 100 for best quality. 758 | Default: (90, 100). 759 | 760 | Returns: 761 | (Numpy array): Returned image after JPG, shape (h, w, c), range[0, 1], 762 | float32. 
763 | """ 764 | quality = np.random.uniform(quality_range[0], quality_range[1]) 765 | return add_jpg_compression(img, quality) 766 | -------------------------------------------------------------------------------- /src/utils/utils_image.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import random 4 | import numpy as np 5 | import torch 6 | import cv2 7 | from numpy import cov 8 | from numpy import trace 9 | from numpy import iscomplexobj 10 | from scipy.linalg import sqrtm 11 | from torchvision.utils import make_grid 12 | from datetime import datetime 13 | # import torchvision.transforms as transforms 14 | import matplotlib.pyplot as plt 15 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" 16 | 17 | 18 | IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tif'] 19 | 20 | 21 | def is_image_file(filename): 22 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 23 | 24 | 25 | def get_timestamp(): 26 | return datetime.now().strftime('%y%m%d-%H%M%S') 27 | 28 | 29 | def imshow(x, title=None, cbar=False, figsize=None): 30 | plt.figure(figsize=figsize) 31 | plt.imshow(np.squeeze(x), interpolation='nearest', cmap='gray') 32 | if title: 33 | plt.title(title) 34 | if cbar: 35 | plt.colorbar() 36 | plt.show() 37 | 38 | 39 | def surf(Z, cmap='rainbow', figsize=None): 40 | plt.figure(figsize=figsize) 41 | ax3 = plt.axes(projection='3d') 42 | 43 | w, h = Z.shape[:2] 44 | xx = np.arange(0,w,1) 45 | yy = np.arange(0,h,1) 46 | X, Y = np.meshgrid(xx, yy) 47 | ax3.plot_surface(X,Y,Z,cmap=cmap) 48 | #ax3.contour(X,Y,Z, zdim='z',offset=-2,cmap=cmap) 49 | plt.show() 50 | 51 | 52 | ''' 53 | # -------------------------------------------- 54 | # get image pathes 55 | # -------------------------------------------- 56 | ''' 57 | 58 | 59 | def get_image_paths(dataroot): 60 | paths = None # return None if dataroot is None 61 | if dataroot is not None: 62 | paths = sorted(_get_paths_from_images(dataroot)) 63 | return paths 64 | 65 | 66 | def _get_paths_from_images(path): 67 | assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) 68 | images = [] 69 | for dirpath, _, fnames in sorted(os.walk(path)): 70 | for fname in sorted(fnames): 71 | if is_image_file(fname): 72 | img_path = os.path.join(dirpath, fname) 73 | images.append(img_path) 74 | assert images, '{:s} has no valid image file'.format(path) 75 | return images 76 | 77 | 78 | ''' 79 | # -------------------------------------------- 80 | # split large images into small images 81 | # -------------------------------------------- 82 | ''' 83 | 84 | 85 | def patches_from_image(img, p_size=512, p_overlap=64, p_max=800): 86 | w, h = img.shape[:2] 87 | patches = [] 88 | if w > p_max and h > p_max: 89 | w1 = list(np.arange(0, w-p_size, p_size-p_overlap, dtype=np.int)) 90 | h1 = list(np.arange(0, h-p_size, p_size-p_overlap, dtype=np.int)) 91 | w1.append(w-p_size) 92 | h1.append(h-p_size) 93 | # print(w1) 94 | # print(h1) 95 | for i in w1: 96 | for j in h1: 97 | patches.append(img[i:i+p_size, j:j+p_size,:]) 98 | else: 99 | patches.append(img) 100 | 101 | return patches 102 | 103 | 104 | def imssave(imgs, img_path): 105 | """ 106 | imgs: list, N images of size WxHxC 107 | """ 108 | img_name, ext = os.path.splitext(os.path.basename(img_path)) 109 | 110 | for i, img in enumerate(imgs): 111 | if img.ndim == 3: 112 | img = img[:, :, [2, 1, 0]] 113 | new_path = os.path.join(os.path.dirname(img_path), 
img_name+str('_s{:04d}'.format(i))+'.png')
114 | cv2.imwrite(new_path, img)
115 | 
116 | 
117 | def split_imageset(original_dataroot, target_dataroot, n_channels=3, p_size=800, p_overlap=96, p_max=1000):
118 | """
119 | split the large images from original_dataroot into small overlapped images with size (p_size)x(p_size),
120 | and save them into target_dataroot; only images larger than (p_max)x(p_max)
121 | will be split.
122 | 
123 | Args:
124 | original_dataroot:
125 | target_dataroot:
126 | p_size: size of small images
127 | p_overlap: overlap between patches; the training patch size is a good choice
128 | p_max: images smaller than (p_max)x(p_max) are kept unchanged.
129 | """
130 | paths = get_image_paths(original_dataroot)
131 | for img_path in paths:
132 | # img_name, ext = os.path.splitext(os.path.basename(img_path))
133 | img = imread_uint(img_path, n_channels=n_channels)
134 | patches = patches_from_image(img, p_size, p_overlap, p_max)
135 | imssave(patches, os.path.join(target_dataroot, os.path.basename(img_path)))
136 | #if original_dataroot == target_dataroot:
137 | #del img_path
138 | 
139 | '''
140 | # --------------------------------------------
141 | # makedir
142 | # --------------------------------------------
143 | '''
144 | 
145 | 
146 | def mkdir(path):
147 | if not os.path.exists(path):
148 | os.makedirs(path)
149 | 
150 | 
151 | def mkdirs(paths):
152 | if isinstance(paths, str):
153 | mkdir(paths)
154 | else:
155 | for path in paths:
156 | mkdir(path)
157 | 
158 | 
159 | def mkdir_and_rename(path):
160 | if os.path.exists(path):
161 | new_name = path + '_archived_' + get_timestamp()
162 | print('Path already exists. Renaming it to [{:s}]'.format(new_name))
163 | os.rename(path, new_name)
164 | os.makedirs(path)
165 | 
166 | 
167 | '''
168 | # --------------------------------------------
169 | # read image from path
170 | # opencv is fast, but reads BGR numpy images
171 | # --------------------------------------------
172 | '''
173 | 
174 | 
175 | # --------------------------------------------
176 | # get uint8 image of size HxWxn_channels (RGB)
177 | # --------------------------------------------
178 | def imread_uint(path, n_channels=3):
179 | # input: path
180 | # output: HxWx3(RGB or GGG), or HxWx1 (G)
181 | if n_channels == 1:
182 | img = cv2.imread(path, 0) # cv2.IMREAD_GRAYSCALE
183 | img = np.expand_dims(img, axis=2) # HxWx1
184 | elif n_channels == 3:
185 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR or G
186 | if img.ndim == 2:
187 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) # GGG
188 | else:
189 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # RGB
190 | return img
191 | 
192 | 
193 | # --------------------------------------------
194 | # matlab's imwrite
195 | # --------------------------------------------
196 | def imsave(img, img_path):
197 | img = np.squeeze(img)
198 | if img.ndim == 3:
199 | img = img[:, :, [2, 1, 0]]
200 | cv2.imwrite(img_path, img)
201 | 
202 | def imwrite(img, img_path):  # duplicate of imsave, kept for backward compatibility
203 | img = np.squeeze(img)
204 | if img.ndim == 3:
205 | img = img[:, :, [2, 1, 0]]
206 | cv2.imwrite(img_path, img)
207 | 
208 | 
209 | 
210 | # --------------------------------------------
211 | # get single image of size HxWxn_channels (BGR)
212 | # --------------------------------------------
213 | def read_img(path):
214 | # read image by cv2
215 | # return: Numpy float32, HWC, BGR, [0,1]
216 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # cv2.IMREAD_GRAYSCALE
217 | img = img.astype(np.float32) / 255.
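    # NOTE: cv2.IMREAD_UNCHANGED may return HxW (grayscale) or HxWx4 (with an alpha
    # channel); the two branches below normalize both cases to HxWxC, BGR.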
226 | '''
227 | # --------------------------------------------
228 | # image format conversion
229 | # --------------------------------------------
230 | # numpy(single) <--->  numpy(uint)
231 | # numpy(single) <--->  tensor
232 | # numpy(uint)   <--->  tensor
233 | # --------------------------------------------
234 | '''
235 | 
236 | 
237 | # --------------------------------------------
238 | # numpy(single) [0, 1] <--->  numpy(uint)
239 | # --------------------------------------------
240 | 
241 | 
242 | def uint2single(img):
243 | 
244 |     return np.float32(img/255.)
245 | 
246 | 
247 | def single2uint(img):
248 | 
249 |     return np.uint8((img.clip(0, 1)*255.).round())
250 | 
251 | 
252 | def uint162single(img):
253 | 
254 |     return np.float32(img/65535.)
255 | 
256 | 
257 | def single2uint16(img):
258 | 
259 |     return np.uint16((img.clip(0, 1)*65535.).round())
260 | 
261 | 
262 | # --------------------------------------------
263 | # numpy(uint) (HxWxC or HxW) <--->  tensor
264 | # --------------------------------------------
265 | 
266 | 
267 | # convert uint (HxWxC or HxW) to 4-dimensional torch tensor (1xCxHxW, [0, 1])
268 | def uint2tensor4(img):
269 |     if img.ndim == 2:
270 |         img = np.expand_dims(img, axis=2)
271 |     return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.).unsqueeze(0)
272 | 
273 | 
274 | # convert uint (HxWxC or HxW) to 3-dimensional torch tensor (CxHxW, [0, 1])
275 | def uint2tensor3(img):
276 |     if img.ndim == 2:
277 |         img = np.expand_dims(img, axis=2)
278 |     return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().div(255.)
279 | 
280 | 
281 | # convert 2/3/4-dimensional torch tensor to uint8 numpy image
282 | def tensor2uint(img):
283 |     img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
284 |     if img.ndim == 3:
285 |         img = np.transpose(img, (1, 2, 0))
286 |     return np.uint8((img*255.0).round())
287 | 
288 | 
289 | # --------------------------------------------
290 | # numpy(single) (HxWxC) <--->  tensor
291 | # --------------------------------------------
292 | 
293 | 
294 | # convert single (HxWxC) to 3-dimensional torch tensor (CxHxW)
295 | def single2tensor3(img):
296 |     return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float()
297 | 
298 | 
299 | # convert single (HxWxC) to 4-dimensional torch tensor (1xCxHxW)
300 | def single2tensor4(img):
301 |     return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1).float().unsqueeze(0)
302 | 
303 | 
304 | # convert torch tensor to single (HxWxC or HxW)
305 | def tensor2single(img):
306 |     img = img.data.squeeze().float().cpu().numpy()
307 |     if img.ndim == 3:
308 |         img = np.transpose(img, (1, 2, 0))
309 | 
310 |     return img
311 | 
312 | # convert torch tensor to single, keeping a channel dimension for 2-D results
313 | def tensor2single3(img):
314 |     img = img.data.squeeze().float().cpu().numpy()
315 |     if img.ndim == 3:
316 |         img = np.transpose(img, (1, 2, 0))
317 |     elif img.ndim == 2:
318 |         img = np.expand_dims(img, axis=2)
319 |     return img
320 | 
321 | # convert single (HxWxCxN) to 5-dimensional torch tensor (1xCxHxWxN)
322 | def single2tensor5(img):
323 |     return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float().unsqueeze(0)
324 | 
325 | # prepend two singleton dimensions to a single image tensor
326 | def single32tensor5(img):
327 |     return torch.from_numpy(np.ascontiguousarray(img)).float().unsqueeze(0).unsqueeze(0)
328 | 
329 | # convert single (HxWxCxN) to 4-dimensional torch tensor (CxHxWxN)
330 | def single42tensor4(img):
331 |     return torch.from_numpy(np.ascontiguousarray(img)).permute(2, 0, 1, 3).float()
332 | 
333 | 
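A typical round trip through these converters, sketched with a hypothetical file name:

    img = imread_uint('image.png', n_channels=3)   # HxWx3 uint8, RGB
    t = uint2tensor4(img)                          # 1x3xHxW float tensor in [0, 1]
    # ... run a model on t ...
    out = tensor2uint(t)                           # back to HxWx3 uint8
    imsave(out, 'image_out.png')
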
334 | # from skimage.io import imread, imsave
335 | def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
336 |     '''
337 |     Converts a torch Tensor into an image Numpy array of BGR channel order
338 |     Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
339 |     Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
340 |     '''
341 |     tensor = tensor.squeeze().float().cpu().clamp_(*min_max)  # squeeze first, then clamp
342 |     tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])  # to range [0,1]
343 |     n_dim = tensor.dim()
344 |     if n_dim == 4:
345 |         n_img = len(tensor)
346 |         img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
347 |         img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
348 |     elif n_dim == 3:
349 |         img_np = tensor.numpy()
350 |         img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
351 |     elif n_dim == 2:
352 |         img_np = tensor.numpy()
353 |     else:
354 |         raise TypeError(
355 |             'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
356 |     if out_type == np.uint8:
357 |         img_np = (img_np * 255.0).round()
358 |         # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
359 |     return img_np.astype(out_type)
360 | 
361 | # quantize tensor using the round operation
362 | def quantize(img, rgb_range=255):
363 |     '''
364 |     args:
365 |         img - tensor in [0, rgb_range]
366 |     return:
367 |         img - tensor in [0, rgb_range], rounded to valid pixel values
368 |     '''
369 |     pixel_range = 255 / rgb_range
370 |     return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range)
371 | 
372 | 
373 | '''
374 | # --------------------------------------------
375 | # Augmentation, flip and/or rotate
376 | # --------------------------------------------
377 | # The following two are enough.
378 | # (1) augment_img: numpy image of HxWxC or HxW
379 | # (2) augment_img_tensor4: tensor image 1xCxHxW
380 | # --------------------------------------------
381 | '''
382 | 
383 | 
384 | def augment_img(img, mode=0):
385 |     '''Kai Zhang (github: https://github.com/cszn)
386 |     '''
387 |     if mode == 0:
388 |         return img
389 |     elif mode == 1:
390 |         return np.flipud(np.rot90(img))
391 |     elif mode == 2:
392 |         return np.flipud(img)
393 |     elif mode == 3:
394 |         return np.rot90(img, k=3)
395 |     elif mode == 4:
396 |         return np.flipud(np.rot90(img, k=2))
397 |     elif mode == 5:
398 |         return np.rot90(img)
399 |     elif mode == 6:
400 |         return np.rot90(img, k=2)
401 |     elif mode == 7:
402 |         return np.flipud(np.rot90(img, k=3))
403 | 
404 | 
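In SR pipelines these eight modes (identity, flips, and 90-degree rotations) are drawn at random but applied identically to the LR/HR pair so the two stay aligned; a minimal sketch, where lr and hr are hypothetical numpy images:

    mode = random.randint(0, 7)
    lr_aug = augment_img(lr, mode=mode)
    hr_aug = augment_img(hr, mode=mode)
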
405 | def augment_img_tensor4(img, mode=0):
406 |     '''Kai Zhang (github: https://github.com/cszn)
407 |     '''
408 |     if mode == 0:
409 |         return img
410 |     elif mode == 1:
411 |         return img.rot90(1, [2, 3]).flip([2])
412 |     elif mode == 2:
413 |         return img.flip([2])
414 |     elif mode == 3:
415 |         return img.rot90(3, [2, 3])
416 |     elif mode == 4:
417 |         return img.rot90(2, [2, 3]).flip([2])
418 |     elif mode == 5:
419 |         return img.rot90(1, [2, 3])
420 |     elif mode == 6:
421 |         return img.rot90(2, [2, 3])
422 |     elif mode == 7:
423 |         return img.rot90(3, [2, 3]).flip([2])
424 | 
425 | 
426 | def augment_img_tensor(img, mode=0):
427 |     '''Kai Zhang (github: https://github.com/cszn)
428 |     '''
429 |     img_size = img.size()
430 |     img_np = img.data.cpu().numpy()
431 |     if len(img_size) == 3:
432 |         img_np = np.transpose(img_np, (1, 2, 0))
433 |     elif len(img_size) == 4:
434 |         img_np = np.transpose(img_np, (2, 3, 1, 0))
435 |     img_np = augment_img(img_np, mode=mode)
436 |     img_tensor = torch.from_numpy(np.ascontiguousarray(img_np))
437 |     if len(img_size) == 3:
438 |         img_tensor = img_tensor.permute(2, 0, 1)
439 |     elif len(img_size) == 4:
440 |         img_tensor = img_tensor.permute(3, 2, 0, 1)
441 | 
442 |     return img_tensor.type_as(img)
443 | 
444 | 
445 | def augment_img_np3(img, mode=0):
446 |     if mode == 0:
447 |         return img
448 |     elif mode == 1:
449 |         return img.transpose(1, 0, 2)
450 |     elif mode == 2:
451 |         return img[::-1, :, :]
452 |     elif mode == 3:
453 |         img = img[::-1, :, :]
454 |         img = img.transpose(1, 0, 2)
455 |         return img
456 |     elif mode == 4:
457 |         return img[:, ::-1, :]
458 |     elif mode == 5:
459 |         img = img[:, ::-1, :]
460 |         img = img.transpose(1, 0, 2)
461 |         return img
462 |     elif mode == 6:
463 |         img = img[:, ::-1, :]
464 |         img = img[::-1, :, :]
465 |         return img
466 |     elif mode == 7:
467 |         img = img[:, ::-1, :]
468 |         img = img[::-1, :, :]
469 |         img = img.transpose(1, 0, 2)
470 |         return img
471 | 
472 | 
473 | def augment_imgs(img_list, hflip=True, rot=True):
474 |     # randomly flip and/or rotate, consistently for every image in the list
475 |     hflip = hflip and random.random() < 0.5
476 |     vflip = rot and random.random() < 0.5
477 |     rot90 = rot and random.random() < 0.5
478 | 
479 |     def _augment(img):
480 |         if hflip:
481 |             img = img[:, ::-1, :]
482 |         if vflip:
483 |             img = img[::-1, :, :]
484 |         if rot90:
485 |             img = img.transpose(1, 0, 2)
486 |         return img
487 | 
488 |     return [_augment(img) for img in img_list]
489 | 
490 | 
491 | '''
492 | # --------------------------------------------
493 | # modcrop and shave
494 | # --------------------------------------------
495 | '''
496 | 
497 | 
498 | def modcrop(img_in, scale):
499 |     # img_in: Numpy, HWC or HW
500 |     img = np.copy(img_in)
501 |     if img.ndim == 2:
502 |         H, W = img.shape
503 |         H_r, W_r = H % scale, W % scale
504 |         img = img[:H - H_r, :W - W_r]
505 |     elif img.ndim == 3:
506 |         H, W, C = img.shape
507 |         H_r, W_r = H % scale, W % scale
508 |         img = img[:H - H_r, :W - W_r, :]
509 |     else:
510 |         raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
511 |     return img
512 | 
513 | 
514 | def shave(img_in, border=0):
515 |     # img_in: Numpy, HWC or HW
516 |     img = np.copy(img_in)
517 |     h, w = img.shape[:2]
518 |     img = img[border:h-border, border:w-border]
519 |     return img
520 | 
521 | 
522 | '''
523 | # --------------------------------------------
524 | # image processing process on numpy image
525 | # channel_convert(in_c, tar_type, img_list):
526 | # rgb2ycbcr(img, only_y=True):
527 | # bgr2ycbcr(img, only_y=True):
528 | # ycbcr2rgb(img):
529 | # --------------------------------------------
530 | '''
531 | 
532 | 
533 | def rgb2ycbcr(img, only_y=True):
534 |     '''same as matlab rgb2ycbcr
535 |     only_y: only return Y channel
536 |     Input:
537 |         uint8, [0, 255]
538 |         float, [0, 1]
539 |     '''
540 |     in_img_type = img.dtype
541 |     img = img.astype(np.float32)  # astype returns a copy, so it must be reassigned
542 |     if in_img_type != np.uint8:
543 |         img *= 255.
544 |     # convert
545 |     if only_y:
546 |         rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
547 |     else:
548 |         rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
549 |                               [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
550 |     if in_img_type == np.uint8:
551 |         rlt = rlt.round()
552 |     else:
553 |         rlt /= 255.
554 |     return rlt.astype(in_img_type)
555 | 
556 | 
557 | def ycbcr2rgb(img):
558 |     '''same as matlab ycbcr2rgb
559 |     Input:
560 |         uint8, [0, 255]
561 |         float, [0, 1]
562 |     '''
563 |     in_img_type = img.dtype
564 |     img = img.astype(np.float32)  # astype returns a copy, so it must be reassigned
565 |     if in_img_type != np.uint8:
566 |         img *= 255.
567 |     # convert
568 |     rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
569 |                           [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
570 |     if in_img_type == np.uint8:
571 |         rlt = rlt.round()
572 |     else:
573 |         rlt /= 255.
574 |     return rlt.astype(in_img_type)
575 | 
576 | 
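These color-space helpers usually feed the PSNR/SSIM utilities that follow, since SR results are commonly evaluated on the Y channel only. A sketch, where sr and hr are hypothetical HxWx3 uint8 RGB images and calc_psnr_ssim is defined just below:

    sr_y = rgb2ycbcr(sr, only_y=True)
    hr_y = rgb2ycbcr(hr, only_y=True)
    psnr, ssim = calc_psnr_ssim(sr_y, hr_y, scale=2)   # shaves a 2-pixel border first
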
577 | def bgr2ycbcr(img, only_y=True):
578 |     '''bgr version of rgb2ycbcr
579 |     only_y: only return Y channel
580 |     Input:
581 |         uint8, [0, 255]
582 |         float, [0, 1]
583 |     '''
584 |     in_img_type = img.dtype
585 |     img = img.astype(np.float32)  # astype returns a copy, so it must be reassigned
586 |     if in_img_type != np.uint8:
587 |         img *= 255.
588 |     # convert
589 |     if only_y:
590 |         rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
591 |     else:
592 |         rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
593 |                               [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
594 |     if in_img_type == np.uint8:
595 |         rlt = rlt.round()
596 |     else:
597 |         rlt /= 255.
598 |     return rlt.astype(in_img_type)
599 | 
600 | 
601 | def channel_convert(in_c, tar_type, img_list):
602 |     # conversion among BGR, gray and y
603 |     if in_c == 3 and tar_type == 'gray':  # BGR to gray
604 |         gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
605 |         return [np.expand_dims(img, axis=2) for img in gray_list]
606 |     elif in_c == 3 and tar_type == 'y':  # BGR to y
607 |         y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
608 |         return [np.expand_dims(img, axis=2) for img in y_list]
609 |     elif in_c == 1 and tar_type == 'RGB':  # gray/y to BGR
610 |         return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
611 |     else:
612 |         return img_list
613 | 
614 | 
615 | '''
616 | # --------------------------------------------
617 | # metric, PSNR and SSIM
618 | # --------------------------------------------
619 | '''
620 | def calculate_psnr(img1, img2):
621 |     '''img1 and img2 have range [0, 255]'''
622 |     # identical images give mse == 0, reported as infinite PSNR
623 |     img1 = img1.astype(np.float64)
624 |     img2 = img2.astype(np.float64)
625 |     mse = np.mean((img1 - img2) ** 2)
626 |     if mse == 0:
627 |         return float("inf")
628 |     return 20 * math.log10(255.0 / math.sqrt(mse))
629 | 
630 | 
631 | def calculate_ssim(img1, img2):
632 |     '''single-channel SSIM, commonly computed on the Y channel'''
633 |     C1 = (0.01 * 255) ** 2
634 |     C2 = (0.03 * 255) ** 2
635 | 
636 |     img1 = img1.astype(np.float64)
637 |     img2 = img2.astype(np.float64)
638 |     kernel = cv2.getGaussianKernel(11, 1.5)
639 |     window = np.outer(kernel, kernel.transpose())
640 | 
641 |     mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
642 |     mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
643 |     mu1_sq = mu1 ** 2
644 |     mu2_sq = mu2 ** 2
645 |     mu1_mu2 = mu1 * mu2
646 |     sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
647 |     sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
648 |     sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
649 | 
650 |     ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
651 |         (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
652 |     )
653 |     return ssim_map.mean()
654 | 
655 | def calc_psnr_ssim(img1, img2, scale=2):
656 |     assert scale != 0  # a border of `scale` pixels is shaved before measuring
657 |     img1_crop = img1[scale:-scale, scale:-scale]
658 |     img2_crop = img2[scale:-scale, scale:-scale]
659 |     psnr = calculate_psnr(img1_crop, img2_crop)
660 |     ssim = calculate_ssim(img1_crop, img2_crop)
661 |     return psnr, ssim
662 | 
663 | '''
664 | # --------------------------------------------
665 | # matlab's bicubic imresize (numpy and torch) [0, 1]
666 | # --------------------------------------------
667 | '''
668 | 
669 | 
670 | # matlab 'imresize' function, now only supports 'bicubic'
671 | def cubic(x):
672 |     absx = torch.abs(x)
673 |     absx2 = absx**2
674 |     absx3 = absx**3
675 |     return (1.5*absx3 - 2.5*absx2 + 1) * ((absx <= 1).type_as(absx)) + \
676 |         (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * (((absx > 1)*(absx
<= 2)).type_as(absx)) 677 | 678 | 679 | def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): 680 | if (scale < 1) and (antialiasing): 681 | # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width 682 | kernel_width = kernel_width / scale 683 | 684 | # Output-space coordinates 685 | x = torch.linspace(1, out_length, out_length) 686 | 687 | # Input-space coordinates. Calculate the inverse mapping such that 0.5 688 | # in output space maps to 0.5 in input space, and 0.5+scale in output 689 | # space maps to 1.5 in input space. 690 | u = x / scale + 0.5 * (1 - 1 / scale) 691 | 692 | # What is the left-most pixel that can be involved in the computation? 693 | left = torch.floor(u - kernel_width / 2) 694 | 695 | # What is the maximum number of pixels that can be involved in the 696 | # computation? Note: it's OK to use an extra pixel here; if the 697 | # corresponding weights are all zero, it will be eliminated at the end 698 | # of this function. 699 | P = math.ceil(kernel_width) + 2 700 | 701 | # The indices of the input pixels involved in computing the k-th output 702 | # pixel are in row k of the indices matrix. 703 | indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( 704 | 1, P).expand(out_length, P) 705 | 706 | # The weights used to compute the k-th output pixel are in row k of the 707 | # weights matrix. 708 | distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices 709 | # apply cubic kernel 710 | if (scale < 1) and (antialiasing): 711 | weights = scale * cubic(distance_to_center * scale) 712 | else: 713 | weights = cubic(distance_to_center) 714 | # Normalize the weights matrix so that each row sums to 1. 715 | weights_sum = torch.sum(weights, 1).view(out_length, 1) 716 | weights = weights / weights_sum.expand(out_length, P) 717 | 718 | # If a column in weights is all zero, get rid of it. only consider the first and last column. 719 | weights_zero_tmp = torch.sum((weights == 0), 0) 720 | if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): 721 | indices = indices.narrow(1, 1, P - 2) 722 | weights = weights.narrow(1, 1, P - 2) 723 | if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): 724 | indices = indices.narrow(1, 0, P - 2) 725 | weights = weights.narrow(1, 0, P - 2) 726 | weights = weights.contiguous() 727 | indices = indices.contiguous() 728 | sym_len_s = -indices.min() + 1 729 | sym_len_e = indices.max() - in_length 730 | indices = indices + sym_len_s - 1 731 | return weights, indices, int(sym_len_s), int(sym_len_e) 732 | 733 | 734 | # -------------------------------------------- 735 | # imresize for tensor image [0, 1] 736 | # -------------------------------------------- 737 | def imresize(img, scale, antialiasing=True): 738 | # Now the scale should be the same for H and W 739 | # input: img: pytorch tensor, CHW or HW [0,1] 740 | # output: CHW or HW [0,1] w/o round 741 | need_squeeze = True if img.dim() == 2 else False 742 | if need_squeeze: 743 | img.unsqueeze_(0) 744 | in_C, in_H, in_W = img.size() 745 | out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) 746 | kernel_width = 4 747 | kernel = 'cubic' 748 | 749 | # Return the desired dimension order for performing the resize. The 750 | # strategy is to perform the resize first along the dimension with the 751 | # smallest scale factor. 752 | # Now we do not support this. 
753 | 754 | # get weights and indices 755 | weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( 756 | in_H, out_H, scale, kernel, kernel_width, antialiasing) 757 | weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( 758 | in_W, out_W, scale, kernel, kernel_width, antialiasing) 759 | # process H dimension 760 | # symmetric copying 761 | img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) 762 | img_aug.narrow(1, sym_len_Hs, in_H).copy_(img) 763 | 764 | sym_patch = img[:, :sym_len_Hs, :] 765 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() 766 | sym_patch_inv = sym_patch.index_select(1, inv_idx) 767 | img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv) 768 | 769 | sym_patch = img[:, -sym_len_He:, :] 770 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() 771 | sym_patch_inv = sym_patch.index_select(1, inv_idx) 772 | img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) 773 | 774 | out_1 = torch.FloatTensor(in_C, out_H, in_W) 775 | kernel_width = weights_H.size(1) 776 | for i in range(out_H): 777 | idx = int(indices_H[i][0]) 778 | for j in range(out_C): 779 | out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) 780 | 781 | # process W dimension 782 | # symmetric copying 783 | out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We) 784 | out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1) 785 | 786 | sym_patch = out_1[:, :, :sym_len_Ws] 787 | inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() 788 | sym_patch_inv = sym_patch.index_select(2, inv_idx) 789 | out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv) 790 | 791 | sym_patch = out_1[:, :, -sym_len_We:] 792 | inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() 793 | sym_patch_inv = sym_patch.index_select(2, inv_idx) 794 | out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) 795 | 796 | out_2 = torch.FloatTensor(in_C, out_H, out_W) 797 | kernel_width = weights_W.size(1) 798 | for i in range(out_W): 799 | idx = int(indices_W[i][0]) 800 | for j in range(out_C): 801 | out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_W[i]) 802 | if need_squeeze: 803 | out_2.squeeze_() 804 | return out_2 805 | 806 | 807 | # -------------------------------------------- 808 | # imresize for numpy image [0, 1] 809 | # -------------------------------------------- 810 | def imresize_np(img, scale, antialiasing=True): 811 | # Now the scale should be the same for H and W 812 | # input: img: Numpy, HWC or HW [0,1] 813 | # output: HWC or HW [0,1] w/o round 814 | img = torch.from_numpy(img) 815 | need_squeeze = True if img.dim() == 2 else False 816 | if need_squeeze: 817 | img.unsqueeze_(2) 818 | 819 | in_H, in_W, in_C = img.size() 820 | out_C, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) 821 | kernel_width = 4 822 | kernel = 'cubic' 823 | 824 | # Return the desired dimension order for performing the resize. The 825 | # strategy is to perform the resize first along the dimension with the 826 | # smallest scale factor. 827 | # Now we do not support this. 
828 | 829 | # get weights and indices 830 | weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( 831 | in_H, out_H, scale, kernel, kernel_width, antialiasing) 832 | weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( 833 | in_W, out_W, scale, kernel, kernel_width, antialiasing) 834 | # process H dimension 835 | # symmetric copying 836 | img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) 837 | img_aug.narrow(0, sym_len_Hs, in_H).copy_(img) 838 | 839 | sym_patch = img[:sym_len_Hs, :, :] 840 | inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() 841 | sym_patch_inv = sym_patch.index_select(0, inv_idx) 842 | img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv) 843 | 844 | sym_patch = img[-sym_len_He:, :, :] 845 | inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() 846 | sym_patch_inv = sym_patch.index_select(0, inv_idx) 847 | img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) 848 | 849 | out_1 = torch.FloatTensor(out_H, in_W, in_C) 850 | kernel_width = weights_H.size(1) 851 | for i in range(out_H): 852 | idx = int(indices_H[i][0]) 853 | for j in range(out_C): 854 | out_1[i, :, j] = img_aug[idx:idx + kernel_width, :, j].transpose(0, 1).mv(weights_H[i]) 855 | 856 | # process W dimension 857 | # symmetric copying 858 | out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) 859 | out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) 860 | 861 | sym_patch = out_1[:, :sym_len_Ws, :] 862 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() 863 | sym_patch_inv = sym_patch.index_select(1, inv_idx) 864 | out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) 865 | 866 | sym_patch = out_1[:, -sym_len_We:, :] 867 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() 868 | sym_patch_inv = sym_patch.index_select(1, inv_idx) 869 | out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) 870 | 871 | out_2 = torch.FloatTensor(out_H, out_W, in_C) 872 | kernel_width = weights_W.size(1) 873 | for i in range(out_W): 874 | idx = int(indices_W[i][0]) 875 | for j in range(out_C): 876 | out_2[:, i, j] = out_1_aug[:, idx:idx + kernel_width, j].mv(weights_W[i]) 877 | if need_squeeze: 878 | out_2.squeeze_() 879 | 880 | return out_2.numpy() 881 | 882 | 883 | 884 | 885 | 886 | 887 | 888 | 889 | --------------------------------------------------------------------------------
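As a closing note on the two resizers above: imresize operates on CHW torch tensors and imresize_np on HWC numpy arrays, both in [0, 1], and both reproduce MATLAB's antialiased bicubic kernel (which differs from cv2.resize and is typically what bicubic LR data such as DIV2K is generated with). A minimal sketch with a hypothetical array:

    import numpy as np
    hr = np.random.rand(128, 128, 3).astype(np.float32)  # HxWxC in [0, 1]
    lr = imresize_np(hr, scale=0.5, antialiasing=True)   # -> 64x64x3, still in [0, 1]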