├── .DS_Store ├── .gitignore ├── LICENSE ├── README.md ├── VisToyExample.ipynb ├── basicsr ├── .DS_Store ├── __init__.py ├── archs │ ├── RRDB_arch.py │ ├── __init__.py │ ├── adacode_arch.py │ ├── adacode_contrast_arch.py │ ├── arch_util.py │ ├── discriminator_arch.py │ ├── fema_utils.py │ ├── femasr_arch.py │ ├── network_swinir.py │ └── vgg_arch.py ├── data │ ├── __init__.py │ ├── bsrgan_train_dataset.py │ ├── bsrgan_util.py │ ├── data_sampler.py │ ├── data_util.py │ ├── paired_image_dataset.py │ ├── prefetch_dataloader.py │ ├── single_image_dataset.py │ └── transforms.py ├── losses │ ├── __init__.py │ ├── loss_util.py │ └── losses.py ├── models │ ├── __init__.py │ ├── base_model.py │ ├── femasr_model.py │ ├── lr_scheduler.py │ └── ori_femasr_model.py ├── test.py ├── train.py ├── train_mergedcodebook.py └── utils │ ├── __init__.py │ ├── diffjpeg.py │ ├── dist_util.py │ ├── download_util.py │ ├── face_util.py │ ├── file_client.py │ ├── flow_util.py │ ├── img_process_util.py │ ├── img_util.py │ ├── lmdb_util.py │ ├── logger.py │ ├── matlab_functions.py │ ├── misc.py │ ├── options.py │ └── registry.py ├── evaluate.py ├── figures ├── .DS_Store ├── model.jpeg └── teaser.jpeg ├── generate_dataset.py ├── generate_inpaint_dataset.py ├── inference.py ├── options ├── stage1 │ └── train_AdaCode_HQ_stage1_category.yml ├── stage2 │ └── train_AdaCode_HQ_stage2.yaml └── stage3 │ └── train_AdaCode_stage3.yaml ├── requirements.txt ├── scripts ├── .DS_Store ├── data_preparation │ ├── create_lmdb.py │ ├── download_datasets.py │ ├── extract_images_from_tfrecords.py │ ├── extract_subimages.py │ ├── generate_meta_info.py │ ├── prepare_SR_dataset.py │ ├── prepare_hifacegan_dataset.py │ └── regroup_reds_dataset.py ├── download_gdrive.py ├── download_pretrained_models.py ├── matlab_scripts │ ├── back_projection │ │ ├── backprojection.m │ │ ├── main_bp.m │ │ └── main_reverse_filter.m │ ├── generate_LR_Vimeo90K.m │ └── generate_bicubic_img.m ├── metrics │ ├── calculate_fid_folder.py │ ├── calculate_fid_stats_from_datasets.py │ ├── calculate_lpips.py │ ├── calculate_niqe.py │ ├── calculate_psnr_ssim.py │ └── calculate_stylegan2_fid.py ├── model_conversion │ ├── convert_dfdnet.py │ ├── convert_models.py │ ├── convert_ridnet.py │ └── convert_stylegan.py └── publish_models.py ├── setup.py ├── test_recon.py ├── testset ├── .DS_Store ├── inpaint │ ├── .DS_Store │ ├── 0801_s009.png │ ├── 0804_s003.png │ ├── 0804_s006.png │ ├── 0808_s010.png │ ├── 0809_s006.png │ ├── 0813_s009.png │ ├── 0817_s003.png │ ├── 0817_s012.png │ ├── 0830_s007.png │ ├── 0831_s006.png │ ├── 0833_s003.png │ ├── 0836_s001.png │ ├── 0836_s010.png │ ├── 0839_s012.png │ ├── 0842_s002.png │ ├── 0848_s009.png │ ├── 0849_s004.png │ ├── 0852_s010.png │ ├── 0856_s006.png │ ├── 0856_s010.png │ ├── 0863_s004.png │ ├── 0865_s011.png │ ├── 0866_s005.png │ ├── 0870_s005.png │ ├── 0874_s007.png │ └── 64010.png └── sr │ ├── 0003.jpg │ ├── 0014.jpg │ ├── 0015.jpg │ ├── 0030.jpg │ ├── 0032.jpg │ ├── 0054.jpg │ ├── 0068.jpg │ ├── 49370.png │ ├── ADE_val_00000015.jpg │ ├── ADE_val_00000114.jpg │ ├── Canon_004_LR4.png │ ├── Canon_045_LR4.png │ ├── OST_009.png │ ├── OST_020.png │ ├── OST_120.png │ ├── building.png │ ├── butterfly.png │ ├── butterfly2.png │ ├── chip.png │ ├── comic1.png │ ├── comic2.png │ ├── comic3.png │ ├── computer.png │ ├── dog.png │ ├── dped_crop00061.png │ ├── foreman.png │ ├── frog.png │ ├── oldphoto3.png │ ├── oldphoto6.png │ ├── painting.png │ ├── pattern.png │ ├── ppt3.png │ └── tiger.png └── vis_codebook.py /.DS_Store: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # ignored folders 2 | data/* 3 | datasets/* 4 | experiments* 5 | experiments/* 6 | test_results*/ 7 | results*/ 8 | tb_logger*/* 9 | wandb/* 10 | tmp*/* 11 | tmp* 12 | *.sh 13 | .vscode* 14 | .github 15 | */.DS_Store 16 | 17 | # ignored files 18 | version.py 19 | 20 | # ignored files with suffix 21 | *.html 22 | *.gif 23 | *.pth 24 | *.zip 25 | *.npy 26 | 27 | # template 28 | 29 | # Byte-compiled / optimized / DLL files 30 | __pycache__/ 31 | *.py[cod] 32 | *$py.class 33 | 34 | # C extensions 35 | *.so 36 | 37 | # Distribution / packaging 38 | .Python 39 | build/ 40 | develop-eggs/ 41 | dist/ 42 | downloads/ 43 | eggs/ 44 | .eggs/ 45 | lib/ 46 | lib64/ 47 | parts/ 48 | sdist/ 49 | var/ 50 | wheels/ 51 | *.egg-info/ 52 | .installed.cfg 53 | *.egg 54 | MANIFEST 55 | 56 | # PyInstaller 57 | # Usually these files are written by a python script from a template 58 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 59 | *.manifest 60 | *.spec 61 | 62 | # Installer logs 63 | pip-log.txt 64 | pip-delete-this-directory.txt 65 | 66 | # Unit test / coverage reports 67 | htmlcov/ 68 | .tox/ 69 | .coverage 70 | .coverage.* 71 | .cache 72 | nosetests.xml 73 | coverage.xml 74 | *.cover 75 | .hypothesis/ 76 | .pytest_cache/ 77 | 78 | # Translations 79 | *.mo 80 | *.pot 81 | 82 | # Django stuff: 83 | *.log 84 | local_settings.py 85 | db.sqlite3 86 | 87 | # Flask stuff: 88 | instance/ 89 | .webassets-cache 90 | 91 | # Scrapy stuff: 92 | .scrapy 93 | 94 | # Sphinx documentation 95 | docs/_build/ 96 | 97 | # PyBuilder 98 | target/ 99 | 100 | # Jupyter Notebook 101 | .ipynb_checkpoints 102 | 103 | # pyenv 104 | .python-version 105 | 106 | # celery beat schedule file 107 | celerybeat-schedule 108 | 109 | # SageMath parsed files 110 | *.sage.py 111 | 112 | # Environments 113 | .env 114 | .venv 115 | env/ 116 | venv/ 117 | ENV/ 118 | env.bak/ 119 | venv.bak/ 120 | 121 | # Spyder project settings 122 | .spyderproject 123 | .spyproject 124 | 125 | # Rope project settings 126 | .ropeproject 127 | 128 | # mkdocs documentation 129 | /site 130 | 131 | # mypy 132 | .mypy_cache/ 133 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AdaCode 2 | 3 | This repository provides the official implementation for the paper 4 | 5 | > **Learning Image-Adaptive Codebooks for Class-Agnostic Image Restoration**
6 | > [Kechun Liu](https://kechunl.github.io/), Yitong Jiang, Inchang Choi, [Jinwei Gu](https://www.gujinwei.org/)
7 | > In ICCV 2023. 8 | 9 | [![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2306.06513) 10 | [![Website](https://img.shields.io/badge/website-link-blue.svg)](https://kechunl.github.io/AdaCode/) 11 | [![Video](https://img.shields.io/badge/video-red.svg)](https://www.youtube.com/watch?v=GOp-4kbgyoM) 12 | > **Abstract:** Recent work on discrete generative priors, in the form of codebooks, has shown exciting performance for image reconstruction and restoration, since the discrete prior space spanned by the codebooks increases the robustness against diverse image degradations. Nevertheless, these methods require separate training of codebooks for different image categories, which limits their use to specific image categories only (e.g. face, architecture, etc.), and they fail to handle arbitrary natural images. **In this paper, we propose AdaCode for learning image-adaptive codebooks for class-agnostic image restoration. Instead of learning a single codebook for all categories of images, we learn a set of basis codebooks. For a given input image, AdaCode learns a weight map with which we compute a weighted combination of these basis codebooks for adaptive image restoration.** Intuitively, AdaCode is a more flexible and expressive discrete generative prior than previous work. Experimental results show that AdaCode achieves state-of-the-art performance on image reconstruction and restoration tasks, including image super-resolution and inpainting. 13 | 14 |  15 |  16 | --- 17 | ### Dependencies and Installation 18 | 19 | - Ubuntu >= 18.04 20 | - CUDA >= 11.0 21 | - Other required packages in `requirements.txt` 22 | ``` 23 | # git clone this repository 24 | git clone https://github.com/kechunl/AdaCode.git 25 | cd AdaCode 26 | 27 | # create new anaconda env 28 | conda create -n AdaCode python=3.8 29 | conda activate AdaCode 30 | 31 | # install python dependencies 32 | pip install -r requirements.txt 33 | python setup.py develop 34 | ``` 35 | 36 | --- 37 | 38 | ### Inference 39 | 40 | ``` 41 | # super resolution 42 | python inference.py -i ./testset/sr -t sr -s 2 -o results_x2/ 43 | python inference.py -i ./testset/sr -t sr -s 4 -o results_x4/ 44 | 45 | # inpainting 46 | python inference.py -i ./testset/inpaint -t inpaint -o results_inpaint/ 47 | ``` 48 | 49 | ## Train the model 50 | 51 | ### Preparation 52 | 53 | #### Dataset 54 | 55 | Please prepare the training and testing data following the descriptions in the [paper](https://arxiv.org/abs/2306.06513). In brief, you need to crop 512 x 512 high-resolution patches and generate the low-resolution patches with the [`degradation_bsrgan`](https://github.com/cszn/BSRGAN/blob/3a958f40a9a24e8b81c3cb1960f05b0e91f1b421/utils/utils_blindsr.py?_pjax=%23js-repo-pjax-container%2C%20div%5Bitemtype%3D%22http%3A%2F%2Fschema.org%2FSoftwareSourceCode%22%5D%20main%2C%20%5Bdata-pjax-container%5D#L432) function provided by [BSRGAN](https://github.com/cszn/BSRGAN). You may find `generate_dataset.py` and `generate_inpaint_dataset.py` useful for generating the low-resolution and masked patches. 56 | 57 | #### Model preparation 58 | 59 | Before training, you need to 60 | - Download the pretrained stage 2 model: [generator](https://github.com/kechunl/AdaCode/releases/download/v0-pretrain_models/AdaCode_S2_model_g.pth), [discriminator](https://github.com/kechunl/AdaCode/releases/download/v0-pretrain_models/AdaCode_S2_model_d.pth) 61 | - Put the pretrained models in `experiments/pretrained_models` 62 | - Specify their paths in the corresponding option file (see the sketch below). 
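The exact option keys differ between the stage option files, so treat the following as a minimal sketch rather than a drop-in block: it assumes the BasicSR-style `path` section and uses hypothetical key names (`pretrain_network_g`, `pretrain_network_d`); check `options/stage3/train_AdaCode_stage3.yaml` for the names it actually expects.

```
# Sketch only: point the stage 3 option file at the downloaded stage 2 checkpoints.
# Key names below follow the usual BasicSR convention and may differ in this repo.
path:
  pretrain_network_g: experiments/pretrained_models/AdaCode_S2_model_g.pth
  pretrain_network_d: experiments/pretrained_models/AdaCode_S2_model_d.pth
```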
63 | 64 | ### Train SR model 65 | 66 | ``` 67 | python basicsr/train.py -opt options/stage3/train_AdaCode_stage3.yaml 68 | ``` 69 | 70 | ## Acknowledgement 71 | 72 | This project is based on [FeMaSR](https://github.com/chaofengc/FeMaSR). -------------------------------------------------------------------------------- /basicsr/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/basicsr/.DS_Store -------------------------------------------------------------------------------- /basicsr/__init__.py: -------------------------------------------------------------------------------- 1 | # https://github.com/xinntao/BasicSR 2 | # flake8: noqa 3 | from .archs import * 4 | from .data import * 5 | from .losses import * 6 | from .models import * 7 | from .test import * 8 | from .train import * 9 | from .utils import * 10 | -------------------------------------------------------------------------------- /basicsr/archs/RRDB_arch.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | def make_layer(block, n_layers): 8 | layers = [] 9 | for _ in range(n_layers): 10 | layers.append(block()) 11 | return nn.Sequential(*layers) 12 | 13 | 14 | class ResidualDenseBlock_5C(nn.Module): 15 | def __init__(self, nf=64, gc=32, bias=True): 16 | super(ResidualDenseBlock_5C, self).__init__() 17 | # gc: growth channel, i.e. intermediate channels 18 | self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) 19 | self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) 20 | self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) 21 | self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) 22 | self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) 23 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 24 | 25 | # initialization 26 | # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) 27 | 28 | def forward(self, x): 29 | x1 = self.lrelu(self.conv1(x)) 30 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) 31 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) 32 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) 33 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 34 | return x5 * 0.2 + x 35 | 36 | 37 | class RRDB(nn.Module): 38 | '''Residual in Residual Dense Block''' 39 | 40 | def __init__(self, nf, gc=32): 41 | super(RRDB, self).__init__() 42 | self.RDB1 = ResidualDenseBlock_5C(nf, gc) 43 | self.RDB2 = ResidualDenseBlock_5C(nf, gc) 44 | self.RDB3 = ResidualDenseBlock_5C(nf, gc) 45 | 46 | def forward(self, x): 47 | out = self.RDB1(x) 48 | out = self.RDB2(out) 49 | out = self.RDB3(out) 50 | return out * 0.2 + x 51 | 52 | 53 | class RRDBNet(nn.Module): 54 | def __init__(self, in_nc, out_nc, nf, nb, gc=32): 55 | super(RRDBNet, self).__init__() 56 | RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) 57 | 58 | self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) 59 | self.RRDB_trunk = make_layer(RRDB_block_f, nb) 60 | self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 61 | #### upsampling 62 | self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 63 | self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 64 | self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 65 | self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) 66 | 67 | self.lrelu = 
nn.LeakyReLU(negative_slope=0.2, inplace=True) 68 | 69 | def forward(self, x): 70 | fea = self.conv_first(x) 71 | trunk = self.trunk_conv(self.RRDB_trunk(fea)) 72 | fea = fea + trunk 73 | 74 | fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) 75 | fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) 76 | out = self.conv_last(self.lrelu(self.HRconv(fea))) 77 | 78 | return out -------------------------------------------------------------------------------- /basicsr/archs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from basicsr.utils.registry import ARCH_REGISTRY 7 | 8 | __all__ = ['build_network'] 9 | 10 | # automatically scan and import arch modules for registry 11 | # scan all the files under the 'archs' folder and collect files ending with 12 | # '_arch.py' 13 | arch_folder = osp.dirname(osp.abspath(__file__)) 14 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] 15 | # import all the arch modules 16 | _arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames] 17 | 18 | 19 | def build_network(opt): 20 | opt = deepcopy(opt) 21 | network_type = opt.pop('type') 22 | net = ARCH_REGISTRY.get(network_type)(**opt) 23 | logger = get_root_logger() 24 | logger.info(f'Network [{net.__class__.__name__}] is created.') 25 | return net 26 | -------------------------------------------------------------------------------- /basicsr/archs/discriminator_arch.py: -------------------------------------------------------------------------------- 1 | from basicsr.utils.registry import ARCH_REGISTRY 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | from torch.nn.utils import spectral_norm 5 | 6 | 7 | @ARCH_REGISTRY.register() 8 | class UNetDiscriminatorSN(nn.Module): 9 | """Defines a U-Net discriminator with spectral normalization (SN) 10 | 11 | It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data. 12 | 13 | Arg: 14 | num_in_ch (int): Channel number of inputs. Default: 3. 15 | num_feat (int): Channel number of base intermediate features. Default: 64. 16 | skip_connection (bool): Whether to use skip connections between U-Net. Default: True. 
17 | """ 18 | 19 | def __init__(self, num_in_ch, num_feat=64, skip_connection=True): 20 | super(UNetDiscriminatorSN, self).__init__() 21 | self.skip_connection = skip_connection 22 | norm = spectral_norm 23 | # the first convolution 24 | self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1) 25 | # downsample 26 | self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False)) 27 | self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False)) 28 | self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False)) 29 | # upsample 30 | self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False)) 31 | self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False)) 32 | self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False)) 33 | # extra convolutions 34 | self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) 35 | self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) 36 | self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1) 37 | 38 | def forward(self, x): 39 | # downsample 40 | x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True) 41 | x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True) 42 | x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True) 43 | x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True) 44 | 45 | # upsample 46 | x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False) 47 | x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True) 48 | 49 | if self.skip_connection: 50 | x4 = x4 + x2 51 | x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False) 52 | x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True) 53 | 54 | if self.skip_connection: 55 | x5 = x5 + x1 56 | x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False) 57 | x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True) 58 | 59 | if self.skip_connection: 60 | x6 = x6 + x0 61 | 62 | # extra convolutions 63 | out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True) 64 | out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True) 65 | out = self.conv9(out) 66 | 67 | return out -------------------------------------------------------------------------------- /basicsr/archs/fema_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | from torch import nn as nn 4 | 5 | class NormLayer(nn.Module): 6 | """Normalization Layers. 7 | ------------ 8 | # Arguments 9 | - channels: input channels, for batch norm and instance norm. 10 | - input_size: input shape without batch size, for layer norm. 11 | """ 12 | def __init__(self, channels, norm_type='bn'): 13 | super(NormLayer, self).__init__() 14 | norm_type = norm_type.lower() 15 | self.norm_type = norm_type 16 | self.channels = channels 17 | if norm_type == 'bn': 18 | self.norm = nn.BatchNorm2d(channels, affine=True) 19 | elif norm_type == 'in': 20 | self.norm = nn.InstanceNorm2d(channels, affine=False) 21 | elif norm_type == 'gn': 22 | self.norm = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True) 23 | elif norm_type == 'none': 24 | self.norm = lambda x: x*1.0 25 | else: 26 | assert 1==0, 'Norm type {} not support.'.format(norm_type) 27 | 28 | def forward(self, x): 29 | return self.norm(x) 30 | 31 | 32 | class ActLayer(nn.Module): 33 | """activation layer. 
34 | ------------ 35 | # Arguments 36 | - relu type: type of relu layer, candidates are 37 | - ReLU 38 | - LeakyReLU: default relu slope 0.2 39 | - PRelu 40 | - SELU 41 | - none: direct pass 42 | """ 43 | def __init__(self, channels, relu_type='leakyrelu'): 44 | super(ActLayer, self).__init__() 45 | relu_type = relu_type.lower() 46 | if relu_type == 'relu': 47 | self.func = nn.ReLU(True) 48 | elif relu_type == 'leakyrelu': 49 | self.func = nn.LeakyReLU(0.2, inplace=True) 50 | elif relu_type == 'prelu': 51 | self.func = nn.PReLU(channels) 52 | elif relu_type == 'none': 53 | self.func = lambda x: x*1.0 54 | elif relu_type == 'silu': 55 | self.func = nn.SiLU(True) 56 | elif relu_type == 'gelu': 57 | self.func = nn.GELU() 58 | else: 59 | assert 1==0, 'activation type {} not support.'.format(relu_type) 60 | 61 | def forward(self, x): 62 | return self.func(x) 63 | 64 | 65 | class ResBlock(nn.Module): 66 | """ 67 | Use preactivation version of residual block, the same as taming 68 | """ 69 | def __init__(self, in_channel, out_channel, norm_type='gn', act_type='leakyrelu'): 70 | super(ResBlock, self).__init__() 71 | 72 | self.conv = nn.Sequential( 73 | NormLayer(in_channel, norm_type), 74 | ActLayer(in_channel, act_type), 75 | nn.Conv2d(in_channel, out_channel, 3, stride=1, padding=1), 76 | NormLayer(out_channel, norm_type), 77 | ActLayer(out_channel, act_type), 78 | nn.Conv2d(out_channel, out_channel, 3, stride=1, padding=1), 79 | ) 80 | 81 | def forward(self, input): 82 | res = self.conv(input) 83 | out = res + input 84 | return out 85 | 86 | 87 | class CombineQuantBlock(nn.Module): 88 | def __init__(self, in_ch1, in_ch2, out_channel): 89 | super().__init__() 90 | self.conv = nn.Conv2d(in_ch1 + in_ch2, out_channel, 3, 1, 1) 91 | 92 | def forward(self, input1, input2=None): 93 | if input2 is not None: 94 | input2 = F.interpolate(input2, input1.shape[2:]) 95 | input = torch.cat((input1, input2), dim=1) 96 | else: 97 | input = input1 98 | out = self.conv(input) 99 | return out 100 | 101 | 102 | -------------------------------------------------------------------------------- /basicsr/archs/vgg_arch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | from torch import nn as nn 5 | from torchvision.models import vgg as vgg 6 | 7 | from basicsr.utils.registry import ARCH_REGISTRY 8 | 9 | VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth' 10 | NAMES = { 11 | 'vgg11': [ 12 | 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 13 | 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 14 | 'pool5' 15 | ], 16 | 'vgg13': [ 17 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 18 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 19 | 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5' 20 | ], 21 | 'vgg16': [ 22 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 23 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 24 | 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 25 | 'pool5' 26 | ], 27 | 'vgg19': [ 28 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 
'relu2_2', 'pool2', 29 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1', 30 | 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1', 31 | 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5' 32 | ] 33 | } 34 | 35 | 36 | def insert_bn(names): 37 | """Insert bn layer after each conv. 38 | 39 | Args: 40 | names (list): The list of layer names. 41 | 42 | Returns: 43 | list: The list of layer names with bn layers. 44 | """ 45 | names_bn = [] 46 | for name in names: 47 | names_bn.append(name) 48 | if 'conv' in name: 49 | position = name.replace('conv', '') 50 | names_bn.append('bn' + position) 51 | return names_bn 52 | 53 | 54 | @ARCH_REGISTRY.register() 55 | class VGGFeatureExtractor(nn.Module): 56 | """VGG network for feature extraction. 57 | 58 | In this implementation, we allow users to choose whether use normalization 59 | in the input feature and the type of vgg network. Note that the pretrained 60 | path must fit the vgg type. 61 | 62 | Args: 63 | layer_name_list (list[str]): Forward function returns the corresponding 64 | features according to the layer_name_list. 65 | Example: {'relu1_1', 'relu2_1', 'relu3_1'}. 66 | vgg_type (str): Set the type of vgg network. Default: 'vgg19'. 67 | use_input_norm (bool): If True, normalize the input image. Importantly, 68 | the input feature must in the range [0, 1]. Default: True. 69 | range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. 70 | Default: False. 71 | requires_grad (bool): If true, the parameters of VGG network will be 72 | optimized. Default: False. 73 | remove_pooling (bool): If true, the max pooling operations in VGG net 74 | will be removed. Default: False. 75 | pooling_stride (int): The stride of max pooling operation. Default: 2. 
76 | """ 77 | 78 | def __init__(self, 79 | layer_name_list, 80 | vgg_type='vgg19', 81 | use_input_norm=True, 82 | range_norm=False, 83 | requires_grad=False, 84 | remove_pooling=False, 85 | pooling_stride=2): 86 | super(VGGFeatureExtractor, self).__init__() 87 | 88 | self.layer_name_list = layer_name_list 89 | self.use_input_norm = use_input_norm 90 | self.range_norm = range_norm 91 | 92 | self.names = NAMES[vgg_type.replace('_bn', '')] 93 | if 'bn' in vgg_type: 94 | self.names = insert_bn(self.names) 95 | 96 | # only borrow layers that will be used to avoid unused params 97 | max_idx = 0 98 | for v in layer_name_list: 99 | idx = self.names.index(v) 100 | if idx > max_idx: 101 | max_idx = idx 102 | 103 | if os.path.exists(VGG_PRETRAIN_PATH): 104 | vgg_net = getattr(vgg, vgg_type)(pretrained=False) 105 | state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage) 106 | vgg_net.load_state_dict(state_dict) 107 | else: 108 | vgg_net = getattr(vgg, vgg_type)(pretrained=True) 109 | 110 | features = vgg_net.features[:max_idx + 1] 111 | 112 | modified_net = OrderedDict() 113 | for k, v in zip(self.names, features): 114 | if 'pool' in k: 115 | # if remove_pooling is true, pooling operation will be removed 116 | if remove_pooling: 117 | continue 118 | else: 119 | # in some cases, we may want to change the default stride 120 | modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride) 121 | else: 122 | modified_net[k] = v 123 | 124 | self.vgg_net = nn.Sequential(modified_net) 125 | 126 | if not requires_grad: 127 | self.vgg_net.eval() 128 | for param in self.parameters(): 129 | param.requires_grad = False 130 | else: 131 | self.vgg_net.train() 132 | for param in self.parameters(): 133 | param.requires_grad = True 134 | 135 | if self.use_input_norm: 136 | # the mean is for image with range [0, 1] 137 | self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) 138 | # the std is for image with range [0, 1] 139 | self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) 140 | 141 | def forward(self, x): 142 | """Forward function. 143 | 144 | Args: 145 | x (Tensor): Input tensor with shape (n, c, h, w). 146 | 147 | Returns: 148 | Tensor: Forward results. 
149 | """ 150 | if self.range_norm: 151 | x = (x + 1) / 2 152 | if self.use_input_norm: 153 | x = (x - self.mean) / self.std 154 | 155 | output = {} 156 | for key, layer in self.vgg_net._modules.items(): 157 | x = layer(x) 158 | if key in self.layer_name_list: 159 | output[key] = x.clone() 160 | 161 | return output 162 | -------------------------------------------------------------------------------- /basicsr/data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import numpy as np 3 | import random 4 | import torch 5 | import torch.utils.data 6 | from copy import deepcopy 7 | from functools import partial 8 | from os import path as osp 9 | 10 | from basicsr.data.prefetch_dataloader import PrefetchDataLoader 11 | from basicsr.utils import get_root_logger, scandir 12 | from basicsr.utils.dist_util import get_dist_info 13 | from basicsr.utils.registry import DATASET_REGISTRY 14 | 15 | __all__ = ['build_dataset', 'build_dataloader'] 16 | 17 | # automatically scan and import dataset modules for registry 18 | # scan all the files under the data folder with '_dataset' in file names 19 | data_folder = osp.dirname(osp.abspath(__file__)) 20 | dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] 21 | # import all the dataset modules 22 | _dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames] 23 | 24 | 25 | def build_dataset(dataset_opt): 26 | """Build dataset from options. 27 | 28 | Args: 29 | dataset_opt (dict): Configuration for dataset. It must contain: 30 | name (str): Dataset name. 31 | type (str): Dataset type. 32 | """ 33 | dataset_opt = deepcopy(dataset_opt) 34 | dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt) 35 | logger = get_root_logger() 36 | logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} ' 'is built.') 37 | return dataset 38 | 39 | 40 | def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None): 41 | """Build dataloader. 42 | 43 | Args: 44 | dataset (torch.utils.data.Dataset): Dataset. 45 | dataset_opt (dict): Dataset options. It contains the following keys: 46 | phase (str): 'train' or 'val'. 47 | num_worker_per_gpu (int): Number of workers for each GPU. 48 | batch_size_per_gpu (int): Training batch size for each GPU. 49 | num_gpu (int): Number of GPUs. Used only in the train phase. 50 | Default: 1. 51 | dist (bool): Whether in distributed training. Used only in the train 52 | phase. Default: False. 53 | sampler (torch.utils.data.sampler): Data sampler. Default: None. 54 | seed (int | None): Seed. 
Default: None 55 | """ 56 | phase = dataset_opt['phase'] 57 | rank, _ = get_dist_info() 58 | if phase == 'train': 59 | if dist: # distributed training 60 | batch_size = dataset_opt['batch_size_per_gpu'] 61 | num_workers = dataset_opt['num_worker_per_gpu'] 62 | else: # non-distributed training 63 | multiplier = 1 if num_gpu == 0 else num_gpu 64 | batch_size = dataset_opt['batch_size_per_gpu'] * multiplier 65 | num_workers = dataset_opt['num_worker_per_gpu'] * multiplier 66 | dataloader_args = dict( 67 | dataset=dataset, 68 | batch_size=batch_size, 69 | shuffle=False, 70 | num_workers=num_workers, 71 | sampler=sampler, 72 | drop_last=True) 73 | if sampler is None: 74 | dataloader_args['shuffle'] = True 75 | dataloader_args['worker_init_fn'] = partial( 76 | worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None 77 | elif phase in ['val', 'test']: # validation 78 | dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0) 79 | else: 80 | raise ValueError(f'Wrong dataset phase: {phase}. ' "Supported ones are 'train', 'val' and 'test'.") 81 | 82 | dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) 83 | dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False) 84 | 85 | prefetch_mode = dataset_opt.get('prefetch_mode') 86 | if prefetch_mode == 'cpu': # CPUPrefetcher 87 | num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) 88 | logger = get_root_logger() 89 | logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}') 90 | return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args) 91 | else: 92 | # prefetch_mode=None: Normal dataloader 93 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher 94 | return torch.utils.data.DataLoader(**dataloader_args) 95 | 96 | 97 | def worker_init_fn(worker_id, num_workers, rank, seed): 98 | # Set the worker seed to num_workers * rank + worker_id + seed 99 | worker_seed = num_workers * rank + worker_id + seed 100 | np.random.seed(worker_seed) 101 | random.seed(worker_seed) 102 | -------------------------------------------------------------------------------- /basicsr/data/bsrgan_train_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from torch.utils import data as data 3 | 4 | from basicsr.data.bsrgan_util import degradation_bsrgan 5 | from basicsr.data.transforms import augment 6 | from basicsr.utils import FileClient, img2tensor 7 | from basicsr.utils.registry import DATASET_REGISTRY 8 | 9 | from .data_util import make_dataset 10 | 11 | import cv2 12 | import random 13 | 14 | 15 | def random_resize(img, scale_factor=1.): 16 | return cv2.resize(img, None, fx=scale_factor, fy=scale_factor, interpolation=cv2.INTER_CUBIC) 17 | 18 | 19 | def random_crop(img, out_size): 20 | h, w = img.shape[:2] 21 | rnd_h = random.randint(0, h - out_size) 22 | rnd_w = random.randint(0, w - out_size) 23 | return img[rnd_h: rnd_h + out_size, rnd_w: rnd_w + out_size] 24 | 25 | 26 | @DATASET_REGISTRY.register() 27 | class BSRGANTrainDataset(data.Dataset): 28 | """Synthesize LR-HR pairs online with BSRGAN for image restoration. 29 | 30 | Args: 31 | opt (dict): Config for train datasets. It contains the following keys: 32 | dataroot_gt (str): Data root path for gt. 33 | dataroot_lq (str): Data root path for lq. 34 | meta_info_file (str): Path for meta information file. 35 | gt_size (int): Cropped patched size for gt patches. 
36 | use_flip (bool): Use horizontal flips. 37 | use_rot (bool): Use rotation (use vertical flip and transposing h 38 | and w for implementation). 39 | 40 | scale (bool): Scale, which will be added automatically. 41 | phase (str): 'train' or 'val'. 42 | """ 43 | 44 | def __init__(self, opt): 45 | super(BSRGANTrainDataset, self).__init__() 46 | self.opt = opt 47 | # file client (io backend) 48 | self.file_client = None 49 | self.io_backend_opt = opt['io_backend'] 50 | 51 | if opt.get('dataroot_gt') is not None: 52 | self.gt_folder = opt['dataroot_gt'] 53 | self.gt_paths = make_dataset(self.gt_folder) 54 | elif opt.get('datafile_gt') is not None: 55 | with open(opt.get('datafile_gt'), "r") as f: 56 | paths = f.read().splitlines() 57 | self.gt_paths = paths 58 | else: 59 | raise ValueError("Unknown path for gt.") 60 | 61 | def __getitem__(self, index): 62 | 63 | scale = self.opt['scale'] 64 | 65 | gt_path = self.gt_paths[index] 66 | img_gt = cv2.imread(gt_path).astype(np.float32) / 255. 67 | 68 | img_gt = img_gt[:, :, [2, 1, 0]] # BGR to RGB 69 | gt_size = self.opt['gt_size'] 70 | 71 | if self.opt['phase'] == 'train': 72 | if self.opt['use_resize_crop']: 73 | input_gt_size = img_gt.shape[0] 74 | input_gt_random_size = random.randint(gt_size, input_gt_size) 75 | resize_factor = input_gt_random_size / input_gt_size 76 | img_gt = random_resize(img_gt, resize_factor) 77 | 78 | img_gt = random_crop(img_gt, gt_size) 79 | 80 | img_lq, img_gt = degradation_bsrgan(img_gt, sf=scale, lq_patchsize=self.opt['gt_size'] // scale, use_crop=False) 81 | img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_flip'], 82 | self.opt['use_rot']) 83 | img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=False, float32=True) 84 | 85 | return { 86 | 'lq': img_lq, 87 | 'gt': img_gt, 88 | 'lq_path': gt_path, 89 | 'gt_path': gt_path 90 | } 91 | 92 | def __len__(self): 93 | return len(self.gt_paths) 94 | -------------------------------------------------------------------------------- /basicsr/data/data_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data.sampler import Sampler 4 | 5 | 6 | class EnlargedSampler(Sampler): 7 | """Sampler that restricts data loading to a subset of the dataset. 8 | 9 | Modified from torch.utils.data.distributed.DistributedSampler 10 | Support enlarging the dataset for iteration-based training, for saving 11 | time when restart the dataloader after each epoch 12 | 13 | Args: 14 | dataset (torch.utils.data.Dataset): Dataset used for sampling. 15 | num_replicas (int | None): Number of processes participating in 16 | the training. It is usually the world_size. 17 | rank (int | None): Rank of the current process within num_replicas. 18 | ratio (int): Enlarging ratio. Default: 1. 
19 | """ 20 | 21 | def __init__(self, dataset, num_replicas, rank, ratio=1): 22 | self.dataset = dataset 23 | self.num_replicas = num_replicas 24 | self.rank = rank 25 | self.epoch = 0 26 | self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas) 27 | self.total_size = self.num_samples * self.num_replicas 28 | 29 | def __iter__(self): 30 | # deterministically shuffle based on epoch 31 | g = torch.Generator() 32 | g.manual_seed(self.epoch) 33 | indices = torch.randperm(self.total_size, generator=g).tolist() 34 | 35 | dataset_size = len(self.dataset) 36 | indices = [v % dataset_size for v in indices] 37 | 38 | # subsample 39 | indices = indices[self.rank:self.total_size:self.num_replicas] 40 | assert len(indices) == self.num_samples 41 | 42 | return iter(indices) 43 | 44 | def __len__(self): 45 | return self.num_samples 46 | 47 | def set_epoch(self, epoch): 48 | self.epoch = epoch 49 | -------------------------------------------------------------------------------- /basicsr/data/prefetch_dataloader.py: -------------------------------------------------------------------------------- 1 | import queue as Queue 2 | import threading 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | 7 | class PrefetchGenerator(threading.Thread): 8 | """A general prefetch generator. 9 | 10 | Ref: 11 | https://stackoverflow.com/questions/7323664/python-generator-pre-fetch 12 | 13 | Args: 14 | generator: Python generator. 15 | num_prefetch_queue (int): Number of prefetch queue. 16 | """ 17 | 18 | def __init__(self, generator, num_prefetch_queue): 19 | threading.Thread.__init__(self) 20 | self.queue = Queue.Queue(num_prefetch_queue) 21 | self.generator = generator 22 | self.daemon = True 23 | self.start() 24 | 25 | def run(self): 26 | for item in self.generator: 27 | self.queue.put(item) 28 | self.queue.put(None) 29 | 30 | def __next__(self): 31 | next_item = self.queue.get() 32 | if next_item is None: 33 | raise StopIteration 34 | return next_item 35 | 36 | def __iter__(self): 37 | return self 38 | 39 | 40 | class PrefetchDataLoader(DataLoader): 41 | """Prefetch version of dataloader. 42 | 43 | Ref: 44 | https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# 45 | 46 | TODO: 47 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in 48 | ddp. 49 | 50 | Args: 51 | num_prefetch_queue (int): Number of prefetch queue. 52 | kwargs (dict): Other arguments for dataloader. 53 | """ 54 | 55 | def __init__(self, num_prefetch_queue, **kwargs): 56 | self.num_prefetch_queue = num_prefetch_queue 57 | super(PrefetchDataLoader, self).__init__(**kwargs) 58 | 59 | def __iter__(self): 60 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) 61 | 62 | 63 | class CPUPrefetcher(): 64 | """CPU prefetcher. 65 | 66 | Args: 67 | loader: Dataloader. 68 | """ 69 | 70 | def __init__(self, loader): 71 | self.ori_loader = loader 72 | self.loader = iter(loader) 73 | 74 | def next(self): 75 | try: 76 | return next(self.loader) 77 | except StopIteration: 78 | return None 79 | 80 | def reset(self): 81 | self.loader = iter(self.ori_loader) 82 | 83 | 84 | class CUDAPrefetcher(): 85 | """CUDA prefetcher. 86 | 87 | Ref: 88 | https://github.com/NVIDIA/apex/issues/304# 89 | 90 | It may consums more GPU memory. 91 | 92 | Args: 93 | loader: Dataloader. 94 | opt (dict): Options. 
95 | """ 96 | 97 | def __init__(self, loader, opt): 98 | self.ori_loader = loader 99 | self.loader = iter(loader) 100 | self.opt = opt 101 | self.stream = torch.cuda.Stream() 102 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') 103 | self.preload() 104 | 105 | def preload(self): 106 | try: 107 | self.batch = next(self.loader) # self.batch is a dict 108 | except StopIteration: 109 | self.batch = None 110 | return None 111 | # put tensors to gpu 112 | with torch.cuda.stream(self.stream): 113 | for k, v in self.batch.items(): 114 | if torch.is_tensor(v): 115 | self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True) 116 | 117 | def next(self): 118 | torch.cuda.current_stream().wait_stream(self.stream) 119 | batch = self.batch 120 | self.preload() 121 | return batch 122 | 123 | def reset(self): 124 | self.loader = iter(self.ori_loader) 125 | self.preload() 126 | -------------------------------------------------------------------------------- /basicsr/data/single_image_dataset.py: -------------------------------------------------------------------------------- 1 | from os import path as osp 2 | from torch.utils import data as data 3 | from torchvision.transforms.functional import normalize 4 | 5 | from basicsr.data.data_util import paths_from_lmdb 6 | from basicsr.utils import FileClient, imfrombytes, img2tensor, scandir 7 | from basicsr.utils.matlab_functions import rgb2ycbcr 8 | from basicsr.utils.registry import DATASET_REGISTRY 9 | 10 | 11 | @DATASET_REGISTRY.register() 12 | class SingleImageDataset(data.Dataset): 13 | """Read only lq images in the test phase. 14 | 15 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc). 16 | 17 | There are two modes: 18 | 1. 'meta_info_file': Use meta information file to generate paths. 19 | 2. 'folder': Scan folders to generate paths. 20 | 21 | Args: 22 | opt (dict): Config for train datasets. It contains the following keys: 23 | dataroot_lq (str): Data root path for lq. 24 | meta_info_file (str): Path for meta information file. 25 | io_backend (dict): IO backend type and other kwarg. 
26 | """ 27 | 28 | def __init__(self, opt): 29 | super(SingleImageDataset, self).__init__() 30 | self.opt = opt 31 | # file client (io backend) 32 | self.file_client = None 33 | self.io_backend_opt = opt['io_backend'] 34 | self.mean = opt['mean'] if 'mean' in opt else None 35 | self.std = opt['std'] if 'std' in opt else None 36 | self.lq_folder = opt['dataroot_lq'] 37 | 38 | if self.io_backend_opt['type'] == 'lmdb': 39 | self.io_backend_opt['db_paths'] = [self.lq_folder] 40 | self.io_backend_opt['client_keys'] = ['lq'] 41 | self.paths = paths_from_lmdb(self.lq_folder) 42 | elif 'meta_info_file' in self.opt: 43 | with open(self.opt['meta_info_file'], 'r') as fin: 44 | self.paths = [osp.join(self.lq_folder, line.rstrip().split(' ')[0]) for line in fin] 45 | else: 46 | self.paths = sorted(list(scandir(self.lq_folder, full_path=True))) 47 | 48 | def __getitem__(self, index): 49 | if self.file_client is None: 50 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 51 | 52 | # load lq image 53 | lq_path = self.paths[index] 54 | img_bytes = self.file_client.get(lq_path, 'lq') 55 | img_lq = imfrombytes(img_bytes, float32=True) 56 | 57 | # color space transform 58 | if 'color' in self.opt and self.opt['color'] == 'y': 59 | img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None] 60 | 61 | # BGR to RGB, HWC to CHW, numpy to tensor 62 | img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True) 63 | # normalize 64 | if self.mean is not None or self.std is not None: 65 | normalize(img_lq, self.mean, self.std, inplace=True) 66 | return {'lq': img_lq, 'lq_path': lq_path} 67 | 68 | def __len__(self): 69 | return len(self.paths) 70 | -------------------------------------------------------------------------------- /basicsr/data/transforms.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import random 3 | import torch 4 | 5 | 6 | def mod_crop(img, scale): 7 | """Mod crop images, used during testing. 8 | 9 | Args: 10 | img (ndarray): Input image. 11 | scale (int): Scale factor. 12 | 13 | Returns: 14 | ndarray: Result image. 15 | """ 16 | img = img.copy() 17 | if img.ndim in (2, 3): 18 | h, w = img.shape[0], img.shape[1] 19 | h_remainder, w_remainder = h % scale, w % scale 20 | img = img[:h - h_remainder, :w - w_remainder, ...] 21 | else: 22 | raise ValueError(f'Wrong img ndim: {img.ndim}.') 23 | return img 24 | 25 | 26 | def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None): 27 | """Paired random crop. Support Numpy array and Tensor inputs. 28 | 29 | It crops lists of lq and gt images with corresponding locations. 30 | 31 | Args: 32 | img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images 33 | should have the same shape. If the input is an ndarray, it will 34 | be transformed to a list containing itself. 35 | img_lqs (list[ndarray] | ndarray): LQ images. Note that all images 36 | should have the same shape. If the input is an ndarray, it will 37 | be transformed to a list containing itself. 38 | gt_patch_size (int): GT patch size. 39 | scale (int): Scale factor. 40 | gt_path (str): Path to ground-truth. Default: None. 41 | 42 | Returns: 43 | list[ndarray] | ndarray: GT images and LQ images. If returned results 44 | only have one element, just return ndarray. 
45 | """ 46 | 47 | if not isinstance(img_gts, list): 48 | img_gts = [img_gts] 49 | if not isinstance(img_lqs, list): 50 | img_lqs = [img_lqs] 51 | 52 | # determine input type: Numpy array or Tensor 53 | input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy' 54 | 55 | if input_type == 'Tensor': 56 | h_lq, w_lq = img_lqs[0].size()[-2:] 57 | h_gt, w_gt = img_gts[0].size()[-2:] 58 | else: 59 | h_lq, w_lq = img_lqs[0].shape[0:2] 60 | h_gt, w_gt = img_gts[0].shape[0:2] 61 | lq_patch_size = gt_patch_size // scale 62 | 63 | if h_gt != h_lq * scale or w_gt != w_lq * scale: 64 | raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ', 65 | f'multiplication of LQ ({h_lq}, {w_lq}).') 66 | if h_lq < lq_patch_size or w_lq < lq_patch_size: 67 | raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size ' 68 | f'({lq_patch_size}, {lq_patch_size}). ' 69 | f'Please remove {gt_path}.') 70 | 71 | # randomly choose top and left coordinates for lq patch 72 | top = random.randint(0, h_lq - lq_patch_size) 73 | left = random.randint(0, w_lq - lq_patch_size) 74 | 75 | # crop lq patch 76 | if input_type == 'Tensor': 77 | img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs] 78 | else: 79 | img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs] 80 | 81 | # crop corresponding gt patch 82 | top_gt, left_gt = int(top * scale), int(left * scale) 83 | if input_type == 'Tensor': 84 | img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts] 85 | else: 86 | img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts] 87 | if len(img_gts) == 1: 88 | img_gts = img_gts[0] 89 | if len(img_lqs) == 1: 90 | img_lqs = img_lqs[0] 91 | return img_gts, img_lqs 92 | 93 | 94 | def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False): 95 | """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees). 96 | 97 | We use vertical flip and transpose for rotation implementation. 98 | All the images in the list use the same augmentation. 99 | 100 | Args: 101 | imgs (list[ndarray] | ndarray): Images to be augmented. If the input 102 | is an ndarray, it will be transformed to a list. 103 | hflip (bool): Horizontal flip. Default: True. 104 | rotation (bool): Ratotation. Default: True. 105 | flows (list[ndarray]: Flows to be augmented. If the input is an 106 | ndarray, it will be transformed to a list. 107 | Dimension is (h, w, 2). Default: None. 108 | return_status (bool): Return the status of flip and rotation. 109 | Default: False. 110 | 111 | Returns: 112 | list[ndarray] | ndarray: Augmented images and flows. If returned 113 | results only have one element, just return ndarray. 
114 | 115 | """ 116 | hflip = hflip and random.random() < 0.5 117 | vflip = rotation and random.random() < 0.5 118 | rot90 = rotation and random.random() < 0.5 119 | 120 | def _augment(img): 121 | if hflip: # horizontal 122 | cv2.flip(img, 1, img) 123 | if vflip: # vertical 124 | cv2.flip(img, 0, img) 125 | if rot90: 126 | img = img.transpose(1, 0, 2) 127 | return img 128 | 129 | def _augment_flow(flow): 130 | if hflip: # horizontal 131 | cv2.flip(flow, 1, flow) 132 | flow[:, :, 0] *= -1 133 | if vflip: # vertical 134 | cv2.flip(flow, 0, flow) 135 | flow[:, :, 1] *= -1 136 | if rot90: 137 | flow = flow.transpose(1, 0, 2) 138 | flow = flow[:, :, [1, 0]] 139 | return flow 140 | 141 | if not isinstance(imgs, list): 142 | imgs = [imgs] 143 | imgs = [_augment(img) for img in imgs] 144 | if len(imgs) == 1: 145 | imgs = imgs[0] 146 | 147 | if flows is not None: 148 | if not isinstance(flows, list): 149 | flows = [flows] 150 | flows = [_augment_flow(flow) for flow in flows] 151 | if len(flows) == 1: 152 | flows = flows[0] 153 | return imgs, flows 154 | else: 155 | if return_status: 156 | return imgs, (hflip, vflip, rot90) 157 | else: 158 | return imgs 159 | 160 | 161 | def img_rotate(img, angle, center=None, scale=1.0): 162 | """Rotate image. 163 | 164 | Args: 165 | img (ndarray): Image to be rotated. 166 | angle (float): Rotation angle in degrees. Positive values mean 167 | counter-clockwise rotation. 168 | center (tuple[int]): Rotation center. If the center is None, 169 | initialize it as the center of the image. Default: None. 170 | scale (float): Isotropic scale factor. Default: 1.0. 171 | """ 172 | (h, w) = img.shape[:2] 173 | 174 | if center is None: 175 | center = (w // 2, h // 2) 176 | 177 | matrix = cv2.getRotationMatrix2D(center, angle, scale) 178 | rotated_img = cv2.warpAffine(img, matrix, (w, h)) 179 | return rotated_img 180 | -------------------------------------------------------------------------------- /basicsr/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from basicsr.utils import get_root_logger 4 | from basicsr.utils.registry import LOSS_REGISTRY 5 | from .losses import (CharbonnierLoss, GANLoss, ContrastiveLoss, L1Loss, MSELoss, SoftCrossEntropy, PerceptualLoss, WeightedTVLoss, g_path_regularize, 6 | gradient_penalty_loss, r1_penalty) 7 | 8 | __all__ = [ 9 | 'ContrastiveLoss', 'L1Loss', 'MSELoss', 'SoftCrossEntropy', 'CharbonnierLoss', 'WeightedTVLoss', 'PerceptualLoss', 'GANLoss', 'gradient_penalty_loss', 10 | 'r1_penalty', 'g_path_regularize' 11 | ] 12 | 13 | 14 | def build_loss(opt): 15 | """Build loss from options. 16 | 17 | Args: 18 | opt (dict): Configuration. It must contain: 19 | type (str): Model type. 20 | """ 21 | opt = deepcopy(opt) 22 | loss_type = opt.pop('type') 23 | loss = LOSS_REGISTRY.get(loss_type)(**opt) 24 | logger = get_root_logger() 25 | logger.info(f'Loss [{loss.__class__.__name__}] is created.') 26 | return loss 27 | -------------------------------------------------------------------------------- /basicsr/losses/loss_util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from torch.nn import functional as F 3 | 4 | 5 | def reduce_loss(loss, reduction): 6 | """Reduce loss as specified. 7 | 8 | Args: 9 | loss (Tensor): Elementwise loss tensor. 10 | reduction (str): Options are 'none', 'mean' and 'sum'. 11 | 12 | Returns: 13 | Tensor: Reduced loss tensor. 
14 | """ 15 | reduction_enum = F._Reduction.get_enum(reduction) 16 | # none: 0, elementwise_mean:1, sum: 2 17 | if reduction_enum == 0: 18 | return loss 19 | elif reduction_enum == 1: 20 | return loss.mean() 21 | else: 22 | return loss.sum() 23 | 24 | 25 | def weight_reduce_loss(loss, weight=None, reduction='mean'): 26 | """Apply element-wise weight and reduce loss. 27 | 28 | Args: 29 | loss (Tensor): Element-wise loss. 30 | weight (Tensor): Element-wise weights. Default: None. 31 | reduction (str): Same as built-in losses of PyTorch. Options are 32 | 'none', 'mean' and 'sum'. Default: 'mean'. 33 | 34 | Returns: 35 | Tensor: Loss values. 36 | """ 37 | # if weight is specified, apply element-wise weight 38 | if weight is not None: 39 | assert weight.dim() == loss.dim() 40 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1) 41 | loss = loss * weight 42 | 43 | # if weight is not specified or reduction is sum, just reduce the loss 44 | if weight is None or reduction == 'sum': 45 | loss = reduce_loss(loss, reduction) 46 | # if reduction is mean, then compute mean over weight region 47 | elif reduction == 'mean': 48 | if weight.size(1) > 1: 49 | weight = weight.sum() 50 | else: 51 | weight = weight.sum() * loss.size(1) 52 | loss = loss.sum() / weight 53 | 54 | return loss 55 | 56 | 57 | def weighted_loss(loss_func): 58 | """Create a weighted version of a given loss function. 59 | 60 | To use this decorator, the loss function must have the signature like 61 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 62 | element-wise loss without any reduction. This decorator will add weight 63 | and reduction arguments to the function. The decorated function will have 64 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 65 | **kwargs)`. 66 | 67 | :Example: 68 | 69 | >>> import torch 70 | >>> @weighted_loss 71 | >>> def l1_loss(pred, target): 72 | >>> return (pred - target).abs() 73 | 74 | >>> pred = torch.Tensor([0, 2, 3]) 75 | >>> target = torch.Tensor([1, 1, 1]) 76 | >>> weight = torch.Tensor([1, 0, 1]) 77 | 78 | >>> l1_loss(pred, target) 79 | tensor(1.3333) 80 | >>> l1_loss(pred, target, weight) 81 | tensor(1.5000) 82 | >>> l1_loss(pred, target, reduction='none') 83 | tensor([1., 1., 2.]) 84 | >>> l1_loss(pred, target, weight, reduction='sum') 85 | tensor(3.) 
86 | """ 87 | 88 | @functools.wraps(loss_func) 89 | def wrapper(pred, target, weight=None, reduction='mean', **kwargs): 90 | # get element-wise loss 91 | loss = loss_func(pred, target, **kwargs) 92 | loss = weight_reduce_loss(loss, weight, reduction) 93 | return loss 94 | 95 | return wrapper 96 | -------------------------------------------------------------------------------- /basicsr/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from basicsr.utils.registry import MODEL_REGISTRY 7 | 8 | __all__ = ['build_model'] 9 | 10 | # automatically scan and import model modules for registry 11 | # scan all the files under the 'models' folder and collect files ending with 12 | # '_model.py' 13 | model_folder = osp.dirname(osp.abspath(__file__)) 14 | model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] 15 | # import all the model modules 16 | _model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames] 17 | 18 | 19 | def build_model(opt): 20 | """Build model from options. 21 | 22 | Args: 23 | opt (dict): Configuration. It must contain: 24 | model_type (str): Model type. 25 | """ 26 | opt = deepcopy(opt) 27 | model = MODEL_REGISTRY.get(opt['model_type'])(opt) 28 | logger = get_root_logger() 29 | logger.info(f'Model [{model.__class__.__name__}] is created.') 30 | return model 31 | -------------------------------------------------------------------------------- /basicsr/models/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import Counter 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | 5 | 6 | class MultiStepRestartLR(_LRScheduler): 7 | """ MultiStep with restarts learning rate scheme. 8 | 9 | Args: 10 | optimizer (torch.nn.optimizer): Torch optimizer. 11 | milestones (list): Iterations that will decrease learning rate. 12 | gamma (float): Decrease ratio. Default: 0.1. 13 | restarts (list): Restart iterations. Default: [0]. 14 | restart_weights (list): Restart weights at each restart iteration. 15 | Default: [1]. 16 | last_epoch (int): Used in _LRScheduler. Default: -1. 17 | """ 18 | 19 | def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1): 20 | self.milestones = Counter(milestones) 21 | self.gamma = gamma 22 | self.restarts = restarts 23 | self.restart_weights = restart_weights 24 | assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.' 25 | super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) 26 | 27 | def get_lr(self): 28 | if self.last_epoch in self.restarts: 29 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 30 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 31 | if self.last_epoch not in self.milestones: 32 | return [group['lr'] for group in self.optimizer.param_groups] 33 | return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups] 34 | 35 | 36 | def get_position_from_periods(iteration, cumulative_period): 37 | """Get the position from a period list. 38 | 39 | It will return the index of the right-closest number in the period list. 
40 | For example, the cumulative_period = [100, 200, 300, 400], 41 | if iteration == 50, return 0; 42 | if iteration == 210, return 2; 43 | if iteration == 300, return 2. 44 | 45 | Args: 46 | iteration (int): Current iteration. 47 | cumulative_period (list[int]): Cumulative period list. 48 | 49 | Returns: 50 | int: The position of the right-closest number in the period list. 51 | """ 52 | for i, period in enumerate(cumulative_period): 53 | if iteration <= period: 54 | return i 55 | 56 | 57 | class CosineAnnealingRestartLR(_LRScheduler): 58 | """ Cosine annealing with restarts learning rate scheme. 59 | 60 | An example of config: 61 | periods = [10, 10, 10, 10] 62 | restart_weights = [1, 0.5, 0.5, 0.5] 63 | eta_min=1e-7 64 | 65 | It has four cycles, each with 10 iterations. At the 10th, 20th and 30th iterations, the 66 | scheduler will restart with the weights in restart_weights. 67 | 68 | Args: 69 | optimizer (torch.nn.optimizer): Torch optimizer. 70 | periods (list): Period for each cosine annealing cycle. 71 | restart_weights (list): Restart weights at each restart iteration. 72 | Default: [1]. 73 | eta_min (float): The minimum lr. Default: 0. 74 | last_epoch (int): Used in _LRScheduler. Default: -1. 75 | """ 76 | 77 | def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1): 78 | self.periods = periods 79 | self.restart_weights = restart_weights 80 | self.eta_min = eta_min 81 | assert (len(self.periods) == len( 82 | self.restart_weights)), 'periods and restart_weights should have the same length.' 83 | self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))] 84 | super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch) 85 | 86 | def get_lr(self): 87 | idx = get_position_from_periods(self.last_epoch, self.cumulative_period) 88 | current_weight = self.restart_weights[idx] 89 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] 90 | current_period = self.periods[idx] 91 | 92 | return [ 93 | self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * 94 | (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period))) 95 | for base_lr in self.base_lrs 96 | ] 97 | -------------------------------------------------------------------------------- /basicsr/test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | from os import path as osp 4 | 5 | from basicsr.data import build_dataloader, build_dataset 6 | from basicsr.models import build_model 7 | from basicsr.utils import get_env_info, get_root_logger, get_time_str, make_exp_dirs 8 | from basicsr.utils.options import dict2str, parse_options 9 | 10 | 11 | def test_pipeline(root_path): 12 | # parse options, set distributed setting, set random seed 13 | opt, _ = parse_options(root_path, is_train=False) 14 | 15 | torch.backends.cudnn.benchmark = True 16 | # torch.backends.cudnn.deterministic = True 17 | 18 | # mkdir and initialize loggers 19 | make_exp_dirs(opt) 20 | log_file = osp.join(opt['path']['log'], f"test_{opt['name']}_{get_time_str()}.log") 21 | logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) 22 | logger.info(get_env_info()) 23 | logger.info(dict2str(opt)) 24 | 25 | # create test dataset and dataloader 26 | test_loaders = [] 27 | for _, dataset_opt in sorted(opt['datasets'].items()): 28 | test_set = build_dataset(dataset_opt) 29 | test_loader = build_dataloader( 30 | test_set, dataset_opt, num_gpu=opt['num_gpu'],
dist=opt['dist'], sampler=None, seed=opt['manual_seed']) 31 | logger.info(f"Number of test images in {dataset_opt['name']}: {len(test_set)}") 32 | test_loaders.append(test_loader) 33 | 34 | # create model 35 | model = build_model(opt) 36 | 37 | for test_loader in test_loaders: 38 | test_set_name = test_loader.dataset.opt['name'] 39 | logger.info(f'Testing {test_set_name}...') 40 | model.validation(test_loader, current_iter=opt['name'], tb_logger=None, save_img=opt['val']['save_img']) 41 | 42 | 43 | if __name__ == '__main__': 44 | root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) 45 | test_pipeline(root_path) 46 | -------------------------------------------------------------------------------- /basicsr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .diffjpeg import DiffJPEG 2 | from .file_client import FileClient 3 | from .img_process_util import USMSharp, usm_sharp 4 | from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img 5 | from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger 6 | from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt 7 | 8 | __all__ = [ 9 | # file_client.py 10 | 'FileClient', 11 | # img_util.py 12 | 'img2tensor', 13 | 'tensor2img', 14 | 'imfrombytes', 15 | 'imwrite', 16 | 'crop_border', 17 | # logger.py 18 | 'MessageLogger', 19 | 'AvgTimer', 20 | 'init_tb_logger', 21 | 'init_wandb_logger', 22 | 'get_root_logger', 23 | 'get_env_info', 24 | # misc.py 25 | 'set_random_seed', 26 | 'get_time_str', 27 | 'mkdir_and_rename', 28 | 'make_exp_dirs', 29 | 'scandir', 30 | 'check_resume', 31 | 'sizeof_fmt', 32 | # diffjpeg 33 | 'DiffJPEG', 34 | # img_process_util 35 | 'USMSharp', 36 | 'usm_sharp' 37 | ] 38 | -------------------------------------------------------------------------------- /basicsr/utils/dist_util.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 2 | import functools 3 | import os 4 | import subprocess 5 | import torch 6 | import torch.distributed as dist 7 | import torch.multiprocessing as mp 8 | 9 | 10 | def init_dist(launcher, backend='nccl', **kwargs): 11 | if mp.get_start_method(allow_none=True) is None: 12 | mp.set_start_method('spawn') 13 | if launcher == 'pytorch': 14 | _init_dist_pytorch(backend, **kwargs) 15 | elif launcher == 'slurm': 16 | _init_dist_slurm(backend, **kwargs) 17 | else: 18 | raise ValueError(f'Invalid launcher type: {launcher}') 19 | 20 | 21 | def _init_dist_pytorch(backend, **kwargs): 22 | rank = int(os.environ['RANK']) 23 | num_gpus = torch.cuda.device_count() 24 | torch.cuda.set_device(rank % num_gpus) 25 | dist.init_process_group(backend=backend, **kwargs) 26 | 27 | 28 | def _init_dist_slurm(backend, port=None): 29 | """Initialize slurm distributed training environment. 30 | 31 | If argument ``port`` is not specified, then the master port will be system 32 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 33 | environment variable, then a default port ``29500`` will be used. 34 | 35 | Args: 36 | backend (str): Backend of torch.distributed. 37 | port (int, optional): Master port. Defaults to None. 
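# Usage sketch (illustrative, not from this repository): the typical wiring of these helpers
# in a training entry point. The 'pytorch' launcher assumes the process was started by
# torchrun / torch.distributed.launch so that RANK is already set in the environment.
from basicsr.utils.dist_util import init_dist, get_dist_info

launcher = 'pytorch'                     # or 'slurm'; anything else raises ValueError
init_dist(launcher)                      # picks the CUDA device and calls init_process_group
rank, world_size = get_dist_info()       # falls back to (0, 1) if dist is uninitialized
if rank == 0:
    print(f'distributed run with world_size={world_size}')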
38 | """ 39 | proc_id = int(os.environ['SLURM_PROCID']) 40 | ntasks = int(os.environ['SLURM_NTASKS']) 41 | node_list = os.environ['SLURM_NODELIST'] 42 | num_gpus = torch.cuda.device_count() 43 | torch.cuda.set_device(proc_id % num_gpus) 44 | addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1') 45 | # specify master port 46 | if port is not None: 47 | os.environ['MASTER_PORT'] = str(port) 48 | elif 'MASTER_PORT' in os.environ: 49 | pass # use MASTER_PORT in the environment variable 50 | else: 51 | # 29500 is torch.distributed default port 52 | os.environ['MASTER_PORT'] = '29500' 53 | os.environ['MASTER_ADDR'] = addr 54 | os.environ['WORLD_SIZE'] = str(ntasks) 55 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 56 | os.environ['RANK'] = str(proc_id) 57 | dist.init_process_group(backend=backend) 58 | 59 | 60 | def get_dist_info(): 61 | if dist.is_available(): 62 | initialized = dist.is_initialized() 63 | else: 64 | initialized = False 65 | if initialized: 66 | rank = dist.get_rank() 67 | world_size = dist.get_world_size() 68 | else: 69 | rank = 0 70 | world_size = 1 71 | return rank, world_size 72 | 73 | 74 | def master_only(func): 75 | 76 | @functools.wraps(func) 77 | def wrapper(*args, **kwargs): 78 | rank, _ = get_dist_info() 79 | if rank == 0: 80 | return func(*args, **kwargs) 81 | 82 | return wrapper 83 | -------------------------------------------------------------------------------- /basicsr/utils/download_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import requests 4 | from torch.hub import download_url_to_file, get_dir 5 | from tqdm import tqdm 6 | from urllib.parse import urlparse 7 | 8 | from .misc import sizeof_fmt 9 | 10 | 11 | def download_file_from_google_drive(file_id, save_path): 12 | """Download files from google drive. 13 | 14 | Ref: 15 | https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 16 | 17 | Args: 18 | file_id (str): File id. 19 | save_path (str): Save path. 
20 | """ 21 | 22 | session = requests.Session() 23 | URL = 'https://docs.google.com/uc?export=download' 24 | params = {'id': file_id} 25 | 26 | response = session.get(URL, params=params, stream=True) 27 | token = get_confirm_token(response) 28 | if token: 29 | params['confirm'] = token 30 | response = session.get(URL, params=params, stream=True) 31 | 32 | # get file size 33 | response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) 34 | if 'Content-Range' in response_file_size.headers: 35 | file_size = int(response_file_size.headers['Content-Range'].split('/')[1]) 36 | else: 37 | file_size = None 38 | 39 | save_response_content(response, save_path, file_size) 40 | 41 | 42 | def get_confirm_token(response): 43 | for key, value in response.cookies.items(): 44 | if key.startswith('download_warning'): 45 | return value 46 | return None 47 | 48 | 49 | def save_response_content(response, destination, file_size=None, chunk_size=32768): 50 | if file_size is not None: 51 | pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') 52 | 53 | readable_file_size = sizeof_fmt(file_size) 54 | else: 55 | pbar = None 56 | 57 | with open(destination, 'wb') as f: 58 | downloaded_size = 0 59 | for chunk in response.iter_content(chunk_size): 60 | downloaded_size += chunk_size 61 | if pbar is not None: 62 | pbar.update(1) 63 | pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}') 64 | if chunk: # filter out keep-alive new chunks 65 | f.write(chunk) 66 | if pbar is not None: 67 | pbar.close() 68 | 69 | 70 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None): 71 | """Load file form http url, will download models if necessary. 72 | 73 | Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py 74 | 75 | Args: 76 | url (str): URL to be downloaded. 77 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. 78 | Default: None. 79 | progress (bool): Whether to show the download progress. Default: True. 80 | file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. 81 | 82 | Returns: 83 | str: The path to the downloaded file. 84 | """ 85 | if model_dir is None: # use the pytorch hub_dir 86 | hub_dir = get_dir() 87 | model_dir = os.path.join(hub_dir, 'checkpoints') 88 | 89 | os.makedirs(model_dir, exist_ok=True) 90 | 91 | parts = urlparse(url) 92 | filename = os.path.basename(parts.path) 93 | if file_name is not None: 94 | filename = file_name 95 | cached_file = os.path.abspath(os.path.join(model_dir, filename)) 96 | if not os.path.exists(cached_file): 97 | print(f'Downloading: "{url}" to {cached_file}\n') 98 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) 99 | return cached_file 100 | -------------------------------------------------------------------------------- /basicsr/utils/file_client.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501 2 | from abc import ABCMeta, abstractmethod 3 | 4 | 5 | class BaseStorageBackend(metaclass=ABCMeta): 6 | """Abstract class of storage backends. 7 | 8 | All backends need to implement two apis: ``get()`` and ``get_text()``. 9 | ``get()`` reads the file as a byte stream and ``get_text()`` reads the file 10 | as texts. 
11 | """ 12 | 13 | @abstractmethod 14 | def get(self, filepath): 15 | pass 16 | 17 | @abstractmethod 18 | def get_text(self, filepath): 19 | pass 20 | 21 | 22 | class MemcachedBackend(BaseStorageBackend): 23 | """Memcached storage backend. 24 | 25 | Attributes: 26 | server_list_cfg (str): Config file for memcached server list. 27 | client_cfg (str): Config file for memcached client. 28 | sys_path (str | None): Additional path to be appended to `sys.path`. 29 | Default: None. 30 | """ 31 | 32 | def __init__(self, server_list_cfg, client_cfg, sys_path=None): 33 | if sys_path is not None: 34 | import sys 35 | sys.path.append(sys_path) 36 | try: 37 | import mc 38 | except ImportError: 39 | raise ImportError('Please install memcached to enable MemcachedBackend.') 40 | 41 | self.server_list_cfg = server_list_cfg 42 | self.client_cfg = client_cfg 43 | self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg) 44 | # mc.pyvector servers as a point which points to a memory cache 45 | self._mc_buffer = mc.pyvector() 46 | 47 | def get(self, filepath): 48 | filepath = str(filepath) 49 | import mc 50 | self._client.Get(filepath, self._mc_buffer) 51 | value_buf = mc.ConvertBuffer(self._mc_buffer) 52 | return value_buf 53 | 54 | def get_text(self, filepath): 55 | raise NotImplementedError 56 | 57 | 58 | class HardDiskBackend(BaseStorageBackend): 59 | """Raw hard disks storage backend.""" 60 | 61 | def get(self, filepath): 62 | filepath = str(filepath) 63 | with open(filepath, 'rb') as f: 64 | value_buf = f.read() 65 | return value_buf 66 | 67 | def get_text(self, filepath): 68 | filepath = str(filepath) 69 | with open(filepath, 'r') as f: 70 | value_buf = f.read() 71 | return value_buf 72 | 73 | 74 | class LmdbBackend(BaseStorageBackend): 75 | """Lmdb storage backend. 76 | 77 | Args: 78 | db_paths (str | list[str]): Lmdb database paths. 79 | client_keys (str | list[str]): Lmdb client keys. Default: 'default'. 80 | readonly (bool, optional): Lmdb environment parameter. If True, 81 | disallow any write operations. Default: True. 82 | lock (bool, optional): Lmdb environment parameter. If False, when 83 | concurrent access occurs, do not lock the database. Default: False. 84 | readahead (bool, optional): Lmdb environment parameter. If False, 85 | disable the OS filesystem readahead mechanism, which may improve 86 | random read performance when a database is larger than RAM. 87 | Default: False. 88 | 89 | Attributes: 90 | db_paths (list): Lmdb database path. 91 | _client (list): A list of several lmdb envs. 92 | """ 93 | 94 | def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs): 95 | try: 96 | import lmdb 97 | except ImportError: 98 | raise ImportError('Please install lmdb to enable LmdbBackend.') 99 | 100 | if isinstance(client_keys, str): 101 | client_keys = [client_keys] 102 | 103 | if isinstance(db_paths, list): 104 | self.db_paths = [str(v) for v in db_paths] 105 | elif isinstance(db_paths, str): 106 | self.db_paths = [str(db_paths)] 107 | assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, ' 108 | f'but received {len(client_keys)} and {len(self.db_paths)}.') 109 | 110 | self._client = {} 111 | for client, path in zip(client_keys, self.db_paths): 112 | self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs) 113 | 114 | def get(self, filepath, client_key): 115 | """Get values according to the filepath from one lmdb named client_key. 
116 | 117 | Args: 118 | filepath (str | obj:`Path`): Here, filepath is the lmdb key. 119 | client_key (str): Used for distinguishing different lmdb envs. 120 | """ 121 | filepath = str(filepath) 122 | assert client_key in self._client, (f'client_key {client_key} is not ' 'in lmdb clients.') 123 | client = self._client[client_key] 124 | with client.begin(write=False) as txn: 125 | value_buf = txn.get(filepath.encode('ascii')) 126 | return value_buf 127 | 128 | def get_text(self, filepath): 129 | raise NotImplementedError 130 | 131 | 132 | class FileClient(object): 133 | """A general file client to access files in different backend. 134 | 135 | The client loads a file or text in a specified backend from its path 136 | and return it as a binary file. it can also register other backend 137 | accessor with a given name and backend class. 138 | 139 | Attributes: 140 | backend (str): The storage backend type. Options are "disk", 141 | "memcached" and "lmdb". 142 | client (:obj:`BaseStorageBackend`): The backend object. 143 | """ 144 | 145 | _backends = { 146 | 'disk': HardDiskBackend, 147 | 'memcached': MemcachedBackend, 148 | 'lmdb': LmdbBackend, 149 | } 150 | 151 | def __init__(self, backend='disk', **kwargs): 152 | if backend not in self._backends: 153 | raise ValueError(f'Backend {backend} is not supported. Currently supported ones' 154 | f' are {list(self._backends.keys())}') 155 | self.backend = backend 156 | self.client = self._backends[backend](**kwargs) 157 | 158 | def get(self, filepath, client_key='default'): 159 | # client_key is used only for lmdb, where different fileclients have 160 | # different lmdb environments. 161 | if self.backend == 'lmdb': 162 | return self.client.get(filepath, client_key) 163 | else: 164 | return self.client.get(filepath) 165 | 166 | def get_text(self, filepath): 167 | return self.client.get_text(filepath) 168 | -------------------------------------------------------------------------------- /basicsr/utils/flow_util.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501 2 | import cv2 3 | import numpy as np 4 | import os 5 | 6 | 7 | def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs): 8 | """Read an optical flow map. 9 | 10 | Args: 11 | flow_path (ndarray or str): Flow path. 12 | quantize (bool): whether to read quantized pair, if set to True, 13 | remaining args will be passed to :func:`dequantize_flow`. 14 | concat_axis (int): The axis that dx and dy are concatenated, 15 | can be either 0 or 1. Ignored if quantize is False. 
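# Usage sketch (illustrative): reading an image through FileClient with the default disk
# backend and decoding it with imfrombytes. The lmdb path in the commented lines is an
# assumption; 'testset/sr/butterfly.png' ships with this repository.
from basicsr.utils import FileClient, imfrombytes

client = FileClient('disk')
img_bytes = client.get('testset/sr/butterfly.png')       # raw bytes from disk
img = imfrombytes(img_bytes, float32=True)               # HWC, BGR, float32 in [0, 1]

# lmdb_client = FileClient('lmdb', db_paths='datasets/train_HR.lmdb', client_keys='gt')
# gt_bytes = lmdb_client.get('0001_s001', 'gt')          # filepath acts as the lmdb key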
16 | 17 | Returns: 18 | ndarray: Optical flow represented as a (h, w, 2) numpy array 19 | """ 20 | if quantize: 21 | assert concat_axis in [0, 1] 22 | cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED) 23 | if cat_flow.ndim != 2: 24 | raise IOError(f'{flow_path} is not a valid quantized flow file, its dimension is {cat_flow.ndim}.') 25 | assert cat_flow.shape[concat_axis] % 2 == 0 26 | dx, dy = np.split(cat_flow, 2, axis=concat_axis) 27 | flow = dequantize_flow(dx, dy, *args, **kwargs) 28 | else: 29 | with open(flow_path, 'rb') as f: 30 | try: 31 | header = f.read(4).decode('utf-8') 32 | except Exception: 33 | raise IOError(f'Invalid flow file: {flow_path}') 34 | else: 35 | if header != 'PIEH': 36 | raise IOError(f'Invalid flow file: {flow_path}, ' 'header does not contain PIEH') 37 | 38 | w = np.fromfile(f, np.int32, 1).squeeze() 39 | h = np.fromfile(f, np.int32, 1).squeeze() 40 | flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2)) 41 | 42 | return flow.astype(np.float32) 43 | 44 | 45 | def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs): 46 | """Write optical flow to file. 47 | 48 | If the flow is not quantized, it will be saved as a .flo file losslessly, 49 | otherwise a jpeg image which is lossy but of much smaller size. (dx and dy 50 | will be concatenated horizontally into a single image if quantize is True.) 51 | 52 | Args: 53 | flow (ndarray): (h, w, 2) array of optical flow. 54 | filename (str): Output filepath. 55 | quantize (bool): Whether to quantize the flow and save it to 2 jpeg 56 | images. If set to True, remaining args will be passed to 57 | :func:`quantize_flow`. 58 | concat_axis (int): The axis that dx and dy are concatenated, 59 | can be either 0 or 1. Ignored if quantize is False. 60 | """ 61 | if not quantize: 62 | with open(filename, 'wb') as f: 63 | f.write('PIEH'.encode('utf-8')) 64 | np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f) 65 | flow = flow.astype(np.float32) 66 | flow.tofile(f) 67 | f.flush() 68 | else: 69 | assert concat_axis in [0, 1] 70 | dx, dy = quantize_flow(flow, *args, **kwargs) 71 | dxdy = np.concatenate((dx, dy), axis=concat_axis) 72 | os.makedirs(os.path.dirname(filename), exist_ok=True) 73 | cv2.imwrite(filename, dxdy) 74 | 75 | 76 | def quantize_flow(flow, max_val=0.02, norm=True): 77 | """Quantize flow to [0, 255]. 78 | 79 | After this step, the size of flow will be much smaller, and can be 80 | dumped as jpeg images. 81 | 82 | Args: 83 | flow (ndarray): (h, w, 2) array of optical flow. 84 | max_val (float): Maximum value of flow, values beyond 85 | [-max_val, max_val] will be truncated. 86 | norm (bool): Whether to divide flow values by image width/height. 87 | 88 | Returns: 89 | tuple[ndarray]: Quantized dx and dy. 90 | """ 91 | h, w, _ = flow.shape 92 | dx = flow[..., 0] 93 | dy = flow[..., 1] 94 | if norm: 95 | dx = dx / w # avoid inplace operations 96 | dy = dy / h 97 | # use 255 levels instead of 256 to make sure 0 is 0 after dequantization. 98 | flow_comps = [quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]] 99 | return tuple(flow_comps) 100 | 101 | 102 | def dequantize_flow(dx, dy, max_val=0.02, denorm=True): 103 | """Recover from quantized flow. 104 | 105 | Args: 106 | dx (ndarray): Quantized dx. 107 | dy (ndarray): Quantized dy. 108 | max_val (float): Maximum value used when quantizing. 109 | denorm (bool): Whether to multiply flow values with width/height. 110 | 111 | Returns: 112 | ndarray: Dequantized flow. 
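# Round-trip sketch (illustrative, not part of this file): quantize a small flow field to
# uint8 and recover an approximation. Shapes and values are arbitrary assumptions.
import numpy as np

flow = (np.random.rand(64, 64, 2).astype(np.float32) - 0.5) * 0.02    # values within +-0.01
dx_q, dy_q = quantize_flow(flow, max_val=0.02, norm=False)            # two (64, 64) uint8 maps
flow_rec = dequantize_flow(dx_q, dy_q, max_val=0.02, denorm=False)    # approximate float flow
print(np.abs(flow_rec - flow).max())    # quantization error, at most (2 * max_val) / 255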
113 | """ 114 | assert dx.shape == dy.shape 115 | assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1) 116 | 117 | dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]] 118 | 119 | if denorm: 120 | dx *= dx.shape[1] 121 | dy *= dx.shape[0] 122 | flow = np.dstack((dx, dy)) 123 | return flow 124 | 125 | 126 | def quantize(arr, min_val, max_val, levels, dtype=np.int64): 127 | """Quantize an array of (-inf, inf) to [0, levels-1]. 128 | 129 | Args: 130 | arr (ndarray): Input array. 131 | min_val (scalar): Minimum value to be clipped. 132 | max_val (scalar): Maximum value to be clipped. 133 | levels (int): Quantization levels. 134 | dtype (np.type): The type of the quantized array. 135 | 136 | Returns: 137 | tuple: Quantized array. 138 | """ 139 | if not (isinstance(levels, int) and levels > 1): 140 | raise ValueError(f'levels must be a positive integer, but got {levels}') 141 | if min_val >= max_val: 142 | raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})') 143 | 144 | arr = np.clip(arr, min_val, max_val) - min_val 145 | quantized_arr = np.minimum(np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1) 146 | 147 | return quantized_arr 148 | 149 | 150 | def dequantize(arr, min_val, max_val, levels, dtype=np.float64): 151 | """Dequantize an array. 152 | 153 | Args: 154 | arr (ndarray): Input array. 155 | min_val (scalar): Minimum value to be clipped. 156 | max_val (scalar): Maximum value to be clipped. 157 | levels (int): Quantization levels. 158 | dtype (np.type): The type of the dequantized array. 159 | 160 | Returns: 161 | tuple: Dequantized array. 162 | """ 163 | if not (isinstance(levels, int) and levels > 1): 164 | raise ValueError(f'levels must be a positive integer, but got {levels}') 165 | if min_val >= max_val: 166 | raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})') 167 | 168 | dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - min_val) / levels + min_val 169 | 170 | return dequantized_arr 171 | -------------------------------------------------------------------------------- /basicsr/utils/img_process_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | 7 | def filter2D(img, kernel): 8 | """PyTorch version of cv2.filter2D 9 | 10 | Args: 11 | img (Tensor): (b, c, h, w) 12 | kernel (Tensor): (b, k, k) 13 | """ 14 | k = kernel.size(-1) 15 | b, c, h, w = img.size() 16 | if k % 2 == 1: 17 | img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect') 18 | else: 19 | raise ValueError('Wrong kernel size') 20 | 21 | ph, pw = img.size()[-2:] 22 | 23 | if kernel.size(0) == 1: 24 | # apply the same kernel to all batch images 25 | img = img.view(b * c, 1, ph, pw) 26 | kernel = kernel.view(1, 1, k, k) 27 | return F.conv2d(img, kernel, padding=0).view(b, c, h, w) 28 | else: 29 | img = img.view(1, b * c, ph, pw) 30 | kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k) 31 | return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w) 32 | 33 | 34 | def usm_sharp(img, weight=0.5, radius=50, threshold=10): 35 | """USM sharpening. 36 | 37 | Input image: I; Blurry image: B. 38 | 1. sharp = I + weight * (I - B) 39 | 2. Mask = 1 if abs(I - B) > threshold, else: 0 40 | 3. Blur mask: 41 | 4. Out = Mask * sharp + (1 - Mask) * I 42 | 43 | 44 | Args: 45 | img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. 46 | weight (float): Sharp weight. 
Default: 1. 47 | radius (float): Kernel size of Gaussian blur. Default: 50. 48 | threshold (int): 49 | """ 50 | if radius % 2 == 0: 51 | radius += 1 52 | blur = cv2.GaussianBlur(img, (radius, radius), 0) 53 | residual = img - blur 54 | mask = np.abs(residual) * 255 > threshold 55 | mask = mask.astype('float32') 56 | soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) 57 | 58 | sharp = img + weight * residual 59 | sharp = np.clip(sharp, 0, 1) 60 | return soft_mask * sharp + (1 - soft_mask) * img 61 | 62 | 63 | class USMSharp(torch.nn.Module): 64 | 65 | def __init__(self, radius=50, sigma=0): 66 | super(USMSharp, self).__init__() 67 | if radius % 2 == 0: 68 | radius += 1 69 | self.radius = radius 70 | kernel = cv2.getGaussianKernel(radius, sigma) 71 | kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0) 72 | self.register_buffer('kernel', kernel) 73 | 74 | def forward(self, img, weight=0.5, threshold=10): 75 | blur = filter2D(img, self.kernel) 76 | residual = img - blur 77 | 78 | mask = torch.abs(residual) * 255 > threshold 79 | mask = mask.float() 80 | soft_mask = filter2D(mask, self.kernel) 81 | sharp = img + weight * residual 82 | sharp = torch.clip(sharp, 0, 1) 83 | return soft_mask * sharp + (1 - soft_mask) * img 84 | -------------------------------------------------------------------------------- /basicsr/utils/img_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import numpy as np 4 | import os 5 | import torch 6 | from torchvision.utils import make_grid 7 | 8 | 9 | def img2tensor(imgs, bgr2rgb=True, float32=True): 10 | """Numpy array to tensor. 11 | 12 | Args: 13 | imgs (list[ndarray] | ndarray): Input images. 14 | bgr2rgb (bool): Whether to change bgr to rgb. 15 | float32 (bool): Whether to change to float32. 16 | 17 | Returns: 18 | list[tensor] | tensor: Tensor images. If returned results only have 19 | one element, just return tensor. 20 | """ 21 | 22 | def _totensor(img, bgr2rgb, float32): 23 | if img.shape[2] == 3 and bgr2rgb: 24 | if img.dtype == 'float64': 25 | img = img.astype('float32') 26 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 27 | img = torch.from_numpy(img.transpose(2, 0, 1)) 28 | if float32: 29 | img = img.float() 30 | return img 31 | 32 | if isinstance(imgs, list): 33 | return [_totensor(img, bgr2rgb, float32) for img in imgs] 34 | else: 35 | return _totensor(imgs, bgr2rgb, float32) 36 | 37 | 38 | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): 39 | """Convert torch Tensors into image numpy arrays. 40 | 41 | After clamping to [min, max], values will be normalized to [0, 1]. 42 | 43 | Args: 44 | tensor (Tensor or list[Tensor]): Accept shapes: 45 | 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); 46 | 2) 3D Tensor of shape (3/1 x H x W); 47 | 3) 2D Tensor of shape (H x W). 48 | Tensor channel should be in RGB order. 49 | rgb2bgr (bool): Whether to change rgb to bgr. 50 | out_type (numpy type): output types. If ``np.uint8``, transform outputs 51 | to uint8 type with range [0, 255]; otherwise, float type with 52 | range [0, 1]. Default: ``np.uint8``. 53 | min_max (tuple[int]): min and max values for clamp. 54 | 55 | Returns: 56 | (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of 57 | shape (H x W). The channel order is BGR. 
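# Usage sketch (illustrative, not part of this file): the functional and the module form of
# the USM sharpener defined in img_process_util.py above. The image path comes from this
# repository's testset.
import cv2
import numpy as np
from basicsr.utils import USMSharp, usm_sharp, img2tensor

img = cv2.imread('testset/sr/butterfly.png').astype(np.float32) / 255.   # HWC, BGR, [0, 1]
sharp_np = usm_sharp(img, weight=0.5, radius=50, threshold=10)           # numpy version

usm = USMSharp()                                    # tensor version with a fixed Gaussian kernel
img_t = img2tensor(img).unsqueeze(0)                # 1 x 3 x H x W
sharp_t = usm(img_t, weight=0.5, threshold=10)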
58 | """ 59 | if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): 60 | raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') 61 | 62 | if torch.is_tensor(tensor): 63 | tensor = [tensor] 64 | result = [] 65 | for _tensor in tensor: 66 | _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) 67 | _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) 68 | 69 | n_dim = _tensor.dim() 70 | if n_dim == 4: 71 | img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() 72 | img_np = img_np.transpose(1, 2, 0) 73 | if rgb2bgr: 74 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) 75 | elif n_dim == 3: 76 | img_np = _tensor.numpy() 77 | img_np = img_np.transpose(1, 2, 0) 78 | if img_np.shape[2] == 1: # gray image 79 | img_np = np.squeeze(img_np, axis=2) 80 | else: 81 | if rgb2bgr: 82 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) 83 | elif n_dim == 2: 84 | img_np = _tensor.numpy() 85 | else: 86 | raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}') 87 | if out_type == np.uint8: 88 | # Unlike MATLAB, numpy.unit8() WILL NOT round by default. 89 | img_np = (img_np * 255.0).round() 90 | img_np = img_np.astype(out_type) 91 | result.append(img_np) 92 | if len(result) == 1: 93 | result = result[0] 94 | return result 95 | 96 | 97 | def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)): 98 | """This implementation is slightly faster than tensor2img. 99 | It now only supports torch tensor with shape (1, c, h, w). 100 | 101 | Args: 102 | tensor (Tensor): Now only support torch tensor with (1, c, h, w). 103 | rgb2bgr (bool): Whether to change rgb to bgr. Default: True. 104 | min_max (tuple[int]): min and max values for clamp. 105 | """ 106 | output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0) 107 | output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255 108 | output = output.type(torch.uint8).cpu().numpy() 109 | if rgb2bgr: 110 | output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) 111 | return output 112 | 113 | 114 | def imfrombytes(content, flag='color', float32=False): 115 | """Read an image from bytes. 116 | 117 | Args: 118 | content (bytes): Image bytes got from files or other streams. 119 | flag (str): Flags specifying the color type of a loaded image, 120 | candidates are `color`, `grayscale` and `unchanged`. 121 | float32 (bool): Whether to change to float32., If True, will also norm 122 | to [0, 1]. Default: False. 123 | 124 | Returns: 125 | ndarray: Loaded image array. 126 | """ 127 | img_np = np.frombuffer(content, np.uint8) 128 | imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED} 129 | img = cv2.imdecode(img_np, imread_flags[flag]) 130 | if float32: 131 | img = img.astype(np.float32) / 255. 132 | return img 133 | 134 | 135 | def imwrite(img, file_path, params=None, auto_mkdir=True): 136 | """Write image to file. 137 | 138 | Args: 139 | img (ndarray): Image array to be written. 140 | file_path (str): Image file path. 141 | params (None or list): Same as opencv's :func:`imwrite` interface. 142 | auto_mkdir (bool): If the parent folder of `file_path` does not exist, 143 | whether to create it automatically. 144 | 145 | Returns: 146 | bool: Successful or not. 
147 | """ 148 | if auto_mkdir: 149 | dir_name = os.path.abspath(os.path.dirname(file_path)) 150 | os.makedirs(dir_name, exist_ok=True) 151 | ok = cv2.imwrite(file_path, img, params) 152 | if not ok: 153 | raise IOError('Failed in writing images.') 154 | 155 | 156 | def crop_border(imgs, crop_border): 157 | """Crop borders of images. 158 | 159 | Args: 160 | imgs (list[ndarray] | ndarray): Images with shape (h, w, c). 161 | crop_border (int): Crop border for each end of height and weight. 162 | 163 | Returns: 164 | list[ndarray]: Cropped images. 165 | """ 166 | if crop_border == 0: 167 | return imgs 168 | else: 169 | if isinstance(imgs, list): 170 | return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs] 171 | else: 172 | return imgs[crop_border:-crop_border, crop_border:-crop_border, ...] 173 | -------------------------------------------------------------------------------- /basicsr/utils/misc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import random 4 | import time 5 | import torch 6 | from os import path as osp 7 | 8 | from .dist_util import master_only 9 | 10 | 11 | def set_random_seed(seed): 12 | """Set random seeds.""" 13 | random.seed(seed) 14 | np.random.seed(seed) 15 | torch.manual_seed(seed) 16 | torch.cuda.manual_seed(seed) 17 | torch.cuda.manual_seed_all(seed) 18 | 19 | 20 | def get_time_str(): 21 | return time.strftime('%Y%m%d_%H%M%S', time.localtime()) 22 | 23 | 24 | def mkdir_and_rename(path): 25 | """mkdirs. If path exists, rename it with timestamp and create a new one. 26 | 27 | Args: 28 | path (str): Folder path. 29 | """ 30 | if osp.exists(path): 31 | new_name = path + '_archived_' + get_time_str() 32 | print(f'Path already exists. Rename it to {new_name}', flush=True) 33 | os.rename(path, new_name) 34 | os.makedirs(path, exist_ok=True) 35 | 36 | 37 | @master_only 38 | def make_exp_dirs(opt): 39 | """Make dirs for experiments.""" 40 | path_opt = opt['path'].copy() 41 | if opt['is_train']: 42 | mkdir_and_rename(path_opt.pop('experiments_root')) 43 | else: 44 | mkdir_and_rename(path_opt.pop('results_root')) 45 | for key, path in path_opt.items(): 46 | if ('strict_load' in key) or ('pretrain_network' in key) or ('resume' in key) or ('param_key' in key) or ('pretrain_codebook' in key): 47 | continue 48 | else: 49 | os.makedirs(path, exist_ok=True) 50 | 51 | 52 | def scandir(dir_path, suffix=None, recursive=False, full_path=False): 53 | """Scan a directory to find the interested files. 54 | 55 | Args: 56 | dir_path (str): Path of the directory. 57 | suffix (str | tuple(str), optional): File suffix that we are 58 | interested in. Default: None. 59 | recursive (bool, optional): If set to True, recursively scan the 60 | directory. Default: False. 61 | full_path (bool, optional): If set to True, include the dir_path. 62 | Default: False. 63 | 64 | Returns: 65 | A generator for all the interested files with relative paths. 
66 | """ 67 | 68 | if (suffix is not None) and not isinstance(suffix, (str, tuple)): 69 | raise TypeError('"suffix" must be a string or tuple of strings') 70 | 71 | root = dir_path 72 | 73 | def _scandir(dir_path, suffix, recursive): 74 | for entry in os.scandir(dir_path): 75 | if not entry.name.startswith('.') and entry.is_file(): 76 | if full_path: 77 | return_path = entry.path 78 | else: 79 | return_path = osp.relpath(entry.path, root) 80 | 81 | if suffix is None: 82 | yield return_path 83 | elif return_path.endswith(suffix): 84 | yield return_path 85 | else: 86 | if recursive: 87 | yield from _scandir(entry.path, suffix=suffix, recursive=recursive) 88 | else: 89 | continue 90 | 91 | return _scandir(dir_path, suffix=suffix, recursive=recursive) 92 | 93 | 94 | def check_resume(opt, resume_iter): 95 | """Check resume states and pretrain_network paths. 96 | 97 | Args: 98 | opt (dict): Options. 99 | resume_iter (int): Resume iteration. 100 | """ 101 | if opt['path']['resume_state']: 102 | # get all the networks 103 | networks = [key for key in opt.keys() if key.startswith('network_')] 104 | flag_pretrain = False 105 | for network in networks: 106 | if opt['path'].get(f'pretrain_{network}') is not None: 107 | flag_pretrain = True 108 | if flag_pretrain: 109 | print('pretrain_network path will be ignored during resuming.') 110 | # set pretrained model paths 111 | for network in networks: 112 | name = f'pretrain_{network}' 113 | basename = network.replace('network_', '') 114 | if opt['path'].get('ignore_resume_networks') is None or (network 115 | not in opt['path']['ignore_resume_networks']): 116 | opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth') 117 | print(f"Set {name} to {opt['path'][name]}") 118 | 119 | # change param_key to params in resume 120 | param_keys = [key for key in opt['path'].keys() if key.startswith('param_key')] 121 | for param_key in param_keys: 122 | if opt['path'][param_key] == 'params_ema': 123 | opt['path'][param_key] = 'params' 124 | print(f'Set {param_key} to params') 125 | 126 | 127 | def sizeof_fmt(size, suffix='B'): 128 | """Get human readable file size. 129 | 130 | Args: 131 | size (int): File size. 132 | suffix (str): Suffix. Default: 'B'. 133 | 134 | Return: 135 | str: Formatted file siz. 136 | """ 137 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: 138 | if abs(size) < 1024.0: 139 | return f'{size:3.1f} {unit}{suffix}' 140 | size /= 1024.0 141 | return f'{size:3.1f} Y{suffix}' 142 | -------------------------------------------------------------------------------- /basicsr/utils/registry.py: -------------------------------------------------------------------------------- 1 | # Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501 2 | 3 | 4 | class Registry(): 5 | """ 6 | The registry that provides name -> object mapping, to support third-party 7 | users' custom modules. 8 | 9 | To create a registry (e.g. a backbone registry): 10 | 11 | .. code-block:: python 12 | 13 | BACKBONE_REGISTRY = Registry('BACKBONE') 14 | 15 | To register an object: 16 | 17 | .. code-block:: python 18 | 19 | @BACKBONE_REGISTRY.register() 20 | class MyBackbone(): 21 | ... 22 | 23 | Or: 24 | 25 | .. 
code-block:: python 26 | 27 | BACKBONE_REGISTRY.register(MyBackbone) 28 | """ 29 | 30 | def __init__(self, name): 31 | """ 32 | Args: 33 | name (str): the name of this registry 34 | """ 35 | self._name = name 36 | self._obj_map = {} 37 | 38 | def _do_register(self, name, obj): 39 | assert (name not in self._obj_map), (f"An object named '{name}' was already registered " 40 | f"in '{self._name}' registry!") 41 | self._obj_map[name] = obj 42 | 43 | def register(self, obj=None): 44 | """ 45 | Register the given object under the the name `obj.__name__`. 46 | Can be used as either a decorator or not. 47 | See docstring of this class for usage. 48 | """ 49 | if obj is None: 50 | # used as a decorator 51 | def deco(func_or_class): 52 | name = func_or_class.__name__ 53 | self._do_register(name, func_or_class) 54 | return func_or_class 55 | 56 | return deco 57 | 58 | # used as a function call 59 | name = obj.__name__ 60 | self._do_register(name, obj) 61 | 62 | def get(self, name): 63 | ret = self._obj_map.get(name) 64 | if ret is None: 65 | raise KeyError(f"No object named '{name}' found in '{self._name}' registry!") 66 | return ret 67 | 68 | def __contains__(self, name): 69 | return name in self._obj_map 70 | 71 | def __iter__(self): 72 | return iter(self._obj_map.items()) 73 | 74 | def keys(self): 75 | return self._obj_map.keys() 76 | 77 | 78 | DATASET_REGISTRY = Registry('dataset') 79 | ARCH_REGISTRY = Registry('arch') 80 | MODEL_REGISTRY = Registry('model') 81 | LOSS_REGISTRY = Registry('loss') 82 | METRIC_REGISTRY = Registry('metric') 83 | -------------------------------------------------------------------------------- /evaluate.py: -------------------------------------------------------------------------------- 1 | import pyiqa 2 | import argparse 3 | import cv2 4 | import glob 5 | import os 6 | from tqdm import tqdm 7 | import torch 8 | import yaml 9 | 10 | 11 | def main(): 12 | """Evaluation. 
Metrics: PSNR, SSIM, LPIPS 13 | """ 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('-p', '--pred', type=str, help='Predicted image or folder') 16 | parser.add_argument('-g', '--gt', type=str, help='groundtruth image or folder') 17 | parser.add_argument('-o', '--output', type=str, help='Output folder') 18 | args = parser.parse_args() 19 | 20 | if args.output is None: 21 | args.output = args.pred 22 | os.makedirs(args.output, exist_ok=True) 23 | 24 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 25 | metric_funcs = {} 26 | metric_funcs['psnr'] = pyiqa.create_metric('psnr', device=device, crop_border=4, test_y_channel=True) 27 | metric_funcs['ssim'] = pyiqa.create_metric('ssim', device=device, crop_border=4, test_y_channel=True) 28 | metric_funcs['lpips'] = pyiqa.create_metric('lpips', device=device) 29 | 30 | metric_results = {'psnr': 0, 'ssim': 0, 'lpips': 0} 31 | 32 | # image 33 | pred_img = glob.glob(os.path.join(args.pred, '*.png')) 34 | gt_img = glob.glob(os.path.join(args.gt, '*.png')) 35 | basename = list(set([os.path.basename(img) for img in pred_img]) & set([os.path.basename(img) for img in gt_img])) 36 | data = [[os.path.join(args.pred, bn), os.path.join(args.gt, bn)] for bn in basename] 37 | 38 | # evaluate 39 | for idx in range(len(data)): 40 | for name in metric_funcs.keys(): 41 | metric_results[name] += metric_funcs[name](*data[idx]).item() 42 | 43 | for name in metric_results.keys(): 44 | metric_results[name] /= len(data) 45 | 46 | print(metric_results) 47 | with open(os.path.join(args.output, 'result.yaml'), 'w') as outfile: 48 | yaml.dump(metric_results, outfile, default_flow_style=False) 49 | 50 | if __name__ == '__main__': 51 | main() -------------------------------------------------------------------------------- /figures/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/figures/.DS_Store -------------------------------------------------------------------------------- /figures/model.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/figures/model.jpeg -------------------------------------------------------------------------------- /figures/teaser.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/figures/teaser.jpeg -------------------------------------------------------------------------------- /generate_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | import random 5 | from tqdm import tqdm 6 | from multiprocessing import Pool 7 | import argparse 8 | 9 | from basicsr.data.bsrgan_util import degradation_bsrgan 10 | 11 | IMG_EXTENSIONS = [ 12 | '.jpg', '.JPG', '.jpeg', '.JPEG', 13 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 14 | '.tif', '.TIF', '.tiff', '.TIFF', 15 | ] 16 | 17 | 18 | def is_image_file(filename): 19 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 20 | 21 | 22 | def make_dataset(dir, max_dataset_size=float("inf"), followlinks=True): 23 | images = [] 24 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 25 | 26 | for root, _, fnames in sorted(os.walk(dir, followlinks=followlinks)): 27 | for 
fname in fnames: 28 | if is_image_file(fname): 29 | path = os.path.join(root, fname) 30 | images.append(path) 31 | return images[:min(max_dataset_size, len(images))] 32 | 33 | def degrade_img(hr_path, save_path): 34 | img_gt = cv2.imread(hr_path).astype(np.float32) / 255. 35 | img_gt = img_gt[:, :, [2, 1, 0]] # BGR to RGB 36 | img_lq, img_gt = degradation_bsrgan(img_gt, sf=scale, use_crop=False) 37 | img_lq = (img_lq[:, :, [2, 1, 0]] * 255).astype(np.uint8) 38 | print(f'Save {save_path}') 39 | cv2.imwrite(save_path, img_lq) 40 | 41 | argparser = argparse.ArgumentParser(description='Generate degradation dataset following BSRGAN') 42 | argparser.add_argument('--scale', '-s', type=int, choices=[2, 4], help='downscale factor') 43 | argparser.add_argument('--dir', type=str, help='image directory') 44 | args = argparser.parse_args() 45 | 46 | seed = 123 47 | random.seed(seed) 48 | np.random.seed(seed) 49 | 50 | scale = args.scale 51 | hr_img_list = make_dataset(args.dir) 52 | pool = Pool(processes=40) 53 | 54 | for hr_path in hr_img_list: 55 | try: 56 | save_dir = os.path.dirname(hr_path) + f'_LR_X{scale}' 57 | save_path = os.path.join(save_dir, os.path.basename(hr_path)) 58 | if not os.path.exists(save_dir): 59 | os.makedirs(save_dir, exist_ok=True) 60 | pool.apply_async(degrade_img(hr_path, save_path)) 61 | except: 62 | print(hr_path, ': LR not generated.') 63 | 64 | pool.close() 65 | pool.join() 66 | 67 | -------------------------------------------------------------------------------- /generate_inpaint_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | from PIL import Image 4 | import numpy as np 5 | import random 6 | from tqdm import tqdm 7 | from multiprocessing import Pool 8 | import argparse 9 | 10 | from basicsr.data.paired_image_dataset import brush_stroke_mask 11 | 12 | IMG_EXTENSIONS = [ 13 | '.jpg', '.JPG', '.jpeg', '.JPEG', 14 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 15 | '.tif', '.TIF', '.tiff', '.TIFF', 16 | ] 17 | 18 | 19 | def is_image_file(filename): 20 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 21 | 22 | 23 | def make_dataset(dir, max_dataset_size=float("inf"), followlinks=True): 24 | images = [] 25 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 26 | 27 | for root, _, fnames in sorted(os.walk(dir, followlinks=followlinks)): 28 | for fname in fnames: 29 | if is_image_file(fname): 30 | path = os.path.join(root, fname) 31 | images.append(path) 32 | return images[:min(max_dataset_size, len(images))] 33 | 34 | def degrade_img(hr_path, save_path): 35 | img_gt = cv2.imread(hr_path) 36 | img_lq = np.array(brush_stroke_mask(Image.fromarray(img_gt))).astype(np.uint8) 37 | print(f'Save {save_path}') 38 | cv2.imwrite(save_path, img_lq) 39 | 40 | argparser = argparse.ArgumentParser(description='Generate inpainting dataset') 41 | argparser.add_argument('--dir', type=str, help='image directory') 42 | args = argparser.parse_args() 43 | 44 | seed = 123 45 | random.seed(seed) 46 | np.random.seed(seed) 47 | 48 | hr_img_list = make_dataset(args.dir) 49 | pool = Pool(processes=40) 50 | 51 | for hr_path in hr_img_list: 52 | try: 53 | save_dir = os.path.dirname(hr_path) + f'_inpaint' 54 | save_path = os.path.join(save_dir, os.path.basename(hr_path)) 55 | if not os.path.exists(save_dir): 56 | os.makedirs(save_dir, exist_ok=True) 57 | pool.apply_async(degrade_img(hr_path, save_path)) 58 | except: 59 | print(hr_path, ': masked not generated.') 60 | 61 | pool.close() 62 | 
pool.join() 63 | 64 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import glob 4 | import os 5 | from tqdm import tqdm 6 | import numpy as np 7 | import torch 8 | from yaml import load 9 | import matplotlib as m 10 | import matplotlib.pyplot as plt 11 | 12 | from basicsr.utils import img2tensor, tensor2img, imwrite 13 | from basicsr.archs.adacode_contrast_arch import AdaCodeSRNet_Contrast 14 | from basicsr.utils.download_util import load_file_from_url 15 | 16 | pretrain_model_url = { 17 | 'SR_x2': 'https://github.com/kechunl/AdaCode/releases/download/v0-pretrain_models/AdaCode_SR_X2_model_g.pth', 18 | 'SR_x4': 'https://github.com/kechunl/AdaCode/releases/download/v0-pretrain_models/AdaCode_SR_X4_model_g.pth', 19 | 'inpaint': 'https://github.com/kechunl/AdaCode/releases/download/v0-pretrain_models/AdaCode_Inpaint_model_g.pth', 20 | } 21 | 22 | def main(): 23 | """Inference demo for FeMaSR 24 | """ 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('-i', '--input', type=str, default='inputs', help='Input image or folder') 27 | parser.add_argument('-w', '--weight', type=str, default=None, help='path for model weights') 28 | parser.add_argument('-o', '--output', type=str, default='results', help='Output folder') 29 | parser.add_argument('-t', '--task', type=str, choices=['sr', 'inpaint'], help='inference task') 30 | parser.add_argument('-s', '--out_scale', type=int, default=4, help='final upsampling scale of the image for SR task') 31 | parser.add_argument('--suffix', type=str, default='', help='Suffix of the restored image') 32 | parser.add_argument('--max_size', type=int, default=600, help='Max image size for whole image inference, otherwise use tiled_test') 33 | parser.add_argument('--vis_weight', action='store_true', help='visualize weight map') 34 | args = parser.parse_args() 35 | 36 | if args.vis_weight: 37 | os.makedirs(os.path.join(args.output, 'weight_map'), exist_ok=True) 38 | 39 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 40 | 41 | if args.task == 'sr': 42 | model_url_key = f'SR_x{args.out_scale}' 43 | scale = args.out_scale 44 | bs = 8 45 | else: 46 | model_url_key = 'inpaint' 47 | scale = 1 48 | bs = 2 49 | 50 | if args.weight is None: 51 | os.makedirs('./checkpoint', exist_ok=True) 52 | weight_path = load_file_from_url(pretrain_model_url[f'{model_url_key}'], model_dir='./checkpoint') 53 | else: 54 | weight_path = args.weight 55 | 56 | # set up the model 57 | model_params = torch.load(weight_path)['params'] 58 | codebook_dim = np.array([v.size() for k,v in model_params.items() if 'quantize_group' in k]) 59 | codebook_dim_list = [] 60 | for k in codebook_dim: 61 | temp = k.tolist() 62 | temp.insert(0,32) 63 | codebook_dim_list.append(temp) 64 | model = AdaCodeSRNet_Contrast(codebook_params=codebook_dim_list, LQ_stage=True, AdaCode_stage=True, weight_softmax=False, batch_size=bs, scale_factor=scale).to(device) 65 | model.load_state_dict(torch.load(weight_path)['params'], strict=False) 66 | model.eval() 67 | 68 | os.makedirs(args.output, exist_ok=True) 69 | if os.path.isfile(args.input): 70 | paths = [args.input] 71 | else: 72 | paths = sorted(glob.glob(os.path.join(args.input, '*'))) 73 | 74 | pbar = tqdm(total=len(paths), unit='image') 75 | for idx, path in enumerate(paths): 76 | try: 77 | img_name = os.path.basename(path) 78 | pbar.set_description(f'Test {img_name}') 79 | 
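            # Read the degraded input as stored on disk (BGR, uint8), scale it to [0, 1] and add a
            # batch dimension; inputs whose pixel count reaches max_size**2 are routed to the
            # tiled test path below to keep GPU memory bounded.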
80 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED) 81 | img_tensor = img2tensor(img).to(device) / 255. 82 | img_tensor = img_tensor.unsqueeze(0) 83 | 84 | max_size = args.max_size ** 2 85 | h, w = img_tensor.shape[2:] 86 | if h * w < max_size: 87 | output = model.test(img_tensor, vis_weight=args.vis_weight) 88 | else: 89 | output = model.test_tile(img_tensor, vis_weight=args.vis_weight) 90 | 91 | if args.vis_weight: 92 | weight_map = output[1] 93 | vis_weight(weight_map, os.path.join(args.output, 'weight_map', img_name)) 94 | output = output[0] 95 | output_img = tensor2img(output) 96 | 97 | save_path = os.path.join(args.output, f'{img_name}') 98 | imwrite(output_img, save_path) 99 | pbar.update(1) 100 | except: 101 | print(path, ' fails.') 102 | pbar.close() 103 | 104 | def vis_weight(weight, save_path): 105 | # weight: B x n x 1 x H x W 106 | weight = weight.cpu().numpy() 107 | # normalize weights 108 | # norm_weight = weight 109 | norm_weight = (weight - weight.mean()) / weight.std() / 2 110 | norm_weight = np.abs(norm_weight) 111 | norm_weight *= 255 112 | norm_weight = np.clip(norm_weight, 0, 255) 113 | norm_weight = norm_weight.astype(np.uint8) 114 | # visualize 115 | display_grid = np.zeros((weight.shape[3], (weight.shape[4]+1)*weight.shape[1]-1)) 116 | for img_id in range(len(norm_weight)): 117 | for c in range(norm_weight.shape[1]): 118 | display_grid[:, c*weight.shape[4]+c:(c+1)*weight.shape[4]+c] = norm_weight[img_id, c, 0, :, :] 119 | # weight_path = save_path.split('.')[0] + '_{}.png'.format(str(c)) 120 | # Image.fromarray(norm_weight[img_id, c, 0, :, :]).save(weight_path) 121 | plt.figure(figsize=(30,150)) 122 | plt.axis('off') 123 | plt.imshow(display_grid) 124 | plt.savefig(save_path, bbox_inches='tight', pad_inches=0) 125 | plt.close() 126 | 127 | if __name__ == '__main__': 128 | main() 129 | -------------------------------------------------------------------------------- /options/stage1/train_AdaCode_HQ_stage1_category.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: AdaCode_stage1_category 3 | model_type: FeMaSRModel 4 | scale: 4 # doesn't matter for this stage 5 | num_gpu: 4 # set num_gpu: 0 for cpu mode 6 | manual_seed: 0 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: General_Image_Train 12 | type: BSRGANTrainDataset 13 | datafile_gt: ./data/category_data/train/portrait.txt 14 | io_backend: 15 | type: disk 16 | 17 | gt_size: 256 18 | use_resize_crop: true 19 | use_flip: true 20 | use_rot: true 21 | 22 | # data loader 23 | use_shuffle: true 24 | batch_size_per_gpu: &bsz 8 25 | num_worker_per_gpu: *bsz 26 | dataset_enlarge_ratio: 1 27 | 28 | prefetch_mode: cpu 29 | num_prefetch_queue: *bsz 30 | 31 | val: 32 | name: General_Image_Valid 33 | type: PairedImageDataset 34 | datafile_gt: ./data/category_data/valid/portrait.txt 35 | datafile_lq: ./data/category_data/valid/portrait.txt 36 | io_backend: 37 | type: disk 38 | 39 | # network structures 40 | network_g: 41 | type: FeMaSRNet 42 | gt_resolution: 256 43 | norm_type: 'gn' 44 | act_type: 'silu' 45 | use_semantic_loss: true 46 | codebook_params: # has to order from low to high 47 | - [32, 256, 256] 48 | 49 | # for HQ stage training 50 | LQ_stage: false 51 | use_quantize: true 52 | 53 | network_d: 54 | type: UNetDiscriminatorSN 55 | num_in_ch: 3 56 | 57 | # path 58 | path: 59 | # pretrain_network_g: ./experiments/pretrained_models/QuanTexSR/pretrain_semantic_vqgan_net_g_latest.pth 60 | # pretrain_network_d: ~ 61 | # 
pretrain_network_g: ./experiments/004_FeMaSR_HQ_stage/models/net_g_best_.pth 62 | # pretrain_network_d: ./experiments/004_FeMaSR_HQ_stage/models/net_d_best_.pth 63 | strict_load: false 64 | # resume_state: ~ 65 | 66 | # training settings 67 | train: 68 | optim_g: 69 | type: Adam 70 | lr: !!float 1e-4 71 | weight_decay: 0 72 | betas: [0.9, 0.99] 73 | optim_d: 74 | type: Adam 75 | lr: !!float 4e-4 76 | weight_decay: 0 77 | betas: [0.9, 0.99] 78 | 79 | scheduler: 80 | type: MultiStepLR 81 | milestones: [1800, 3600, 5400, 7200, 9000, 10800, 12600] 82 | gamma: 1 83 | 84 | total_iter: 73500 85 | warmup_iter: -1 # no warm up 86 | 87 | # losses 88 | pixel_opt: 89 | type: L1Loss 90 | loss_weight: 1.0 91 | reduction: mean 92 | 93 | perceptual_opt: 94 | type: LPIPSLoss 95 | loss_weight: !!float 1.0 96 | 97 | gan_opt: 98 | type: GANLoss 99 | gan_type: hinge 100 | real_label_val: 1.0 101 | fake_label_val: 0.0 102 | loss_weight: 0.1 103 | 104 | codebook_opt: 105 | loss_weight: 1.0 106 | 107 | semantic_opt: 108 | loss_weight: 0.1 109 | 110 | net_d_iters: 1 111 | net_d_init_iters: !!float 0 112 | 113 | # validation settings· 114 | val: 115 | val_freq: !!float 3e2 116 | save_img: false 117 | 118 | key_metric: lpips 119 | metrics: 120 | psnr: # metric name, not used in this codebase 121 | type: psnr 122 | crop_border: 4 123 | test_y_channel: true 124 | color_space: ycbcr 125 | ssim: 126 | type: ssim 127 | crop_border: 4 128 | test_y_channel: true 129 | color_space: ycbcr 130 | lpips: 131 | type: lpips 132 | better: lower 133 | 134 | # logging settings 135 | logger: 136 | print_freq: 100 137 | save_checkpoint_freq: !!float 1e9 138 | save_latest_freq: !!float 1e3 139 | show_tf_imgs_freq: !!float 3e2 140 | use_tb_logger: true 141 | -------------------------------------------------------------------------------- /options/stage2/train_AdaCode_HQ_stage2.yaml: -------------------------------------------------------------------------------- 1 | # GENERATE TIME: Fri Oct 21 18:30:19 2022 2 | # CMD: 3 | # basicsr/train.py --local_rank=0 -opt options/train_AdaCode_HQ_stage2.yaml --launcher pytorch 4 | 5 | # general settings 6 | name: AdaCode_stage2_recon 7 | model_type: FeMaSRModel 8 | scale: &upscale 4 # doens't matter for this stage 9 | num_gpu: 8 # set num_gpu: 0 for cpu mode 10 | manual_seed: 0 11 | 12 | # dataset and data loader settings 13 | datasets: 14 | train: 15 | name: General_Image_Train 16 | type: BSRGANTrainDataset 17 | datafile_gt: ./data/category_data/train/all_train.txt 18 | io_backend: 19 | type: disk 20 | 21 | gt_size: 256 22 | use_resize_crop: true 23 | use_flip: true 24 | use_rot: true 25 | 26 | # data loader 27 | use_shuffle: true 28 | batch_size_per_gpu: &bsz 8 29 | num_worker_per_gpu: *bsz 30 | dataset_enlarge_ratio: 1 31 | 32 | prefetch_mode: cpu 33 | num_prefetch_queue: *bsz 34 | 35 | val: 36 | name: General_Image_Valid 37 | type: PairedImageDataset 38 | datafile_gt: ./data/category_data/valid/all_valid.txt 39 | datafile_lq: ./data/category_data/valid/all_valid.txt 40 | # crop_eval_size: 384 41 | io_backend: 42 | type: disk 43 | 44 | # network structures 45 | network_g: 46 | type: AdaCodeSRNet 47 | gt_resolution: 256 48 | norm_type: 'gn' 49 | act_type: 'silu' 50 | scale_factor: *upscale 51 | 52 | ### TODO: modify the configuration carefully 53 | AdaCode_stage: true 54 | LQ_stage: false 55 | frozen_module_keywords: ['quantize'] 56 | 57 | network_d: 58 | type: UNetDiscriminatorSN 59 | num_in_ch: 3 60 | 61 | # path 62 | path: 63 | pretrain_codebook: 64 | - 
./experiments/AdaCode_stage1_c0_codebook512x256/models/net_g_best_.pth 65 | - ./experiments/AdaCode_stage1_c1_codebook256x256/models/net_g_best_.pth 66 | - ./experiments/AdaCode_stage1_c2_codebook512x256/models/net_g_best_.pth 67 | - ./experiments/AdaCode_stage1_c3_codebook256x256/models/net_g_best_.pth 68 | - ./experiments/AdaCode_stage1_ffhq_codebook256x256/models/net_g_best_.pth 69 | 70 | # pretrain_network_hq: ./experiments/008_FeMaSR_HQ_stage/models/net_g_best_.pth 71 | # pretrain_network_g: ~ 72 | # pretrain_network_d: ./experiments/008_FeMaSR_HQ_stage/models/net_d_best_.pth 73 | strict_load: false 74 | # resume_state: ~ 75 | 76 | # training settings 77 | train: 78 | optim_g: 79 | type: Adam 80 | lr: !!float 1e-4 81 | weight_decay: 0 82 | betas: [0.9, 0.99] 83 | optim_d: 84 | type: Adam 85 | lr: !!float 4e-4 86 | weight_decay: 0 87 | betas: [0.9, 0.99] 88 | 89 | scheduler: 90 | type: MultiStepLR 91 | milestones: [18000, 36000, 54000, 72000, 90000, 108000, 126000] 92 | gamma: 1 93 | 94 | total_iter: 730000 95 | warmup_iter: -1 # no warm up 96 | 97 | # losses 98 | pixel_opt: 99 | type: L1Loss 100 | loss_weight: 1.0 101 | reduction: mean 102 | 103 | perceptual_opt: 104 | type: LPIPSLoss 105 | loss_weight: !!float 1.0 106 | 107 | gan_opt: 108 | type: GANLoss 109 | gan_type: hinge 110 | real_label_val: 1.0 111 | fake_label_val: 0.0 112 | loss_weight: 0.1 113 | 114 | codebook_opt: 115 | loss_weight: 1.0 116 | 117 | semantic_opt: 118 | loss_weight: 0.1 119 | 120 | net_d_iters: 1 121 | net_d_init_iters: !!float 0 122 | 123 | # validation settings· 124 | val: 125 | val_freq: !!float 3e3 126 | save_img: true 127 | 128 | key_metric: lpips 129 | metrics: 130 | psnr: # metric name, can be arbitrary 131 | type: psnr 132 | crop_border: 4 133 | test_y_channel: true 134 | ssim: 135 | type: ssim 136 | crop_border: 4 137 | test_y_channel: true 138 | lpips: 139 | type: lpips 140 | better: lower 141 | 142 | # logging settings 143 | logger: 144 | print_freq: 100 145 | save_checkpoint_freq: !!float 1e9 146 | save_latest_freq: !!float 3e3 147 | show_tf_imgs_freq: !!float 1e3 148 | use_tb_logger: true 149 | 150 | # wandb: 151 | # project: ESRGAN 152 | # resume_id: ~ 153 | 154 | # dist training settings 155 | # dist_params: 156 | # backend: nccl 157 | # port: 16500 #29500 158 | -------------------------------------------------------------------------------- /options/stage3/train_AdaCode_stage3.yaml: -------------------------------------------------------------------------------- 1 | # GENERATE TIME: Sun Nov 27 00:55:18 2022 2 | # CMD: 3 | # basicsr/train.py --local_rank=0 -opt options/train_AdaCode_stage3.yaml --launcher pytorch 4 | 5 | # general settings 6 | name: AdaCode_stage3_SR_X2 7 | model_type: FeMaSRModel 8 | scale: &upscale 2 9 | num_gpu: 8 # set num_gpu: 0 for cpu mode 10 | manual_seed: 0 11 | 12 | # dataset and data loader settings 13 | datasets: 14 | train: 15 | name: General_Image_Train 16 | # type: BSRGANTrainDataset 17 | # datafile_gt: /newDisk/users/liukechun/research/FeMaSR/data/category_data/train/all_train.txt 18 | type: PairedImageDataset 19 | datafile_gt: ./data/category_data/train/all_train.txt 20 | datafile_lq: ./data/category_data/train/all_train_X2.txt 21 | io_backend: 22 | type: disk 23 | 24 | gt_size: 256 25 | use_resize_crop: true 26 | use_flip: true 27 | use_rot: true 28 | 29 | # data loader 30 | use_shuffle: true 31 | batch_size_per_gpu: &bsz 8 32 | num_worker_per_gpu: *bsz 33 | dataset_enlarge_ratio: 1 34 | 35 | prefetch_mode: cpu 36 | num_prefetch_queue: *bsz 37 | 38 | val: 39 
| name: General_Image_Valid 40 | type: PairedImageDataset 41 | datafile_gt: ./data/category_data/valid/all_valid.txt 42 | datafile_lq: ./data/category_data/valid/all_valid_X2.txt 43 | # crop_eval_size: 384 44 | io_backend: 45 | type: disk 46 | 47 | # network structures 48 | network_g: 49 | type: AdaCodeSRNet_Contrast 50 | gt_resolution: 256 51 | norm_type: 'gn' 52 | act_type: 'silu' 53 | scale_factor: *upscale 54 | 55 | ### TODO: modify the configuration carefully 56 | AdaCode_stage: true 57 | LQ_stage: true 58 | frozen_module_keywords: ['quantize', 'decoder', 'after_quant_group', 'out_conv'] 59 | 60 | weight_softmax: false 61 | 62 | network_d: 63 | type: UNetDiscriminatorSN 64 | num_in_ch: 3 65 | 66 | # path 67 | path: 68 | # pretrain_codebook: 69 | # - ./experiments/AdaCode_stage1_c0_codebook512x256/models/net_g_best_.pth 70 | # - ./experiments/AdaCode_stage1_c1_codebook256x256/models/net_g_best_.pth 71 | # - ./experiments/AdaCode_stage1_c2_codebook512x256/models/net_g_best_.pth 72 | # - ./experiments/AdaCode_stage1_c3_codebook256x256/models/net_g_best_.pth 73 | # - ./experiments/AdaCode_stage1_ffhq_codebook256x256/models/net_g_best_.pth 74 | 75 | pretrain_network_hq: ./experiments/stage2/AdaCode_stage2_recon/models/net_g_best_.pth 76 | # pretrain_network_g: ~ 77 | pretrain_network_d: ./experiments/stage2/AdaCode_stage2_recon/models/net_d_best_.pth 78 | strict_load: false 79 | # resume_state: ~ 80 | 81 | # training settings 82 | train: 83 | optim_g: 84 | type: Adam 85 | lr: !!float 1e-4 86 | weight_decay: 0 87 | betas: [0.9, 0.99] 88 | optim_d: 89 | type: Adam 90 | lr: !!float 4e-4 91 | weight_decay: 0 92 | betas: [0.9, 0.99] 93 | 94 | scheduler: 95 | type: MultiStepLR 96 | milestones: [18000, 36000, 54000, 72000, 90000, 108000, 126000] 97 | gamma: 1 98 | 99 | total_iter: 730000 100 | warmup_iter: -1 # no warm up 101 | 102 | # losses 103 | pixel_opt: 104 | type: L1Loss 105 | loss_weight: 1.0 106 | reduction: mean 107 | 108 | perceptual_opt: 109 | type: LPIPSLoss 110 | loss_weight: !!float 1.0 111 | 112 | gan_opt: 113 | type: GANLoss 114 | gan_type: hinge 115 | real_label_val: 1.0 116 | fake_label_val: 0.0 117 | loss_weight: 0.1 118 | 119 | codebook_opt: 120 | loss_weight: 1.0 121 | 122 | semantic_opt: 123 | loss_weight: 0.1 124 | 125 | before_quant_opt: 126 | loss_weight: 0.0 127 | 128 | after_quant_opt: 129 | loss_weight: 1.0 130 | 131 | contrast_opt: 132 | loss_weight: 1.0 133 | 134 | net_d_iters: 1 135 | net_d_init_iters: !!float 0 136 | 137 | # validation settings· 138 | val: 139 | val_freq: !!float 3e3 140 | save_img: false 141 | 142 | key_metric: lpips 143 | metrics: 144 | psnr: # metric name, can be arbitrary 145 | type: psnr 146 | crop_border: 4 147 | test_y_channel: true 148 | ssim: 149 | type: ssim 150 | crop_border: 4 151 | test_y_channel: true 152 | lpips: 153 | type: lpips 154 | better: lower 155 | 156 | # logging settings 157 | logger: 158 | print_freq: 100 159 | save_checkpoint_freq: !!float 1e9 160 | save_latest_freq: !!float 3e3 161 | show_tf_imgs_freq: !!float 1e3 162 | use_tb_logger: true 163 | 164 | # wandb: 165 | # project: ESRGAN 166 | # resume_id: ~ 167 | 168 | # dist training settings 169 | # dist_params: 170 | # backend: nccl 171 | # port: 16500 #29500 172 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | addict 2 | future 3 | lmdb 4 | setuptools>=57.0.0 5 | numpy>=1.20.3 6 | opencv-python-headless 7 | Pillow>=8.3.2 8 | pyyaml 9 
| requests 10 | scikit-image 11 | scikit-learn 12 | scipy 13 | tb-nightly 14 | torch>=1.8.1 15 | torchvision>=0.9 16 | tqdm 17 | yapf 18 | pyiqa 19 | einops 20 | -------------------------------------------------------------------------------- /scripts/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/scripts/.DS_Store -------------------------------------------------------------------------------- /scripts/data_preparation/create_lmdb.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from os import path as osp 3 | 4 | from basicsr.utils import scandir 5 | from basicsr.utils.lmdb_util import make_lmdb_from_imgs 6 | 7 | 8 | def create_lmdb_for_div2k(): 9 | """Create lmdb files for DIV2K dataset. 10 | 11 | Usage: 12 | Before run this script, please run `extract_subimages.py`. 13 | Typically, there are four folders to be processed for DIV2K dataset. 14 | DIV2K_train_HR_sub 15 | DIV2K_train_LR_bicubic/X2_sub 16 | DIV2K_train_LR_bicubic/X3_sub 17 | DIV2K_train_LR_bicubic/X4_sub 18 | Remember to modify opt configurations according to your settings. 19 | """ 20 | # HR images 21 | folder_path = 'datasets/DIV2K/DIV2K_train_HR_sub' 22 | lmdb_path = 'datasets/DIV2K/DIV2K_train_HR_sub.lmdb' 23 | img_path_list, keys = prepare_keys_div2k(folder_path) 24 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 25 | 26 | # LRx2 images 27 | folder_path = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X2_sub' 28 | lmdb_path = 'datasets/DIV2K/DIV2K_train_LR_bicubic_X2_sub.lmdb' 29 | img_path_list, keys = prepare_keys_div2k(folder_path) 30 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 31 | 32 | # LRx3 images 33 | folder_path = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X3_sub' 34 | lmdb_path = 'datasets/DIV2K/DIV2K_train_LR_bicubic_X3_sub.lmdb' 35 | img_path_list, keys = prepare_keys_div2k(folder_path) 36 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 37 | 38 | # LRx4 images 39 | folder_path = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X4_sub' 40 | lmdb_path = 'datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb' 41 | img_path_list, keys = prepare_keys_div2k(folder_path) 42 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 43 | 44 | 45 | def prepare_keys_div2k(folder_path): 46 | """Prepare image path list and keys for DIV2K dataset. 47 | 48 | Args: 49 | folder_path (str): Folder path. 50 | 51 | Returns: 52 | list[str]: Image path list. 53 | list[str]: Key list. 54 | """ 55 | print('Reading image path list ...') 56 | img_path_list = sorted(list(scandir(folder_path, suffix='png', recursive=False))) 57 | keys = [img_path.split('.png')[0] for img_path in sorted(img_path_list)] 58 | 59 | return img_path_list, keys 60 | 61 | 62 | def create_lmdb_for_reds(): 63 | """Create lmdb files for REDS dataset. 64 | 65 | Usage: 66 | Before run this script, please run `merge_reds_train_val.py`. 67 | We take two folders for example: 68 | train_sharp 69 | train_sharp_bicubic 70 | Remember to modify opt configurations according to your settings. 
71 | """ 72 | # train_sharp 73 | folder_path = 'datasets/REDS/train_sharp' 74 | lmdb_path = 'datasets/REDS/train_sharp_with_val.lmdb' 75 | img_path_list, keys = prepare_keys_reds(folder_path) 76 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys, multiprocessing_read=True) 77 | 78 | # train_sharp_bicubic 79 | folder_path = 'datasets/REDS/train_sharp_bicubic' 80 | lmdb_path = 'datasets/REDS/train_sharp_bicubic_with_val.lmdb' 81 | img_path_list, keys = prepare_keys_reds(folder_path) 82 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys, multiprocessing_read=True) 83 | 84 | 85 | def prepare_keys_reds(folder_path): 86 | """Prepare image path list and keys for REDS dataset. 87 | 88 | Args: 89 | folder_path (str): Folder path. 90 | 91 | Returns: 92 | list[str]: Image path list. 93 | list[str]: Key list. 94 | """ 95 | print('Reading image path list ...') 96 | img_path_list = sorted(list(scandir(folder_path, suffix='png', recursive=True))) 97 | keys = [v.split('.png')[0] for v in img_path_list] # example: 000/00000000 98 | 99 | return img_path_list, keys 100 | 101 | 102 | def create_lmdb_for_vimeo90k(): 103 | """Create lmdb files for Vimeo90K dataset. 104 | 105 | Usage: 106 | Remember to modify opt configurations according to your settings. 107 | """ 108 | # GT 109 | folder_path = 'datasets/vimeo90k/vimeo_septuplet/sequences' 110 | lmdb_path = 'datasets/vimeo90k/vimeo90k_train_GT_only4th.lmdb' 111 | train_list_path = 'datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt' 112 | img_path_list, keys = prepare_keys_vimeo90k(folder_path, train_list_path, 'gt') 113 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys, multiprocessing_read=True) 114 | 115 | # LQ 116 | folder_path = 'datasets/vimeo90k/vimeo_septuplet_matlabLRx4/sequences' 117 | lmdb_path = 'datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb' 118 | train_list_path = 'datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt' 119 | img_path_list, keys = prepare_keys_vimeo90k(folder_path, train_list_path, 'lq') 120 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys, multiprocessing_read=True) 121 | 122 | 123 | def prepare_keys_vimeo90k(folder_path, train_list_path, mode): 124 | """Prepare image path list and keys for Vimeo90K dataset. 125 | 126 | Args: 127 | folder_path (str): Folder path. 128 | train_list_path (str): Path to the official train list. 129 | mode (str): One of 'gt' or 'lq'. 130 | 131 | Returns: 132 | list[str]: Image path list. 133 | list[str]: Key list. 
134 | """ 135 | print('Reading image path list ...') 136 | with open(train_list_path, 'r') as fin: 137 | train_list = [line.strip() for line in fin] 138 | 139 | img_path_list = [] 140 | keys = [] 141 | for line in train_list: 142 | folder, sub_folder = line.split('/') 143 | img_path_list.extend([osp.join(folder, sub_folder, f'im{j + 1}.png') for j in range(7)]) 144 | keys.extend([f'{folder}/{sub_folder}/im{j + 1}' for j in range(7)]) 145 | 146 | if mode == 'gt': 147 | print('Only keep the 4th frame for the gt mode.') 148 | img_path_list = [v for v in img_path_list if v.endswith('im4.png')] 149 | keys = [v for v in keys if v.endswith('/im4')] 150 | 151 | return img_path_list, keys 152 | 153 | 154 | if __name__ == '__main__': 155 | parser = argparse.ArgumentParser() 156 | 157 | parser.add_argument( 158 | '--dataset', 159 | type=str, 160 | help=("Options: 'DIV2K', 'REDS', 'Vimeo90K'. " 161 | 'You may need to modify the corresponding configurations in codes.')) 162 | args = parser.parse_args() 163 | dataset = args.dataset.lower() 164 | if dataset == 'div2k': 165 | create_lmdb_for_div2k() 166 | elif dataset == 'reds': 167 | create_lmdb_for_reds() 168 | elif dataset == 'vimeo90k': 169 | create_lmdb_for_vimeo90k() 170 | else: 171 | raise ValueError('Wrong dataset.') 172 | -------------------------------------------------------------------------------- /scripts/data_preparation/download_datasets.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | from os import path as osp 5 | 6 | from basicsr.utils.download_util import download_file_from_google_drive 7 | 8 | 9 | def download_dataset(dataset, file_ids): 10 | save_path_root = './datasets/' 11 | os.makedirs(save_path_root, exist_ok=True) 12 | 13 | for file_name, file_id in file_ids.items(): 14 | save_path = osp.abspath(osp.join(save_path_root, file_name)) 15 | if osp.exists(save_path): 16 | user_response = input(f'{file_name} already exists. Do you want to overwrite it? Y/N\n') 17 | if user_response.lower() == 'y': 18 | print(f'Overwriting {file_name} at {save_path}') 19 | download_file_from_google_drive(file_id, save_path) 20 | elif user_response.lower() == 'n': 21 | print(f'Skipping {file_name}') 22 | else: 23 | raise ValueError('Wrong input. Only accepts Y/N.') 24 | else: 25 | print(f'Downloading {file_name} to {save_path}') 26 | download_file_from_google_drive(file_id, save_path) 27 | 28 | # unzip 29 | if save_path.endswith('.zip'): 30 | extracted_path = save_path.replace('.zip', '') 31 | print(f'Extract {save_path} to {extracted_path}') 32 | import zipfile 33 | with zipfile.ZipFile(save_path, 'r') as zip_ref: 34 | zip_ref.extractall(extracted_path) 35 | 36 | file_name = file_name.replace('.zip', '') 37 | subfolder = osp.join(extracted_path, file_name) 38 | if osp.isdir(subfolder): 39 | print(f'Move {subfolder} to {extracted_path}') 40 | import shutil 41 | for path in glob.glob(osp.join(subfolder, '*')): 42 | shutil.move(path, extracted_path) 43 | shutil.rmtree(subfolder) 44 | 45 | 46 | if __name__ == '__main__': 47 | parser = argparse.ArgumentParser() 48 | 49 | parser.add_argument( 50 | 'dataset', 51 | type=str, 52 | help=("Options: 'Set5', 'Set14'. 
" 53 | "Set to 'all' if you want to download all the dataset.")) 54 | args = parser.parse_args() 55 | 56 | file_ids = { 57 | 'Set5': { 58 | 'Set5.zip': # file name 59 | '1RtyIeUFTyW8u7oa4z7a0lSzT3T1FwZE9', # file id 60 | }, 61 | 'Set14': { 62 | 'Set14.zip': '1vsw07sV8wGrRQ8UARe2fO5jjgy9QJy_E', 63 | } 64 | } 65 | 66 | if args.dataset == 'all': 67 | for dataset in file_ids.keys(): 68 | download_dataset(dataset, file_ids[dataset]) 69 | else: 70 | download_dataset(args.dataset, file_ids[args.dataset]) 71 | -------------------------------------------------------------------------------- /scripts/data_preparation/extract_subimages.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import os 4 | import sys 5 | from multiprocessing import Pool 6 | from os import path as osp 7 | from tqdm import tqdm 8 | 9 | from basicsr.utils import scandir 10 | 11 | 12 | def main(): 13 | """A multi-thread tool to crop large images to sub-images for faster IO. 14 | 15 | It is used for DIV2K dataset. 16 | 17 | opt (dict): Configuration dict. It contains: 18 | n_thread (int): Thread number. 19 | compression_level (int): CV_IMWRITE_PNG_COMPRESSION from 0 to 9. 20 | A higher value means a smaller size and longer compression time. 21 | Use 0 for faster CPU decompression. Default: 3, same in cv2. 22 | 23 | input_folder (str): Path to the input folder. 24 | save_folder (str): Path to save folder. 25 | crop_size (int): Crop size. 26 | step (int): Step for overlapped sliding window. 27 | thresh_size (int): Threshold size. Patches whose size is lower 28 | than thresh_size will be dropped. 29 | 30 | Usage: 31 | For each folder, run this script. 32 | Typically, there are four folders to be processed for DIV2K dataset. 33 | DIV2K_train_HR 34 | DIV2K_train_LR_bicubic/X2 35 | DIV2K_train_LR_bicubic/X3 36 | DIV2K_train_LR_bicubic/X4 37 | After process, each sub_folder should have the same number of 38 | subimages. 39 | Remember to modify opt configurations according to your settings. 40 | """ 41 | 42 | opt = {} 43 | opt['n_thread'] = 20 44 | opt['compression_level'] = 3 45 | 46 | opt['input_folder'] = '../../../datasets/SR_OST_datasets/OutdoorSceneTrain_v2/' 47 | opt['save_folder'] = '../../../datasets/HQ_sub/OST_train_HR_sub/' 48 | opt['crop_size'] = 320 49 | opt['step'] = 160 50 | opt['thresh_size'] = 0 51 | 52 | # HR images 53 | opt['input_folder'] = 'datasets/DIV2K/DIV2K_train_HR' 54 | opt['save_folder'] = 'datasets/DIV2K/DIV2K_train_HR_sub' 55 | opt['crop_size'] = 480 56 | opt['step'] = 240 57 | opt['thresh_size'] = 0 58 | extract_subimages(opt) 59 | 60 | # LRx2 images 61 | opt['input_folder'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X2' 62 | opt['save_folder'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X2_sub' 63 | opt['crop_size'] = 240 64 | opt['step'] = 120 65 | opt['thresh_size'] = 0 66 | extract_subimages(opt) 67 | 68 | # LRx3 images 69 | opt['input_folder'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X3' 70 | opt['save_folder'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X3_sub' 71 | opt['crop_size'] = 160 72 | opt['step'] = 80 73 | opt['thresh_size'] = 0 74 | extract_subimages(opt) 75 | 76 | # LRx4 images 77 | opt['input_folder'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X4' 78 | opt['save_folder'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X4_sub' 79 | opt['crop_size'] = 120 80 | opt['step'] = 60 81 | opt['thresh_size'] = 0 82 | extract_subimages(opt) 83 | 84 | 85 | def extract_subimages(opt): 86 | """Crop images to subimages. 
87 | 88 | Args: 89 | opt (dict): Configuration dict. It contains: 90 | input_folder (str): Path to the input folder. 91 | save_folder (str): Path to save folder. 92 | n_thread (int): Thread number. 93 | """ 94 | input_folder = opt['input_folder'] 95 | save_folder = opt['save_folder'] 96 | if not osp.exists(save_folder): 97 | os.makedirs(save_folder) 98 | print(f'mkdir {save_folder} ...') 99 | else: 100 | print(f'Folder {save_folder} already exists. Exit.') 101 | sys.exit(1) 102 | 103 | img_list = list(scandir(input_folder, recursive=True, full_path=True)) 104 | 105 | pbar = tqdm(total=len(img_list), unit='image', desc='Extract') 106 | pool = Pool(opt['n_thread']) 107 | for path in img_list: 108 | pool.apply_async(worker, args=(path, opt), callback=lambda arg: pbar.update(1)) 109 | pool.close() 110 | pool.join() 111 | pbar.close() 112 | print('All processes done.') 113 | 114 | 115 | def worker(path, opt): 116 | """Worker for each process. 117 | 118 | Args: 119 | path (str): Image path. 120 | opt (dict): Configuration dict. It contains: 121 | crop_size (int): Crop size. 122 | step (int): Step for overlapped sliding window. 123 | thresh_size (int): Threshold size. Patches whose size is lower 124 | than thresh_size will be dropped. 125 | save_folder (str): Path to save folder. 126 | compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION. 127 | 128 | Returns: 129 | process_info (str): Process information displayed in progress bar. 130 | """ 131 | crop_size = opt['crop_size'] 132 | step = opt['step'] 133 | thresh_size = opt['thresh_size'] 134 | img_name, extension = osp.splitext(osp.basename(path)) 135 | 136 | # remove the x2, x3, x4 and x8 in the filename for DIV2K 137 | img_name = img_name.replace('x2', '').replace('x3', '').replace('x4', '').replace('x8', '') 138 | 139 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED) 140 | 141 | h, w = img.shape[0:2] 142 | h_space = np.arange(0, h - crop_size + 1, step) 143 | if h - (h_space[-1] + crop_size) > thresh_size: 144 | h_space = np.append(h_space, h - crop_size) 145 | w_space = np.arange(0, w - crop_size + 1, step) 146 | if w - (w_space[-1] + crop_size) > thresh_size: 147 | w_space = np.append(w_space, w - crop_size) 148 | 149 | index = 0 150 | for x in h_space: 151 | for y in w_space: 152 | index += 1 153 | cropped_img = img[x:x + crop_size, y:y + crop_size, ...] 154 | cropped_img = np.ascontiguousarray(cropped_img) 155 | cv2.imwrite( 156 | osp.join(opt['save_folder'], f'{img_name}_s{index:03d}{extension}'), cropped_img, 157 | [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']]) 158 | process_info = f'Processing {img_name} ...' 159 | return process_info 160 | 161 | 162 | if __name__ == '__main__': 163 | main() 164 | -------------------------------------------------------------------------------- /scripts/data_preparation/generate_meta_info.py: -------------------------------------------------------------------------------- 1 | from os import path as osp 2 | from PIL import Image 3 | 4 | from basicsr.utils import scandir 5 | 6 | 7 | def generate_meta_info_div2k(): 8 | """Generate meta info for DIV2K dataset. 
9 | """ 10 | 11 | gt_folder = 'datasets/DIV2K/DIV2K_train_HR_sub/' 12 | meta_info_txt = 'basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt' 13 | 14 | img_list = sorted(list(scandir(gt_folder))) 15 | 16 | with open(meta_info_txt, 'w') as f: 17 | for idx, img_path in enumerate(img_list): 18 | img = Image.open(osp.join(gt_folder, img_path)) # lazy load 19 | width, height = img.size 20 | mode = img.mode 21 | if mode == 'RGB': 22 | n_channel = 3 23 | elif mode == 'L': 24 | n_channel = 1 25 | else: 26 | raise ValueError(f'Unsupported mode {mode}.') 27 | 28 | info = f'{img_path} ({height},{width},{n_channel})' 29 | print(idx + 1, info) 30 | f.write(f'{info}\n') 31 | 32 | 33 | if __name__ == '__main__': 34 | generate_meta_info_div2k() 35 | -------------------------------------------------------------------------------- /scripts/data_preparation/prepare_hifacegan_dataset.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | from tqdm import tqdm 4 | 5 | 6 | class Mosaic16x: 7 | """ 8 | Mosaic16x: A customized image augmentor for 16-pixel mosaic 9 | By default it replaces each pixel value with the mean value 10 | of its 16x16 neighborhood 11 | """ 12 | 13 | def augment_image(self, x): 14 | h, w = x.shape[:2] 15 | x = x.astype('float') # avoid overflow for uint8 16 | irange, jrange = (h + 15) // 16, (w + 15) // 16 17 | for i in range(irange): 18 | for j in range(jrange): 19 | mean = x[i * 16:(i + 1) * 16, j * 16:(j + 1) * 16].mean(axis=(0, 1)) 20 | x[i * 16:(i + 1) * 16, j * 16:(j + 1) * 16] = mean 21 | 22 | return x.astype('uint8') 23 | 24 | 25 | class DegradationSimulator: 26 | """ 27 | Generating training/testing data pairs on the fly. 28 | The degradation script is aligned with HiFaceGAN paper settings. 
29 | 30 | Args: 31 | opt(str | op): Config for degradation script, with degradation type and parameters 32 | Custom degradation is possible by passing an inherited class from ia.augmentors 33 | """ 34 | 35 | def __init__(self, ): 36 | import imgaug.augmenters as ia 37 | self.default_deg_templates = { 38 | 'sr4x': 39 | ia.Sequential([ 40 | # It's almost like a 4x bicubic downsampling 41 | ia.Resize((0.25000, 0.25001), cv2.INTER_AREA), 42 | ia.Resize({ 43 | 'height': 512, 44 | 'width': 512 45 | }, cv2.INTER_CUBIC), 46 | ]), 47 | 'sr4x8x': 48 | ia.Sequential([ 49 | ia.Resize((0.125, 0.25), cv2.INTER_AREA), 50 | ia.Resize({ 51 | 'height': 512, 52 | 'width': 512 53 | }, cv2.INTER_CUBIC), 54 | ]), 55 | 'denoise': 56 | ia.OneOf([ 57 | ia.AdditiveGaussianNoise(scale=(20, 40), per_channel=True), 58 | ia.AdditiveLaplaceNoise(scale=(20, 40), per_channel=True), 59 | ia.AdditivePoissonNoise(lam=(15, 30), per_channel=True), 60 | ]), 61 | 'deblur': 62 | ia.OneOf([ 63 | ia.MotionBlur(k=(10, 20)), 64 | ia.GaussianBlur((3.0, 8.0)), 65 | ]), 66 | 'jpeg': 67 | ia.JpegCompression(compression=(50, 85)), 68 | '16x': 69 | Mosaic16x(), 70 | } 71 | 72 | rand_deg_list = [ 73 | self.default_deg_templates['deblur'], 74 | self.default_deg_templates['denoise'], 75 | self.default_deg_templates['jpeg'], 76 | self.default_deg_templates['sr4x8x'], 77 | ] 78 | self.default_deg_templates['face_renov'] = ia.Sequential(rand_deg_list, random_order=True) 79 | 80 | def create_training_dataset(self, deg, gt_folder, lq_folder=None): 81 | from imgaug.augmenters.meta import Augmenter # baseclass 82 | """ 83 | Create a degradation simulator and apply it to GT images on the fly 84 | Save the degraded result in the lq_folder (if None, name it as GT_deg) 85 | """ 86 | if not lq_folder: 87 | suffix = deg if isinstance(deg, str) else 'custom' 88 | lq_folder = '_'.join([gt_folder.replace('gt', 'lq'), suffix]) 89 | print(lq_folder) 90 | os.makedirs(lq_folder, exist_ok=True) 91 | 92 | if isinstance(deg, str): 93 | assert deg in self.default_deg_templates, ( 94 | f'Degration type {deg} not recognized: {"|".join(list(self.default_deg_templates.keys()))}') 95 | deg = self.default_deg_templates[deg] 96 | else: 97 | assert isinstance(deg, Augmenter), f'Deg must be either str|Augmenter, got {deg}' 98 | 99 | names = os.listdir(gt_folder) 100 | for name in tqdm(names): 101 | gt = cv2.imread(os.path.join(gt_folder, name)) 102 | lq = deg.augment_image(gt) 103 | # pack = np.concatenate([lq, gt], axis=0) 104 | cv2.imwrite(os.path.join(lq_folder, name), lq) 105 | 106 | print('Dataset prepared.') 107 | 108 | 109 | if __name__ == '__main__': 110 | simuator = DegradationSimulator() 111 | gt_folder = 'datasets/FFHQ_512_gt' 112 | deg = 'sr4x' 113 | simuator.create_training_dataset(deg, gt_folder) 114 | -------------------------------------------------------------------------------- /scripts/data_preparation/regroup_reds_dataset.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | 5 | def regroup_reds_dataset(train_path, val_path): 6 | """Regroup original REDS datasets. 7 | 8 | We merge train and validation data into one folder, and separate the 9 | validation clips in reds_dataset.py. 10 | There are 240 training clips (starting from 0 to 239), 11 | so we name the validation clip index starting from 240 to 269 (total 30 12 | validation clips). 13 | 14 | Args: 15 | train_path (str): Path to the train folder. 16 | val_path (str): Path to the validation folder. 
17 | """ 18 | # move the validation data to the train folder 19 | val_folders = glob.glob(os.path.join(val_path, '*')) 20 | for folder in val_folders: 21 | new_folder_idx = int(folder.split('/')[-1]) + 240 22 | os.system(f'cp -r {folder} {os.path.join(train_path, str(new_folder_idx))}') 23 | 24 | 25 | if __name__ == '__main__': 26 | # train_sharp 27 | train_path = 'datasets/REDS/train_sharp' 28 | val_path = 'datasets/REDS/val_sharp' 29 | regroup_reds_dataset(train_path, val_path) 30 | 31 | # train_sharp_bicubic 32 | train_path = 'datasets/REDS/train_sharp_bicubic/X4' 33 | val_path = 'datasets/REDS/val_sharp_bicubic/X4' 34 | regroup_reds_dataset(train_path, val_path) 35 | -------------------------------------------------------------------------------- /scripts/download_gdrive.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from basicsr.utils.download_util import download_file_from_google_drive 4 | 5 | if __name__ == '__main__': 6 | parser = argparse.ArgumentParser() 7 | 8 | parser.add_argument('--id', type=str, help='File id') 9 | parser.add_argument('--output', type=str, help='Save path') 10 | args = parser.parse_args() 11 | 12 | download_file_from_google_drive(args.id, args.output)  # pass the --output path defined above 13 | -------------------------------------------------------------------------------- /scripts/download_pretrained_models.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from os import path as osp 4 | 5 | from basicsr.utils.download_util import download_file_from_google_drive 6 | 7 | 8 | def download_pretrained_models(method, file_ids): 9 | save_path_root = f'./experiments/pretrained_models/{method}' 10 | os.makedirs(save_path_root, exist_ok=True) 11 | 12 | for file_name, file_id in file_ids.items(): 13 | save_path = osp.abspath(osp.join(save_path_root, file_name)) 14 | if osp.exists(save_path): 15 | user_response = input(f'{file_name} already exists. Do you want to overwrite it? Y/N\n') 16 | if user_response.lower() == 'y': 17 | print(f'Overwriting {file_name} at {save_path}') 18 | download_file_from_google_drive(file_id, save_path) 19 | elif user_response.lower() == 'n': 20 | print(f'Skipping {file_name}') 21 | else: 22 | raise ValueError('Wrong input. Only accepts Y/N.') 23 | else: 24 | print(f'Downloading {file_name} to {save_path}') 25 | download_file_from_google_drive(file_id, save_path) 26 | 27 | 28 | if __name__ == '__main__': 29 | parser = argparse.ArgumentParser() 30 | 31 | parser.add_argument( 32 | 'method', 33 | type=str, 34 | help=("Options: 'ESRGAN', 'EDVR', 'StyleGAN', 'EDSR', 'DUF', 'DFDNet', 'dlib', 'TOF', 'flownet', 'BasicVSR'. 
" 35 | "Set to 'all' to download all the models.")) 36 | args = parser.parse_args() 37 | 38 | file_ids = { 39 | 'ESRGAN': { 40 | 'ESRGAN_SRx4_DF2KOST_official-ff704c30.pth': # file name 41 | '1b3_bWZTjNO3iL2js1yWkJfjZykcQgvzT', # file id 42 | 'ESRGAN_PSNR_SRx4_DF2K_official-150ff491.pth': '1swaV5iBMFfg-DL6ZyiARztbhutDCWXMM' 43 | }, 44 | 'EDVR': { 45 | 'EDVR_L_x4_SR_REDS_official-9f5f5039.pth': '127KXEjlCwfoPC1aXyDkluNwr9elwyHNb', 46 | 'EDVR_L_x4_SR_Vimeo90K_official-162b54e4.pth': '1aVR3lkX6ItCphNLcT7F5bbbC484h4Qqy', 47 | 'EDVR_M_woTSA_x4_SR_REDS_official-1edf645c.pth': '1C_WdN-NyNj-P7SOB5xIVuHl4EBOwd-Ny', 48 | 'EDVR_M_x4_SR_REDS_official-32075921.pth': '1dd6aFj-5w2v08VJTq5mS9OFsD-wALYD6', 49 | 'EDVR_L_x4_SRblur_REDS_official-983d7b8e.pth': '1GZz_87ybR8eAAY3X2HWwI3L6ny7-5Yvl', 50 | 'EDVR_L_deblur_REDS_official-ca46bd8c.pth': '1_ma2tgHscZtkIY2tEJkVdU-UP8bnqBRE', 51 | 'EDVR_L_deblurcomp_REDS_official-0e988e5c.pth': '1fEoSeLFnHSBbIs95Au2W197p8e4ws4DW' 52 | }, 53 | 'StyleGAN': { 54 | 'stylegan2_ffhq_config_f_1024_official-3ab41b38.pth': '1qtdsT1FrvKQsFiW3OqOcIb-VS55TVy1g', 55 | 'stylegan2_ffhq_config_f_1024_discriminator_official-a386354a.pth': '1nPqCxm8TkDU3IvXdHCzPUxlBwR5Pd78G', 56 | 'stylegan2_cat_config_f_256_official-0a9173ad.pth': '1gfJkX6XO5pJ2J8LyMdvUgGldz7xwWpBJ', 57 | 'stylegan2_cat_config_f_256_discriminator_official-2c97fd08.pth': '1hy5FEQQl28XvfqpiWvSBd8YnIzsyDRb7', 58 | 'stylegan2_church_config_f_256_official-44ba63bf.pth': '1FCQMZXeOKZyl-xYKbl1Y_x2--rFl-1N_', 59 | 'stylegan2_church_config_f_256_discriminator_official-20cd675b.pth': # noqa: E501 60 | '1BS9ODHkUkhfTGFVfR6alCMGtr9nGm9ox', 61 | 'stylegan2_car_config_f_512_official-e8fcab4f.pth': '14jS-nWNTguDSd1kTIX-tBHp2WdvK7hva', 62 | 'stylegan2_car_config_f_512_discriminator_official-5008e3d1.pth': '1UxkAzZ0zvw4KzBVOUpShCivsdXBS8Zi2', 63 | 'stylegan2_horse_config_f_256_official-26d57fee.pth': '12QsZ-mrO8_4gC0UrO15Jb3ykcQ88HxFx', 64 | 'stylegan2_horse_config_f_256_discriminator_official-be6c4c33.pth': '1me4ybSib72xA9ZxmzKsHDtP-eNCKw_X4' 65 | }, 66 | 'EDSR': { 67 | 'EDSR_Mx2_f64b16_DIV2K_official-3ba7b086.pth': '1mREMGVDymId3NzIc2u90sl_X4-pb4ZcV', 68 | 'EDSR_Mx3_f64b16_DIV2K_official-6908f88a.pth': '1EriqQqlIiRyPbrYGBbwr_FZzvb3iwqz5', 69 | 'EDSR_Mx4_f64b16_DIV2K_official-0c287733.pth': '1bCK6cFYU01uJudLgUUe-jgx-tZ3ikOWn', 70 | 'EDSR_Lx2_f256b32_DIV2K_official-be38e77d.pth': '15257lZCRZ0V6F9LzTyZFYbbPrqNjKyMU', 71 | 'EDSR_Lx3_f256b32_DIV2K_official-3660f70d.pth': '18q_D434sLG_rAZeHGonAX8dkqjoyZ2su', 72 | 'EDSR_Lx4_f256b32_DIV2K_official-76ee1c8f.pth': '1GCi30YYCzgMCcgheGWGusP9aWKOAy5vl' 73 | }, 74 | 'DUF': { 75 | 'DUF_x2_16L_official-39537cb9.pth': '1e91cEZOlUUk35keK9EnuK0F54QegnUKo', 76 | 'DUF_x3_16L_official-34ce53ec.pth': '1XN6aQj20esM7i0hxTbfiZr_SL8i4PZ76', 77 | 'DUF_x4_16L_official-bf8f0cfa.pth': '1V_h9U1CZgLSHTv1ky2M3lvuH-hK5hw_J', 78 | 'DUF_x4_28L_official-cbada450.pth': '1M8w0AMBJW65MYYD-_8_be0cSH_SHhDQ4', 79 | 'DUF_x4_52L_official-483d2c78.pth': '1GcmEWNr7mjTygi-QCOVgQWOo5OCNbh_T' 80 | }, 81 | 'TOF': { 82 | 'tof_x4_vimeo90k_official-32c9e01f.pth': '1TgQiXXsvkTBFrQ1D0eKPgL10tQGu0gKb' 83 | }, 84 | 'DFDNet': { 85 | 'DFDNet_dict_512-f79685f0.pth': '1iH00oMsoN_1OJaEQw3zP7_wqiAYMnY79', 86 | 'DFDNet_official-d1fa5650.pth': '1u6Sgcp8gVoy4uVTrOJKD3y9RuqH2JBAe' 87 | }, 88 | 'dlib': { 89 | 'mmod_human_face_detector-4cb19393.dat': '1FUM-hcoxNzFCOpCWbAUStBBMiU4uIGIL', 90 | 'shape_predictor_5_face_landmarks-c4b1e980.dat': '1PNPSmFjmbuuUDd5Mg5LDxyk7tu7TQv2F', 91 | 'shape_predictor_68_face_landmarks-fbdc2cb8.dat': 
'1IneH-O-gNkG0SQpNCplwxtOAtRCkG2ni' 92 | }, 93 | 'flownet': { 94 | 'spynet_sintel_final-3d2a1287.pth': '1VZz1cikwTRVX7zXoD247DB7n5Tj_LQpF' 95 | }, 96 | 'BasicVSR': { 97 | 'BasicVSR_REDS4-543c8261.pth': '1wLWdz18lWf9Z7lomHPkdySZ-_GV2920p', 98 | 'BasicVSR_Vimeo90K_BDx4-e9bf46eb.pth': '1baaf4RSpzs_zcDAF_s2CyArrGvLgmXxW', 99 | 'BasicVSR_Vimeo90K_BIx4-2a29695a.pth': '1ykIu1jv5wo95Kca2TjlieJFxeV4VVfHP', 100 | 'EDVR_REDS_pretrained_for_IconVSR-f62a2f1e.pth': '1ShfwddugTmT3_kB8VL6KpCMrIpEO5sBi', 101 | 'EDVR_Vimeo90K_pretrained_for_IconVSR-ee48ee92.pth': '16vR262NDVyVv5Q49xp2Sb-Llu05f63tt', 102 | 'IconVSR_REDS-aaa5367f.pth': '1b8ir754uIAFUSJ8YW_cmPzqer19AR7Hz', 103 | 'IconVSR_Vimeo90K_BDx4-cfcb7e00.pth': '13lp55s-YTd-fApx8tTy24bbHsNIGXdAH', 104 | 'IconVSR_Vimeo90K_BIx4-35fec07c.pth': '1lWUB36ERjFbAspr-8UsopJ6xwOuWjh2g' 105 | } 106 | } 107 | 108 | if args.method == 'all': 109 | for method in file_ids.keys(): 110 | download_pretrained_models(method, file_ids[method]) 111 | else: 112 | download_pretrained_models(args.method, file_ids[args.method]) 113 | -------------------------------------------------------------------------------- /scripts/matlab_scripts/back_projection/backprojection.m: -------------------------------------------------------------------------------- 1 | function [im_h] = backprojection(im_h, im_l, maxIter) 2 | 3 | [row_l, col_l,~] = size(im_l); 4 | [row_h, col_h,~] = size(im_h); 5 | 6 | p = fspecial('gaussian', 5, 1); 7 | p = p.^2; 8 | p = p./sum(p(:)); 9 | 10 | im_l = double(im_l); 11 | im_h = double(im_h); 12 | 13 | for ii = 1:maxIter 14 | im_l_s = imresize(im_h, [row_l, col_l], 'bicubic'); 15 | im_diff = im_l - im_l_s; 16 | im_diff = imresize(im_diff, [row_h, col_h], 'bicubic'); 17 | im_h(:,:,1) = im_h(:,:,1) + conv2(im_diff(:,:,1), p, 'same'); 18 | im_h(:,:,2) = im_h(:,:,2) + conv2(im_diff(:,:,2), p, 'same'); 19 | im_h(:,:,3) = im_h(:,:,3) + conv2(im_diff(:,:,3), p, 'same'); 20 | end 21 | -------------------------------------------------------------------------------- /scripts/matlab_scripts/back_projection/main_bp.m: -------------------------------------------------------------------------------- 1 | clear; close all; clc; 2 | 3 | LR_folder = './LR'; % LR 4 | preout_folder = './results'; % pre output 5 | save_folder = './results_20bp'; 6 | filepaths = dir(fullfile(preout_folder, '*.png')); 7 | max_iter = 20; 8 | 9 | if ~ exist(save_folder, 'dir') 10 | mkdir(save_folder); 11 | end 12 | 13 | for idx_im = 1:length(filepaths) 14 | fprintf([num2str(idx_im) '\n']); 15 | im_name = filepaths(idx_im).name; 16 | im_LR = im2double(imread(fullfile(LR_folder, im_name))); 17 | im_out = im2double(imread(fullfile(preout_folder, im_name))); 18 | %tic 19 | im_out = backprojection(im_out, im_LR, max_iter); 20 | %toc 21 | imwrite(im_out, fullfile(save_folder, im_name)); 22 | end 23 | -------------------------------------------------------------------------------- /scripts/matlab_scripts/back_projection/main_reverse_filter.m: -------------------------------------------------------------------------------- 1 | clear; close all; clc; 2 | 3 | LR_folder = './LR'; % LR 4 | preout_folder = './results'; % pre output 5 | save_folder = './results_20if'; 6 | filepaths = dir(fullfile(preout_folder, '*.png')); 7 | max_iter = 20; 8 | 9 | if ~ exist(save_folder, 'dir') 10 | mkdir(save_folder); 11 | end 12 | 13 | for idx_im = 1:length(filepaths) 14 | fprintf([num2str(idx_im) '\n']); 15 | im_name = filepaths(idx_im).name; 16 | im_LR = im2double(imread(fullfile(LR_folder, im_name))); 17 | im_out = 
im2double(imread(fullfile(preout_folder, im_name))); 18 | J = imresize(im_LR,4,'bicubic'); 19 | %tic 20 | for m = 1:max_iter 21 | im_out = im_out + (J - imresize(imresize(im_out,1/4,'bicubic'),4,'bicubic')); 22 | end 23 | %toc 24 | imwrite(im_out, fullfile(save_folder, im_name)); 25 | end 26 | -------------------------------------------------------------------------------- /scripts/matlab_scripts/generate_LR_Vimeo90K.m: -------------------------------------------------------------------------------- 1 | function generate_LR_Vimeo90K() 2 | %% matlab code to generate bicubic-downsampled images for the Vimeo90K dataset 3 | 4 | up_scale = 4; 5 | mod_scale = 4; 6 | idx = 0; 7 | filepaths = dir('/home/xtwang/datasets/vimeo90k/vimeo_septuplet/sequences/*/*/*.png'); 8 | for i = 1 : length(filepaths) 9 | [~,imname,ext] = fileparts(filepaths(i).name); 10 | folder_path = filepaths(i).folder; 11 | save_LR_folder = strrep(folder_path,'vimeo_septuplet','vimeo_septuplet_matlabLRx4'); 12 | if ~exist(save_LR_folder, 'dir') 13 | mkdir(save_LR_folder); 14 | end 15 | if isempty(imname) 16 | disp('Ignore . folder.'); 17 | elseif strcmp(imname, '.') 18 | disp('Ignore .. folder.'); 19 | else 20 | idx = idx + 1; 21 | str_result = sprintf('%d\t%s.\n', idx, imname); 22 | fprintf(str_result); 23 | % read image 24 | img = imread(fullfile(folder_path, [imname, ext])); 25 | img = im2double(img); 26 | % modcrop 27 | img = modcrop(img, mod_scale); 28 | % LR 29 | im_LR = imresize(img, 1/up_scale, 'bicubic'); 30 | if exist('save_LR_folder', 'var') 31 | imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png'])); 32 | end 33 | end 34 | end 35 | end 36 | 37 | %% modcrop 38 | function img = modcrop(img, modulo) 39 | if size(img,3) == 1 40 | sz = size(img); 41 | sz = sz - mod(sz, modulo); 42 | img = img(1:sz(1), 1:sz(2)); 43 | else 44 | tmpsz = size(img); 45 | sz = tmpsz(1:2); 46 | sz = sz - mod(sz, modulo); 47 | img = img(1:sz(1), 1:sz(2),:); 48 | end 49 | end 50 | -------------------------------------------------------------------------------- /scripts/matlab_scripts/generate_bicubic_img.m: -------------------------------------------------------------------------------- 1 | function generate_bicubic_img() 2 | %% matlab code to generate mod images, bicubic-downsampled images and 3 | %% bicubic-upsampled images 4 | 5 | %% set configurations 6 | % comment the unnecessary lines 7 | input_folder = '../../datasets/Set5/original'; 8 | save_mod_folder = '../../datasets/Set5/GTmod12'; 9 | save_lr_folder = '../../datasets/Set5/LRbicx2'; 10 | % save_bic_folder = ''; 11 | 12 | mod_scale = 12; 13 | up_scale = 2; 14 | 15 | if exist('save_mod_folder', 'var') 16 | if exist(save_mod_folder, 'dir') 17 | disp(['It will overwrite ', save_mod_folder]); 18 | else 19 | mkdir(save_mod_folder); 20 | end 21 | end 22 | if exist('save_lr_folder', 'var') 23 | if exist(save_lr_folder, 'dir') 24 | disp(['It will overwrite ', save_lr_folder]); 25 | else 26 | mkdir(save_lr_folder); 27 | end 28 | end 29 | if exist('save_bic_folder', 'var') 30 | if exist(save_bic_folder, 'dir') 31 | disp(['It will overwrite ', save_bic_folder]); 32 | else 33 | mkdir(save_bic_folder); 34 | end 35 | end 36 | 37 | idx = 0; 38 | filepaths = dir(fullfile(input_folder,'*.*')); 39 | for i = 1 : length(filepaths) 40 | [paths, img_name, ext] = fileparts(filepaths(i).name); 41 | if isempty(img_name) 42 | disp('Ignore . folder.'); 43 | elseif strcmp(img_name, '.') 44 | disp('Ignore .. 
folder.'); 45 | else 46 | idx = idx + 1; 47 | str_result = sprintf('%d\t%s.\n', idx, img_name); 48 | fprintf(str_result); 49 | 50 | % read image 51 | img = imread(fullfile(input_folder, [img_name, ext])); 52 | img = im2double(img); 53 | 54 | % modcrop 55 | img = modcrop(img, mod_scale); 56 | if exist('save_mod_folder', 'var') 57 | imwrite(img, fullfile(save_mod_folder, [img_name, '.png'])); 58 | end 59 | 60 | % LR 61 | im_lr = imresize(img, 1/up_scale, 'bicubic'); 62 | if exist('save_lr_folder', 'var') 63 | imwrite(im_lr, fullfile(save_lr_folder, [img_name, '.png'])); 64 | end 65 | 66 | % Bicubic 67 | if exist('save_bic_folder', 'var') 68 | im_bicubic = imresize(im_lr, up_scale, 'bicubic'); 69 | imwrite(im_bicubic, fullfile(save_bic_folder, [img_name, '.png'])); 70 | end 71 | end 72 | end 73 | end 74 | 75 | %% modcrop 76 | function img = modcrop(img, modulo) 77 | if size(img,3) == 1 78 | sz = size(img); 79 | sz = sz - mod(sz, modulo); 80 | img = img(1:sz(1), 1:sz(2)); 81 | else 82 | tmpsz = size(img); 83 | sz = tmpsz(1:2); 84 | sz = sz - mod(sz, modulo); 85 | img = img(1:sz(1), 1:sz(2),:); 86 | end 87 | end 88 | -------------------------------------------------------------------------------- /scripts/metrics/calculate_fid_folder.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import DataLoader 6 | 7 | from basicsr.data import build_dataset 8 | from basicsr.metrics.fid import calculate_fid, extract_inception_features, load_patched_inception_v3 9 | 10 | 11 | def calculate_fid_folder(): 12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('folder', type=str, help='Path to the folder.') 16 | parser.add_argument('--fid_stats', type=str, help='Path to the dataset fid statistics.') 17 | parser.add_argument('--batch_size', type=int, default=64) 18 | parser.add_argument('--num_sample', type=int, default=50000) 19 | parser.add_argument('--num_workers', type=int, default=4) 20 | parser.add_argument('--backend', type=str, default='disk', help='io backend for dataset. 
Option: disk, lmdb') 21 | args = parser.parse_args() 22 | 23 | # inception model 24 | inception = load_patched_inception_v3(device) 25 | 26 | # create dataset 27 | opt = {} 28 | opt['name'] = 'SingleImageDataset' 29 | opt['type'] = 'SingleImageDataset' 30 | opt['dataroot_lq'] = args.folder 31 | opt['io_backend'] = dict(type=args.backend) 32 | opt['mean'] = [0.5, 0.5, 0.5] 33 | opt['std'] = [0.5, 0.5, 0.5] 34 | dataset = build_dataset(opt) 35 | 36 | # create dataloader 37 | data_loader = DataLoader( 38 | dataset=dataset, 39 | batch_size=args.batch_size, 40 | shuffle=False, 41 | num_workers=args.num_workers, 42 | sampler=None, 43 | drop_last=False) 44 | args.num_sample = min(args.num_sample, len(dataset)) 45 | total_batch = math.ceil(args.num_sample / args.batch_size) 46 | 47 | def data_generator(data_loader, total_batch): 48 | for idx, data in enumerate(data_loader): 49 | if idx >= total_batch: 50 | break 51 | else: 52 | yield data['lq'] 53 | 54 | features = extract_inception_features(data_generator(data_loader, total_batch), inception, total_batch, device) 55 | features = features.numpy() 56 | total_len = features.shape[0] 57 | features = features[:args.num_sample] 58 | print(f'Extracted {total_len} features, use the first {features.shape[0]} features to calculate stats.') 59 | 60 | sample_mean = np.mean(features, 0) 61 | sample_cov = np.cov(features, rowvar=False) 62 | 63 | # load the dataset stats 64 | stats = torch.load(args.fid_stats) 65 | real_mean = stats['mean'] 66 | real_cov = stats['cov'] 67 | 68 | # calculate FID metric 69 | fid = calculate_fid(sample_mean, sample_cov, real_mean, real_cov) 70 | print('fid:', fid) 71 | 72 | 73 | if __name__ == '__main__': 74 | calculate_fid_folder() 75 | -------------------------------------------------------------------------------- /scripts/metrics/calculate_fid_stats_from_datasets.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import numpy as np 4 | import torch 5 | from torch.utils.data import DataLoader 6 | 7 | from basicsr.data import build_dataset 8 | from basicsr.metrics.fid import extract_inception_features, load_patched_inception_v3 9 | 10 | 11 | def calculate_stats_from_dataset(): 12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--num_sample', type=int, default=50000) 16 | parser.add_argument('--batch_size', type=int, default=64) 17 | parser.add_argument('--size', type=int, default=512) 18 | parser.add_argument('--dataroot', type=str, default='datasets/ffhq') 19 | args = parser.parse_args() 20 | 21 | # inception model 22 | inception = load_patched_inception_v3(device) 23 | 24 | # create dataset 25 | opt = {} 26 | opt['name'] = 'FFHQ' 27 | opt['type'] = 'FFHQDataset' 28 | opt['dataroot_gt'] = f'datasets/ffhq/ffhq_{args.size}.lmdb' 29 | opt['io_backend'] = dict(type='lmdb') 30 | opt['use_hflip'] = False 31 | opt['mean'] = [0.5, 0.5, 0.5] 32 | opt['std'] = [0.5, 0.5, 0.5] 33 | dataset = build_dataset(opt) 34 | 35 | # create dataloader 36 | data_loader = DataLoader( 37 | dataset=dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, sampler=None, drop_last=False) 38 | total_batch = math.ceil(args.num_sample / args.batch_size) 39 | 40 | def data_generator(data_loader, total_batch): 41 | for idx, data in enumerate(data_loader): 42 | if idx >= total_batch: 43 | break 44 | else: 45 | yield data['gt'] 46 | 47 | features = 
extract_inception_features(data_generator(data_loader, total_batch), inception, total_batch, device) 48 | features = features.numpy() 49 | total_len = features.shape[0] 50 | features = features[:args.num_sample] 51 | print(f'Extracted {total_len} features, use the first {features.shape[0]} features to calculate stats.') 52 | mean = np.mean(features, 0) 53 | cov = np.cov(features, rowvar=False) 54 | 55 | save_path = f'inception_{opt["name"]}_{args.size}.pth' 56 | torch.save( 57 | dict(name=opt['name'], size=args.size, mean=mean, cov=cov), save_path, _use_new_zipfile_serialization=False) 58 | 59 | 60 | if __name__ == '__main__': 61 | calculate_stats_from_dataset() 62 | -------------------------------------------------------------------------------- /scripts/metrics/calculate_lpips.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import glob 3 | import numpy as np 4 | import os.path as osp 5 | from torchvision.transforms.functional import normalize 6 | 7 | from basicsr.utils import img2tensor 8 | 9 | try: 10 | import lpips 11 | except ImportError: 12 | print('Please install lpips: pip install lpips') 13 | 14 | 15 | def main(): 16 | # Configurations 17 | # ------------------------------------------------------------------------- 18 | folder_gt = 'datasets/celeba/celeba_512_validation' 19 | folder_restored = 'datasets/celeba/celeba_512_validation_lq' 20 | # crop_border = 4 21 | suffix = '' 22 | # ------------------------------------------------------------------------- 23 | loss_fn_vgg = lpips.LPIPS(net='vgg').cuda() # RGB, normalized to [-1,1] 24 | lpips_all = [] 25 | img_list = sorted(glob.glob(osp.join(folder_gt, '*'))) 26 | 27 | mean = [0.5, 0.5, 0.5] 28 | std = [0.5, 0.5, 0.5] 29 | for i, img_path in enumerate(img_list): 30 | basename, ext = osp.splitext(osp.basename(img_path)) 31 | img_gt = cv2.imread(img_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. 32 | img_restored = cv2.imread(osp.join(folder_restored, basename + suffix + ext), cv2.IMREAD_UNCHANGED).astype( 33 | np.float32) / 255. 34 | 35 | img_gt, img_restored = img2tensor([img_gt, img_restored], bgr2rgb=True, float32=True) 36 | # norm to [-1, 1] 37 | normalize(img_gt, mean, std, inplace=True) 38 | normalize(img_restored, mean, std, inplace=True) 39 | 40 | # calculate lpips 41 | lpips_val = loss_fn_vgg(img_restored.unsqueeze(0).cuda(), img_gt.unsqueeze(0).cuda()) 42 | 43 | print(f'{i+1:3d}: {basename:25}. \tLPIPS: {lpips_val:.6f}.') 44 | lpips_all.append(lpips_val) 45 | 46 | print(f'Average: LPIPS: {sum(lpips_all) / len(lpips_all):.6f}') 47 | 48 | 49 | if __name__ == '__main__': 50 | main() 51 | -------------------------------------------------------------------------------- /scripts/metrics/calculate_niqe.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import os 4 | import warnings 5 | 6 | from basicsr.metrics import calculate_niqe 7 | from basicsr.utils import scandir 8 | 9 | 10 | def main(args): 11 | 12 | niqe_all = [] 13 | img_list = sorted(scandir(args.input, recursive=True, full_path=True)) 14 | 15 | for i, img_path in enumerate(img_list): 16 | basename, _ = os.path.splitext(os.path.basename(img_path)) 17 | img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) 18 | 19 | with warnings.catch_warnings(): 20 | warnings.simplefilter('ignore', category=RuntimeWarning) 21 | niqe_score = calculate_niqe(img, args.crop_border, input_order='HWC', convert_to='y') 22 | print(f'{i+1:3d}: {basename:25}. 
\tNIQE: {niqe_score:.6f}') 23 | niqe_all.append(niqe_score) 24 | 25 | print(args.input) 26 | print(f'Average: NIQE: {sum(niqe_all) / len(niqe_all):.6f}') 27 | 28 | 29 | if __name__ == '__main__': 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--input', type=str, default='datasets/val_set14/Set14', help='Input path') 32 | parser.add_argument('--crop_border', type=int, default=0, help='Crop border for each side') 33 | args = parser.parse_args() 34 | main(args) 35 | -------------------------------------------------------------------------------- /scripts/metrics/calculate_psnr_ssim.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import numpy as np 4 | from os import path as osp 5 | 6 | from basicsr.metrics import calculate_psnr, calculate_ssim 7 | from basicsr.utils import scandir 8 | from basicsr.utils.matlab_functions import bgr2ycbcr 9 | 10 | 11 | def main(args): 12 | """Calculate PSNR and SSIM for images. 13 | """ 14 | psnr_all = [] 15 | ssim_all = [] 16 | img_list_gt = sorted(list(scandir(args.gt, recursive=True, full_path=True))) 17 | img_list_restored = sorted(list(scandir(args.restored, recursive=True, full_path=True))) 18 | 19 | if args.test_y_channel: 20 | print('Testing Y channel.') 21 | else: 22 | print('Testing RGB channels.') 23 | 24 | for i, img_path in enumerate(img_list_gt): 25 | basename, ext = osp.splitext(osp.basename(img_path)) 26 | img_gt = cv2.imread(img_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. 27 | if args.suffix == '': 28 | img_path_restored = img_list_restored[i] 29 | else: 30 | img_path_restored = osp.join(args.restored, basename + args.suffix + ext) 31 | img_restored = cv2.imread(img_path_restored, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. 32 | 33 | if args.correct_mean_var: 34 | mean_l = [] 35 | std_l = [] 36 | for j in range(3): 37 | mean_l.append(np.mean(img_gt[:, :, j])) 38 | std_l.append(np.std(img_gt[:, :, j])) 39 | for j in range(3): 40 | # correct twice 41 | mean = np.mean(img_restored[:, :, j]) 42 | img_restored[:, :, j] = img_restored[:, :, j] - mean + mean_l[j] 43 | std = np.std(img_restored[:, :, j]) 44 | img_restored[:, :, j] = img_restored[:, :, j] / std * std_l[j] 45 | 46 | mean = np.mean(img_restored[:, :, j]) 47 | img_restored[:, :, j] = img_restored[:, :, j] - mean + mean_l[j] 48 | std = np.std(img_restored[:, :, j]) 49 | img_restored[:, :, j] = img_restored[:, :, j] / std * std_l[j] 50 | 51 | if args.test_y_channel and img_gt.ndim == 3 and img_gt.shape[2] == 3: 52 | img_gt = bgr2ycbcr(img_gt, y_only=True) 53 | img_restored = bgr2ycbcr(img_restored, y_only=True) 54 | 55 | # calculate PSNR and SSIM 56 | psnr = calculate_psnr(img_gt * 255, img_restored * 255, crop_border=args.crop_border, input_order='HWC') 57 | ssim = calculate_ssim(img_gt * 255, img_restored * 255, crop_border=args.crop_border, input_order='HWC') 58 | print(f'{i+1:3d}: {basename:25}. 
\tPSNR: {psnr:.6f} dB, \tSSIM: {ssim:.6f}') 59 | psnr_all.append(psnr) 60 | ssim_all.append(ssim) 61 | print(args.gt) 62 | print(args.restored) 63 | print(f'Average: PSNR: {sum(psnr_all) / len(psnr_all):.6f} dB, SSIM: {sum(ssim_all) / len(ssim_all):.6f}') 64 | 65 | 66 | if __name__ == '__main__': 67 | parser = argparse.ArgumentParser() 68 | parser.add_argument('--gt', type=str, default='datasets/val_set14/Set14', help='Path to gt (Ground-Truth)') 69 | parser.add_argument('--restored', type=str, default='results/Set14', help='Path to restored images') 70 | parser.add_argument('--crop_border', type=int, default=0, help='Crop border for each side') 71 | parser.add_argument('--suffix', type=str, default='', help='Suffix for restored images') 72 | parser.add_argument( 73 | '--test_y_channel', 74 | action='store_true', 75 | help='If True, test Y channel (In MatLab YCbCr format). If False, test RGB channels.') 76 | parser.add_argument('--correct_mean_var', action='store_true', help='Correct the mean and var of restored images.') 77 | args = parser.parse_args() 78 | main(args) 79 | -------------------------------------------------------------------------------- /scripts/metrics/calculate_stylegan2_fid.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import numpy as np 4 | import torch 5 | from torch import nn 6 | 7 | from basicsr.archs.stylegan2_arch import StyleGAN2Generator 8 | from basicsr.metrics.fid import calculate_fid, extract_inception_features, load_patched_inception_v3 9 | 10 | 11 | def calculate_stylegan2_fid(): 12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 13 | 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('ckpt', type=str, help='Path to the stylegan2 checkpoint.') 16 | parser.add_argument('fid_stats', type=str, help='Path to the dataset fid statistics.') 17 | parser.add_argument('--size', type=int, default=256) 18 | parser.add_argument('--channel_multiplier', type=int, default=2) 19 | parser.add_argument('--batch_size', type=int, default=64) 20 | parser.add_argument('--num_sample', type=int, default=50000) 21 | parser.add_argument('--truncation', type=float, default=1) 22 | parser.add_argument('--truncation_mean', type=int, default=4096) 23 | args = parser.parse_args() 24 | 25 | # create stylegan2 model 26 | generator = StyleGAN2Generator( 27 | out_size=args.size, 28 | num_style_feat=512, 29 | num_mlp=8, 30 | channel_multiplier=args.channel_multiplier, 31 | resample_kernel=(1, 3, 3, 1)) 32 | generator.load_state_dict(torch.load(args.ckpt)['params_ema']) 33 | generator = nn.DataParallel(generator).eval().to(device) 34 | 35 | if args.truncation < 1: 36 | with torch.no_grad(): 37 | truncation_latent = generator.mean_latent(args.truncation_mean) 38 | else: 39 | truncation_latent = None 40 | 41 | # inception model 42 | inception = load_patched_inception_v3(device) 43 | 44 | total_batch = math.ceil(args.num_sample / args.batch_size) 45 | 46 | def sample_generator(total_batch): 47 | for _ in range(total_batch): 48 | with torch.no_grad(): 49 | latent = torch.randn(args.batch_size, 512, device=device) 50 | samples, _ = generator([latent], truncation=args.truncation, truncation_latent=truncation_latent) 51 | yield samples 52 | 53 | features = extract_inception_features(sample_generator(total_batch), inception, total_batch, device) 54 | features = features.numpy() 55 | total_len = features.shape[0] 56 | features = features[:args.num_sample] 57 | print(f'Extracted 
{total_len} features, use the first {features.shape[0]} features to calculate stats.') 58 | sample_mean = np.mean(features, 0) 59 | sample_cov = np.cov(features, rowvar=False) 60 | 61 | # load the dataset stats 62 | stats = torch.load(args.fid_stats) 63 | real_mean = stats['mean'] 64 | real_cov = stats['cov'] 65 | 66 | # calculate FID metric 67 | fid = calculate_fid(sample_mean, sample_cov, real_mean, real_cov) 68 | print('fid:', fid) 69 | 70 | 71 | if __name__ == '__main__': 72 | calculate_stylegan2_fid() 73 | -------------------------------------------------------------------------------- /scripts/model_conversion/convert_dfdnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from basicsr.archs.dfdnet_arch import DFDNet 4 | from basicsr.archs.vgg_arch import NAMES 5 | 6 | 7 | def convert_net(ori_net, crt_net): 8 | 9 | for crt_k, _ in crt_net.items(): 10 | # vgg feature extractor 11 | if 'vgg_extractor' in crt_k: 12 | ori_k = crt_k.replace('vgg_extractor', 'VggExtract').replace('vgg_net', 'model') 13 | if 'mean' in crt_k: 14 | ori_k = ori_k.replace('mean', 'RGB_mean') 15 | elif 'std' in crt_k: 16 | ori_k = ori_k.replace('std', 'RGB_std') 17 | else: 18 | idx = NAMES['vgg19'].index(crt_k.split('.')[2]) 19 | if 'weight' in crt_k: 20 | ori_k = f'VggExtract.model.features.{idx}.weight' 21 | else: 22 | ori_k = f'VggExtract.model.features.{idx}.bias' 23 | elif 'attn_blocks' in crt_k: 24 | if 'left_eye' in crt_k: 25 | ori_k = crt_k.replace('attn_blocks.left_eye', 'le') 26 | elif 'right_eye' in crt_k: 27 | ori_k = crt_k.replace('attn_blocks.right_eye', 're') 28 | elif 'mouth' in crt_k: 29 | ori_k = crt_k.replace('attn_blocks.mouth', 'mo') 30 | elif 'nose' in crt_k: 31 | ori_k = crt_k.replace('attn_blocks.nose', 'no') 32 | else: 33 | raise ValueError('Wrong!') 34 | elif 'multi_scale_dilation' in crt_k: 35 | if 'conv_blocks' in crt_k: 36 | _, _, c, d, e = crt_k.split('.') 37 | ori_k = f'MSDilate.conv{int(c)+1}.{d}.{e}' 38 | else: 39 | ori_k = crt_k.replace('multi_scale_dilation.conv_fusion', 'MSDilate.convi') 40 | 41 | elif crt_k.startswith('upsample'): 42 | ori_k = crt_k.replace('upsample', 'up') 43 | if 'scale_block' in crt_k: 44 | ori_k = ori_k.replace('scale_block', 'ScaleModel1') 45 | elif 'shift_block' in crt_k: 46 | ori_k = ori_k.replace('shift_block', 'ShiftModel1') 47 | 48 | elif 'upsample4' in crt_k and 'body' in crt_k: 49 | ori_k = ori_k.replace('body', 'Model') 50 | 51 | else: 52 | print('unprocess key: ', crt_k) 53 | 54 | # replace 55 | if crt_net[crt_k].size() != ori_net[ori_k].size(): 56 | raise ValueError('Wrong tensor size: \n' 57 | f'crt_net: {crt_net[crt_k].size()}\n' 58 | f'ori_net: {ori_net[ori_k].size()}') 59 | else: 60 | crt_net[crt_k] = ori_net[ori_k] 61 | 62 | return crt_net 63 | 64 | 65 | if __name__ == '__main__': 66 | ori_net = torch.load('experiments/pretrained_models/DFDNet/DFDNet_official_original.pth') 67 | dfd_net = DFDNet(64, dict_path='experiments/pretrained_models/DFDNet/DFDNet_dict_512.pth') 68 | crt_net = dfd_net.state_dict() 69 | crt_net_params = convert_net(ori_net, crt_net) 70 | 71 | torch.save( 72 | dict(params=crt_net_params), 73 | 'experiments/pretrained_models/DFDNet/DFDNet_official.pth', 74 | _use_new_zipfile_serialization=False) 75 | -------------------------------------------------------------------------------- /scripts/model_conversion/convert_ridnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import OrderedDict 
3 | 4 | from basicsr.archs.ridnet_arch import RIDNet 5 | 6 | if __name__ == '__main__': 7 | ori_net_checkpoint = torch.load( 8 | 'experiments/pretrained_models/RIDNet/RIDNet_official_original.pt', map_location=lambda storage, loc: storage) 9 | rid_net = RIDNet(3, 64, 3) 10 | new_ridnet_dict = OrderedDict() 11 | 12 | rid_net_namelist = [] 13 | for name, param in rid_net.named_parameters(): 14 | rid_net_namelist.append(name) 15 | 16 | count = 0 17 | for name, param in ori_net_checkpoint.items(): 18 | new_ridnet_dict[rid_net_namelist[count]] = param 19 | count += 1 20 | 21 | rid_net.load_state_dict(new_ridnet_dict) 22 | torch.save(rid_net.state_dict(), 'experiments/pretrained_models/RIDNet/RIDNet.pth') 23 | -------------------------------------------------------------------------------- /scripts/model_conversion/convert_stylegan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from basicsr.archs.stylegan2_arch import StyleGAN2Discriminator, StyleGAN2Generator 4 | 5 | 6 | def convert_net_g(ori_net, crt_net): 7 | """Convert network generator.""" 8 | 9 | for crt_k, crt_v in crt_net.items(): 10 | if 'style_mlp' in crt_k: 11 | ori_k = crt_k.replace('style_mlp', 'style') 12 | elif 'constant_input.weight' in crt_k: 13 | ori_k = crt_k.replace('constant_input.weight', 'input.input') 14 | # style conv1 15 | elif 'style_conv1.modulated_conv' in crt_k: 16 | ori_k = crt_k.replace('style_conv1.modulated_conv', 'conv1.conv') 17 | elif 'style_conv1' in crt_k: 18 | if crt_v.shape == torch.Size([1]): 19 | ori_k = crt_k.replace('style_conv1', 'conv1.noise') 20 | else: 21 | ori_k = crt_k.replace('style_conv1', 'conv1') 22 | # style conv 23 | elif 'style_convs' in crt_k: 24 | ori_k = crt_k.replace('style_convs', 'convs').replace('modulated_conv', 'conv') 25 | if crt_v.shape == torch.Size([1]): 26 | ori_k = ori_k.replace('.weight', '.noise.weight') 27 | # to_rgb1 28 | elif 'to_rgb1.modulated_conv' in crt_k: 29 | ori_k = crt_k.replace('to_rgb1.modulated_conv', 'to_rgb1.conv') 30 | # to_rgbs 31 | elif 'to_rgbs' in crt_k: 32 | ori_k = crt_k.replace('modulated_conv', 'conv') 33 | elif 'noises' in crt_k: 34 | ori_k = crt_k.replace('.noise', '.noise_') 35 | else: 36 | ori_k = crt_k 37 | 38 | # replace 39 | if crt_net[crt_k].size() != ori_net[ori_k].size(): 40 | raise ValueError('Wrong tensor size: \n' 41 | f'crt_net: {crt_net[crt_k].size()}\n' 42 | f'ori_net: {ori_net[ori_k].size()}') 43 | else: 44 | crt_net[crt_k] = ori_net[ori_k] 45 | 46 | return crt_net 47 | 48 | 49 | def convert_net_d(ori_net, crt_net): 50 | """Convert network discriminator.""" 51 | 52 | for crt_k, _ in crt_net.items(): 53 | if 'conv_body' in crt_k: 54 | ori_k = crt_k.replace('conv_body', 'convs') 55 | else: 56 | ori_k = crt_k 57 | 58 | # replace 59 | if crt_net[crt_k].size() != ori_net[ori_k].size(): 60 | raise ValueError('Wrong tensor size: \n' 61 | f'crt_net: {crt_net[crt_k].size()}\n' 62 | f'ori_net: {ori_net[ori_k].size()}') 63 | else: 64 | crt_net[crt_k] = ori_net[ori_k] 65 | return crt_net 66 | 67 | 68 | if __name__ == '__main__': 69 | """Convert official stylegan2 weights from stylegan2-pytorch.""" 70 | 71 | # configuration 72 | ori_net = torch.load('experiments/pretrained_models/stylegan2-ffhq.pth') 73 | save_path_g = 'experiments/pretrained_models/stylegan2_ffhq_config_f_1024_official.pth' # noqa: E501 74 | save_path_d = 'experiments/pretrained_models/stylegan2_ffhq_config_f_1024_discriminator_official.pth' # noqa: E501 75 | out_size = 1024 76 | channel_multiplier = 1 77 | 78 | # 
convert generator 79 | crt_net = StyleGAN2Generator(out_size, num_style_feat=512, num_mlp=8, channel_multiplier=channel_multiplier) 80 | crt_net = crt_net.state_dict() 81 | 82 | crt_net_params_ema = convert_net_g(ori_net['g_ema'], crt_net) 83 | torch.save(dict(params_ema=crt_net_params_ema, latent_avg=ori_net['latent_avg']), save_path_g) 84 | 85 | # convert discriminator 86 | crt_net = StyleGAN2Discriminator(out_size, channel_multiplier=channel_multiplier) 87 | crt_net = crt_net.state_dict() 88 | 89 | crt_net_params = convert_net_d(ori_net['d'], crt_net) 90 | torch.save(dict(params=crt_net_params), save_path_d) 91 | -------------------------------------------------------------------------------- /scripts/publish_models.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import subprocess 3 | import torch 4 | from os import path as osp 5 | from torch.serialization import _is_zipfile, _open_file_like 6 | 7 | 8 | def update_sha(paths): 9 | print('# Update sha ...') 10 | for idx, path in enumerate(paths): 11 | print(f'{idx+1:03d}: Processing {path}') 12 | net = torch.load(path, map_location=torch.device('cpu')) 13 | basename = osp.basename(path) 14 | if 'params' not in net and 'params_ema' not in net: 15 | user_response = input(f'WARN: Model {basename} does not have "params"/"params_ema" key. ' 16 | 'Do you still want to continue? Y/N\n') 17 | if user_response.lower() == 'y': 18 | pass 19 | elif user_response.lower() == 'n': 20 | raise ValueError('Please modify..') 21 | else: 22 | raise ValueError('Wrong input. Only accepts Y/N.') 23 | 24 | if '-' in basename: 25 | # check whether the sha is the latest 26 | old_sha = basename.split('-')[1].split('.')[0] 27 | new_sha = subprocess.check_output(['sha256sum', path]).decode()[:8] 28 | if old_sha != new_sha: 29 | final_file = path.split('-')[0] + f'-{new_sha}.pth' 30 | print(f'\tSave from {path} to {final_file}') 31 | subprocess.Popen(['mv', path, final_file]) 32 | else: 33 | sha = subprocess.check_output(['sha256sum', path]).decode()[:8] 34 | final_file = path.split('.pth')[0] + f'-{sha}.pth' 35 | print(f'\tSave from {path} to {final_file}') 36 | subprocess.Popen(['mv', path, final_file]) 37 | 38 | 39 | def convert_to_backward_compatible_models(paths): 40 | """Convert to backward compatible pth files. 41 | 42 | PyTorch 1.6 uses a updated version of torch.save. In order to be compatible 43 | with previous PyTorch version, save it with 44 | _use_new_zipfile_serialization=False. 
45 | """ 46 | print('# Convert to backward compatible pth files ...') 47 | for idx, path in enumerate(paths): 48 | print(f'{idx+1:03d}: Processing {path}') 49 | flag_need_conversion = False 50 | with _open_file_like(path, 'rb') as opened_file: 51 | if _is_zipfile(opened_file): 52 | flag_need_conversion = True 53 | 54 | if flag_need_conversion: 55 | net = torch.load(path, map_location=torch.device('cpu')) 56 | print('\tConverting to compatible pth file...') 57 | torch.save(net, path, _use_new_zipfile_serialization=False) 58 | 59 | 60 | if __name__ == '__main__': 61 | paths = glob.glob('experiments/pretrained_models/*.pth') + glob.glob('experiments/pretrained_models/**/*.pth') 62 | convert_to_backward_compatible_models(paths) 63 | update_sha(paths) 64 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import find_packages, setup 4 | 5 | import os 6 | import subprocess 7 | import time 8 | import torch 9 | from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension 10 | 11 | version_file = 'basicsr/version.py' 12 | 13 | 14 | def readme(): 15 | with open('README.md', encoding='utf-8') as f: 16 | content = f.read() 17 | return content 18 | 19 | 20 | def get_git_hash(): 21 | 22 | def _minimal_ext_cmd(cmd): 23 | # construct minimal environment 24 | env = {} 25 | for k in ['SYSTEMROOT', 'PATH', 'HOME']: 26 | v = os.environ.get(k) 27 | if v is not None: 28 | env[k] = v 29 | # LANGUAGE is used on win32 30 | env['LANGUAGE'] = 'C' 31 | env['LANG'] = 'C' 32 | env['LC_ALL'] = 'C' 33 | out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0] 34 | return out 35 | 36 | try: 37 | out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) 38 | sha = out.strip().decode('ascii') 39 | except OSError: 40 | sha = 'unknown' 41 | 42 | return sha 43 | 44 | 45 | def get_hash(): 46 | if os.path.exists('.git'): 47 | sha = get_git_hash()[:7] 48 | # currently ignore this 49 | # elif os.path.exists(version_file): 50 | # try: 51 | # from basicsr.version import __version__ 52 | # sha = __version__.split('+')[-1] 53 | # except ImportError: 54 | # raise ImportError('Unable to get git version') 55 | else: 56 | sha = 'unknown' 57 | 58 | return sha 59 | 60 | 61 | def write_version_py(): 62 | content = """# GENERATED VERSION FILE 63 | # TIME: {} 64 | __version__ = '{}' 65 | __gitsha__ = '{}' 66 | version_info = ({}) 67 | """ 68 | sha = get_hash() 69 | with open('VERSION', 'r') as f: 70 | SHORT_VERSION = f.read().strip() 71 | VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')]) 72 | 73 | version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO) 74 | with open(version_file, 'w') as f: 75 | f.write(version_file_str) 76 | 77 | 78 | def get_version(): 79 | with open(version_file, 'r') as f: 80 | exec(compile(f.read(), version_file, 'exec')) 81 | return locals()['__version__'] 82 | 83 | 84 | def make_cuda_ext(name, module, sources, sources_cuda=None): 85 | if sources_cuda is None: 86 | sources_cuda = [] 87 | define_macros = [] 88 | extra_compile_args = {'cxx': []} 89 | 90 | if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1': 91 | define_macros += [('WITH_CUDA', None)] 92 | extension = CUDAExtension 93 | extra_compile_args['nvcc'] = [ 94 | '-D__CUDA_NO_HALF_OPERATORS__', 95 | '-D__CUDA_NO_HALF_CONVERSIONS__', 96 | '-D__CUDA_NO_HALF2_OPERATORS__', 97 
| ] 98 | sources += sources_cuda 99 | else: 100 | print(f'Compiling {name} without CUDA') 101 | extension = CppExtension 102 | 103 | return extension( 104 | name=f'{module}.{name}', 105 | sources=[os.path.join(*module.split('.'), p) for p in sources], 106 | define_macros=define_macros, 107 | extra_compile_args=extra_compile_args) 108 | 109 | 110 | def get_requirements(filename='requirements.txt'): 111 | here = os.path.dirname(os.path.realpath(__file__)) 112 | with open(os.path.join(here, filename), 'r') as f: 113 | requires = [line.replace('\n', '') for line in f.readlines()] 114 | return requires 115 | 116 | 117 | if __name__ == '__main__': 118 | cuda_ext = os.getenv('BASICSR_EXT') # whether compile cuda ext 119 | if cuda_ext == 'True': 120 | ext_modules = [ 121 | make_cuda_ext( 122 | name='deform_conv_ext', 123 | module='basicsr.ops.dcn', 124 | sources=['src/deform_conv_ext.cpp'], 125 | sources_cuda=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']), 126 | make_cuda_ext( 127 | name='fused_act_ext', 128 | module='basicsr.ops.fused_act', 129 | sources=['src/fused_bias_act.cpp'], 130 | sources_cuda=['src/fused_bias_act_kernel.cu']), 131 | make_cuda_ext( 132 | name='upfirdn2d_ext', 133 | module='basicsr.ops.upfirdn2d', 134 | sources=['src/upfirdn2d.cpp'], 135 | sources_cuda=['src/upfirdn2d_kernel.cu']), 136 | ] 137 | else: 138 | ext_modules = [] 139 | 140 | # write_version_py() 141 | setup( 142 | name='basicsr', 143 | # version=get_version(), 144 | description='Open Source Image and Video Super-Resolution Toolbox', 145 | long_description=readme(), 146 | long_description_content_type='text/markdown', 147 | author='Xintao Wang', 148 | author_email='xintao.wang@outlook.com', 149 | keywords='computer vision, restoration, super resolution', 150 | url='https://github.com/xinntao/BasicSR', 151 | include_package_data=True, 152 | packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')), 153 | classifiers=[ 154 | 'Development Status :: 4 - Beta', 155 | 'License :: OSI Approved :: Apache Software License', 156 | 'Operating System :: OS Independent', 157 | 'Programming Language :: Python :: 3', 158 | 'Programming Language :: Python :: 3.7', 159 | 'Programming Language :: Python :: 3.8', 160 | ], 161 | license='Apache License 2.0', 162 | setup_requires=['cython', 'numpy'], 163 | install_requires=get_requirements(), 164 | ext_modules=ext_modules, 165 | cmdclass={'build_ext': BuildExtension}, 166 | zip_safe=False) 167 | -------------------------------------------------------------------------------- /test_recon.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import glob 4 | import os 5 | from tqdm import tqdm 6 | import torch 7 | from yaml import load 8 | import pdb 9 | import numpy as np 10 | import matplotlib as m 11 | import matplotlib.pyplot as plt 12 | from PIL import Image 13 | import pyiqa 14 | 15 | from basicsr.utils import img2tensor, tensor2img, imwrite 16 | from basicsr.archs.adacode_arch import AdaCodeSRNet 17 | from basicsr.archs.femasr_arch import FeMaSRNet 18 | from basicsr.archs.adacode_contrast_arch import AdaCodeSRNet_Contrast 19 | from basicsr.utils.download_util import load_file_from_url 20 | 21 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 22 | 23 | #eval metrics 24 | metric_funcs = {} 25 | metric_funcs['psnr'] = pyiqa.create_metric('psnr', device=device, crop_border=4, test_y_channel=True) 26 | metric_funcs['ssim'] = 
pyiqa.create_metric('ssim', device=device, crop_border=4, test_y_channel=True) 27 | metric_funcs['lpips'] = pyiqa.create_metric('lpips', device=device) 28 | 29 | def main(args): 30 | """Inference demo for FeMaSR 31 | """ 32 | metric_results = {'psnr': 0, 'ssim': 0, 'lpips': 0} 33 | 34 | weight_path = args.weight 35 | 36 | # set up the model 37 | model_params = torch.load(weight_path)['params'] 38 | codebook_dim = np.array([v.size() for k,v in model_params.items() if 'quantize_group' in k]) 39 | codebook_dim_list = [] 40 | for k in codebook_dim: 41 | temp = k.tolist() 42 | temp.insert(0,32) 43 | codebook_dim_list.append(temp) 44 | # recon_model = FeMaSRNet(codebook_params=[[32, 512, 256]], LQ_stage=False, scale_factor=2).to(device) 45 | recon_model = AdaCodeSRNet_Contrast(codebook_params=codebook_dim_list, LQ_stage=False, AdaCode_stage=True, batch_size=2, weight_softmax=False).to(device) 46 | recon_model.load_state_dict(model_params, strict=False) 47 | recon_model.eval() 48 | 49 | os.makedirs(args.output, exist_ok=True) 50 | if os.path.isfile(args.input): 51 | paths = [args.input] 52 | else: 53 | paths = sorted(glob.glob(os.path.join(args.input, '*.png'))) 54 | 55 | pbar = tqdm(total=len(paths), unit='image') 56 | for idx, path in enumerate(paths): 57 | # try: 58 | img_name = os.path.basename(path) 59 | pbar.set_description(f'Test {img_name}') 60 | 61 | # recon 62 | img_HR = cv2.imread(path, cv2.IMREAD_UNCHANGED) 63 | img_HR_tensor = img2tensor(img_HR).to(device) / 255. 64 | img_HR_tensor = img_HR_tensor.unsqueeze(0) 65 | 66 | max_size = args.max_size ** 2 67 | h, w = img_HR_tensor.shape[2:] 68 | if h * w < max_size: 69 | output_HR = recon_model.test(img_HR_tensor, vis_weight=args.vis_weight) 70 | else: 71 | output_HR = recon_model.test_tile(img_HR_tensor, vis_weight=args.vis_weight) 72 | 73 | if args.vis_weight: 74 | weight_map = output_HR[1] 75 | vis_weight(weight_map, os.path.join(args.output, 'weight_map', img_name)) 76 | output = output_HR[0] 77 | else: 78 | output = output_HR 79 | output_img = tensor2img(output) 80 | 81 | imwrite(output_img, os.path.join(args.output, f'{img_name}')) 82 | 83 | for name in metric_funcs.keys(): 84 | metric_results[name] += metric_funcs[name](img_HR_tensor, output).item() 85 | pbar.update(1) 86 | # except: 87 | # print(path, ' fails.') 88 | pbar.close() 89 | 90 | for name in metric_results.keys(): 91 | metric_results[name] /= len(paths) 92 | print('Result for {}'.format(args.weight)) 93 | print(metric_results) 94 | 95 | 96 | def vis_weight(weight, save_path): 97 | # weight: B x n x 1 x H x W 98 | weight = weight.cpu().numpy() 99 | # normalize weights 100 | # norm_weight = weight 101 | norm_weight = (weight - weight.mean()) / weight.std() / 2 102 | norm_weight = np.abs(norm_weight) 103 | norm_weight *= 255 104 | norm_weight = np.clip(norm_weight, 0, 255) 105 | norm_weight = norm_weight.astype(np.uint8) 106 | # visualize 107 | display_grid = np.zeros((weight.shape[3], (weight.shape[4]+1)*weight.shape[1]-1)) 108 | for img_id in range(len(norm_weight)): 109 | for c in range(norm_weight.shape[1]): 110 | display_grid[:, c*weight.shape[4]+c:(c+1)*weight.shape[4]+c] = norm_weight[img_id, c, 0, :, :] 111 | # weight_path = save_path.split('.')[0] + '_{}.png'.format(str(c)) 112 | # Image.fromarray(norm_weight[img_id, c, 0, :, :]).save(weight_path) 113 | plt.figure(figsize=(30,150)) 114 | plt.axis('off') 115 | plt.imshow(display_grid) 116 | plt.savefig(save_path, bbox_inches='tight', pad_inches=0) 117 | plt.close() 118 | 119 | 120 | 121 | 122 | if __name__ == 
'__main__': 123 | parser = argparse.ArgumentParser() 124 | parser.add_argument('-i', '--input', type=str, default='inputs', help='Input image or folder') 125 | parser.add_argument('-w', '--weight', type=str, default=None, help='path for model weights') 126 | parser.add_argument('-o', '--output', type=str, default='results', help='Output folder') 127 | parser.add_argument('--suffix', type=str, default='', help='Suffix of the restored image') 128 | parser.add_argument('--max_size', type=int, default=600000, help='Max image size for whole image inference, otherwise use tiled_test') 129 | parser.add_argument('--vis_weight', action='store_true', help='visualize weight map') 130 | args = parser.parse_args() 131 | 132 | if args.vis_weight: 133 | os.makedirs(os.path.join(args.output, 'weight_map'), exist_ok=True) 134 | 135 | main(args) 136 | -------------------------------------------------------------------------------- /testset/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/.DS_Store -------------------------------------------------------------------------------- /testset/inpaint/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/.DS_Store -------------------------------------------------------------------------------- /testset/inpaint/0801_s009.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0801_s009.png -------------------------------------------------------------------------------- /testset/inpaint/0804_s003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0804_s003.png -------------------------------------------------------------------------------- /testset/inpaint/0804_s006.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0804_s006.png -------------------------------------------------------------------------------- /testset/inpaint/0808_s010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0808_s010.png -------------------------------------------------------------------------------- /testset/inpaint/0809_s006.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0809_s006.png -------------------------------------------------------------------------------- /testset/inpaint/0813_s009.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0813_s009.png -------------------------------------------------------------------------------- /testset/inpaint/0817_s003.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0817_s003.png -------------------------------------------------------------------------------- /testset/inpaint/0817_s012.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0817_s012.png -------------------------------------------------------------------------------- /testset/inpaint/0830_s007.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0830_s007.png -------------------------------------------------------------------------------- /testset/inpaint/0831_s006.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0831_s006.png -------------------------------------------------------------------------------- /testset/inpaint/0833_s003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0833_s003.png -------------------------------------------------------------------------------- /testset/inpaint/0836_s001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0836_s001.png -------------------------------------------------------------------------------- /testset/inpaint/0836_s010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0836_s010.png -------------------------------------------------------------------------------- /testset/inpaint/0839_s012.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0839_s012.png -------------------------------------------------------------------------------- /testset/inpaint/0842_s002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0842_s002.png -------------------------------------------------------------------------------- /testset/inpaint/0848_s009.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0848_s009.png -------------------------------------------------------------------------------- /testset/inpaint/0849_s004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0849_s004.png -------------------------------------------------------------------------------- /testset/inpaint/0852_s010.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0852_s010.png -------------------------------------------------------------------------------- /testset/inpaint/0856_s006.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0856_s006.png -------------------------------------------------------------------------------- /testset/inpaint/0856_s010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0856_s010.png -------------------------------------------------------------------------------- /testset/inpaint/0863_s004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0863_s004.png -------------------------------------------------------------------------------- /testset/inpaint/0865_s011.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0865_s011.png -------------------------------------------------------------------------------- /testset/inpaint/0866_s005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0866_s005.png -------------------------------------------------------------------------------- /testset/inpaint/0870_s005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0870_s005.png -------------------------------------------------------------------------------- /testset/inpaint/0874_s007.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0874_s007.png -------------------------------------------------------------------------------- /testset/inpaint/64010.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/64010.png -------------------------------------------------------------------------------- /testset/sr/0003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/0003.jpg -------------------------------------------------------------------------------- /testset/sr/0014.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/0014.jpg -------------------------------------------------------------------------------- /testset/sr/0015.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/0015.jpg 
-------------------------------------------------------------------------------- /testset/sr/0030.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/0030.jpg -------------------------------------------------------------------------------- /testset/sr/0032.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/0032.jpg -------------------------------------------------------------------------------- /testset/sr/0054.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/0054.jpg -------------------------------------------------------------------------------- /testset/sr/0068.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/0068.jpg -------------------------------------------------------------------------------- /testset/sr/49370.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/49370.png -------------------------------------------------------------------------------- /testset/sr/ADE_val_00000015.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/ADE_val_00000015.jpg -------------------------------------------------------------------------------- /testset/sr/ADE_val_00000114.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/ADE_val_00000114.jpg -------------------------------------------------------------------------------- /testset/sr/Canon_004_LR4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/Canon_004_LR4.png -------------------------------------------------------------------------------- /testset/sr/Canon_045_LR4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/Canon_045_LR4.png -------------------------------------------------------------------------------- /testset/sr/OST_009.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/OST_009.png -------------------------------------------------------------------------------- /testset/sr/OST_020.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/OST_020.png -------------------------------------------------------------------------------- /testset/sr/OST_120.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/OST_120.png -------------------------------------------------------------------------------- /testset/sr/building.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/building.png -------------------------------------------------------------------------------- /testset/sr/butterfly.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/butterfly.png -------------------------------------------------------------------------------- /testset/sr/butterfly2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/butterfly2.png -------------------------------------------------------------------------------- /testset/sr/chip.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/chip.png -------------------------------------------------------------------------------- /testset/sr/comic1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/comic1.png -------------------------------------------------------------------------------- /testset/sr/comic2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/comic2.png -------------------------------------------------------------------------------- /testset/sr/comic3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/comic3.png -------------------------------------------------------------------------------- /testset/sr/computer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/computer.png -------------------------------------------------------------------------------- /testset/sr/dog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/dog.png -------------------------------------------------------------------------------- /testset/sr/dped_crop00061.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/dped_crop00061.png -------------------------------------------------------------------------------- /testset/sr/foreman.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/foreman.png -------------------------------------------------------------------------------- /testset/sr/frog.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/frog.png -------------------------------------------------------------------------------- /testset/sr/oldphoto3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/oldphoto3.png -------------------------------------------------------------------------------- /testset/sr/oldphoto6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/oldphoto6.png -------------------------------------------------------------------------------- /testset/sr/painting.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/painting.png -------------------------------------------------------------------------------- /testset/sr/pattern.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/pattern.png -------------------------------------------------------------------------------- /testset/sr/ppt3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/ppt3.png -------------------------------------------------------------------------------- /testset/sr/tiger.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/tiger.png -------------------------------------------------------------------------------- /vis_codebook.py: -------------------------------------------------------------------------------- 1 | from itertools import count 2 | from tokenize import PlainToken 3 | import torch 4 | import torchvision.transforms as tf 5 | from torchvision.utils import save_image 6 | import torchvision.utils as tvu 7 | 8 | import numpy as np 9 | import os 10 | import random 11 | from tqdm import tqdm 12 | import cv2 13 | from matplotlib import pyplot as plt 14 | import seaborn as sns 15 | 16 | from basicsr.utils.misc import set_random_seed 17 | from basicsr.utils import img2tensor, tensor2img, imwrite 18 | from basicsr.archs.femasr_arch import FeMaSRNet 19 | 20 | 21 | def reconstruct_ost(model, data_dir, save_dir, maxnum=100): 22 | 23 | texture_classes = list(os.listdir(data_dir)) 24 | texture_classes.remove('manga109') 25 | code_idx_dict = {} 26 | for tc in texture_classes: 27 | img_name_list = os.listdir(os.path.join(data_dir, tc)) 28 | random.shuffle(img_name_list) 29 | tmp_code_idx_list = [] 30 | for img_name in tqdm(img_name_list[:maxnum]): 31 | img_path = os.path.join(data_dir, tc, img_name) 32 | 33 | img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) 34 | img_tensor = img2tensor(img).to(device) / 255. 
35 | img_tensor = img_tensor.unsqueeze(0) 36 | 37 | rec, _, _, indices = model(img_tensor) 38 | indices = indices[0] 39 | 40 | save_path = os.path.join(save_dir, tc, img_name) 41 | if not os.path.exists(os.path.join(save_dir, tc)): 42 | os.makedirs(os.path.join(save_dir, tc), exist_ok=True) 43 | imwrite(tensor2img(rec), save_path) 44 | 45 | save_org_dir = save_dir.replace('rec', 'org') 46 | save_org_path = os.path.join(save_org_dir, tc, img_name) 47 | if not os.path.exists(os.path.join(save_org_dir, tc)): 48 | os.makedirs(os.path.join(save_org_dir, tc), exist_ok=True) 49 | imwrite(tensor2img(img_tensor), save_org_path) 50 | 51 | tmp_code_idx_list.append(indices) 52 | code_idx_dict[tc] = tmp_code_idx_list 53 | 54 | torch.save(code_idx_dict, './tmp_code_vis/code_idx_dict.pth') 55 | 56 | 57 | def vis_hrp(model, code_list_path, save_dir, samples_each_class=16): 58 | code_idx_dict = torch.load(code_list_path) 59 | classes = list(code_idx_dict.keys()) 60 | 61 | latent_size = 8 62 | color_palette = sns.color_palette() 63 | for idx, (key, value) in enumerate(code_idx_dict.items()): 64 | all_idx = torch.cat([x.flatten() for x in value]) 65 | 66 | plt.figure(figsize=(16, 8)) 67 | sns.histplot(all_idx.cpu().numpy(), color=color_palette[idx]) 68 | plt.xlabel(key, fontsize=30) 69 | plt.ylabel('Count', fontsize=30) 70 | plt.savefig(f'./tmp_code_vis/code_stat/code_index_bincount_{key}.pdf') 71 | 72 | counts = all_idx.bincount() 73 | dist = counts / sum(counts) 74 | dist = dist.cpu().numpy() 75 | 76 | vis_tex_samples = [] 77 | for sid in range(32): 78 | vis_tex_map = np.random.choice(np.arange(dist.shape[0]), latent_size ** 2, p=dist) 79 | vis_tex_map = torch.from_numpy(vis_tex_map).to(all_idx) 80 | vis_tex_map = vis_tex_map.reshape(1, 1, latent_size, latent_size) 81 | vis_tex_img = model.decode_indices(vis_tex_map) 82 | vis_tex_samples.append(vis_tex_img) 83 | vis_tex_samples = torch.cat(vis_tex_samples, dim=0) 84 | save_image(vis_tex_samples, f'./tmp_code_vis/tmp_tex_vis/{key}.jpg', normalize=True, nrow=16) 85 | 86 | 87 | def vis_single_code(model, codenum, save_path, up_factor=4): 88 | code_idx = torch.arange(codenum).reshape(codenum, 1, 1, 1) 89 | code_idx = code_idx.repeat(1, 1, up_factor, up_factor) 90 | output_img = model.decode_indices(code_idx) 91 | output_img = tvu.make_grid(output_img, nrow=32) 92 | save_image(output_img, save_path) 93 | 94 | def vis_rand_single_code(model, codenum, save_path, up_factor=4, vis_num=5): 95 | code_idx = torch.randint(codenum, (vis_num,)).reshape(vis_num, 1, 1, 1) 96 | code_idx = code_idx.repeat(1, 1, up_factor, up_factor) 97 | output_img = model.decode_indices(code_idx) 98 | output_img = tvu.make_grid(output_img, nrow=32) 99 | save_image(output_img, save_path) 100 | 101 | 102 | if __name__ == '__main__': 103 | # set random seeds to ensure reproducibility 104 | set_random_seed(123) 105 | 106 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 107 | 108 | # set up the model 109 | weight_path = './experiments/pretrained_models/S1_c2_512x256_net_g_best.pth' 110 | codebook_size = os.path.basename(weight_path).split('_')[2].split('x') 111 | vqgan = FeMaSRNet(codebook_params=[[32, int(codebook_size[0]), int(codebook_size[1])]], LQ_stage=False).to(device) 112 | vqgan.load_state_dict(torch.load(weight_path)['params'], strict=False) 113 | vqgan.eval() 114 | 115 | os.makedirs('results/codebook_vis', exist_ok=True) 116 | vis_single_code(vqgan, int(codebook_size[0]), 'results/codebook_vis/{}.png'.format(os.path.basename(weight_path).split('.')[0])) 117 | # 
vis_rand_single_code(vqgan, 256, 'codebook_vis/sample_ffhq.png', vis_num=10) 118 | 119 | # reconstruct_ost(vqgan, '../datasets/SR_OST_datasets/OutdoorSceneTrain_v2/', './tmp_code_vis/ost_rec', maxnum=1000) 120 | # vis_hrp(vqgan, './tmp_code_vis/code_idx_dict.pth', './tmp_code_vis/') 121 | --------------------------------------------------------------------------------
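
A note on loading AdaCode checkpoints for reconstruction: the snippet below is a minimal sketch (not part of the repository) that isolates the codebook-size parsing already performed inside test_recon.py before it builds AdaCodeSRNet_Contrast. The checkpoint path is a placeholder, and the leading value 32 is simply the same constant that test_recon.py prepends to each codebook shape; substitute whatever stage-2/3 weights you actually have.

import torch

# Hypothetical checkpoint path -- replace with your trained or downloaded weights.
weight_path = 'experiments/pretrained_models/AdaCode_stage2.pth'
model_params = torch.load(weight_path, map_location='cpu')['params']

# Every state-dict entry whose key contains 'quantize_group' stores one codebook
# tensor. Prepending the constant 32 (as test_recon.py does) rebuilds the
# per-group triplets passed to codebook_params.
codebook_dim_list = [[32] + list(v.size())
                     for k, v in model_params.items()
                     if 'quantize_group' in k]
print(codebook_dim_list)

With a checkpoint in hand, the reconstruction script can then be run with the flags its argparse section defines, e.g. `python test_recon.py -i testset/sr -w <path_to_weights> -o results --vis_weight` (the `--vis_weight` flag additionally writes the per-codebook weight maps under `results/weight_map`).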