├── .DS_Store
├── .gitignore
├── LICENSE
├── README.md
├── VisToyExample.ipynb
├── basicsr
│   ├── .DS_Store
│   ├── __init__.py
│   ├── archs
│   │   ├── RRDB_arch.py
│   │   ├── __init__.py
│   │   ├── adacode_arch.py
│   │   ├── adacode_contrast_arch.py
│   │   ├── arch_util.py
│   │   ├── discriminator_arch.py
│   │   ├── fema_utils.py
│   │   ├── femasr_arch.py
│   │   ├── network_swinir.py
│   │   └── vgg_arch.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── bsrgan_train_dataset.py
│   │   ├── bsrgan_util.py
│   │   ├── data_sampler.py
│   │   ├── data_util.py
│   │   ├── paired_image_dataset.py
│   │   ├── prefetch_dataloader.py
│   │   ├── single_image_dataset.py
│   │   └── transforms.py
│   ├── losses
│   │   ├── __init__.py
│   │   ├── loss_util.py
│   │   └── losses.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── base_model.py
│   │   ├── femasr_model.py
│   │   ├── lr_scheduler.py
│   │   └── ori_femasr_model.py
│   ├── test.py
│   ├── train.py
│   ├── train_mergedcodebook.py
│   └── utils
│       ├── __init__.py
│       ├── diffjpeg.py
│       ├── dist_util.py
│       ├── download_util.py
│       ├── face_util.py
│       ├── file_client.py
│       ├── flow_util.py
│       ├── img_process_util.py
│       ├── img_util.py
│       ├── lmdb_util.py
│       ├── logger.py
│       ├── matlab_functions.py
│       ├── misc.py
│       ├── options.py
│       └── registry.py
├── evaluate.py
├── figures
│   ├── .DS_Store
│   ├── model.jpeg
│   └── teaser.jpeg
├── generate_dataset.py
├── generate_inpaint_dataset.py
├── inference.py
├── options
│   ├── stage1
│   │   └── train_AdaCode_HQ_stage1_category.yml
│   ├── stage2
│   │   └── train_AdaCode_HQ_stage2.yaml
│   └── stage3
│       └── train_AdaCode_stage3.yaml
├── requirements.txt
├── scripts
│   ├── .DS_Store
│   ├── data_preparation
│   │   ├── create_lmdb.py
│   │   ├── download_datasets.py
│   │   ├── extract_images_from_tfrecords.py
│   │   ├── extract_subimages.py
│   │   ├── generate_meta_info.py
│   │   ├── prepare_SR_dataset.py
│   │   ├── prepare_hifacegan_dataset.py
│   │   └── regroup_reds_dataset.py
│   ├── download_gdrive.py
│   ├── download_pretrained_models.py
│   ├── matlab_scripts
│   │   ├── back_projection
│   │   │   ├── backprojection.m
│   │   │   ├── main_bp.m
│   │   │   └── main_reverse_filter.m
│   │   ├── generate_LR_Vimeo90K.m
│   │   └── generate_bicubic_img.m
│   ├── metrics
│   │   ├── calculate_fid_folder.py
│   │   ├── calculate_fid_stats_from_datasets.py
│   │   ├── calculate_lpips.py
│   │   ├── calculate_niqe.py
│   │   ├── calculate_psnr_ssim.py
│   │   └── calculate_stylegan2_fid.py
│   ├── model_conversion
│   │   ├── convert_dfdnet.py
│   │   ├── convert_models.py
│   │   ├── convert_ridnet.py
│   │   └── convert_stylegan.py
│   └── publish_models.py
├── setup.py
├── test_recon.py
├── testset
│   ├── .DS_Store
│   ├── inpaint
│   │   ├── .DS_Store
│   │   ├── 0801_s009.png
│   │   ├── 0804_s003.png
│   │   ├── 0804_s006.png
│   │   ├── 0808_s010.png
│   │   ├── 0809_s006.png
│   │   ├── 0813_s009.png
│   │   ├── 0817_s003.png
│   │   ├── 0817_s012.png
│   │   ├── 0830_s007.png
│   │   ├── 0831_s006.png
│   │   ├── 0833_s003.png
│   │   ├── 0836_s001.png
│   │   ├── 0836_s010.png
│   │   ├── 0839_s012.png
│   │   ├── 0842_s002.png
│   │   ├── 0848_s009.png
│   │   ├── 0849_s004.png
│   │   ├── 0852_s010.png
│   │   ├── 0856_s006.png
│   │   ├── 0856_s010.png
│   │   ├── 0863_s004.png
│   │   ├── 0865_s011.png
│   │   ├── 0866_s005.png
│   │   ├── 0870_s005.png
│   │   ├── 0874_s007.png
│   │   └── 64010.png
│   └── sr
│       ├── 0003.jpg
│       ├── 0014.jpg
│       ├── 0015.jpg
│       ├── 0030.jpg
│       ├── 0032.jpg
│       ├── 0054.jpg
│       ├── 0068.jpg
│       ├── 49370.png
│       ├── ADE_val_00000015.jpg
│       ├── ADE_val_00000114.jpg
│       ├── Canon_004_LR4.png
│       ├── Canon_045_LR4.png
│       ├── OST_009.png
│       ├── OST_020.png
│       ├── OST_120.png
│       ├── building.png
│       ├── butterfly.png
│       ├── butterfly2.png
│       ├── chip.png
│       ├── comic1.png
│       ├── comic2.png
│       ├── comic3.png
│       ├── computer.png
│       ├── dog.png
│       ├── dped_crop00061.png
│       ├── foreman.png
│       ├── frog.png
│       ├── oldphoto3.png
│       ├── oldphoto6.png
│       ├── painting.png
│       ├── pattern.png
│       ├── ppt3.png
│       └── tiger.png
└── vis_codebook.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/.DS_Store
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # ignored folders
2 | data/*
3 | datasets/*
4 | experiments*
5 | experiments/*
6 | test_results*/
7 | results*/
8 | tb_logger*/*
9 | wandb/*
10 | tmp*/*
11 | tmp*
12 | *.sh
13 | .vscode*
14 | .github
15 | */.DS_Store
16 |
17 | # ignored files
18 | version.py
19 |
20 | # ignored files with suffix
21 | *.html
22 | *.gif
23 | *.pth
24 | *.zip
25 | *.npy
26 |
27 | # template
28 |
29 | # Byte-compiled / optimized / DLL files
30 | __pycache__/
31 | *.py[cod]
32 | *$py.class
33 |
34 | # C extensions
35 | *.so
36 |
37 | # Distribution / packaging
38 | .Python
39 | build/
40 | develop-eggs/
41 | dist/
42 | downloads/
43 | eggs/
44 | .eggs/
45 | lib/
46 | lib64/
47 | parts/
48 | sdist/
49 | var/
50 | wheels/
51 | *.egg-info/
52 | .installed.cfg
53 | *.egg
54 | MANIFEST
55 |
56 | # PyInstaller
57 | # Usually these files are written by a python script from a template
58 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
59 | *.manifest
60 | *.spec
61 |
62 | # Installer logs
63 | pip-log.txt
64 | pip-delete-this-directory.txt
65 |
66 | # Unit test / coverage reports
67 | htmlcov/
68 | .tox/
69 | .coverage
70 | .coverage.*
71 | .cache
72 | nosetests.xml
73 | coverage.xml
74 | *.cover
75 | .hypothesis/
76 | .pytest_cache/
77 |
78 | # Translations
79 | *.mo
80 | *.pot
81 |
82 | # Django stuff:
83 | *.log
84 | local_settings.py
85 | db.sqlite3
86 |
87 | # Flask stuff:
88 | instance/
89 | .webassets-cache
90 |
91 | # Scrapy stuff:
92 | .scrapy
93 |
94 | # Sphinx documentation
95 | docs/_build/
96 |
97 | # PyBuilder
98 | target/
99 |
100 | # Jupyter Notebook
101 | .ipynb_checkpoints
102 |
103 | # pyenv
104 | .python-version
105 |
106 | # celery beat schedule file
107 | celerybeat-schedule
108 |
109 | # SageMath parsed files
110 | *.sage.py
111 |
112 | # Environments
113 | .env
114 | .venv
115 | env/
116 | venv/
117 | ENV/
118 | env.bak/
119 | venv.bak/
120 |
121 | # Spyder project settings
122 | .spyderproject
123 | .spyproject
124 |
125 | # Rope project settings
126 | .ropeproject
127 |
128 | # mkdocs documentation
129 | /site
130 |
131 | # mypy
132 | .mypy_cache/
133 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AdaCode
2 |
3 | This repository provides the official implementation for the paper
4 |
5 | > **Learning Image-Adaptive Codebooks for Class-Agnostic Image Restoration**
6 | > [Kechun Liu](https://kechunl.github.io/), Yitong Jiang, Inchang Choi, [Jinwei Gu](https://www.gujinwei.org/)
7 | > In ICCV 2023.
8 |
9 | [arXiv Paper](https://arxiv.org/abs/2306.06513)
10 | [Project Page](https://kechunl.github.io/AdaCode/)
11 | [Video](https://www.youtube.com/watch?v=GOp-4kbgyoM)
12 | > **Abstract:** Recent work on discrete generative priors, in the form of codebooks, has shown exciting performance for image reconstruction and restoration, since the discrete prior space spanned by the codebooks increases the robustness against diverse image degradations. Nevertheless, these methods require separate training of codebooks for different image categories, which limits their use to specific image categories only (e.g. face, architecture, etc.), and fail to handle arbitrary natural images. **In this paper, we propose AdaCode for learning image-adaptive codebooks for class-agnostic image restoration. Instead of learning a single codebook for all categories of images, we learn a set of basis codebooks. For a given input image, AdaCode learns a weight map with which we compute a weighted combination of these basis codebooks for adaptive image restoration.** Intuitively, AdaCode is a more flexible and expressive discrete generative prior than previous work. Experimental results show that AdaCode achieves state-of-the-art performance on image reconstruction and restoration tasks, including image super-resolution and inpainting.
13 |
14 |
15 |
16 | ---
17 | ### Dependencies and Installation
18 |
19 | - Ubuntu >= 18.04
20 | - CUDA >= 11.0
21 | - Other required packages in `requirements.txt`
22 | ```
23 | # git clone this repository
24 | git clone https://github.com/kechunl/AdaCode.git
25 | cd AdaCode
26 |
27 | # create new anaconda env
28 | conda create -n AdaCode python=3.8
29 | conda activate AdaCode
30 |
31 | # install python dependencies
32 | pip install -r requirements.txt
33 | python setup.py develop
34 | ```
35 |
36 | ---
37 |
38 | ### Inference
39 |
40 | ```
41 | # super resolution
42 | python inference.py -i ./testset/sr -t sr -s 2 -o results_x2/
43 | python inference.py -i ./testset/sr -t sr -s 4 -o results_x4/
44 |
45 | # inpainting
46 | python inference.py -i ./testset/inpaint -t inpaint -o results_inpaint/
47 | ```
48 |
49 | ## Train the model
50 |
51 | ### Preparation
52 |
53 | #### Dataset
54 |
55 | Please prepare the training and testing data following the descriptions in the [paper](https://arxiv.org/abs/2306.06513). In brief, you need to crop 512 x 512 high-resolution patches and generate the low-resolution patches with the [`degradation_bsrgan`](https://github.com/cszn/BSRGAN/blob/3a958f40a9a24e8b81c3cb1960f05b0e91f1b421/utils/utils_blindsr.py?_pjax=%23js-repo-pjax-container%2C%20div%5Bitemtype%3D%22http%3A%2F%2Fschema.org%2FSoftwareSourceCode%22%5D%20main%2C%20%5Bdata-pjax-container%5D#L432) function provided by [BSRGAN](https://github.com/cszn/BSRGAN). You may find `generate_dataset.py` and `generate_inpaint_dataset.py` useful for generating low-resolution and masked patches.
56 |
57 | #### Model preparation
58 |
59 | Before training, you need to
60 | - Download the pretrained stage 2 model: [generator](https://github.com/kechunl/AdaCode/releases/download/v0-pretrain_models/AdaCode_S2_model_g.pth), [discriminator](https://github.com/kechunl/AdaCode/releases/download/v0-pretrain_models/AdaCode_S2_model_d.pth)
61 | - Put the pretrained models in `experiments/pretrained_models`
62 | - Specify their path in the corresponding option file.
63 |
64 | ### Train SR model
65 |
66 | ```
67 | python basicsr/train.py -opt options/stage3/train_AdaCode_stage3.yaml
68 | ```
69 |
70 | ## Acknowledgement
71 |
72 | This project is based on [FeMaSR](https://github.com/chaofengc/FeMaSR).
--------------------------------------------------------------------------------
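The abstract in the README above describes the core AdaCode mechanism: predict a per-image weight map over a set of basis codebooks and quantize encoder features against their weighted combination. The snippet below is a minimal, self-contained sketch of that idea only; the tensor shapes, the softmax weighting, and the nearest-neighbour lookup are illustrative assumptions, not the repository's actual implementation (see `basicsr/archs/adacode_arch.py` for the real architecture).

```python
import torch
import torch.nn.functional as F

def adaptive_quantize(feat, basis_codebooks, weight_logits):
    """Toy illustration of an image-adaptive codebook.

    feat:            (B, C, H, W) encoder features to quantize
    basis_codebooks: (K, N, C)    K basis codebooks with N entries each
    weight_logits:   (B, K, H, W) predicted per-position weights over the K codebooks
    """
    B, C, H, W = feat.shape
    K, N, _ = basis_codebooks.shape
    w = F.softmax(weight_logits, dim=1)                        # normalize weights over the K codebooks
    # weighted combination of the basis codebooks at every spatial position
    combined = torch.einsum('bkhw,knc->bhwnc', w, basis_codebooks)   # (B, H, W, N, C)
    f = feat.permute(0, 2, 3, 1)                               # (B, H, W, C)
    dist = ((f.unsqueeze(-2) - combined) ** 2).sum(-1)         # (B, H, W, N) distances to entries
    idx = dist.argmin(-1)                                      # nearest entry per position
    quant = torch.gather(combined, 3, idx[..., None, None].expand(-1, -1, -1, 1, C)).squeeze(-2)
    return quant.permute(0, 3, 1, 2)                           # back to (B, C, H, W)

# toy shapes only: 3 basis codebooks of 16 entries, 8-dim features on a 4x4 grid
q = adaptive_quantize(torch.randn(1, 8, 4, 4), torch.randn(3, 16, 8), torch.randn(1, 3, 4, 4))
```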
/basicsr/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/basicsr/.DS_Store
--------------------------------------------------------------------------------
/basicsr/__init__.py:
--------------------------------------------------------------------------------
1 | # https://github.com/xinntao/BasicSR
2 | # flake8: noqa
3 | from .archs import *
4 | from .data import *
5 | from .losses import *
6 | from .models import *
7 | from .test import *
8 | from .train import *
9 | from .utils import *
10 |
--------------------------------------------------------------------------------
/basicsr/archs/RRDB_arch.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 |
7 | def make_layer(block, n_layers):
8 | layers = []
9 | for _ in range(n_layers):
10 | layers.append(block())
11 | return nn.Sequential(*layers)
12 |
13 |
14 | class ResidualDenseBlock_5C(nn.Module):
15 | def __init__(self, nf=64, gc=32, bias=True):
16 | super(ResidualDenseBlock_5C, self).__init__()
17 | # gc: growth channel, i.e. intermediate channels
18 | self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
19 | self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
20 | self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
21 | self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
22 | self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
23 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
24 |
25 | # initialization
26 | # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)
27 |
28 | def forward(self, x):
29 | x1 = self.lrelu(self.conv1(x))
30 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
31 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
32 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
33 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
34 | return x5 * 0.2 + x
35 |
36 |
37 | class RRDB(nn.Module):
38 | '''Residual in Residual Dense Block'''
39 |
40 | def __init__(self, nf, gc=32):
41 | super(RRDB, self).__init__()
42 | self.RDB1 = ResidualDenseBlock_5C(nf, gc)
43 | self.RDB2 = ResidualDenseBlock_5C(nf, gc)
44 | self.RDB3 = ResidualDenseBlock_5C(nf, gc)
45 |
46 | def forward(self, x):
47 | out = self.RDB1(x)
48 | out = self.RDB2(out)
49 | out = self.RDB3(out)
50 | return out * 0.2 + x
51 |
52 |
53 | class RRDBNet(nn.Module):
54 | def __init__(self, in_nc, out_nc, nf, nb, gc=32):
55 | super(RRDBNet, self).__init__()
56 | RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
57 |
58 | self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
59 | self.RRDB_trunk = make_layer(RRDB_block_f, nb)
60 | self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
61 | #### upsampling
62 | self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
63 | self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
64 | self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
65 | self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)
66 |
67 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
68 |
69 | def forward(self, x):
70 | fea = self.conv_first(x)
71 | trunk = self.trunk_conv(self.RRDB_trunk(fea))
72 | fea = fea + trunk
73 |
74 | fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
75 | fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
76 | out = self.conv_last(self.lrelu(self.HRconv(fea)))
77 |
78 | return out
--------------------------------------------------------------------------------
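As defined above, `RRDBNet` upsamples by a fixed factor of 4 (two nearest-neighbour 2x steps in `forward`). A quick smoke test; `nb=23` matches the common ESRGAN configuration, but any positive block count works here:

```python
import torch
from basicsr.archs.RRDB_arch import RRDBNet

net = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23)
with torch.no_grad():
    y = net(torch.rand(1, 3, 32, 32))
print(y.shape)  # torch.Size([1, 3, 128, 128]) -- 4x the input resolution
```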
/basicsr/archs/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from copy import deepcopy
3 | from os import path as osp
4 |
5 | from basicsr.utils import get_root_logger, scandir
6 | from basicsr.utils.registry import ARCH_REGISTRY
7 |
8 | __all__ = ['build_network']
9 |
10 | # automatically scan and import arch modules for registry
11 | # scan all the files under the 'archs' folder and collect files ending with
12 | # '_arch.py'
13 | arch_folder = osp.dirname(osp.abspath(__file__))
14 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
15 | # import all the arch modules
16 | _arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames]
17 |
18 |
19 | def build_network(opt):
20 | opt = deepcopy(opt)
21 | network_type = opt.pop('type')
22 | net = ARCH_REGISTRY.get(network_type)(**opt)
23 | logger = get_root_logger()
24 | logger.info(f'Network [{net.__class__.__name__}] is created.')
25 | return net
26 |
--------------------------------------------------------------------------------
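`build_network` pops `type` from the option dict, looks the class up in `ARCH_REGISTRY`, and passes the remaining keys through as keyword arguments. A minimal sketch using `UNetDiscriminatorSN`, which is registered in `basicsr/archs/discriminator_arch.py` below:

```python
from basicsr.archs import build_network

# equivalent to UNetDiscriminatorSN(num_in_ch=3, num_feat=64, skip_connection=True)
net_d = build_network({
    'type': 'UNetDiscriminatorSN',
    'num_in_ch': 3,
    'num_feat': 64,
    'skip_connection': True,
})
```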
/basicsr/archs/discriminator_arch.py:
--------------------------------------------------------------------------------
1 | from basicsr.utils.registry import ARCH_REGISTRY
2 | from torch import nn as nn
3 | from torch.nn import functional as F
4 | from torch.nn.utils import spectral_norm
5 |
6 |
7 | @ARCH_REGISTRY.register()
8 | class UNetDiscriminatorSN(nn.Module):
9 | """Defines a U-Net discriminator with spectral normalization (SN)
10 |
11 | It is used in Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data.
12 |
13 | Args:
14 | num_in_ch (int): Channel number of inputs. Default: 3.
15 | num_feat (int): Channel number of base intermediate features. Default: 64.
16 | skip_connection (bool): Whether to use skip connections in the U-Net. Default: True.
17 | """
18 |
19 | def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
20 | super(UNetDiscriminatorSN, self).__init__()
21 | self.skip_connection = skip_connection
22 | norm = spectral_norm
23 | # the first convolution
24 | self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
25 | # downsample
26 | self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
27 | self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
28 | self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
29 | # upsample
30 | self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
31 | self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
32 | self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
33 | # extra convolutions
34 | self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
35 | self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
36 | self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)
37 |
38 | def forward(self, x):
39 | # downsample
40 | x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
41 | x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
42 | x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
43 | x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)
44 |
45 | # upsample
46 | x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
47 | x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)
48 |
49 | if self.skip_connection:
50 | x4 = x4 + x2
51 | x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
52 | x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)
53 |
54 | if self.skip_connection:
55 | x5 = x5 + x1
56 | x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
57 | x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)
58 |
59 | if self.skip_connection:
60 | x6 = x6 + x0
61 |
62 | # extra convolutions
63 | out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
64 | out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
65 | out = self.conv9(out)
66 |
67 | return out
--------------------------------------------------------------------------------
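The discriminator downsamples three times with stride-2 convolutions and upsamples back with bilinear interpolation, so it returns a one-channel logit map at the input resolution; heights and widths that are multiples of 8 keep the skip connections aligned. A small shape check:

```python
import torch
from basicsr.archs.discriminator_arch import UNetDiscriminatorSN

disc = UNetDiscriminatorSN(num_in_ch=3)
logits = disc(torch.rand(1, 3, 256, 256))
print(logits.shape)  # torch.Size([1, 1, 256, 256]) -- per-pixel real/fake logits
```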
/basicsr/archs/fema_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 | from torch import nn as nn
4 |
5 | class NormLayer(nn.Module):
6 | """Normalization Layers.
7 | ------------
8 | # Arguments
9 | - channels: input channels, for batch norm and instance norm.
10 | - input_size: input shape without batch size, for layer norm.
11 | """
12 | def __init__(self, channels, norm_type='bn'):
13 | super(NormLayer, self).__init__()
14 | norm_type = norm_type.lower()
15 | self.norm_type = norm_type
16 | self.channels = channels
17 | if norm_type == 'bn':
18 | self.norm = nn.BatchNorm2d(channels, affine=True)
19 | elif norm_type == 'in':
20 | self.norm = nn.InstanceNorm2d(channels, affine=False)
21 | elif norm_type == 'gn':
22 | self.norm = nn.GroupNorm(num_groups=32, num_channels=channels, eps=1e-6, affine=True)
23 | elif norm_type == 'none':
24 | self.norm = lambda x: x*1.0
25 | else:
26 | assert 1==0, 'Norm type {} not supported.'.format(norm_type)
27 |
28 | def forward(self, x):
29 | return self.norm(x)
30 |
31 |
32 | class ActLayer(nn.Module):
33 | """activation layer.
34 | ------------
35 | # Arguments
36 | - relu type: type of relu layer, candidates are
37 | - ReLU
38 | - LeakyReLU: default relu slope 0.2
39 | - PRelu
40 | - SELU
41 | - none: direct pass
42 | """
43 | def __init__(self, channels, relu_type='leakyrelu'):
44 | super(ActLayer, self).__init__()
45 | relu_type = relu_type.lower()
46 | if relu_type == 'relu':
47 | self.func = nn.ReLU(True)
48 | elif relu_type == 'leakyrelu':
49 | self.func = nn.LeakyReLU(0.2, inplace=True)
50 | elif relu_type == 'prelu':
51 | self.func = nn.PReLU(channels)
52 | elif relu_type == 'none':
53 | self.func = lambda x: x*1.0
54 | elif relu_type == 'silu':
55 | self.func = nn.SiLU(True)
56 | elif relu_type == 'gelu':
57 | self.func = nn.GELU()
58 | else:
59 | assert 1==0, 'activation type {} not supported.'.format(relu_type)
60 |
61 | def forward(self, x):
62 | return self.func(x)
63 |
64 |
65 | class ResBlock(nn.Module):
66 | """
67 | Use preactivation version of residual block, the same as taming
68 | """
69 | def __init__(self, in_channel, out_channel, norm_type='gn', act_type='leakyrelu'):
70 | super(ResBlock, self).__init__()
71 |
72 | self.conv = nn.Sequential(
73 | NormLayer(in_channel, norm_type),
74 | ActLayer(in_channel, act_type),
75 | nn.Conv2d(in_channel, out_channel, 3, stride=1, padding=1),
76 | NormLayer(out_channel, norm_type),
77 | ActLayer(out_channel, act_type),
78 | nn.Conv2d(out_channel, out_channel, 3, stride=1, padding=1),
79 | )
80 |
81 | def forward(self, input):
82 | res = self.conv(input)
83 | out = res + input
84 | return out
85 |
86 |
87 | class CombineQuantBlock(nn.Module):
88 | def __init__(self, in_ch1, in_ch2, out_channel):
89 | super().__init__()
90 | self.conv = nn.Conv2d(in_ch1 + in_ch2, out_channel, 3, 1, 1)
91 |
92 | def forward(self, input1, input2=None):
93 | if input2 is not None:
94 | input2 = F.interpolate(input2, input1.shape[2:])
95 | input = torch.cat((input1, input2), dim=1)
96 | else:
97 | input = input1
98 | out = self.conv(input)
99 | return out
100 |
101 |
102 |
--------------------------------------------------------------------------------
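Note that `ResBlock` adds the convolution branch back onto its input, so it is only shape-valid when `in_channel == out_channel`, and its default GroupNorm (32 groups) expects a channel count divisible by 32. A tiny usage sketch:

```python
import torch
from basicsr.archs.fema_utils import ResBlock

block = ResBlock(64, 64)                  # in/out channels must match for the residual add
out = block(torch.rand(1, 64, 32, 32))    # shape preserved: (1, 64, 32, 32)
```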
/basicsr/archs/vgg_arch.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from collections import OrderedDict
4 | from torch import nn as nn
5 | from torchvision.models import vgg as vgg
6 |
7 | from basicsr.utils.registry import ARCH_REGISTRY
8 |
9 | VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'
10 | NAMES = {
11 | 'vgg11': [
12 | 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
13 | 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
14 | 'pool5'
15 | ],
16 | 'vgg13': [
17 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
18 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
19 | 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
20 | ],
21 | 'vgg16': [
22 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
23 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
24 | 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
25 | 'pool5'
26 | ],
27 | 'vgg19': [
28 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
29 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
30 | 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
31 | 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
32 | ]
33 | }
34 |
35 |
36 | def insert_bn(names):
37 | """Insert bn layer after each conv.
38 |
39 | Args:
40 | names (list): The list of layer names.
41 |
42 | Returns:
43 | list: The list of layer names with bn layers.
44 | """
45 | names_bn = []
46 | for name in names:
47 | names_bn.append(name)
48 | if 'conv' in name:
49 | position = name.replace('conv', '')
50 | names_bn.append('bn' + position)
51 | return names_bn
52 |
53 |
54 | @ARCH_REGISTRY.register()
55 | class VGGFeatureExtractor(nn.Module):
56 | """VGG network for feature extraction.
57 |
58 | In this implementation, we allow users to choose whether use normalization
59 | in the input feature and the type of vgg network. Note that the pretrained
60 | path must fit the vgg type.
61 |
62 | Args:
63 | layer_name_list (list[str]): Forward function returns the corresponding
64 | features according to the layer_name_list.
65 | Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
66 | vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
67 | use_input_norm (bool): If True, normalize the input image. Importantly,
68 | the input feature must be in the range [0, 1]. Default: True.
69 | range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
70 | Default: False.
71 | requires_grad (bool): If true, the parameters of VGG network will be
72 | optimized. Default: False.
73 | remove_pooling (bool): If true, the max pooling operations in VGG net
74 | will be removed. Default: False.
75 | pooling_stride (int): The stride of max pooling operation. Default: 2.
76 | """
77 |
78 | def __init__(self,
79 | layer_name_list,
80 | vgg_type='vgg19',
81 | use_input_norm=True,
82 | range_norm=False,
83 | requires_grad=False,
84 | remove_pooling=False,
85 | pooling_stride=2):
86 | super(VGGFeatureExtractor, self).__init__()
87 |
88 | self.layer_name_list = layer_name_list
89 | self.use_input_norm = use_input_norm
90 | self.range_norm = range_norm
91 |
92 | self.names = NAMES[vgg_type.replace('_bn', '')]
93 | if 'bn' in vgg_type:
94 | self.names = insert_bn(self.names)
95 |
96 | # only borrow layers that will be used to avoid unused params
97 | max_idx = 0
98 | for v in layer_name_list:
99 | idx = self.names.index(v)
100 | if idx > max_idx:
101 | max_idx = idx
102 |
103 | if os.path.exists(VGG_PRETRAIN_PATH):
104 | vgg_net = getattr(vgg, vgg_type)(pretrained=False)
105 | state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
106 | vgg_net.load_state_dict(state_dict)
107 | else:
108 | vgg_net = getattr(vgg, vgg_type)(pretrained=True)
109 |
110 | features = vgg_net.features[:max_idx + 1]
111 |
112 | modified_net = OrderedDict()
113 | for k, v in zip(self.names, features):
114 | if 'pool' in k:
115 | # if remove_pooling is true, pooling operation will be removed
116 | if remove_pooling:
117 | continue
118 | else:
119 | # in some cases, we may want to change the default stride
120 | modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
121 | else:
122 | modified_net[k] = v
123 |
124 | self.vgg_net = nn.Sequential(modified_net)
125 |
126 | if not requires_grad:
127 | self.vgg_net.eval()
128 | for param in self.parameters():
129 | param.requires_grad = False
130 | else:
131 | self.vgg_net.train()
132 | for param in self.parameters():
133 | param.requires_grad = True
134 |
135 | if self.use_input_norm:
136 | # the mean is for image with range [0, 1]
137 | self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
138 | # the std is for image with range [0, 1]
139 | self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
140 |
141 | def forward(self, x):
142 | """Forward function.
143 |
144 | Args:
145 | x (Tensor): Input tensor with shape (n, c, h, w).
146 |
147 | Returns:
148 | Tensor: Forward results.
149 | """
150 | if self.range_norm:
151 | x = (x + 1) / 2
152 | if self.use_input_norm:
153 | x = (x - self.mean) / self.std
154 |
155 | output = {}
156 | for key, layer in self.vgg_net._modules.items():
157 | x = layer(x)
158 | if key in self.layer_name_list:
159 | output[key] = x.clone()
160 |
161 | return output
162 |
--------------------------------------------------------------------------------
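`VGGFeatureExtractor` returns a dict of intermediate activations keyed by the names in `layer_name_list`, which is how perceptual losses tap VGG features. A minimal sketch (weights are loaded from `VGG_PRETRAIN_PATH` if it exists, otherwise torchvision downloads them; the input is expected in [0, 1]):

```python
import torch
from basicsr.archs.vgg_arch import VGGFeatureExtractor

extractor = VGGFeatureExtractor(layer_name_list=['relu1_1', 'relu2_1', 'relu3_1'], vgg_type='vgg19')
feats = extractor(torch.rand(1, 3, 224, 224))
print({k: tuple(v.shape) for k, v in feats.items()})
# {'relu1_1': (1, 64, 224, 224), 'relu2_1': (1, 128, 112, 112), 'relu3_1': (1, 256, 56, 56)}
```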
/basicsr/data/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import numpy as np
3 | import random
4 | import torch
5 | import torch.utils.data
6 | from copy import deepcopy
7 | from functools import partial
8 | from os import path as osp
9 |
10 | from basicsr.data.prefetch_dataloader import PrefetchDataLoader
11 | from basicsr.utils import get_root_logger, scandir
12 | from basicsr.utils.dist_util import get_dist_info
13 | from basicsr.utils.registry import DATASET_REGISTRY
14 |
15 | __all__ = ['build_dataset', 'build_dataloader']
16 |
17 | # automatically scan and import dataset modules for registry
18 | # scan all the files under the data folder with '_dataset' in file names
19 | data_folder = osp.dirname(osp.abspath(__file__))
20 | dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
21 | # import all the dataset modules
22 | _dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames]
23 |
24 |
25 | def build_dataset(dataset_opt):
26 | """Build dataset from options.
27 |
28 | Args:
29 | dataset_opt (dict): Configuration for dataset. It must contain:
30 | name (str): Dataset name.
31 | type (str): Dataset type.
32 | """
33 | dataset_opt = deepcopy(dataset_opt)
34 | dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
35 | logger = get_root_logger()
36 | logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} ' 'is built.')
37 | return dataset
38 |
39 |
40 | def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
41 | """Build dataloader.
42 |
43 | Args:
44 | dataset (torch.utils.data.Dataset): Dataset.
45 | dataset_opt (dict): Dataset options. It contains the following keys:
46 | phase (str): 'train' or 'val'.
47 | num_worker_per_gpu (int): Number of workers for each GPU.
48 | batch_size_per_gpu (int): Training batch size for each GPU.
49 | num_gpu (int): Number of GPUs. Used only in the train phase.
50 | Default: 1.
51 | dist (bool): Whether in distributed training. Used only in the train
52 | phase. Default: False.
53 | sampler (torch.utils.data.sampler): Data sampler. Default: None.
54 | seed (int | None): Seed. Default: None
55 | """
56 | phase = dataset_opt['phase']
57 | rank, _ = get_dist_info()
58 | if phase == 'train':
59 | if dist: # distributed training
60 | batch_size = dataset_opt['batch_size_per_gpu']
61 | num_workers = dataset_opt['num_worker_per_gpu']
62 | else: # non-distributed training
63 | multiplier = 1 if num_gpu == 0 else num_gpu
64 | batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
65 | num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
66 | dataloader_args = dict(
67 | dataset=dataset,
68 | batch_size=batch_size,
69 | shuffle=False,
70 | num_workers=num_workers,
71 | sampler=sampler,
72 | drop_last=True)
73 | if sampler is None:
74 | dataloader_args['shuffle'] = True
75 | dataloader_args['worker_init_fn'] = partial(
76 | worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
77 | elif phase in ['val', 'test']: # validation
78 | dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
79 | else:
80 | raise ValueError(f'Wrong dataset phase: {phase}. ' "Supported ones are 'train', 'val' and 'test'.")
81 |
82 | dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
83 | dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False)
84 |
85 | prefetch_mode = dataset_opt.get('prefetch_mode')
86 | if prefetch_mode == 'cpu': # CPUPrefetcher
87 | num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
88 | logger = get_root_logger()
89 | logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}')
90 | return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
91 | else:
92 | # prefetch_mode=None: Normal dataloader
93 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher
94 | return torch.utils.data.DataLoader(**dataloader_args)
95 |
96 |
97 | def worker_init_fn(worker_id, num_workers, rank, seed):
98 | # Set the worker seed to num_workers * rank + worker_id + seed
99 | worker_seed = num_workers * rank + worker_id + seed
100 | np.random.seed(worker_seed)
101 | random.seed(worker_seed)
102 |
--------------------------------------------------------------------------------
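`build_dataset` and `build_dataloader` are driven entirely by the option dict. A minimal validation-loader sketch using `SingleImageDataset` (defined later in `basicsr/data/single_image_dataset.py`); the folder path is a placeholder:

```python
from basicsr.data import build_dataset, build_dataloader

dataset_opt = {
    'name': 'toy_val',
    'type': 'SingleImageDataset',       # any type registered in DATASET_REGISTRY
    'phase': 'val',                     # val/test loaders use batch_size=1 and no workers
    'dataroot_lq': 'datasets/toy_lq',   # hypothetical folder of LQ images
    'io_backend': {'type': 'disk'},
}
val_set = build_dataset(dataset_opt)
val_loader = build_dataloader(val_set, dataset_opt, num_gpu=1, dist=False)
```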
/basicsr/data/bsrgan_train_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from torch.utils import data as data
3 |
4 | from basicsr.data.bsrgan_util import degradation_bsrgan
5 | from basicsr.data.transforms import augment
6 | from basicsr.utils import FileClient, img2tensor
7 | from basicsr.utils.registry import DATASET_REGISTRY
8 |
9 | from .data_util import make_dataset
10 |
11 | import cv2
12 | import random
13 |
14 |
15 | def random_resize(img, scale_factor=1.):
16 | return cv2.resize(img, None, fx=scale_factor, fy=scale_factor, interpolation=cv2.INTER_CUBIC)
17 |
18 |
19 | def random_crop(img, out_size):
20 | h, w = img.shape[:2]
21 | rnd_h = random.randint(0, h - out_size)
22 | rnd_w = random.randint(0, w - out_size)
23 | return img[rnd_h: rnd_h + out_size, rnd_w: rnd_w + out_size]
24 |
25 |
26 | @DATASET_REGISTRY.register()
27 | class BSRGANTrainDataset(data.Dataset):
28 | """Synthesize LR-HR pairs online with BSRGAN for image restoration.
29 |
30 | Args:
31 | opt (dict): Config for train datasets. It contains the following keys:
32 | dataroot_gt (str): Data root path for gt.
33 | dataroot_lq (str): Data root path for lq.
34 | meta_info_file (str): Path for meta information file.
35 | gt_size (int): Cropped patched size for gt patches.
36 | use_flip (bool): Use horizontal flips.
37 | use_rot (bool): Use rotation (use vertical flip and transposing h
38 | and w for implementation).
39 |
40 | scale (int): Scale factor, which will be added automatically.
41 | phase (str): 'train' or 'val'.
42 | """
43 |
44 | def __init__(self, opt):
45 | super(BSRGANTrainDataset, self).__init__()
46 | self.opt = opt
47 | # file client (io backend)
48 | self.file_client = None
49 | self.io_backend_opt = opt['io_backend']
50 |
51 | if opt.get('dataroot_gt') is not None:
52 | self.gt_folder = opt['dataroot_gt']
53 | self.gt_paths = make_dataset(self.gt_folder)
54 | elif opt.get('datafile_gt') is not None:
55 | with open(opt.get('datafile_gt'), "r") as f:
56 | paths = f.read().splitlines()
57 | self.gt_paths = paths
58 | else:
59 | raise ValueError("Unknown path for gt.")
60 |
61 | def __getitem__(self, index):
62 |
63 | scale = self.opt['scale']
64 |
65 | gt_path = self.gt_paths[index]
66 | img_gt = cv2.imread(gt_path).astype(np.float32) / 255.
67 |
68 | img_gt = img_gt[:, :, [2, 1, 0]] # BGR to RGB
69 | gt_size = self.opt['gt_size']
70 |
71 | if self.opt['phase'] == 'train':
72 | if self.opt['use_resize_crop']:
73 | input_gt_size = img_gt.shape[0]
74 | input_gt_random_size = random.randint(gt_size, input_gt_size)
75 | resize_factor = input_gt_random_size / input_gt_size
76 | img_gt = random_resize(img_gt, resize_factor)
77 |
78 | img_gt = random_crop(img_gt, gt_size)
79 |
80 | img_lq, img_gt = degradation_bsrgan(img_gt, sf=scale, lq_patchsize=self.opt['gt_size'] // scale, use_crop=False)
81 | img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_flip'],
82 | self.opt['use_rot'])
83 | img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=False, float32=True)
84 |
85 | return {
86 | 'lq': img_lq,
87 | 'gt': img_gt,
88 | 'lq_path': gt_path,
89 | 'gt_path': gt_path
90 | }
91 |
92 | def __len__(self):
93 | return len(self.gt_paths)
94 |
--------------------------------------------------------------------------------
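The option keys consumed by `BSRGANTrainDataset` can be read off `__init__` and `__getitem__` above. A hedged training-loader sketch; the paths and patch sizes are placeholders, and the loader keys (`batch_size_per_gpu`, `num_worker_per_gpu`) are consumed by `build_dataloader`:

```python
from basicsr.data import build_dataset, build_dataloader

train_opt = {
    'name': 'toy_train',
    'type': 'BSRGANTrainDataset',
    'phase': 'train',
    'scale': 4,
    'gt_size': 256,                     # HR patch size; the LQ patch becomes gt_size // scale
    'use_resize_crop': True,
    'use_flip': True,
    'use_rot': True,
    'dataroot_gt': 'datasets/HQ_sub',   # hypothetical folder of HR training crops
    'io_backend': {'type': 'disk'},
    'batch_size_per_gpu': 4,
    'num_worker_per_gpu': 2,
}
train_set = build_dataset(train_opt)
train_loader = build_dataloader(train_set, train_opt, num_gpu=1, dist=False, seed=0)
```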
/basicsr/data/data_sampler.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.utils.data.sampler import Sampler
4 |
5 |
6 | class EnlargedSampler(Sampler):
7 | """Sampler that restricts data loading to a subset of the dataset.
8 |
9 | Modified from torch.utils.data.distributed.DistributedSampler
10 | Support enlarging the dataset for iteration-based training, saving
11 | time when restarting the dataloader after each epoch.
12 |
13 | Args:
14 | dataset (torch.utils.data.Dataset): Dataset used for sampling.
15 | num_replicas (int | None): Number of processes participating in
16 | the training. It is usually the world_size.
17 | rank (int | None): Rank of the current process within num_replicas.
18 | ratio (int): Enlarging ratio. Default: 1.
19 | """
20 |
21 | def __init__(self, dataset, num_replicas, rank, ratio=1):
22 | self.dataset = dataset
23 | self.num_replicas = num_replicas
24 | self.rank = rank
25 | self.epoch = 0
26 | self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas)
27 | self.total_size = self.num_samples * self.num_replicas
28 |
29 | def __iter__(self):
30 | # deterministically shuffle based on epoch
31 | g = torch.Generator()
32 | g.manual_seed(self.epoch)
33 | indices = torch.randperm(self.total_size, generator=g).tolist()
34 |
35 | dataset_size = len(self.dataset)
36 | indices = [v % dataset_size for v in indices]
37 |
38 | # subsample
39 | indices = indices[self.rank:self.total_size:self.num_replicas]
40 | assert len(indices) == self.num_samples
41 |
42 | return iter(indices)
43 |
44 | def __len__(self):
45 | return self.num_samples
46 |
47 | def set_epoch(self, epoch):
48 | self.epoch = epoch
49 |
--------------------------------------------------------------------------------
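`EnlargedSampler` virtually repeats the dataset `ratio` times so that iteration-based training does not pay the dataloader restart cost every short epoch; `set_epoch` must be called each epoch to reshuffle. A single-process sketch:

```python
import torch
from basicsr.data.data_sampler import EnlargedSampler

dataset = torch.utils.data.TensorDataset(torch.arange(10))
sampler = EnlargedSampler(dataset, num_replicas=1, rank=0, ratio=3)   # 30 indices per epoch
for epoch in range(2):
    sampler.set_epoch(epoch)            # deterministic reshuffle per epoch
    indices = list(iter(sampler))
    assert len(indices) == 30 and max(indices) < len(dataset)
```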
/basicsr/data/prefetch_dataloader.py:
--------------------------------------------------------------------------------
1 | import queue as Queue
2 | import threading
3 | import torch
4 | from torch.utils.data import DataLoader
5 |
6 |
7 | class PrefetchGenerator(threading.Thread):
8 | """A general prefetch generator.
9 |
10 | Ref:
11 | https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
12 |
13 | Args:
14 | generator: Python generator.
15 | num_prefetch_queue (int): Number of prefetch queue.
16 | """
17 |
18 | def __init__(self, generator, num_prefetch_queue):
19 | threading.Thread.__init__(self)
20 | self.queue = Queue.Queue(num_prefetch_queue)
21 | self.generator = generator
22 | self.daemon = True
23 | self.start()
24 |
25 | def run(self):
26 | for item in self.generator:
27 | self.queue.put(item)
28 | self.queue.put(None)
29 |
30 | def __next__(self):
31 | next_item = self.queue.get()
32 | if next_item is None:
33 | raise StopIteration
34 | return next_item
35 |
36 | def __iter__(self):
37 | return self
38 |
39 |
40 | class PrefetchDataLoader(DataLoader):
41 | """Prefetch version of dataloader.
42 |
43 | Ref:
44 | https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
45 |
46 | TODO:
47 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in
48 | ddp.
49 |
50 | Args:
51 | num_prefetch_queue (int): Number of prefetch queue.
52 | kwargs (dict): Other arguments for dataloader.
53 | """
54 |
55 | def __init__(self, num_prefetch_queue, **kwargs):
56 | self.num_prefetch_queue = num_prefetch_queue
57 | super(PrefetchDataLoader, self).__init__(**kwargs)
58 |
59 | def __iter__(self):
60 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
61 |
62 |
63 | class CPUPrefetcher():
64 | """CPU prefetcher.
65 |
66 | Args:
67 | loader: Dataloader.
68 | """
69 |
70 | def __init__(self, loader):
71 | self.ori_loader = loader
72 | self.loader = iter(loader)
73 |
74 | def next(self):
75 | try:
76 | return next(self.loader)
77 | except StopIteration:
78 | return None
79 |
80 | def reset(self):
81 | self.loader = iter(self.ori_loader)
82 |
83 |
84 | class CUDAPrefetcher():
85 | """CUDA prefetcher.
86 |
87 | Ref:
88 | https://github.com/NVIDIA/apex/issues/304#
89 |
90 | It may consume more GPU memory.
91 |
92 | Args:
93 | loader: Dataloader.
94 | opt (dict): Options.
95 | """
96 |
97 | def __init__(self, loader, opt):
98 | self.ori_loader = loader
99 | self.loader = iter(loader)
100 | self.opt = opt
101 | self.stream = torch.cuda.Stream()
102 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
103 | self.preload()
104 |
105 | def preload(self):
106 | try:
107 | self.batch = next(self.loader) # self.batch is a dict
108 | except StopIteration:
109 | self.batch = None
110 | return None
111 | # put tensors to gpu
112 | with torch.cuda.stream(self.stream):
113 | for k, v in self.batch.items():
114 | if torch.is_tensor(v):
115 | self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
116 |
117 | def next(self):
118 | torch.cuda.current_stream().wait_stream(self.stream)
119 | batch = self.batch
120 | self.preload()
121 | return batch
122 |
123 | def reset(self):
124 | self.loader = iter(self.ori_loader)
125 | self.preload()
126 |
--------------------------------------------------------------------------------
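Both prefetchers expose the same `next()`/`reset()` interface: `next()` returns the next batch, or `None` once the epoch is exhausted. A self-contained consumption loop with a toy loader (any DataLoader works the same way):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from basicsr.data.prefetch_dataloader import CPUPrefetcher

toy_loader = DataLoader(TensorDataset(torch.arange(8)), batch_size=2)
prefetcher = CPUPrefetcher(toy_loader)
for epoch in range(2):
    prefetcher.reset()
    batch = prefetcher.next()
    while batch is not None:
        # ... one training/validation step on `batch` would go here ...
        batch = prefetcher.next()
```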
/basicsr/data/single_image_dataset.py:
--------------------------------------------------------------------------------
1 | from os import path as osp
2 | from torch.utils import data as data
3 | from torchvision.transforms.functional import normalize
4 |
5 | from basicsr.data.data_util import paths_from_lmdb
6 | from basicsr.utils import FileClient, imfrombytes, img2tensor, scandir
7 | from basicsr.utils.matlab_functions import rgb2ycbcr
8 | from basicsr.utils.registry import DATASET_REGISTRY
9 |
10 |
11 | @DATASET_REGISTRY.register()
12 | class SingleImageDataset(data.Dataset):
13 | """Read only lq images in the test phase.
14 |
15 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc).
16 |
17 | There are two modes:
18 | 1. 'meta_info_file': Use meta information file to generate paths.
19 | 2. 'folder': Scan folders to generate paths.
20 |
21 | Args:
22 | opt (dict): Config for train datasets. It contains the following keys:
23 | dataroot_lq (str): Data root path for lq.
24 | meta_info_file (str): Path for meta information file.
25 | io_backend (dict): IO backend type and other kwarg.
26 | """
27 |
28 | def __init__(self, opt):
29 | super(SingleImageDataset, self).__init__()
30 | self.opt = opt
31 | # file client (io backend)
32 | self.file_client = None
33 | self.io_backend_opt = opt['io_backend']
34 | self.mean = opt['mean'] if 'mean' in opt else None
35 | self.std = opt['std'] if 'std' in opt else None
36 | self.lq_folder = opt['dataroot_lq']
37 |
38 | if self.io_backend_opt['type'] == 'lmdb':
39 | self.io_backend_opt['db_paths'] = [self.lq_folder]
40 | self.io_backend_opt['client_keys'] = ['lq']
41 | self.paths = paths_from_lmdb(self.lq_folder)
42 | elif 'meta_info_file' in self.opt:
43 | with open(self.opt['meta_info_file'], 'r') as fin:
44 | self.paths = [osp.join(self.lq_folder, line.rstrip().split(' ')[0]) for line in fin]
45 | else:
46 | self.paths = sorted(list(scandir(self.lq_folder, full_path=True)))
47 |
48 | def __getitem__(self, index):
49 | if self.file_client is None:
50 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
51 |
52 | # load lq image
53 | lq_path = self.paths[index]
54 | img_bytes = self.file_client.get(lq_path, 'lq')
55 | img_lq = imfrombytes(img_bytes, float32=True)
56 |
57 | # color space transform
58 | if 'color' in self.opt and self.opt['color'] == 'y':
59 | img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None]
60 |
61 | # BGR to RGB, HWC to CHW, numpy to tensor
62 | img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True)
63 | # normalize
64 | if self.mean is not None or self.std is not None:
65 | normalize(img_lq, self.mean, self.std, inplace=True)
66 | return {'lq': img_lq, 'lq_path': lq_path}
67 |
68 | def __len__(self):
69 | return len(self.paths)
70 |
--------------------------------------------------------------------------------
/basicsr/data/transforms.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import random
3 | import torch
4 |
5 |
6 | def mod_crop(img, scale):
7 | """Mod crop images, used during testing.
8 |
9 | Args:
10 | img (ndarray): Input image.
11 | scale (int): Scale factor.
12 |
13 | Returns:
14 | ndarray: Result image.
15 | """
16 | img = img.copy()
17 | if img.ndim in (2, 3):
18 | h, w = img.shape[0], img.shape[1]
19 | h_remainder, w_remainder = h % scale, w % scale
20 | img = img[:h - h_remainder, :w - w_remainder, ...]
21 | else:
22 | raise ValueError(f'Wrong img ndim: {img.ndim}.')
23 | return img
24 |
25 |
26 | def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
27 | """Paired random crop. Support Numpy array and Tensor inputs.
28 |
29 | It crops lists of lq and gt images with corresponding locations.
30 |
31 | Args:
32 | img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
33 | should have the same shape. If the input is an ndarray, it will
34 | be transformed to a list containing itself.
35 | img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
36 | should have the same shape. If the input is an ndarray, it will
37 | be transformed to a list containing itself.
38 | gt_patch_size (int): GT patch size.
39 | scale (int): Scale factor.
40 | gt_path (str): Path to ground-truth. Default: None.
41 |
42 | Returns:
43 | list[ndarray] | ndarray: GT images and LQ images. If returned results
44 | only have one element, just return ndarray.
45 | """
46 |
47 | if not isinstance(img_gts, list):
48 | img_gts = [img_gts]
49 | if not isinstance(img_lqs, list):
50 | img_lqs = [img_lqs]
51 |
52 | # determine input type: Numpy array or Tensor
53 | input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'
54 |
55 | if input_type == 'Tensor':
56 | h_lq, w_lq = img_lqs[0].size()[-2:]
57 | h_gt, w_gt = img_gts[0].size()[-2:]
58 | else:
59 | h_lq, w_lq = img_lqs[0].shape[0:2]
60 | h_gt, w_gt = img_gts[0].shape[0:2]
61 | lq_patch_size = gt_patch_size // scale
62 |
63 | if h_gt != h_lq * scale or w_gt != w_lq * scale:
64 | raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ',
65 | f'multiplication of LQ ({h_lq}, {w_lq}).')
66 | if h_lq < lq_patch_size or w_lq < lq_patch_size:
67 | raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
68 | f'({lq_patch_size}, {lq_patch_size}). '
69 | f'Please remove {gt_path}.')
70 |
71 | # randomly choose top and left coordinates for lq patch
72 | top = random.randint(0, h_lq - lq_patch_size)
73 | left = random.randint(0, w_lq - lq_patch_size)
74 |
75 | # crop lq patch
76 | if input_type == 'Tensor':
77 | img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
78 | else:
79 | img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
80 |
81 | # crop corresponding gt patch
82 | top_gt, left_gt = int(top * scale), int(left * scale)
83 | if input_type == 'Tensor':
84 | img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
85 | else:
86 | img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
87 | if len(img_gts) == 1:
88 | img_gts = img_gts[0]
89 | if len(img_lqs) == 1:
90 | img_lqs = img_lqs[0]
91 | return img_gts, img_lqs
92 |
93 |
94 | def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
95 | """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
96 |
97 | We use vertical flip and transpose for rotation implementation.
98 | All the images in the list use the same augmentation.
99 |
100 | Args:
101 | imgs (list[ndarray] | ndarray): Images to be augmented. If the input
102 | is an ndarray, it will be transformed to a list.
103 | hflip (bool): Horizontal flip. Default: True.
104 | rotation (bool): Rotation. Default: True.
105 | flows (list[ndarray]): Flows to be augmented. If the input is an
106 | ndarray, it will be transformed to a list.
107 | Dimension is (h, w, 2). Default: None.
108 | return_status (bool): Return the status of flip and rotation.
109 | Default: False.
110 |
111 | Returns:
112 | list[ndarray] | ndarray: Augmented images and flows. If returned
113 | results only have one element, just return ndarray.
114 |
115 | """
116 | hflip = hflip and random.random() < 0.5
117 | vflip = rotation and random.random() < 0.5
118 | rot90 = rotation and random.random() < 0.5
119 |
120 | def _augment(img):
121 | if hflip: # horizontal
122 | cv2.flip(img, 1, img)
123 | if vflip: # vertical
124 | cv2.flip(img, 0, img)
125 | if rot90:
126 | img = img.transpose(1, 0, 2)
127 | return img
128 |
129 | def _augment_flow(flow):
130 | if hflip: # horizontal
131 | cv2.flip(flow, 1, flow)
132 | flow[:, :, 0] *= -1
133 | if vflip: # vertical
134 | cv2.flip(flow, 0, flow)
135 | flow[:, :, 1] *= -1
136 | if rot90:
137 | flow = flow.transpose(1, 0, 2)
138 | flow = flow[:, :, [1, 0]]
139 | return flow
140 |
141 | if not isinstance(imgs, list):
142 | imgs = [imgs]
143 | imgs = [_augment(img) for img in imgs]
144 | if len(imgs) == 1:
145 | imgs = imgs[0]
146 |
147 | if flows is not None:
148 | if not isinstance(flows, list):
149 | flows = [flows]
150 | flows = [_augment_flow(flow) for flow in flows]
151 | if len(flows) == 1:
152 | flows = flows[0]
153 | return imgs, flows
154 | else:
155 | if return_status:
156 | return imgs, (hflip, vflip, rot90)
157 | else:
158 | return imgs
159 |
160 |
161 | def img_rotate(img, angle, center=None, scale=1.0):
162 | """Rotate image.
163 |
164 | Args:
165 | img (ndarray): Image to be rotated.
166 | angle (float): Rotation angle in degrees. Positive values mean
167 | counter-clockwise rotation.
168 | center (tuple[int]): Rotation center. If the center is None,
169 | initialize it as the center of the image. Default: None.
170 | scale (float): Isotropic scale factor. Default: 1.0.
171 | """
172 | (h, w) = img.shape[:2]
173 |
174 | if center is None:
175 | center = (w // 2, h // 2)
176 |
177 | matrix = cv2.getRotationMatrix2D(center, angle, scale)
178 | rotated_img = cv2.warpAffine(img, matrix, (w, h))
179 | return rotated_img
180 |
--------------------------------------------------------------------------------
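A short sketch of how `paired_random_crop` and `augment` combine in a paired pipeline; the random arrays stand in for HWC float images and the sizes are illustrative only:

```python
import numpy as np
from basicsr.data.transforms import paired_random_crop, augment

scale = 4
img_gt = np.random.rand(512, 512, 3).astype(np.float32)
img_lq = np.random.rand(128, 128, 3).astype(np.float32)    # 512 / scale

# crop aligned 256 / 64 patches, then apply the same random flip/rotation to both
img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_patch_size=256, scale=scale)
img_gt, img_lq = augment([img_gt, img_lq], hflip=True, rotation=True)
print(img_gt.shape, img_lq.shape)   # (256, 256, 3) (64, 64, 3)
```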
/basicsr/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 |
3 | from basicsr.utils import get_root_logger
4 | from basicsr.utils.registry import LOSS_REGISTRY
5 | from .losses import (CharbonnierLoss, GANLoss, ContrastiveLoss, L1Loss, MSELoss, SoftCrossEntropy, PerceptualLoss, WeightedTVLoss, g_path_regularize,
6 | gradient_penalty_loss, r1_penalty)
7 |
8 | __all__ = [
9 | 'ContrastiveLoss', 'L1Loss', 'MSELoss', 'SoftCrossEntropy', 'CharbonnierLoss', 'WeightedTVLoss', 'PerceptualLoss', 'GANLoss', 'gradient_penalty_loss',
10 | 'r1_penalty', 'g_path_regularize'
11 | ]
12 |
13 |
14 | def build_loss(opt):
15 | """Build loss from options.
16 |
17 | Args:
18 | opt (dict): Configuration. It must contain:
19 | type (str): Model type.
20 | """
21 | opt = deepcopy(opt)
22 | loss_type = opt.pop('type')
23 | loss = LOSS_REGISTRY.get(loss_type)(**opt)
24 | logger = get_root_logger()
25 | logger.info(f'Loss [{loss.__class__.__name__}] is created.')
26 | return loss
27 |
--------------------------------------------------------------------------------
/basicsr/losses/loss_util.py:
--------------------------------------------------------------------------------
1 | import functools
2 | from torch.nn import functional as F
3 |
4 |
5 | def reduce_loss(loss, reduction):
6 | """Reduce loss as specified.
7 |
8 | Args:
9 | loss (Tensor): Elementwise loss tensor.
10 | reduction (str): Options are 'none', 'mean' and 'sum'.
11 |
12 | Returns:
13 | Tensor: Reduced loss tensor.
14 | """
15 | reduction_enum = F._Reduction.get_enum(reduction)
16 | # none: 0, elementwise_mean:1, sum: 2
17 | if reduction_enum == 0:
18 | return loss
19 | elif reduction_enum == 1:
20 | return loss.mean()
21 | else:
22 | return loss.sum()
23 |
24 |
25 | def weight_reduce_loss(loss, weight=None, reduction='mean'):
26 | """Apply element-wise weight and reduce loss.
27 |
28 | Args:
29 | loss (Tensor): Element-wise loss.
30 | weight (Tensor): Element-wise weights. Default: None.
31 | reduction (str): Same as built-in losses of PyTorch. Options are
32 | 'none', 'mean' and 'sum'. Default: 'mean'.
33 |
34 | Returns:
35 | Tensor: Loss values.
36 | """
37 | # if weight is specified, apply element-wise weight
38 | if weight is not None:
39 | assert weight.dim() == loss.dim()
40 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
41 | loss = loss * weight
42 |
43 | # if weight is not specified or reduction is sum, just reduce the loss
44 | if weight is None or reduction == 'sum':
45 | loss = reduce_loss(loss, reduction)
46 | # if reduction is mean, then compute mean over weight region
47 | elif reduction == 'mean':
48 | if weight.size(1) > 1:
49 | weight = weight.sum()
50 | else:
51 | weight = weight.sum() * loss.size(1)
52 | loss = loss.sum() / weight
53 |
54 | return loss
55 |
56 |
57 | def weighted_loss(loss_func):
58 | """Create a weighted version of a given loss function.
59 |
60 | To use this decorator, the loss function must have the signature like
61 | `loss_func(pred, target, **kwargs)`. The function only needs to compute
62 | element-wise loss without any reduction. This decorator will add weight
63 | and reduction arguments to the function. The decorated function will have
64 | the signature like `loss_func(pred, target, weight=None, reduction='mean',
65 | **kwargs)`.
66 |
67 | :Example:
68 |
69 | >>> import torch
70 | >>> @weighted_loss
71 | >>> def l1_loss(pred, target):
72 | >>> return (pred - target).abs()
73 |
74 | >>> pred = torch.Tensor([0, 2, 3])
75 | >>> target = torch.Tensor([1, 1, 1])
76 | >>> weight = torch.Tensor([1, 0, 1])
77 |
78 | >>> l1_loss(pred, target)
79 | tensor(1.3333)
80 | >>> l1_loss(pred, target, weight)
81 | tensor(1.5000)
82 | >>> l1_loss(pred, target, reduction='none')
83 | tensor([1., 1., 2.])
84 | >>> l1_loss(pred, target, weight, reduction='sum')
85 | tensor(3.)
86 | """
87 |
88 | @functools.wraps(loss_func)
89 | def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
90 | # get element-wise loss
91 | loss = loss_func(pred, target, **kwargs)
92 | loss = weight_reduce_loss(loss, weight, reduction)
93 | return loss
94 |
95 | return wrapper
96 |
--------------------------------------------------------------------------------
/basicsr/models/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from copy import deepcopy
3 | from os import path as osp
4 |
5 | from basicsr.utils import get_root_logger, scandir
6 | from basicsr.utils.registry import MODEL_REGISTRY
7 |
8 | __all__ = ['build_model']
9 |
10 | # automatically scan and import model modules for registry
11 | # scan all the files under the 'models' folder and collect files ending with
12 | # '_model.py'
13 | model_folder = osp.dirname(osp.abspath(__file__))
14 | model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
15 | # import all the model modules
16 | _model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames]
17 |
18 |
19 | def build_model(opt):
20 | """Build model from options.
21 |
22 | Args:
23 | opt (dict): Configuration. It must contain:
24 | model_type (str): Model type.
25 | """
26 | opt = deepcopy(opt)
27 | model = MODEL_REGISTRY.get(opt['model_type'])(opt)
28 | logger = get_root_logger()
29 | logger.info(f'Model [{model.__class__.__name__}] is created.')
30 | return model
31 |
--------------------------------------------------------------------------------
/basicsr/models/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from collections import Counter
3 | from torch.optim.lr_scheduler import _LRScheduler
4 |
5 |
6 | class MultiStepRestartLR(_LRScheduler):
7 | """ MultiStep with restarts learning rate scheme.
8 |
9 | Args:
10 | optimizer (torch.nn.optimizer): Torch optimizer.
11 | milestones (list): Iterations that will decrease learning rate.
12 | gamma (float): Decrease ratio. Default: 0.1.
13 | restarts (list): Restart iterations. Default: [0].
14 | restart_weights (list): Restart weights at each restart iteration.
15 | Default: [1].
16 | last_epoch (int): Used in _LRScheduler. Default: -1.
17 | """
18 |
19 | def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1):
20 | self.milestones = Counter(milestones)
21 | self.gamma = gamma
22 | self.restarts = restarts
23 | self.restart_weights = restart_weights
24 | assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.'
25 | super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)
26 |
27 | def get_lr(self):
28 | if self.last_epoch in self.restarts:
29 | weight = self.restart_weights[self.restarts.index(self.last_epoch)]
30 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
31 | if self.last_epoch not in self.milestones:
32 | return [group['lr'] for group in self.optimizer.param_groups]
33 | return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups]
34 |
35 |
36 | def get_position_from_periods(iteration, cumulative_period):
37 | """Get the position from a period list.
38 |
39 | It will return the index of the right-closest number in the period list.
40 | For example, the cumulative_period = [100, 200, 300, 400],
41 | if iteration == 50, return 0;
42 | if iteration == 210, return 2;
43 | if iteration == 300, return 2.
44 |
45 | Args:
46 | iteration (int): Current iteration.
47 | cumulative_period (list[int]): Cumulative period list.
48 |
49 | Returns:
50 | int: The position of the right-closest number in the period list.
51 | """
52 | for i, period in enumerate(cumulative_period):
53 | if iteration <= period:
54 | return i
55 |
56 |
57 | class CosineAnnealingRestartLR(_LRScheduler):
58 | """ Cosine annealing with restarts learning rate scheme.
59 |
60 | An example of config:
61 | periods = [10, 10, 10, 10]
62 | restart_weights = [1, 0.5, 0.5, 0.5]
63 | eta_min=1e-7
64 |
65 | It has four cycles, each with 10 iterations. At the 10th, 20th and 30th
66 | iterations, the scheduler restarts with the weights in restart_weights.
67 |
68 | Args:
69 | optimizer (torch.optim.Optimizer): Torch optimizer.
70 | periods (list): Period for each cosine annealing cycle.
71 | restart_weights (list): Restart weights at each restart iteration.
72 | Default: [1].
73 | eta_min (float): The minimum lr. Default: 0.
74 | last_epoch (int): Used in _LRScheduler. Default: -1.
75 | """
76 |
77 | def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1):
78 | self.periods = periods
79 | self.restart_weights = restart_weights
80 | self.eta_min = eta_min
81 | assert (len(self.periods) == len(
82 | self.restart_weights)), 'periods and restart_weights should have the same length.'
83 | self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))]
84 | super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
85 |
86 | def get_lr(self):
87 | idx = get_position_from_periods(self.last_epoch, self.cumulative_period)
88 | current_weight = self.restart_weights[idx]
89 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
90 | current_period = self.periods[idx]
91 |
92 | return [
93 | self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
94 | (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period)))
95 | for base_lr in self.base_lrs
96 | ]
97 |
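A small, self-contained sketch of the restart behaviour described in the docstring above; the toy optimizer and iteration count are illustrative:

import torch
from basicsr.models.lr_scheduler import CosineAnnealingRestartLR

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=1e-4)
scheduler = CosineAnnealingRestartLR(
    optimizer, periods=[10, 10, 10, 10], restart_weights=[1, 0.5, 0.5, 0.5], eta_min=1e-7)

for _ in range(40):
    optimizer.step()
    scheduler.step()
    # within each 10-iteration period the lr follows a cosine from
    # (restart_weight * base_lr) down towards eta_min, then jumps back at the restart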
--------------------------------------------------------------------------------
/basicsr/test.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch
3 | from os import path as osp
4 |
5 | from basicsr.data import build_dataloader, build_dataset
6 | from basicsr.models import build_model
7 | from basicsr.utils import get_env_info, get_root_logger, get_time_str, make_exp_dirs
8 | from basicsr.utils.options import dict2str, parse_options
9 |
10 |
11 | def test_pipeline(root_path):
12 | # parse options, set distributed setting, set random seed
13 | opt, _ = parse_options(root_path, is_train=False)
14 |
15 | torch.backends.cudnn.benchmark = True
16 | # torch.backends.cudnn.deterministic = True
17 |
18 | # mkdir and initialize loggers
19 | make_exp_dirs(opt)
20 | log_file = osp.join(opt['path']['log'], f"test_{opt['name']}_{get_time_str()}.log")
21 | logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
22 | logger.info(get_env_info())
23 | logger.info(dict2str(opt))
24 |
25 | # create test dataset and dataloader
26 | test_loaders = []
27 | for _, dataset_opt in sorted(opt['datasets'].items()):
28 | test_set = build_dataset(dataset_opt)
29 | test_loader = build_dataloader(
30 | test_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed'])
31 | logger.info(f"Number of test images in {dataset_opt['name']}: {len(test_set)}")
32 | test_loaders.append(test_loader)
33 |
34 | # create model
35 | model = build_model(opt)
36 |
37 | for test_loader in test_loaders:
38 | test_set_name = test_loader.dataset.opt['name']
39 | logger.info(f'Testing {test_set_name}...')
40 | model.validation(test_loader, current_iter=opt['name'], tb_logger=None, save_img=opt['val']['save_img'])
41 |
42 |
43 | if __name__ == '__main__':
44 | root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
45 | test_pipeline(root_path)
46 |
--------------------------------------------------------------------------------
/basicsr/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .diffjpeg import DiffJPEG
2 | from .file_client import FileClient
3 | from .img_process_util import USMSharp, usm_sharp
4 | from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img
5 | from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger
6 | from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt
7 |
8 | __all__ = [
9 | # file_client.py
10 | 'FileClient',
11 | # img_util.py
12 | 'img2tensor',
13 | 'tensor2img',
14 | 'imfrombytes',
15 | 'imwrite',
16 | 'crop_border',
17 | # logger.py
18 | 'MessageLogger',
19 | 'AvgTimer',
20 | 'init_tb_logger',
21 | 'init_wandb_logger',
22 | 'get_root_logger',
23 | 'get_env_info',
24 | # misc.py
25 | 'set_random_seed',
26 | 'get_time_str',
27 | 'mkdir_and_rename',
28 | 'make_exp_dirs',
29 | 'scandir',
30 | 'check_resume',
31 | 'sizeof_fmt',
32 | # diffjpeg
33 | 'DiffJPEG',
34 | # img_process_util
35 | 'USMSharp',
36 | 'usm_sharp'
37 | ]
38 |
--------------------------------------------------------------------------------
/basicsr/utils/dist_util.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
2 | import functools
3 | import os
4 | import subprocess
5 | import torch
6 | import torch.distributed as dist
7 | import torch.multiprocessing as mp
8 |
9 |
10 | def init_dist(launcher, backend='nccl', **kwargs):
11 | if mp.get_start_method(allow_none=True) is None:
12 | mp.set_start_method('spawn')
13 | if launcher == 'pytorch':
14 | _init_dist_pytorch(backend, **kwargs)
15 | elif launcher == 'slurm':
16 | _init_dist_slurm(backend, **kwargs)
17 | else:
18 | raise ValueError(f'Invalid launcher type: {launcher}')
19 |
20 |
21 | def _init_dist_pytorch(backend, **kwargs):
22 | rank = int(os.environ['RANK'])
23 | num_gpus = torch.cuda.device_count()
24 | torch.cuda.set_device(rank % num_gpus)
25 | dist.init_process_group(backend=backend, **kwargs)
26 |
27 |
28 | def _init_dist_slurm(backend, port=None):
29 | """Initialize slurm distributed training environment.
30 |
31 | If argument ``port`` is not specified, then the master port will be system
32 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
33 | environment variable, then a default port ``29500`` will be used.
34 |
35 | Args:
36 | backend (str): Backend of torch.distributed.
37 | port (int, optional): Master port. Defaults to None.
38 | """
39 | proc_id = int(os.environ['SLURM_PROCID'])
40 | ntasks = int(os.environ['SLURM_NTASKS'])
41 | node_list = os.environ['SLURM_NODELIST']
42 | num_gpus = torch.cuda.device_count()
43 | torch.cuda.set_device(proc_id % num_gpus)
44 | addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1')
45 | # specify master port
46 | if port is not None:
47 | os.environ['MASTER_PORT'] = str(port)
48 | elif 'MASTER_PORT' in os.environ:
49 | pass # use MASTER_PORT in the environment variable
50 | else:
51 | # 29500 is torch.distributed default port
52 | os.environ['MASTER_PORT'] = '29500'
53 | os.environ['MASTER_ADDR'] = addr
54 | os.environ['WORLD_SIZE'] = str(ntasks)
55 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
56 | os.environ['RANK'] = str(proc_id)
57 | dist.init_process_group(backend=backend)
58 |
59 |
60 | def get_dist_info():
61 | if dist.is_available():
62 | initialized = dist.is_initialized()
63 | else:
64 | initialized = False
65 | if initialized:
66 | rank = dist.get_rank()
67 | world_size = dist.get_world_size()
68 | else:
69 | rank = 0
70 | world_size = 1
71 | return rank, world_size
72 |
73 |
74 | def master_only(func):
75 |
76 | @functools.wraps(func)
77 | def wrapper(*args, **kwargs):
78 | rank, _ = get_dist_info()
79 | if rank == 0:
80 | return func(*args, **kwargs)
81 |
82 | return wrapper
83 |
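A brief sketch of the two helpers the rest of the codebase relies on. In a single-process run (no torch.distributed initialisation) get_dist_info() falls back to rank 0, world size 1; the decorated function is hypothetical:

from basicsr.utils.dist_util import get_dist_info, master_only

rank, world_size = get_dist_info()   # (0, 1) when distributed training is not initialised

@master_only
def save_checkpoint():
    # only executed on rank 0; on other ranks the wrapper returns None without calling it
    print('saving from the master process')

save_checkpoint()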
--------------------------------------------------------------------------------
/basicsr/utils/download_util.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | import requests
4 | from torch.hub import download_url_to_file, get_dir
5 | from tqdm import tqdm
6 | from urllib.parse import urlparse
7 |
8 | from .misc import sizeof_fmt
9 |
10 |
11 | def download_file_from_google_drive(file_id, save_path):
12 | """Download files from google drive.
13 |
14 | Ref:
15 | https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501
16 |
17 | Args:
18 | file_id (str): File id.
19 | save_path (str): Save path.
20 | """
21 |
22 | session = requests.Session()
23 | URL = 'https://docs.google.com/uc?export=download'
24 | params = {'id': file_id}
25 |
26 | response = session.get(URL, params=params, stream=True)
27 | token = get_confirm_token(response)
28 | if token:
29 | params['confirm'] = token
30 | response = session.get(URL, params=params, stream=True)
31 |
32 | # get file size
33 | response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
34 | if 'Content-Range' in response_file_size.headers:
35 | file_size = int(response_file_size.headers['Content-Range'].split('/')[1])
36 | else:
37 | file_size = None
38 |
39 | save_response_content(response, save_path, file_size)
40 |
41 |
42 | def get_confirm_token(response):
43 | for key, value in response.cookies.items():
44 | if key.startswith('download_warning'):
45 | return value
46 | return None
47 |
48 |
49 | def save_response_content(response, destination, file_size=None, chunk_size=32768):
50 | if file_size is not None:
51 | pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
52 |
53 | readable_file_size = sizeof_fmt(file_size)
54 | else:
55 | pbar = None
56 |
57 | with open(destination, 'wb') as f:
58 | downloaded_size = 0
59 | for chunk in response.iter_content(chunk_size):
60 | downloaded_size += chunk_size
61 | if pbar is not None:
62 | pbar.update(1)
63 | pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}')
64 | if chunk: # filter out keep-alive new chunks
65 | f.write(chunk)
66 | if pbar is not None:
67 | pbar.close()
68 |
69 |
70 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
71 | """Load file from http url, will download models if necessary.
72 |
73 | Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
74 |
75 | Args:
76 | url (str): URL to be downloaded.
77 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
78 | Default: None.
79 | progress (bool): Whether to show the download progress. Default: True.
80 | file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
81 |
82 | Returns:
83 | str: The path to the downloaded file.
84 | """
85 | if model_dir is None: # use the pytorch hub_dir
86 | hub_dir = get_dir()
87 | model_dir = os.path.join(hub_dir, 'checkpoints')
88 |
89 | os.makedirs(model_dir, exist_ok=True)
90 |
91 | parts = urlparse(url)
92 | filename = os.path.basename(parts.path)
93 | if file_name is not None:
94 | filename = file_name
95 | cached_file = os.path.abspath(os.path.join(model_dir, filename))
96 | if not os.path.exists(cached_file):
97 | print(f'Downloading: "{url}" to {cached_file}\n')
98 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
99 | return cached_file
100 |
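As a usage sketch, the pretrained AdaCode weights referenced in inference.py can be fetched with this helper; the target directory is an assumption, and the download is skipped when the file is already cached:

from basicsr.utils.download_util import load_file_from_url

url = 'https://github.com/kechunl/AdaCode/releases/download/v0-pretrain_models/AdaCode_SR_X2_model_g.pth'
weight_path = load_file_from_url(url, model_dir='./checkpoint')  # returns the local file path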
--------------------------------------------------------------------------------
/basicsr/utils/file_client.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
2 | from abc import ABCMeta, abstractmethod
3 |
4 |
5 | class BaseStorageBackend(metaclass=ABCMeta):
6 | """Abstract class of storage backends.
7 |
8 | All backends need to implement two apis: ``get()`` and ``get_text()``.
9 | ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
10 | as texts.
11 | """
12 |
13 | @abstractmethod
14 | def get(self, filepath):
15 | pass
16 |
17 | @abstractmethod
18 | def get_text(self, filepath):
19 | pass
20 |
21 |
22 | class MemcachedBackend(BaseStorageBackend):
23 | """Memcached storage backend.
24 |
25 | Attributes:
26 | server_list_cfg (str): Config file for memcached server list.
27 | client_cfg (str): Config file for memcached client.
28 | sys_path (str | None): Additional path to be appended to `sys.path`.
29 | Default: None.
30 | """
31 |
32 | def __init__(self, server_list_cfg, client_cfg, sys_path=None):
33 | if sys_path is not None:
34 | import sys
35 | sys.path.append(sys_path)
36 | try:
37 | import mc
38 | except ImportError:
39 | raise ImportError('Please install memcached to enable MemcachedBackend.')
40 |
41 | self.server_list_cfg = server_list_cfg
42 | self.client_cfg = client_cfg
43 | self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
44 | # mc.pyvector serves as a pointer to a memory cache
45 | self._mc_buffer = mc.pyvector()
46 |
47 | def get(self, filepath):
48 | filepath = str(filepath)
49 | import mc
50 | self._client.Get(filepath, self._mc_buffer)
51 | value_buf = mc.ConvertBuffer(self._mc_buffer)
52 | return value_buf
53 |
54 | def get_text(self, filepath):
55 | raise NotImplementedError
56 |
57 |
58 | class HardDiskBackend(BaseStorageBackend):
59 | """Raw hard disks storage backend."""
60 |
61 | def get(self, filepath):
62 | filepath = str(filepath)
63 | with open(filepath, 'rb') as f:
64 | value_buf = f.read()
65 | return value_buf
66 |
67 | def get_text(self, filepath):
68 | filepath = str(filepath)
69 | with open(filepath, 'r') as f:
70 | value_buf = f.read()
71 | return value_buf
72 |
73 |
74 | class LmdbBackend(BaseStorageBackend):
75 | """Lmdb storage backend.
76 |
77 | Args:
78 | db_paths (str | list[str]): Lmdb database paths.
79 | client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
80 | readonly (bool, optional): Lmdb environment parameter. If True,
81 | disallow any write operations. Default: True.
82 | lock (bool, optional): Lmdb environment parameter. If False, when
83 | concurrent access occurs, do not lock the database. Default: False.
84 | readahead (bool, optional): Lmdb environment parameter. If False,
85 | disable the OS filesystem readahead mechanism, which may improve
86 | random read performance when a database is larger than RAM.
87 | Default: False.
88 |
89 | Attributes:
90 | db_paths (list): Lmdb database path.
91 | _client (list): A list of several lmdb envs.
92 | """
93 |
94 | def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
95 | try:
96 | import lmdb
97 | except ImportError:
98 | raise ImportError('Please install lmdb to enable LmdbBackend.')
99 |
100 | if isinstance(client_keys, str):
101 | client_keys = [client_keys]
102 |
103 | if isinstance(db_paths, list):
104 | self.db_paths = [str(v) for v in db_paths]
105 | elif isinstance(db_paths, str):
106 | self.db_paths = [str(db_paths)]
107 | assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
108 | f'but received {len(client_keys)} and {len(self.db_paths)}.')
109 |
110 | self._client = {}
111 | for client, path in zip(client_keys, self.db_paths):
112 | self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)
113 |
114 | def get(self, filepath, client_key):
115 | """Get values according to the filepath from one lmdb named client_key.
116 |
117 | Args:
118 | filepath (str | obj:`Path`): Here, filepath is the lmdb key.
119 | client_key (str): Used for distinguishing different lmdb envs.
120 | """
121 | filepath = str(filepath)
122 | assert client_key in self._client, (f'client_key {client_key} is not ' 'in lmdb clients.')
123 | client = self._client[client_key]
124 | with client.begin(write=False) as txn:
125 | value_buf = txn.get(filepath.encode('ascii'))
126 | return value_buf
127 |
128 | def get_text(self, filepath):
129 | raise NotImplementedError
130 |
131 |
132 | class FileClient(object):
133 | """A general file client to access files in different backend.
134 |
135 | The client loads a file or text in a specified backend from its path
136 | and returns it as a binary file. It can also register other backend
137 | accessors with a given name and backend class.
138 |
139 | Attributes:
140 | backend (str): The storage backend type. Options are "disk",
141 | "memcached" and "lmdb".
142 | client (:obj:`BaseStorageBackend`): The backend object.
143 | """
144 |
145 | _backends = {
146 | 'disk': HardDiskBackend,
147 | 'memcached': MemcachedBackend,
148 | 'lmdb': LmdbBackend,
149 | }
150 |
151 | def __init__(self, backend='disk', **kwargs):
152 | if backend not in self._backends:
153 | raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
154 | f' are {list(self._backends.keys())}')
155 | self.backend = backend
156 | self.client = self._backends[backend](**kwargs)
157 |
158 | def get(self, filepath, client_key='default'):
159 | # client_key is used only for lmdb, where different fileclients have
160 | # different lmdb environments.
161 | if self.backend == 'lmdb':
162 | return self.client.get(filepath, client_key)
163 | else:
164 | return self.client.get(filepath)
165 |
166 | def get_text(self, filepath):
167 | return self.client.get_text(filepath)
168 |
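A minimal sketch of the disk backend round-trip (the image path is hypothetical); the lmdb backend would instead take db_paths/client_keys keyword arguments at construction and a client_key in get():

from basicsr.utils import FileClient, imfrombytes

client = FileClient(backend='disk')
img_bytes = client.get('path/to/image.png')   # raw bytes read from disk
img = imfrombytes(img_bytes, float32=True)    # decoded BGR image in [0, 1]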
--------------------------------------------------------------------------------
/basicsr/utils/flow_util.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501
2 | import cv2
3 | import numpy as np
4 | import os
5 |
6 |
7 | def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs):
8 | """Read an optical flow map.
9 |
10 | Args:
11 | flow_path (ndarray or str): Flow path.
12 | quantize (bool): whether to read quantized pair, if set to True,
13 | remaining args will be passed to :func:`dequantize_flow`.
14 | concat_axis (int): The axis that dx and dy are concatenated,
15 | can be either 0 or 1. Ignored if quantize is False.
16 |
17 | Returns:
18 | ndarray: Optical flow represented as a (h, w, 2) numpy array
19 | """
20 | if quantize:
21 | assert concat_axis in [0, 1]
22 | cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED)
23 | if cat_flow.ndim != 2:
24 | raise IOError(f'{flow_path} is not a valid quantized flow file, its dimension is {cat_flow.ndim}.')
25 | assert cat_flow.shape[concat_axis] % 2 == 0
26 | dx, dy = np.split(cat_flow, 2, axis=concat_axis)
27 | flow = dequantize_flow(dx, dy, *args, **kwargs)
28 | else:
29 | with open(flow_path, 'rb') as f:
30 | try:
31 | header = f.read(4).decode('utf-8')
32 | except Exception:
33 | raise IOError(f'Invalid flow file: {flow_path}')
34 | else:
35 | if header != 'PIEH':
36 | raise IOError(f'Invalid flow file: {flow_path}, ' 'header does not contain PIEH')
37 |
38 | w = np.fromfile(f, np.int32, 1).squeeze()
39 | h = np.fromfile(f, np.int32, 1).squeeze()
40 | flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2))
41 |
42 | return flow.astype(np.float32)
43 |
44 |
45 | def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
46 | """Write optical flow to file.
47 |
48 | If the flow is not quantized, it will be saved as a .flo file losslessly,
49 | otherwise a jpeg image which is lossy but of much smaller size. (dx and dy
50 | will be concatenated horizontally into a single image if quantize is True.)
51 |
52 | Args:
53 | flow (ndarray): (h, w, 2) array of optical flow.
54 | filename (str): Output filepath.
55 | quantize (bool): Whether to quantize the flow and save it to 2 jpeg
56 | images. If set to True, remaining args will be passed to
57 | :func:`quantize_flow`.
58 | concat_axis (int): The axis that dx and dy are concatenated,
59 | can be either 0 or 1. Ignored if quantize is False.
60 | """
61 | if not quantize:
62 | with open(filename, 'wb') as f:
63 | f.write('PIEH'.encode('utf-8'))
64 | np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
65 | flow = flow.astype(np.float32)
66 | flow.tofile(f)
67 | f.flush()
68 | else:
69 | assert concat_axis in [0, 1]
70 | dx, dy = quantize_flow(flow, *args, **kwargs)
71 | dxdy = np.concatenate((dx, dy), axis=concat_axis)
72 | os.makedirs(os.path.dirname(filename), exist_ok=True)
73 | cv2.imwrite(filename, dxdy)
74 |
75 |
76 | def quantize_flow(flow, max_val=0.02, norm=True):
77 | """Quantize flow to [0, 255].
78 |
79 | After this step, the size of flow will be much smaller, and can be
80 | dumped as jpeg images.
81 |
82 | Args:
83 | flow (ndarray): (h, w, 2) array of optical flow.
84 | max_val (float): Maximum value of flow, values beyond
85 | [-max_val, max_val] will be truncated.
86 | norm (bool): Whether to divide flow values by image width/height.
87 |
88 | Returns:
89 | tuple[ndarray]: Quantized dx and dy.
90 | """
91 | h, w, _ = flow.shape
92 | dx = flow[..., 0]
93 | dy = flow[..., 1]
94 | if norm:
95 | dx = dx / w # avoid inplace operations
96 | dy = dy / h
97 | # use 255 levels instead of 256 to make sure 0 is 0 after dequantization.
98 | flow_comps = [quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]]
99 | return tuple(flow_comps)
100 |
101 |
102 | def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
103 | """Recover from quantized flow.
104 |
105 | Args:
106 | dx (ndarray): Quantized dx.
107 | dy (ndarray): Quantized dy.
108 | max_val (float): Maximum value used when quantizing.
109 | denorm (bool): Whether to multiply flow values with width/height.
110 |
111 | Returns:
112 | ndarray: Dequantized flow.
113 | """
114 | assert dx.shape == dy.shape
115 | assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)
116 |
117 | dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]]
118 |
119 | if denorm:
120 | dx *= dx.shape[1]
121 | dy *= dx.shape[0]
122 | flow = np.dstack((dx, dy))
123 | return flow
124 |
125 |
126 | def quantize(arr, min_val, max_val, levels, dtype=np.int64):
127 | """Quantize an array of (-inf, inf) to [0, levels-1].
128 |
129 | Args:
130 | arr (ndarray): Input array.
131 | min_val (scalar): Minimum value to be clipped.
132 | max_val (scalar): Maximum value to be clipped.
133 | levels (int): Quantization levels.
134 | dtype (np.type): The type of the quantized array.
135 |
136 | Returns:
137 | ndarray: Quantized array.
138 | """
139 | if not (isinstance(levels, int) and levels > 1):
140 | raise ValueError(f'levels must be a positive integer, but got {levels}')
141 | if min_val >= max_val:
142 | raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})')
143 |
144 | arr = np.clip(arr, min_val, max_val) - min_val
145 | quantized_arr = np.minimum(np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)
146 |
147 | return quantized_arr
148 |
149 |
150 | def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
151 | """Dequantize an array.
152 |
153 | Args:
154 | arr (ndarray): Input array.
155 | min_val (scalar): Minimum value to be clipped.
156 | max_val (scalar): Maximum value to be clipped.
157 | levels (int): Quantization levels.
158 | dtype (np.type): The type of the dequantized array.
159 |
160 | Returns:
161 | ndarray: Dequantized array.
162 | """
163 | if not (isinstance(levels, int) and levels > 1):
164 | raise ValueError(f'levels must be a positive integer, but got {levels}')
165 | if min_val >= max_val:
166 | raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})')
167 |
168 | dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - min_val) / levels + min_val
169 |
170 | return dequantized_arr
171 |
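A short round-trip sketch of the quantization helpers above, on a random toy flow field (shape and value range are illustrative and chosen small enough to stay within max_val after normalization):

import numpy as np
from basicsr.utils.flow_util import quantize_flow, dequantize_flow

flow = np.random.uniform(-0.1, 0.1, size=(8, 12, 2)).astype(np.float32)  # toy (h, w, 2) flow
dx, dy = quantize_flow(flow, max_val=0.02, norm=True)      # two uint8 maps with values in [0, 254]
recon = dequantize_flow(dx, dy, max_val=0.02, denorm=True)
# recon has shape (8, 12, 2) and approximates flow up to the quantization step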
--------------------------------------------------------------------------------
/basicsr/utils/img_process_util.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 | from torch.nn import functional as F
5 |
6 |
7 | def filter2D(img, kernel):
8 | """PyTorch version of cv2.filter2D
9 |
10 | Args:
11 | img (Tensor): (b, c, h, w)
12 | kernel (Tensor): (b, k, k)
13 | """
14 | k = kernel.size(-1)
15 | b, c, h, w = img.size()
16 | if k % 2 == 1:
17 | img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect')
18 | else:
19 | raise ValueError('Wrong kernel size')
20 |
21 | ph, pw = img.size()[-2:]
22 |
23 | if kernel.size(0) == 1:
24 | # apply the same kernel to all batch images
25 | img = img.view(b * c, 1, ph, pw)
26 | kernel = kernel.view(1, 1, k, k)
27 | return F.conv2d(img, kernel, padding=0).view(b, c, h, w)
28 | else:
29 | img = img.view(1, b * c, ph, pw)
30 | kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k)
31 | return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w)
32 |
33 |
34 | def usm_sharp(img, weight=0.5, radius=50, threshold=10):
35 | """USM sharpening.
36 |
37 | Input image: I; Blurry image: B.
38 | 1. sharp = I + weight * (I - B)
39 | 2. Mask = 1 if abs(I - B) > threshold, else: 0
40 | 3. Blur the mask to get soft_mask
41 | 4. Out = soft_mask * sharp + (1 - soft_mask) * I
42 |
43 |
44 | Args:
45 | img (Numpy array): Input image, HWC, BGR; float32, [0, 1].
46 | weight (float): Sharp weight. Default: 0.5.
47 | radius (float): Kernel size of Gaussian blur. Default: 50.
48 | threshold (int): Threshold on abs(I - B) * 255 used to build the mask. Default: 10.
49 | """
50 | if radius % 2 == 0:
51 | radius += 1
52 | blur = cv2.GaussianBlur(img, (radius, radius), 0)
53 | residual = img - blur
54 | mask = np.abs(residual) * 255 > threshold
55 | mask = mask.astype('float32')
56 | soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0)
57 |
58 | sharp = img + weight * residual
59 | sharp = np.clip(sharp, 0, 1)
60 | return soft_mask * sharp + (1 - soft_mask) * img
61 |
62 |
63 | class USMSharp(torch.nn.Module):
64 |
65 | def __init__(self, radius=50, sigma=0):
66 | super(USMSharp, self).__init__()
67 | if radius % 2 == 0:
68 | radius += 1
69 | self.radius = radius
70 | kernel = cv2.getGaussianKernel(radius, sigma)
71 | kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0)
72 | self.register_buffer('kernel', kernel)
73 |
74 | def forward(self, img, weight=0.5, threshold=10):
75 | blur = filter2D(img, self.kernel)
76 | residual = img - blur
77 |
78 | mask = torch.abs(residual) * 255 > threshold
79 | mask = mask.float()
80 | soft_mask = filter2D(mask, self.kernel)
81 | sharp = img + weight * residual
82 | sharp = torch.clip(sharp, 0, 1)
83 | return soft_mask * sharp + (1 - soft_mask) * img
84 |
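A usage sketch for the numpy variant above; the file names are hypothetical, and the input must be float32 BGR in [0, 1] as the docstring states:

import cv2
import numpy as np
from basicsr.utils import usm_sharp

img = cv2.imread('input.png').astype(np.float32) / 255.   # hypothetical path; HWC, BGR, [0, 1]
sharp = usm_sharp(img, weight=0.5, radius=50, threshold=10)
cv2.imwrite('input_sharp.png', (sharp * 255.0).round().astype(np.uint8))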
--------------------------------------------------------------------------------
/basicsr/utils/img_util.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import math
3 | import numpy as np
4 | import os
5 | import torch
6 | from torchvision.utils import make_grid
7 |
8 |
9 | def img2tensor(imgs, bgr2rgb=True, float32=True):
10 | """Numpy array to tensor.
11 |
12 | Args:
13 | imgs (list[ndarray] | ndarray): Input images.
14 | bgr2rgb (bool): Whether to change bgr to rgb.
15 | float32 (bool): Whether to change to float32.
16 |
17 | Returns:
18 | list[tensor] | tensor: Tensor images. If returned results only have
19 | one element, just return tensor.
20 | """
21 |
22 | def _totensor(img, bgr2rgb, float32):
23 | if img.shape[2] == 3 and bgr2rgb:
24 | if img.dtype == 'float64':
25 | img = img.astype('float32')
26 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
27 | img = torch.from_numpy(img.transpose(2, 0, 1))
28 | if float32:
29 | img = img.float()
30 | return img
31 |
32 | if isinstance(imgs, list):
33 | return [_totensor(img, bgr2rgb, float32) for img in imgs]
34 | else:
35 | return _totensor(imgs, bgr2rgb, float32)
36 |
37 |
38 | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
39 | """Convert torch Tensors into image numpy arrays.
40 |
41 | After clamping to [min, max], values will be normalized to [0, 1].
42 |
43 | Args:
44 | tensor (Tensor or list[Tensor]): Accept shapes:
45 | 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
46 | 2) 3D Tensor of shape (3/1 x H x W);
47 | 3) 2D Tensor of shape (H x W).
48 | Tensor channel should be in RGB order.
49 | rgb2bgr (bool): Whether to change rgb to bgr.
50 | out_type (numpy type): output types. If ``np.uint8``, transform outputs
51 | to uint8 type with range [0, 255]; otherwise, float type with
52 | range [0, 1]. Default: ``np.uint8``.
53 | min_max (tuple[int]): min and max values for clamp.
54 |
55 | Returns:
56 | (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
57 | shape (H x W). The channel order is BGR.
58 | """
59 | if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
60 | raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
61 |
62 | if torch.is_tensor(tensor):
63 | tensor = [tensor]
64 | result = []
65 | for _tensor in tensor:
66 | _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
67 | _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
68 |
69 | n_dim = _tensor.dim()
70 | if n_dim == 4:
71 | img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
72 | img_np = img_np.transpose(1, 2, 0)
73 | if rgb2bgr:
74 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
75 | elif n_dim == 3:
76 | img_np = _tensor.numpy()
77 | img_np = img_np.transpose(1, 2, 0)
78 | if img_np.shape[2] == 1: # gray image
79 | img_np = np.squeeze(img_np, axis=2)
80 | else:
81 | if rgb2bgr:
82 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
83 | elif n_dim == 2:
84 | img_np = _tensor.numpy()
85 | else:
86 | raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}')
87 | if out_type == np.uint8:
88 | # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
89 | img_np = (img_np * 255.0).round()
90 | img_np = img_np.astype(out_type)
91 | result.append(img_np)
92 | if len(result) == 1:
93 | result = result[0]
94 | return result
95 |
96 |
97 | def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
98 | """This implementation is slightly faster than tensor2img.
99 | It now only supports torch tensor with shape (1, c, h, w).
100 |
101 | Args:
102 | tensor (Tensor): Now only support torch tensor with (1, c, h, w).
103 | rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
104 | min_max (tuple[int]): min and max values for clamp.
105 | """
106 | output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0)
107 | output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255
108 | output = output.type(torch.uint8).cpu().numpy()
109 | if rgb2bgr:
110 | output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
111 | return output
112 |
113 |
114 | def imfrombytes(content, flag='color', float32=False):
115 | """Read an image from bytes.
116 |
117 | Args:
118 | content (bytes): Image bytes got from files or other streams.
119 | flag (str): Flags specifying the color type of a loaded image,
120 | candidates are `color`, `grayscale` and `unchanged`.
121 | float32 (bool): Whether to change to float32. If True, will also norm
122 | to [0, 1]. Default: False.
123 |
124 | Returns:
125 | ndarray: Loaded image array.
126 | """
127 | img_np = np.frombuffer(content, np.uint8)
128 | imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
129 | img = cv2.imdecode(img_np, imread_flags[flag])
130 | if float32:
131 | img = img.astype(np.float32) / 255.
132 | return img
133 |
134 |
135 | def imwrite(img, file_path, params=None, auto_mkdir=True):
136 | """Write image to file.
137 |
138 | Args:
139 | img (ndarray): Image array to be written.
140 | file_path (str): Image file path.
141 | params (None or list): Same as opencv's :func:`imwrite` interface.
142 | auto_mkdir (bool): If the parent folder of `file_path` does not exist,
143 | whether to create it automatically.
144 |
145 | Returns:
146 | bool: Successful or not.
147 | """
148 | if auto_mkdir:
149 | dir_name = os.path.abspath(os.path.dirname(file_path))
150 | os.makedirs(dir_name, exist_ok=True)
151 | ok = cv2.imwrite(file_path, img, params)
152 | if not ok:
153 | raise IOError('Failed in writing images.')
154 |
155 |
156 | def crop_border(imgs, crop_border):
157 | """Crop borders of images.
158 |
159 | Args:
160 | imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
161 | crop_border (int): Crop border for each end of height and width.
162 |
163 | Returns:
164 | list[ndarray]: Cropped images.
165 | """
166 | if crop_border == 0:
167 | return imgs
168 | else:
169 | if isinstance(imgs, list):
170 | return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
171 | else:
172 | return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]
173 |
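A round-trip sketch with the conversion helpers above (file names are hypothetical):

import cv2
import numpy as np
from basicsr.utils import img2tensor, tensor2img, imwrite

img = cv2.imread('input.png').astype(np.float32) / 255.     # HWC, BGR, [0, 1]
tensor = img2tensor(img, bgr2rgb=True, float32=True)         # (3, H, W) float tensor, RGB order
restored = tensor2img(tensor, rgb2bgr=True, min_max=(0, 1))  # uint8 HWC image, BGR again
imwrite(restored, 'roundtrip.png')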
--------------------------------------------------------------------------------
/basicsr/utils/misc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import random
4 | import time
5 | import torch
6 | from os import path as osp
7 |
8 | from .dist_util import master_only
9 |
10 |
11 | def set_random_seed(seed):
12 | """Set random seeds."""
13 | random.seed(seed)
14 | np.random.seed(seed)
15 | torch.manual_seed(seed)
16 | torch.cuda.manual_seed(seed)
17 | torch.cuda.manual_seed_all(seed)
18 |
19 |
20 | def get_time_str():
21 | return time.strftime('%Y%m%d_%H%M%S', time.localtime())
22 |
23 |
24 | def mkdir_and_rename(path):
25 | """mkdirs. If path exists, rename it with timestamp and create a new one.
26 |
27 | Args:
28 | path (str): Folder path.
29 | """
30 | if osp.exists(path):
31 | new_name = path + '_archived_' + get_time_str()
32 | print(f'Path already exists. Rename it to {new_name}', flush=True)
33 | os.rename(path, new_name)
34 | os.makedirs(path, exist_ok=True)
35 |
36 |
37 | @master_only
38 | def make_exp_dirs(opt):
39 | """Make dirs for experiments."""
40 | path_opt = opt['path'].copy()
41 | if opt['is_train']:
42 | mkdir_and_rename(path_opt.pop('experiments_root'))
43 | else:
44 | mkdir_and_rename(path_opt.pop('results_root'))
45 | for key, path in path_opt.items():
46 | if ('strict_load' in key) or ('pretrain_network' in key) or ('resume' in key) or ('param_key' in key) or ('pretrain_codebook' in key):
47 | continue
48 | else:
49 | os.makedirs(path, exist_ok=True)
50 |
51 |
52 | def scandir(dir_path, suffix=None, recursive=False, full_path=False):
53 | """Scan a directory to find the files of interest.
54 |
55 | Args:
56 | dir_path (str): Path of the directory.
57 | suffix (str | tuple(str), optional): File suffix that we are
58 | interested in. Default: None.
59 | recursive (bool, optional): If set to True, recursively scan the
60 | directory. Default: False.
61 | full_path (bool, optional): If set to True, include the dir_path.
62 | Default: False.
63 |
64 | Returns:
65 | A generator for all the files of interest, with relative paths.
66 | """
67 |
68 | if (suffix is not None) and not isinstance(suffix, (str, tuple)):
69 | raise TypeError('"suffix" must be a string or tuple of strings')
70 |
71 | root = dir_path
72 |
73 | def _scandir(dir_path, suffix, recursive):
74 | for entry in os.scandir(dir_path):
75 | if not entry.name.startswith('.') and entry.is_file():
76 | if full_path:
77 | return_path = entry.path
78 | else:
79 | return_path = osp.relpath(entry.path, root)
80 |
81 | if suffix is None:
82 | yield return_path
83 | elif return_path.endswith(suffix):
84 | yield return_path
85 | else:
86 | if recursive:
87 | yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
88 | else:
89 | continue
90 |
91 | return _scandir(dir_path, suffix=suffix, recursive=recursive)
92 |
93 |
94 | def check_resume(opt, resume_iter):
95 | """Check resume states and pretrain_network paths.
96 |
97 | Args:
98 | opt (dict): Options.
99 | resume_iter (int): Resume iteration.
100 | """
101 | if opt['path']['resume_state']:
102 | # get all the networks
103 | networks = [key for key in opt.keys() if key.startswith('network_')]
104 | flag_pretrain = False
105 | for network in networks:
106 | if opt['path'].get(f'pretrain_{network}') is not None:
107 | flag_pretrain = True
108 | if flag_pretrain:
109 | print('pretrain_network path will be ignored during resuming.')
110 | # set pretrained model paths
111 | for network in networks:
112 | name = f'pretrain_{network}'
113 | basename = network.replace('network_', '')
114 | if opt['path'].get('ignore_resume_networks') is None or (network
115 | not in opt['path']['ignore_resume_networks']):
116 | opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
117 | print(f"Set {name} to {opt['path'][name]}")
118 |
119 | # change param_key to params in resume
120 | param_keys = [key for key in opt['path'].keys() if key.startswith('param_key')]
121 | for param_key in param_keys:
122 | if opt['path'][param_key] == 'params_ema':
123 | opt['path'][param_key] = 'params'
124 | print(f'Set {param_key} to params')
125 |
126 |
127 | def sizeof_fmt(size, suffix='B'):
128 | """Get human readable file size.
129 |
130 | Args:
131 | size (int): File size.
132 | suffix (str): Suffix. Default: 'B'.
133 |
134 | Return:
135 | str: Formatted file size.
136 | """
137 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
138 | if abs(size) < 1024.0:
139 | return f'{size:3.1f} {unit}{suffix}'
140 | size /= 1024.0
141 | return f'{size:3.1f} Y{suffix}'
142 |
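Quick sketches of two helpers above; the folder name is hypothetical:

from basicsr.utils import scandir, sizeof_fmt

png_files = list(scandir('datasets/example', suffix='.png', recursive=True))  # relative paths
print(sizeof_fmt(3 * 1024 ** 2))   # '3.0 MB'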
--------------------------------------------------------------------------------
/basicsr/utils/registry.py:
--------------------------------------------------------------------------------
1 | # Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501
2 |
3 |
4 | class Registry():
5 | """
6 | The registry that provides name -> object mapping, to support third-party
7 | users' custom modules.
8 |
9 | To create a registry (e.g. a backbone registry):
10 |
11 | .. code-block:: python
12 |
13 | BACKBONE_REGISTRY = Registry('BACKBONE')
14 |
15 | To register an object:
16 |
17 | .. code-block:: python
18 |
19 | @BACKBONE_REGISTRY.register()
20 | class MyBackbone():
21 | ...
22 |
23 | Or:
24 |
25 | .. code-block:: python
26 |
27 | BACKBONE_REGISTRY.register(MyBackbone)
28 | """
29 |
30 | def __init__(self, name):
31 | """
32 | Args:
33 | name (str): the name of this registry
34 | """
35 | self._name = name
36 | self._obj_map = {}
37 |
38 | def _do_register(self, name, obj):
39 | assert (name not in self._obj_map), (f"An object named '{name}' was already registered "
40 | f"in '{self._name}' registry!")
41 | self._obj_map[name] = obj
42 |
43 | def register(self, obj=None):
44 | """
45 | Register the given object under the name `obj.__name__`.
46 | Can be used as either a decorator or not.
47 | See docstring of this class for usage.
48 | """
49 | if obj is None:
50 | # used as a decorator
51 | def deco(func_or_class):
52 | name = func_or_class.__name__
53 | self._do_register(name, func_or_class)
54 | return func_or_class
55 |
56 | return deco
57 |
58 | # used as a function call
59 | name = obj.__name__
60 | self._do_register(name, obj)
61 |
62 | def get(self, name):
63 | ret = self._obj_map.get(name)
64 | if ret is None:
65 | raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
66 | return ret
67 |
68 | def __contains__(self, name):
69 | return name in self._obj_map
70 |
71 | def __iter__(self):
72 | return iter(self._obj_map.items())
73 |
74 | def keys(self):
75 | return self._obj_map.keys()
76 |
77 |
78 | DATASET_REGISTRY = Registry('dataset')
79 | ARCH_REGISTRY = Registry('arch')
80 | MODEL_REGISTRY = Registry('model')
81 | LOSS_REGISTRY = Registry('loss')
82 | METRIC_REGISTRY = Registry('metric')
83 |
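A self-contained sketch of the full round-trip (register, membership test, lookup); the toy registry and class names are illustrative:

from basicsr.utils.registry import Registry

TOY_REGISTRY = Registry('toy')

@TOY_REGISTRY.register()
class ToyArch:
    pass

assert 'ToyArch' in TOY_REGISTRY          # __contains__
arch_cls = TOY_REGISTRY.get('ToyArch')    # returns the class; raises KeyError for unknown names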
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | import pyiqa
2 | import argparse
3 | import cv2
4 | import glob
5 | import os
6 | from tqdm import tqdm
7 | import torch
8 | import yaml
9 |
10 |
11 | def main():
12 | """Evaluation. Metrics: PSNR, SSIM, LPIPS
13 | """
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('-p', '--pred', type=str, help='Predicted image or folder')
16 | parser.add_argument('-g', '--gt', type=str, help='Ground-truth image or folder')
17 | parser.add_argument('-o', '--output', type=str, help='Output folder')
18 | args = parser.parse_args()
19 |
20 | if args.output is None:
21 | args.output = args.pred
22 | os.makedirs(args.output, exist_ok=True)
23 |
24 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
25 | metric_funcs = {}
26 | metric_funcs['psnr'] = pyiqa.create_metric('psnr', device=device, crop_border=4, test_y_channel=True)
27 | metric_funcs['ssim'] = pyiqa.create_metric('ssim', device=device, crop_border=4, test_y_channel=True)
28 | metric_funcs['lpips'] = pyiqa.create_metric('lpips', device=device)
29 |
30 | metric_results = {'psnr': 0, 'ssim': 0, 'lpips': 0}
31 |
32 | # image
33 | pred_img = glob.glob(os.path.join(args.pred, '*.png'))
34 | gt_img = glob.glob(os.path.join(args.gt, '*.png'))
35 | basename = list(set([os.path.basename(img) for img in pred_img]) & set([os.path.basename(img) for img in gt_img]))
36 | data = [[os.path.join(args.pred, bn), os.path.join(args.gt, bn)] for bn in basename]
37 |
38 | # evaluate
39 | for idx in range(len(data)):
40 | for name in metric_funcs.keys():
41 | metric_results[name] += metric_funcs[name](*data[idx]).item()
42 |
43 | for name in metric_results.keys():
44 | metric_results[name] /= len(data)
45 |
46 | print(metric_results)
47 | with open(os.path.join(args.output, 'result.yaml'), 'w') as outfile:
48 | yaml.dump(metric_results, outfile, default_flow_style=False)
49 |
50 | if __name__ == '__main__':
51 | main()
--------------------------------------------------------------------------------
/figures/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/figures/.DS_Store
--------------------------------------------------------------------------------
/figures/model.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/figures/model.jpeg
--------------------------------------------------------------------------------
/figures/teaser.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/figures/teaser.jpeg
--------------------------------------------------------------------------------
/generate_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | import random
5 | from tqdm import tqdm
6 | from multiprocessing import Pool
7 | import argparse
8 |
9 | from basicsr.data.bsrgan_util import degradation_bsrgan
10 |
11 | IMG_EXTENSIONS = [
12 | '.jpg', '.JPG', '.jpeg', '.JPEG',
13 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
14 | '.tif', '.TIF', '.tiff', '.TIFF',
15 | ]
16 |
17 |
18 | def is_image_file(filename):
19 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
20 |
21 |
22 | def make_dataset(dir, max_dataset_size=float("inf"), followlinks=True):
23 | images = []
24 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
25 |
26 | for root, _, fnames in sorted(os.walk(dir, followlinks=followlinks)):
27 | for fname in fnames:
28 | if is_image_file(fname):
29 | path = os.path.join(root, fname)
30 | images.append(path)
31 | return images[:min(max_dataset_size, len(images))]
32 |
33 | def degrade_img(hr_path, save_path):
34 | img_gt = cv2.imread(hr_path).astype(np.float32) / 255.
35 | img_gt = img_gt[:, :, [2, 1, 0]] # BGR to RGB
36 | img_lq, img_gt = degradation_bsrgan(img_gt, sf=scale, use_crop=False)
37 | img_lq = (img_lq[:, :, [2, 1, 0]] * 255).astype(np.uint8)
38 | print(f'Save {save_path}')
39 | cv2.imwrite(save_path, img_lq)
40 |
41 | argparser = argparse.ArgumentParser(description='Generate degradation dataset following BSRGAN')
42 | argparser.add_argument('--scale', '-s', type=int, choices=[2, 4], help='downscale factor')
43 | argparser.add_argument('--dir', type=str, help='image directory')
44 | args = argparser.parse_args()
45 |
46 | seed = 123
47 | random.seed(seed)
48 | np.random.seed(seed)
49 |
50 | scale = args.scale
51 | hr_img_list = make_dataset(args.dir)
52 | pool = Pool(processes=40)
53 |
54 | for hr_path in hr_img_list:
55 | try:
56 | save_dir = os.path.dirname(hr_path) + f'_LR_X{scale}'
57 | save_path = os.path.join(save_dir, os.path.basename(hr_path))
58 | if not os.path.exists(save_dir):
59 | os.makedirs(save_dir, exist_ok=True)
60 | pool.apply_async(degrade_img, args=(hr_path, save_path))
61 | except:
62 | print(hr_path, ': LR not generated.')
63 |
64 | pool.close()
65 | pool.join()
66 |
67 |
--------------------------------------------------------------------------------
/generate_inpaint_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | from PIL import Image
4 | import numpy as np
5 | import random
6 | from tqdm import tqdm
7 | from multiprocessing import Pool
8 | import argparse
9 |
10 | from basicsr.data.paired_image_dataset import brush_stroke_mask
11 |
12 | IMG_EXTENSIONS = [
13 | '.jpg', '.JPG', '.jpeg', '.JPEG',
14 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
15 | '.tif', '.TIF', '.tiff', '.TIFF',
16 | ]
17 |
18 |
19 | def is_image_file(filename):
20 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
21 |
22 |
23 | def make_dataset(dir, max_dataset_size=float("inf"), followlinks=True):
24 | images = []
25 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
26 |
27 | for root, _, fnames in sorted(os.walk(dir, followlinks=followlinks)):
28 | for fname in fnames:
29 | if is_image_file(fname):
30 | path = os.path.join(root, fname)
31 | images.append(path)
32 | return images[:min(max_dataset_size, len(images))]
33 |
34 | def degrade_img(hr_path, save_path):
35 | img_gt = cv2.imread(hr_path)
36 | img_lq = np.array(brush_stroke_mask(Image.fromarray(img_gt))).astype(np.uint8)
37 | print(f'Save {save_path}')
38 | cv2.imwrite(save_path, img_lq)
39 |
40 | argparser = argparse.ArgumentParser(description='Generate inpainting dataset')
41 | argparser.add_argument('--dir', type=str, help='image directory')
42 | args = argparser.parse_args()
43 |
44 | seed = 123
45 | random.seed(seed)
46 | np.random.seed(seed)
47 |
48 | hr_img_list = make_dataset(args.dir)
49 | pool = Pool(processes=40)
50 |
51 | for hr_path in hr_img_list:
52 | try:
53 | save_dir = os.path.dirname(hr_path) + f'_inpaint'
54 | save_path = os.path.join(save_dir, os.path.basename(hr_path))
55 | if not os.path.exists(save_dir):
56 | os.makedirs(save_dir, exist_ok=True)
57 | pool.apply_async(degrade_img, args=(hr_path, save_path))
58 | except:
59 | print(hr_path, ': masked not generated.')
60 |
61 | pool.close()
62 | pool.join()
63 |
64 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import cv2
3 | import glob
4 | import os
5 | from tqdm import tqdm
6 | import numpy as np
7 | import torch
8 | from yaml import load
9 | import matplotlib as m
10 | import matplotlib.pyplot as plt
11 |
12 | from basicsr.utils import img2tensor, tensor2img, imwrite
13 | from basicsr.archs.adacode_contrast_arch import AdaCodeSRNet_Contrast
14 | from basicsr.utils.download_util import load_file_from_url
15 |
16 | pretrain_model_url = {
17 | 'SR_x2': 'https://github.com/kechunl/AdaCode/releases/download/v0-pretrain_models/AdaCode_SR_X2_model_g.pth',
18 | 'SR_x4': 'https://github.com/kechunl/AdaCode/releases/download/v0-pretrain_models/AdaCode_SR_X4_model_g.pth',
19 | 'inpaint': 'https://github.com/kechunl/AdaCode/releases/download/v0-pretrain_models/AdaCode_Inpaint_model_g.pth',
20 | }
21 |
22 | def main():
23 | """Inference demo for AdaCode
24 | """
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument('-i', '--input', type=str, default='inputs', help='Input image or folder')
27 | parser.add_argument('-w', '--weight', type=str, default=None, help='path for model weights')
28 | parser.add_argument('-o', '--output', type=str, default='results', help='Output folder')
29 | parser.add_argument('-t', '--task', type=str, choices=['sr', 'inpaint'], help='inference task')
30 | parser.add_argument('-s', '--out_scale', type=int, default=4, help='final upsampling scale of the image for SR task')
31 | parser.add_argument('--suffix', type=str, default='', help='Suffix of the restored image')
32 | parser.add_argument('--max_size', type=int, default=600, help='Max image size for whole image inference, otherwise use tiled_test')
33 | parser.add_argument('--vis_weight', action='store_true', help='visualize weight map')
34 | args = parser.parse_args()
35 |
36 | if args.vis_weight:
37 | os.makedirs(os.path.join(args.output, 'weight_map'), exist_ok=True)
38 |
39 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
40 |
41 | if args.task == 'sr':
42 | model_url_key = f'SR_x{args.out_scale}'
43 | scale = args.out_scale
44 | bs = 8
45 | else:
46 | model_url_key = 'inpaint'
47 | scale = 1
48 | bs = 2
49 |
50 | if args.weight is None:
51 | os.makedirs('./checkpoint', exist_ok=True)
52 | weight_path = load_file_from_url(pretrain_model_url[f'{model_url_key}'], model_dir='./checkpoint')
53 | else:
54 | weight_path = args.weight
55 |
56 | # set up the model
57 | model_params = torch.load(weight_path)['params']
58 | codebook_dim = np.array([v.size() for k,v in model_params.items() if 'quantize_group' in k])
59 | codebook_dim_list = []
60 | for k in codebook_dim:
61 | temp = k.tolist()
62 | temp.insert(0,32)
63 | codebook_dim_list.append(temp)
64 | model = AdaCodeSRNet_Contrast(codebook_params=codebook_dim_list, LQ_stage=True, AdaCode_stage=True, weight_softmax=False, batch_size=bs, scale_factor=scale).to(device)
65 | model.load_state_dict(torch.load(weight_path)['params'], strict=False)
66 | model.eval()
67 |
68 | os.makedirs(args.output, exist_ok=True)
69 | if os.path.isfile(args.input):
70 | paths = [args.input]
71 | else:
72 | paths = sorted(glob.glob(os.path.join(args.input, '*')))
73 |
74 | pbar = tqdm(total=len(paths), unit='image')
75 | for idx, path in enumerate(paths):
76 | try:
77 | img_name = os.path.basename(path)
78 | pbar.set_description(f'Test {img_name}')
79 |
80 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
81 | img_tensor = img2tensor(img).to(device) / 255.
82 | img_tensor = img_tensor.unsqueeze(0)
83 |
84 | max_size = args.max_size ** 2
85 | h, w = img_tensor.shape[2:]
86 | if h * w < max_size:
87 | output = model.test(img_tensor, vis_weight=args.vis_weight)
88 | else:
89 | output = model.test_tile(img_tensor, vis_weight=args.vis_weight)
90 |
91 | if args.vis_weight:
92 | weight_map = output[1]
93 | vis_weight(weight_map, os.path.join(args.output, 'weight_map', img_name))
94 | output = output[0]
95 | output_img = tensor2img(output)
96 |
97 | save_path = os.path.join(args.output, f'{img_name}')
98 | imwrite(output_img, save_path)
99 | pbar.update(1)
100 | except:
101 | print(path, ' fails.')
102 | pbar.close()
103 |
104 | def vis_weight(weight, save_path):
105 | # weight: B x n x 1 x H x W
106 | weight = weight.cpu().numpy()
107 | # normalize weights
108 | # norm_weight = weight
109 | norm_weight = (weight - weight.mean()) / weight.std() / 2
110 | norm_weight = np.abs(norm_weight)
111 | norm_weight *= 255
112 | norm_weight = np.clip(norm_weight, 0, 255)
113 | norm_weight = norm_weight.astype(np.uint8)
114 | # visualize
115 | display_grid = np.zeros((weight.shape[3], (weight.shape[4]+1)*weight.shape[1]-1))
116 | for img_id in range(len(norm_weight)):
117 | for c in range(norm_weight.shape[1]):
118 | display_grid[:, c*weight.shape[4]+c:(c+1)*weight.shape[4]+c] = norm_weight[img_id, c, 0, :, :]
119 | # weight_path = save_path.split('.')[0] + '_{}.png'.format(str(c))
120 | # Image.fromarray(norm_weight[img_id, c, 0, :, :]).save(weight_path)
121 | plt.figure(figsize=(30,150))
122 | plt.axis('off')
123 | plt.imshow(display_grid)
124 | plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
125 | plt.close()
126 |
127 | if __name__ == '__main__':
128 | main()
129 |
--------------------------------------------------------------------------------
/options/stage1/train_AdaCode_HQ_stage1_category.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: AdaCode_stage1_category
3 | model_type: FeMaSRModel
4 | scale: 4 # doesn't matter for this stage
5 | num_gpu: 4 # set num_gpu: 0 for cpu mode
6 | manual_seed: 0
7 |
8 | # dataset and data loader settings
9 | datasets:
10 | train:
11 | name: General_Image_Train
12 | type: BSRGANTrainDataset
13 | datafile_gt: ./data/category_data/train/portrait.txt
14 | io_backend:
15 | type: disk
16 |
17 | gt_size: 256
18 | use_resize_crop: true
19 | use_flip: true
20 | use_rot: true
21 |
22 | # data loader
23 | use_shuffle: true
24 | batch_size_per_gpu: &bsz 8
25 | num_worker_per_gpu: *bsz
26 | dataset_enlarge_ratio: 1
27 |
28 | prefetch_mode: cpu
29 | num_prefetch_queue: *bsz
30 |
31 | val:
32 | name: General_Image_Valid
33 | type: PairedImageDataset
34 | datafile_gt: ./data/category_data/valid/portrait.txt
35 | datafile_lq: ./data/category_data/valid/portrait.txt
36 | io_backend:
37 | type: disk
38 |
39 | # network structures
40 | network_g:
41 | type: FeMaSRNet
42 | gt_resolution: 256
43 | norm_type: 'gn'
44 | act_type: 'silu'
45 | use_semantic_loss: true
46 | codebook_params: # must be ordered from low to high
47 | - [32, 256, 256]
48 |
49 | # for HQ stage training
50 | LQ_stage: false
51 | use_quantize: true
52 |
53 | network_d:
54 | type: UNetDiscriminatorSN
55 | num_in_ch: 3
56 |
57 | # path
58 | path:
59 | # pretrain_network_g: ./experiments/pretrained_models/QuanTexSR/pretrain_semantic_vqgan_net_g_latest.pth
60 | # pretrain_network_d: ~
61 | # pretrain_network_g: ./experiments/004_FeMaSR_HQ_stage/models/net_g_best_.pth
62 | # pretrain_network_d: ./experiments/004_FeMaSR_HQ_stage/models/net_d_best_.pth
63 | strict_load: false
64 | # resume_state: ~
65 |
66 | # training settings
67 | train:
68 | optim_g:
69 | type: Adam
70 | lr: !!float 1e-4
71 | weight_decay: 0
72 | betas: [0.9, 0.99]
73 | optim_d:
74 | type: Adam
75 | lr: !!float 4e-4
76 | weight_decay: 0
77 | betas: [0.9, 0.99]
78 |
79 | scheduler:
80 | type: MultiStepLR
81 | milestones: [1800, 3600, 5400, 7200, 9000, 10800, 12600]
82 | gamma: 1
83 |
84 | total_iter: 73500
85 | warmup_iter: -1 # no warm up
86 |
87 | # losses
88 | pixel_opt:
89 | type: L1Loss
90 | loss_weight: 1.0
91 | reduction: mean
92 |
93 | perceptual_opt:
94 | type: LPIPSLoss
95 | loss_weight: !!float 1.0
96 |
97 | gan_opt:
98 | type: GANLoss
99 | gan_type: hinge
100 | real_label_val: 1.0
101 | fake_label_val: 0.0
102 | loss_weight: 0.1
103 |
104 | codebook_opt:
105 | loss_weight: 1.0
106 |
107 | semantic_opt:
108 | loss_weight: 0.1
109 |
110 | net_d_iters: 1
111 | net_d_init_iters: !!float 0
112 |
113 | # validation settings
114 | val:
115 | val_freq: !!float 3e2
116 | save_img: false
117 |
118 | key_metric: lpips
119 | metrics:
120 | psnr: # metric name, not used in this codebase
121 | type: psnr
122 | crop_border: 4
123 | test_y_channel: true
124 | color_space: ycbcr
125 | ssim:
126 | type: ssim
127 | crop_border: 4
128 | test_y_channel: true
129 | color_space: ycbcr
130 | lpips:
131 | type: lpips
132 | better: lower
133 |
134 | # logging settings
135 | logger:
136 | print_freq: 100
137 | save_checkpoint_freq: !!float 1e9
138 | save_latest_freq: !!float 1e3
139 | show_tf_imgs_freq: !!float 3e2
140 | use_tb_logger: true
141 |
--------------------------------------------------------------------------------
/options/stage2/train_AdaCode_HQ_stage2.yaml:
--------------------------------------------------------------------------------
1 | # GENERATE TIME: Fri Oct 21 18:30:19 2022
2 | # CMD:
3 | # basicsr/train.py --local_rank=0 -opt options/train_AdaCode_HQ_stage2.yaml --launcher pytorch
4 |
5 | # general settings
6 | name: AdaCode_stage2_recon
7 | model_type: FeMaSRModel
8 | scale: &upscale 4 # doesn't matter for this stage
9 | num_gpu: 8 # set num_gpu: 0 for cpu mode
10 | manual_seed: 0
11 |
12 | # dataset and data loader settings
13 | datasets:
14 | train:
15 | name: General_Image_Train
16 | type: BSRGANTrainDataset
17 | datafile_gt: ./data/category_data/train/all_train.txt
18 | io_backend:
19 | type: disk
20 |
21 | gt_size: 256
22 | use_resize_crop: true
23 | use_flip: true
24 | use_rot: true
25 |
26 | # data loader
27 | use_shuffle: true
28 | batch_size_per_gpu: &bsz 8
29 | num_worker_per_gpu: *bsz
30 | dataset_enlarge_ratio: 1
31 |
32 | prefetch_mode: cpu
33 | num_prefetch_queue: *bsz
34 |
35 | val:
36 | name: General_Image_Valid
37 | type: PairedImageDataset
38 | datafile_gt: ./data/category_data/valid/all_valid.txt
39 | datafile_lq: ./data/category_data/valid/all_valid.txt
40 | # crop_eval_size: 384
41 | io_backend:
42 | type: disk
43 |
44 | # network structures
45 | network_g:
46 | type: AdaCodeSRNet
47 | gt_resolution: 256
48 | norm_type: 'gn'
49 | act_type: 'silu'
50 | scale_factor: *upscale
51 |
52 | ### TODO: modify the configuration carefully
53 | AdaCode_stage: true
54 | LQ_stage: false
55 | frozen_module_keywords: ['quantize']
56 |
57 | network_d:
58 | type: UNetDiscriminatorSN
59 | num_in_ch: 3
60 |
61 | # path
62 | path:
63 | pretrain_codebook:
64 | - ./experiments/AdaCode_stage1_c0_codebook512x256/models/net_g_best_.pth
65 | - ./experiments/AdaCode_stage1_c1_codebook256x256/models/net_g_best_.pth
66 | - ./experiments/AdaCode_stage1_c2_codebook512x256/models/net_g_best_.pth
67 | - ./experiments/AdaCode_stage1_c3_codebook256x256/models/net_g_best_.pth
68 | - ./experiments/AdaCode_stage1_ffhq_codebook256x256/models/net_g_best_.pth
69 |
70 | # pretrain_network_hq: ./experiments/008_FeMaSR_HQ_stage/models/net_g_best_.pth
71 | # pretrain_network_g: ~
72 | # pretrain_network_d: ./experiments/008_FeMaSR_HQ_stage/models/net_d_best_.pth
73 | strict_load: false
74 | # resume_state: ~
75 |
76 | # training settings
77 | train:
78 | optim_g:
79 | type: Adam
80 | lr: !!float 1e-4
81 | weight_decay: 0
82 | betas: [0.9, 0.99]
83 | optim_d:
84 | type: Adam
85 | lr: !!float 4e-4
86 | weight_decay: 0
87 | betas: [0.9, 0.99]
88 |
89 | scheduler:
90 | type: MultiStepLR
91 | milestones: [18000, 36000, 54000, 72000, 90000, 108000, 126000]
92 | gamma: 1
93 |
94 | total_iter: 730000
95 | warmup_iter: -1 # no warm up
96 |
97 | # losses
98 | pixel_opt:
99 | type: L1Loss
100 | loss_weight: 1.0
101 | reduction: mean
102 |
103 | perceptual_opt:
104 | type: LPIPSLoss
105 | loss_weight: !!float 1.0
106 |
107 | gan_opt:
108 | type: GANLoss
109 | gan_type: hinge
110 | real_label_val: 1.0
111 | fake_label_val: 0.0
112 | loss_weight: 0.1
113 |
114 | codebook_opt:
115 | loss_weight: 1.0
116 |
117 | semantic_opt:
118 | loss_weight: 0.1
119 |
120 | net_d_iters: 1
121 | net_d_init_iters: !!float 0
122 |
123 | # validation settings
124 | val:
125 | val_freq: !!float 3e3
126 | save_img: true
127 |
128 | key_metric: lpips
129 | metrics:
130 | psnr: # metric name, can be arbitrary
131 | type: psnr
132 | crop_border: 4
133 | test_y_channel: true
134 | ssim:
135 | type: ssim
136 | crop_border: 4
137 | test_y_channel: true
138 | lpips:
139 | type: lpips
140 | better: lower
141 |
142 | # logging settings
143 | logger:
144 | print_freq: 100
145 | save_checkpoint_freq: !!float 1e9
146 | save_latest_freq: !!float 3e3
147 | show_tf_imgs_freq: !!float 1e3
148 | use_tb_logger: true
149 |
150 | # wandb:
151 | # project: ESRGAN
152 | # resume_id: ~
153 |
154 | # dist training settings
155 | # dist_params:
156 | # backend: nccl
157 | # port: 16500 #29500
158 |
--------------------------------------------------------------------------------
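
One detail worth flagging in the training settings above (it also applies to the stage-1 and stage-3 files): `MultiStepLR` is configured with `gamma: 1`, so the learning rate is multiplied by 1 at every milestone and effectively never decays over the run. A small PyTorch sketch, using a hypothetical single-parameter setup and values taken from this config, that illustrates the behaviour:

import torch

params = [torch.zeros(1, requires_grad=True)]  # hypothetical parameter
optimizer = torch.optim.Adam(params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[18000, 36000, 54000], gamma=1)  # gamma=1 -> no decay

for _ in range(60000):
    optimizer.step()      # no-op here since there are no gradients
    scheduler.step()

print(optimizer.param_groups[0]['lr'])  # still 1e-4: the milestones have no effect
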
/options/stage3/train_AdaCode_stage3.yaml:
--------------------------------------------------------------------------------
1 | # GENERATE TIME: Sun Nov 27 00:55:18 2022
2 | # CMD:
3 | # basicsr/train.py --local_rank=0 -opt options/train_AdaCode_stage3.yaml --launcher pytorch
4 |
5 | # general settings
6 | name: AdaCode_stage3_SR_X2
7 | model_type: FeMaSRModel
8 | scale: &upscale 2
9 | num_gpu: 8 # set num_gpu: 0 for cpu mode
10 | manual_seed: 0
11 |
12 | # dataset and data loader settings
13 | datasets:
14 | train:
15 | name: General_Image_Train
16 | # type: BSRGANTrainDataset
17 | # datafile_gt: /newDisk/users/liukechun/research/FeMaSR/data/category_data/train/all_train.txt
18 | type: PairedImageDataset
19 | datafile_gt: ./data/category_data/train/all_train.txt
20 | datafile_lq: ./data/category_data/train/all_train_X2.txt
21 | io_backend:
22 | type: disk
23 |
24 | gt_size: 256
25 | use_resize_crop: true
26 | use_flip: true
27 | use_rot: true
28 |
29 | # data loader
30 | use_shuffle: true
31 | batch_size_per_gpu: &bsz 8
32 | num_worker_per_gpu: *bsz
33 | dataset_enlarge_ratio: 1
34 |
35 | prefetch_mode: cpu
36 | num_prefetch_queue: *bsz
37 |
38 | val:
39 | name: General_Image_Valid
40 | type: PairedImageDataset
41 | datafile_gt: ./data/category_data/valid/all_valid.txt
42 | datafile_lq: ./data/category_data/valid/all_valid_X2.txt
43 | # crop_eval_size: 384
44 | io_backend:
45 | type: disk
46 |
47 | # network structures
48 | network_g:
49 | type: AdaCodeSRNet_Contrast
50 | gt_resolution: 256
51 | norm_type: 'gn'
52 | act_type: 'silu'
53 | scale_factor: *upscale
54 |
55 | ### TODO: modify the configuration carefully
56 | AdaCode_stage: true
57 | LQ_stage: true
58 | frozen_module_keywords: ['quantize', 'decoder', 'after_quant_group', 'out_conv']
59 |
60 | weight_softmax: false
61 |
62 | network_d:
63 | type: UNetDiscriminatorSN
64 | num_in_ch: 3
65 |
66 | # path
67 | path:
68 | # pretrain_codebook:
69 | # - ./experiments/AdaCode_stage1_c0_codebook512x256/models/net_g_best_.pth
70 | # - ./experiments/AdaCode_stage1_c1_codebook256x256/models/net_g_best_.pth
71 | # - ./experiments/AdaCode_stage1_c2_codebook512x256/models/net_g_best_.pth
72 | # - ./experiments/AdaCode_stage1_c3_codebook256x256/models/net_g_best_.pth
73 | # - ./experiments/AdaCode_stage1_ffhq_codebook256x256/models/net_g_best_.pth
74 |
75 | pretrain_network_hq: ./experiments/stage2/AdaCode_stage2_recon/models/net_g_best_.pth
76 | # pretrain_network_g: ~
77 | pretrain_network_d: ./experiments/stage2/AdaCode_stage2_recon/models/net_d_best_.pth
78 | strict_load: false
79 | # resume_state: ~
80 |
81 | # training settings
82 | train:
83 | optim_g:
84 | type: Adam
85 | lr: !!float 1e-4
86 | weight_decay: 0
87 | betas: [0.9, 0.99]
88 | optim_d:
89 | type: Adam
90 | lr: !!float 4e-4
91 | weight_decay: 0
92 | betas: [0.9, 0.99]
93 |
94 | scheduler:
95 | type: MultiStepLR
96 | milestones: [18000, 36000, 54000, 72000, 90000, 108000, 126000]
97 | gamma: 1
98 |
99 | total_iter: 730000
100 | warmup_iter: -1 # no warm up
101 |
102 | # losses
103 | pixel_opt:
104 | type: L1Loss
105 | loss_weight: 1.0
106 | reduction: mean
107 |
108 | perceptual_opt:
109 | type: LPIPSLoss
110 | loss_weight: !!float 1.0
111 |
112 | gan_opt:
113 | type: GANLoss
114 | gan_type: hinge
115 | real_label_val: 1.0
116 | fake_label_val: 0.0
117 | loss_weight: 0.1
118 |
119 | codebook_opt:
120 | loss_weight: 1.0
121 |
122 | semantic_opt:
123 | loss_weight: 0.1
124 |
125 | before_quant_opt:
126 | loss_weight: 0.0
127 |
128 | after_quant_opt:
129 | loss_weight: 1.0
130 |
131 | contrast_opt:
132 | loss_weight: 1.0
133 |
134 | net_d_iters: 1
135 | net_d_init_iters: !!float 0
136 |
137 | # validation settings
138 | val:
139 | val_freq: !!float 3e3
140 | save_img: false
141 |
142 | key_metric: lpips
143 | metrics:
144 | psnr: # metric name, can be arbitrary
145 | type: psnr
146 | crop_border: 4
147 | test_y_channel: true
148 | ssim:
149 | type: ssim
150 | crop_border: 4
151 | test_y_channel: true
152 | lpips:
153 | type: lpips
154 | better: lower
155 |
156 | # logging settings
157 | logger:
158 | print_freq: 100
159 | save_checkpoint_freq: !!float 1e9
160 | save_latest_freq: !!float 3e3
161 | show_tf_imgs_freq: !!float 1e3
162 | use_tb_logger: true
163 |
164 | # wandb:
165 | # project: ESRGAN
166 | # resume_id: ~
167 |
168 | # dist training settings
169 | # dist_params:
170 | # backend: nccl
171 | # port: 16500 #29500
172 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | addict
2 | future
3 | lmdb
4 | setuptools>=57.0.0
5 | numpy>=1.20.3
6 | opencv-python-headless
7 | Pillow>=8.3.2
8 | pyyaml
9 | requests
10 | scikit-image
11 | scikit-learn
12 | scipy
13 | tb-nightly
14 | torch>=1.8.1
15 | torchvision>=0.9
16 | tqdm
17 | yapf
18 | pyiqa
19 | einops
20 |
--------------------------------------------------------------------------------
/scripts/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/scripts/.DS_Store
--------------------------------------------------------------------------------
/scripts/data_preparation/create_lmdb.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from os import path as osp
3 |
4 | from basicsr.utils import scandir
5 | from basicsr.utils.lmdb_util import make_lmdb_from_imgs
6 |
7 |
8 | def create_lmdb_for_div2k():
9 | """Create lmdb files for DIV2K dataset.
10 |
11 | Usage:
12 |         Before running this script, please run `extract_subimages.py`.
13 | Typically, there are four folders to be processed for DIV2K dataset.
14 | DIV2K_train_HR_sub
15 | DIV2K_train_LR_bicubic/X2_sub
16 | DIV2K_train_LR_bicubic/X3_sub
17 | DIV2K_train_LR_bicubic/X4_sub
18 | Remember to modify opt configurations according to your settings.
19 | """
20 | # HR images
21 | folder_path = 'datasets/DIV2K/DIV2K_train_HR_sub'
22 | lmdb_path = 'datasets/DIV2K/DIV2K_train_HR_sub.lmdb'
23 | img_path_list, keys = prepare_keys_div2k(folder_path)
24 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
25 |
26 | # LRx2 images
27 | folder_path = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X2_sub'
28 | lmdb_path = 'datasets/DIV2K/DIV2K_train_LR_bicubic_X2_sub.lmdb'
29 | img_path_list, keys = prepare_keys_div2k(folder_path)
30 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
31 |
32 | # LRx3 images
33 | folder_path = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X3_sub'
34 | lmdb_path = 'datasets/DIV2K/DIV2K_train_LR_bicubic_X3_sub.lmdb'
35 | img_path_list, keys = prepare_keys_div2k(folder_path)
36 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
37 |
38 | # LRx4 images
39 | folder_path = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X4_sub'
40 | lmdb_path = 'datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb'
41 | img_path_list, keys = prepare_keys_div2k(folder_path)
42 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
43 |
44 |
45 | def prepare_keys_div2k(folder_path):
46 | """Prepare image path list and keys for DIV2K dataset.
47 |
48 | Args:
49 | folder_path (str): Folder path.
50 |
51 | Returns:
52 | list[str]: Image path list.
53 | list[str]: Key list.
54 | """
55 | print('Reading image path list ...')
56 | img_path_list = sorted(list(scandir(folder_path, suffix='png', recursive=False)))
57 | keys = [img_path.split('.png')[0] for img_path in sorted(img_path_list)]
58 |
59 | return img_path_list, keys
60 |
61 |
62 | def create_lmdb_for_reds():
63 | """Create lmdb files for REDS dataset.
64 |
65 | Usage:
66 |         Before running this script, please run `merge_reds_train_val.py`.
67 | We take two folders for example:
68 | train_sharp
69 | train_sharp_bicubic
70 | Remember to modify opt configurations according to your settings.
71 | """
72 | # train_sharp
73 | folder_path = 'datasets/REDS/train_sharp'
74 | lmdb_path = 'datasets/REDS/train_sharp_with_val.lmdb'
75 | img_path_list, keys = prepare_keys_reds(folder_path)
76 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys, multiprocessing_read=True)
77 |
78 | # train_sharp_bicubic
79 | folder_path = 'datasets/REDS/train_sharp_bicubic'
80 | lmdb_path = 'datasets/REDS/train_sharp_bicubic_with_val.lmdb'
81 | img_path_list, keys = prepare_keys_reds(folder_path)
82 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys, multiprocessing_read=True)
83 |
84 |
85 | def prepare_keys_reds(folder_path):
86 | """Prepare image path list and keys for REDS dataset.
87 |
88 | Args:
89 | folder_path (str): Folder path.
90 |
91 | Returns:
92 | list[str]: Image path list.
93 | list[str]: Key list.
94 | """
95 | print('Reading image path list ...')
96 | img_path_list = sorted(list(scandir(folder_path, suffix='png', recursive=True)))
97 | keys = [v.split('.png')[0] for v in img_path_list] # example: 000/00000000
98 |
99 | return img_path_list, keys
100 |
101 |
102 | def create_lmdb_for_vimeo90k():
103 | """Create lmdb files for Vimeo90K dataset.
104 |
105 | Usage:
106 | Remember to modify opt configurations according to your settings.
107 | """
108 | # GT
109 | folder_path = 'datasets/vimeo90k/vimeo_septuplet/sequences'
110 | lmdb_path = 'datasets/vimeo90k/vimeo90k_train_GT_only4th.lmdb'
111 | train_list_path = 'datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
112 | img_path_list, keys = prepare_keys_vimeo90k(folder_path, train_list_path, 'gt')
113 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys, multiprocessing_read=True)
114 |
115 | # LQ
116 | folder_path = 'datasets/vimeo90k/vimeo_septuplet_matlabLRx4/sequences'
117 | lmdb_path = 'datasets/vimeo90k/vimeo90k_train_LR7frames.lmdb'
118 | train_list_path = 'datasets/vimeo90k/vimeo_septuplet/sep_trainlist.txt'
119 | img_path_list, keys = prepare_keys_vimeo90k(folder_path, train_list_path, 'lq')
120 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys, multiprocessing_read=True)
121 |
122 |
123 | def prepare_keys_vimeo90k(folder_path, train_list_path, mode):
124 | """Prepare image path list and keys for Vimeo90K dataset.
125 |
126 | Args:
127 | folder_path (str): Folder path.
128 | train_list_path (str): Path to the official train list.
129 | mode (str): One of 'gt' or 'lq'.
130 |
131 | Returns:
132 | list[str]: Image path list.
133 | list[str]: Key list.
134 | """
135 | print('Reading image path list ...')
136 | with open(train_list_path, 'r') as fin:
137 | train_list = [line.strip() for line in fin]
138 |
139 | img_path_list = []
140 | keys = []
141 | for line in train_list:
142 | folder, sub_folder = line.split('/')
143 | img_path_list.extend([osp.join(folder, sub_folder, f'im{j + 1}.png') for j in range(7)])
144 | keys.extend([f'{folder}/{sub_folder}/im{j + 1}' for j in range(7)])
145 |
146 | if mode == 'gt':
147 | print('Only keep the 4th frame for the gt mode.')
148 | img_path_list = [v for v in img_path_list if v.endswith('im4.png')]
149 | keys = [v for v in keys if v.endswith('/im4')]
150 |
151 | return img_path_list, keys
152 |
153 |
154 | if __name__ == '__main__':
155 | parser = argparse.ArgumentParser()
156 |
157 | parser.add_argument(
158 | '--dataset',
159 | type=str,
160 | help=("Options: 'DIV2K', 'REDS', 'Vimeo90K' "
161 | 'You may need to modify the corresponding configurations in codes.'))
162 | args = parser.parse_args()
163 | dataset = args.dataset.lower()
164 | if dataset == 'div2k':
165 | create_lmdb_for_div2k()
166 | elif dataset == 'reds':
167 | create_lmdb_for_reds()
168 | elif dataset == 'vimeo90k':
169 | create_lmdb_for_vimeo90k()
170 | else:
171 | raise ValueError('Wrong dataset.')
172 |
--------------------------------------------------------------------------------
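
The DIV2K/REDS/Vimeo90K helpers above all follow the same two-step pattern: collect image paths plus keys, then call `make_lmdb_from_imgs`. A minimal sketch for packing an arbitrary folder of PNG sub-images with the same utilities (the folder and lmdb paths here are hypothetical placeholders):

from basicsr.utils import scandir
from basicsr.utils.lmdb_util import make_lmdb_from_imgs

folder_path = 'datasets/my_dataset/HR_sub'      # hypothetical input folder
lmdb_path = 'datasets/my_dataset/HR_sub.lmdb'   # hypothetical output lmdb

# keys are the file names without the .png extension, as in prepare_keys_div2k()
img_path_list = sorted(scandir(folder_path, suffix='png', recursive=False))
keys = [p.split('.png')[0] for p in img_path_list]

make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
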
/scripts/data_preparation/download_datasets.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import os
4 | from os import path as osp
5 |
6 | from basicsr.utils.download_util import download_file_from_google_drive
7 |
8 |
9 | def download_dataset(dataset, file_ids):
10 | save_path_root = './datasets/'
11 | os.makedirs(save_path_root, exist_ok=True)
12 |
13 | for file_name, file_id in file_ids.items():
14 | save_path = osp.abspath(osp.join(save_path_root, file_name))
15 | if osp.exists(save_path):
16 |             user_response = input(f'{file_name} already exists. Do you want to overwrite it? Y/N\n')
17 |             if user_response.lower() == 'y':
18 |                 print(f'Overwriting {file_name} to {save_path}')
19 | download_file_from_google_drive(file_id, save_path)
20 | elif user_response.lower() == 'n':
21 | print(f'Skipping {file_name}')
22 | else:
23 | raise ValueError('Wrong input. Only accepts Y/N.')
24 | else:
25 | print(f'Downloading {file_name} to {save_path}')
26 | download_file_from_google_drive(file_id, save_path)
27 |
28 | # unzip
29 | if save_path.endswith('.zip'):
30 | extracted_path = save_path.replace('.zip', '')
31 | print(f'Extract {save_path} to {extracted_path}')
32 | import zipfile
33 | with zipfile.ZipFile(save_path, 'r') as zip_ref:
34 | zip_ref.extractall(extracted_path)
35 |
36 | file_name = file_name.replace('.zip', '')
37 | subfolder = osp.join(extracted_path, file_name)
38 | if osp.isdir(subfolder):
39 | print(f'Move {subfolder} to {extracted_path}')
40 | import shutil
41 | for path in glob.glob(osp.join(subfolder, '*')):
42 | shutil.move(path, extracted_path)
43 | shutil.rmtree(subfolder)
44 |
45 |
46 | if __name__ == '__main__':
47 | parser = argparse.ArgumentParser()
48 |
49 | parser.add_argument(
50 | 'dataset',
51 | type=str,
52 | help=("Options: 'Set5', 'Set14'. "
53 | "Set to 'all' if you want to download all the dataset."))
54 | args = parser.parse_args()
55 |
56 | file_ids = {
57 | 'Set5': {
58 | 'Set5.zip': # file name
59 | '1RtyIeUFTyW8u7oa4z7a0lSzT3T1FwZE9', # file id
60 | },
61 | 'Set14': {
62 | 'Set14.zip': '1vsw07sV8wGrRQ8UARe2fO5jjgy9QJy_E',
63 | }
64 | }
65 |
66 | if args.dataset == 'all':
67 | for dataset in file_ids.keys():
68 | download_dataset(dataset, file_ids[dataset])
69 | else:
70 | download_dataset(args.dataset, file_ids[args.dataset])
71 |
--------------------------------------------------------------------------------
/scripts/data_preparation/extract_subimages.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import os
4 | import sys
5 | from multiprocessing import Pool
6 | from os import path as osp
7 | from tqdm import tqdm
8 |
9 | from basicsr.utils import scandir
10 |
11 |
12 | def main():
13 | """A multi-thread tool to crop large images to sub-images for faster IO.
14 |
15 | It is used for DIV2K dataset.
16 |
17 | opt (dict): Configuration dict. It contains:
18 | n_thread (int): Thread number.
19 | compression_level (int): CV_IMWRITE_PNG_COMPRESSION from 0 to 9.
20 | A higher value means a smaller size and longer compression time.
21 | Use 0 for faster CPU decompression. Default: 3, same in cv2.
22 |
23 | input_folder (str): Path to the input folder.
24 | save_folder (str): Path to save folder.
25 | crop_size (int): Crop size.
26 | step (int): Step for overlapped sliding window.
27 | thresh_size (int): Threshold size. Patches whose size is lower
28 | than thresh_size will be dropped.
29 |
30 | Usage:
31 | For each folder, run this script.
32 | Typically, there are four folders to be processed for DIV2K dataset.
33 | DIV2K_train_HR
34 | DIV2K_train_LR_bicubic/X2
35 | DIV2K_train_LR_bicubic/X3
36 | DIV2K_train_LR_bicubic/X4
37 | After process, each sub_folder should have the same number of
38 | subimages.
39 | Remember to modify opt configurations according to your settings.
40 | """
41 |
42 | opt = {}
43 | opt['n_thread'] = 20
44 | opt['compression_level'] = 3
45 |
46 |     opt['input_folder'] = '../../../datasets/SR_OST_datasets/OutdoorSceneTrain_v2/'  # note: this OST block is overridden by the DIV2K settings below
47 | opt['save_folder'] = '../../../datasets/HQ_sub/OST_train_HR_sub/'
48 | opt['crop_size'] = 320
49 | opt['step'] = 160
50 | opt['thresh_size'] = 0
51 |
52 | # HR images
53 | opt['input_folder'] = 'datasets/DIV2K/DIV2K_train_HR'
54 | opt['save_folder'] = 'datasets/DIV2K/DIV2K_train_HR_sub'
55 | opt['crop_size'] = 480
56 | opt['step'] = 240
57 | opt['thresh_size'] = 0
58 | extract_subimages(opt)
59 |
60 | # LRx2 images
61 | opt['input_folder'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X2'
62 | opt['save_folder'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X2_sub'
63 | opt['crop_size'] = 240
64 | opt['step'] = 120
65 | opt['thresh_size'] = 0
66 | extract_subimages(opt)
67 |
68 | # LRx3 images
69 | opt['input_folder'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X3'
70 | opt['save_folder'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X3_sub'
71 | opt['crop_size'] = 160
72 | opt['step'] = 80
73 | opt['thresh_size'] = 0
74 | extract_subimages(opt)
75 |
76 | # LRx4 images
77 | opt['input_folder'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X4'
78 | opt['save_folder'] = 'datasets/DIV2K/DIV2K_train_LR_bicubic/X4_sub'
79 | opt['crop_size'] = 120
80 | opt['step'] = 60
81 | opt['thresh_size'] = 0
82 | extract_subimages(opt)
83 |
84 |
85 | def extract_subimages(opt):
86 | """Crop images to subimages.
87 |
88 | Args:
89 | opt (dict): Configuration dict. It contains:
90 | input_folder (str): Path to the input folder.
91 | save_folder (str): Path to save folder.
92 | n_thread (int): Thread number.
93 | """
94 | input_folder = opt['input_folder']
95 | save_folder = opt['save_folder']
96 | if not osp.exists(save_folder):
97 | os.makedirs(save_folder)
98 | print(f'mkdir {save_folder} ...')
99 | else:
100 | print(f'Folder {save_folder} already exists. Exit.')
101 | sys.exit(1)
102 |
103 | img_list = list(scandir(input_folder, recursive=True, full_path=True))
104 |
105 | pbar = tqdm(total=len(img_list), unit='image', desc='Extract')
106 | pool = Pool(opt['n_thread'])
107 | for path in img_list:
108 | pool.apply_async(worker, args=(path, opt), callback=lambda arg: pbar.update(1))
109 | pool.close()
110 | pool.join()
111 | pbar.close()
112 | print('All processes done.')
113 |
114 |
115 | def worker(path, opt):
116 | """Worker for each process.
117 |
118 | Args:
119 | path (str): Image path.
120 | opt (dict): Configuration dict. It contains:
121 | crop_size (int): Crop size.
122 | step (int): Step for overlapped sliding window.
123 | thresh_size (int): Threshold size. Patches whose size is lower
124 | than thresh_size will be dropped.
125 | save_folder (str): Path to save folder.
126 | compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION.
127 |
128 | Returns:
129 | process_info (str): Process information displayed in progress bar.
130 | """
131 | crop_size = opt['crop_size']
132 | step = opt['step']
133 | thresh_size = opt['thresh_size']
134 | img_name, extension = osp.splitext(osp.basename(path))
135 |
136 | # remove the x2, x3, x4 and x8 in the filename for DIV2K
137 | img_name = img_name.replace('x2', '').replace('x3', '').replace('x4', '').replace('x8', '')
138 |
139 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
140 |
141 | h, w = img.shape[0:2]
142 | h_space = np.arange(0, h - crop_size + 1, step)
143 | if h - (h_space[-1] + crop_size) > thresh_size:
144 | h_space = np.append(h_space, h - crop_size)
145 | w_space = np.arange(0, w - crop_size + 1, step)
146 | if w - (w_space[-1] + crop_size) > thresh_size:
147 | w_space = np.append(w_space, w - crop_size)
148 |
149 | index = 0
150 | for x in h_space:
151 | for y in w_space:
152 | index += 1
153 | cropped_img = img[x:x + crop_size, y:y + crop_size, ...]
154 | cropped_img = np.ascontiguousarray(cropped_img)
155 | cv2.imwrite(
156 | osp.join(opt['save_folder'], f'{img_name}_s{index:03d}{extension}'), cropped_img,
157 | [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
158 | process_info = f'Processing {img_name} ...'
159 | return process_info
160 |
161 |
162 | if __name__ == '__main__':
163 | main()
164 |
--------------------------------------------------------------------------------
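
To make the sliding-window logic in `worker()` concrete: with `crop_size=480`, `step=240` and `thresh_size=0`, crop origins are laid out every 240 px and one extra origin is appended so the last patch reaches the image border. A short sketch that reproduces just the index computation from `worker()` on a hypothetical 2040x1356 HR image:

import numpy as np

crop_size, step, thresh_size = 480, 240, 0
h, w = 1356, 2040  # hypothetical DIV2K-sized HR image

h_space = np.arange(0, h - crop_size + 1, step)
if h - (h_space[-1] + crop_size) > thresh_size:
    h_space = np.append(h_space, h - crop_size)
w_space = np.arange(0, w - crop_size + 1, step)
if w - (w_space[-1] + crop_size) > thresh_size:
    w_space = np.append(w_space, w - crop_size)

print(h_space)  # [  0 240 480 720 876]
print(w_space)  # [   0  240  480  720  960 1200 1440 1560]
print(len(h_space) * len(w_space), 'sub-images from this image')  # 40
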
/scripts/data_preparation/generate_meta_info.py:
--------------------------------------------------------------------------------
1 | from os import path as osp
2 | from PIL import Image
3 |
4 | from basicsr.utils import scandir
5 |
6 |
7 | def generate_meta_info_div2k():
8 | """Generate meta info for DIV2K dataset.
9 | """
10 |
11 | gt_folder = 'datasets/DIV2K/DIV2K_train_HR_sub/'
12 | meta_info_txt = 'basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt'
13 |
14 | img_list = sorted(list(scandir(gt_folder)))
15 |
16 | with open(meta_info_txt, 'w') as f:
17 | for idx, img_path in enumerate(img_list):
18 | img = Image.open(osp.join(gt_folder, img_path)) # lazy load
19 | width, height = img.size
20 | mode = img.mode
21 | if mode == 'RGB':
22 | n_channel = 3
23 | elif mode == 'L':
24 | n_channel = 1
25 | else:
26 | raise ValueError(f'Unsupported mode {mode}.')
27 |
28 | info = f'{img_path} ({height},{width},{n_channel})'
29 | print(idx + 1, info)
30 | f.write(f'{info}\n')
31 |
32 |
33 | if __name__ == '__main__':
34 | generate_meta_info_div2k()
35 |
--------------------------------------------------------------------------------
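
Each line written by `generate_meta_info_div2k()` has the form `<relative_path> (<height>,<width>,<channels>)`, e.g. `0001_s001.png (480,480,3)` for a 480x480 RGB sub-image. A tiny sketch of reading such a file back (the path is the one hard-coded above):

meta_info_txt = 'basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt'

with open(meta_info_txt, 'r') as f:
    for line in f:
        img_path, shape = line.strip().split(' ', 1)
        height, width, n_channel = map(int, shape.strip('()').split(','))
        # img_path is relative to the gt_folder used when generating the file
        print(img_path, height, width, n_channel)
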
/scripts/data_preparation/prepare_hifacegan_dataset.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | from tqdm import tqdm
4 |
5 |
6 | class Mosaic16x:
7 | """
8 | Mosaic16x: A customized image augmentor for 16-pixel mosaic
9 |     By default it replaces each non-overlapping 16x16 block of pixels
10 |     with the mean value of that block
11 | """
12 |
13 | def augment_image(self, x):
14 | h, w = x.shape[:2]
15 | x = x.astype('float') # avoid overflow for uint8
16 | irange, jrange = (h + 15) // 16, (w + 15) // 16
17 | for i in range(irange):
18 | for j in range(jrange):
19 | mean = x[i * 16:(i + 1) * 16, j * 16:(j + 1) * 16].mean(axis=(0, 1))
20 | x[i * 16:(i + 1) * 16, j * 16:(j + 1) * 16] = mean
21 |
22 | return x.astype('uint8')
23 |
24 |
25 | class DegradationSimulator:
26 | """
27 | Generating training/testing data pairs on the fly.
28 | The degradation script is aligned with HiFaceGAN paper settings.
29 |
30 | Args:
31 |         deg (str | Augmenter): Config for the degradation script, with degradation type and parameters.
32 |             Custom degradation is possible by passing a class inherited from imgaug.augmenters.
33 | """
34 |
35 | def __init__(self, ):
36 | import imgaug.augmenters as ia
37 | self.default_deg_templates = {
38 | 'sr4x':
39 | ia.Sequential([
40 | # It's almost like a 4x bicubic downsampling
41 | ia.Resize((0.25000, 0.25001), cv2.INTER_AREA),
42 | ia.Resize({
43 | 'height': 512,
44 | 'width': 512
45 | }, cv2.INTER_CUBIC),
46 | ]),
47 | 'sr4x8x':
48 | ia.Sequential([
49 | ia.Resize((0.125, 0.25), cv2.INTER_AREA),
50 | ia.Resize({
51 | 'height': 512,
52 | 'width': 512
53 | }, cv2.INTER_CUBIC),
54 | ]),
55 | 'denoise':
56 | ia.OneOf([
57 | ia.AdditiveGaussianNoise(scale=(20, 40), per_channel=True),
58 | ia.AdditiveLaplaceNoise(scale=(20, 40), per_channel=True),
59 | ia.AdditivePoissonNoise(lam=(15, 30), per_channel=True),
60 | ]),
61 | 'deblur':
62 | ia.OneOf([
63 | ia.MotionBlur(k=(10, 20)),
64 | ia.GaussianBlur((3.0, 8.0)),
65 | ]),
66 | 'jpeg':
67 | ia.JpegCompression(compression=(50, 85)),
68 | '16x':
69 | Mosaic16x(),
70 | }
71 |
72 | rand_deg_list = [
73 | self.default_deg_templates['deblur'],
74 | self.default_deg_templates['denoise'],
75 | self.default_deg_templates['jpeg'],
76 | self.default_deg_templates['sr4x8x'],
77 | ]
78 | self.default_deg_templates['face_renov'] = ia.Sequential(rand_deg_list, random_order=True)
79 |
80 | def create_training_dataset(self, deg, gt_folder, lq_folder=None):
81 | from imgaug.augmenters.meta import Augmenter # baseclass
82 | """
83 | Create a degradation simulator and apply it to GT images on the fly
84 | Save the degraded result in the lq_folder (if None, name it as GT_deg)
85 | """
86 | if not lq_folder:
87 | suffix = deg if isinstance(deg, str) else 'custom'
88 | lq_folder = '_'.join([gt_folder.replace('gt', 'lq'), suffix])
89 | print(lq_folder)
90 | os.makedirs(lq_folder, exist_ok=True)
91 |
92 | if isinstance(deg, str):
93 | assert deg in self.default_deg_templates, (
94 |                 f'Degradation type {deg} not recognized: {"|".join(list(self.default_deg_templates.keys()))}')
95 | deg = self.default_deg_templates[deg]
96 | else:
97 | assert isinstance(deg, Augmenter), f'Deg must be either str|Augmenter, got {deg}'
98 |
99 | names = os.listdir(gt_folder)
100 | for name in tqdm(names):
101 | gt = cv2.imread(os.path.join(gt_folder, name))
102 | lq = deg.augment_image(gt)
103 | # pack = np.concatenate([lq, gt], axis=0)
104 | cv2.imwrite(os.path.join(lq_folder, name), lq)
105 |
106 | print('Dataset prepared.')
107 |
108 |
109 | if __name__ == '__main__':
110 |     simulator = DegradationSimulator()
111 | gt_folder = 'datasets/FFHQ_512_gt'
112 | deg = 'sr4x'
113 |     simulator.create_training_dataset(deg, gt_folder)
114 |
--------------------------------------------------------------------------------
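
Besides the `sr4x` template used in `__main__`, any key of `default_deg_templates` (including the composite `face_renov` pipeline) or a custom imgaug Augmenter can be passed as `deg`. A quick sketch of the `Mosaic16x` block averaging plus an alternative simulator call, assuming it runs in the same module as the classes above, that `imgaug` is installed (it is imported lazily and is not listed in requirements.txt), and that the folder paths are hypothetical:

import numpy as np

# Mosaic16x replaces every non-overlapping 16x16 block with its per-channel mean
img = np.random.randint(0, 256, (64, 48, 3), dtype=np.uint8)
mosaic = Mosaic16x().augment_image(img)
assert (mosaic[:16, :16] == mosaic[0, 0]).all()  # the whole block shares one value

# use the composite random degradation instead of plain 4x downsampling
simulator = DegradationSimulator()
simulator.create_training_dataset('face_renov', 'datasets/FFHQ_512_gt', 'datasets/FFHQ_512_lq')
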
/scripts/data_preparation/regroup_reds_dataset.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 |
4 |
5 | def regroup_reds_dataset(train_path, val_path):
6 | """Regroup original REDS datasets.
7 |
8 | We merge train and validation data into one folder, and separate the
9 | validation clips in reds_dataset.py.
10 | There are 240 training clips (starting from 0 to 239),
11 | so we name the validation clip index starting from 240 to 269 (total 30
12 | validation clips).
13 |
14 | Args:
15 | train_path (str): Path to the train folder.
16 | val_path (str): Path to the validation folder.
17 | """
18 | # move the validation data to the train folder
19 | val_folders = glob.glob(os.path.join(val_path, '*'))
20 | for folder in val_folders:
21 | new_folder_idx = int(folder.split('/')[-1]) + 240
22 | os.system(f'cp -r {folder} {os.path.join(train_path, str(new_folder_idx))}')
23 |
24 |
25 | if __name__ == '__main__':
26 | # train_sharp
27 | train_path = 'datasets/REDS/train_sharp'
28 | val_path = 'datasets/REDS/val_sharp'
29 | regroup_reds_dataset(train_path, val_path)
30 |
31 | # train_sharp_bicubic
32 | train_path = 'datasets/REDS/train_sharp_bicubic/X4'
33 | val_path = 'datasets/REDS/val_sharp_bicubic/X4'
34 | regroup_reds_dataset(train_path, val_path)
35 |
--------------------------------------------------------------------------------
/scripts/download_gdrive.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | from basicsr.utils.download_util import download_file_from_google_drive
4 |
5 | if __name__ == '__main__':
6 | parser = argparse.ArgumentParser()
7 |
8 | parser.add_argument('--id', type=str, help='File id')
9 | parser.add_argument('--output', type=str, help='Save path')
10 | args = parser.parse_args()
11 |
12 |     download_file_from_google_drive(args.id, args.output)
13 |
--------------------------------------------------------------------------------
/scripts/download_pretrained_models.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from os import path as osp
4 |
5 | from basicsr.utils.download_util import download_file_from_google_drive
6 |
7 |
8 | def download_pretrained_models(method, file_ids):
9 | save_path_root = f'./experiments/pretrained_models/{method}'
10 | os.makedirs(save_path_root, exist_ok=True)
11 |
12 | for file_name, file_id in file_ids.items():
13 | save_path = osp.abspath(osp.join(save_path_root, file_name))
14 | if osp.exists(save_path):
15 |             user_response = input(f'{file_name} already exists. Do you want to overwrite it? Y/N\n')
16 |             if user_response.lower() == 'y':
17 |                 print(f'Overwriting {file_name} to {save_path}')
18 | download_file_from_google_drive(file_id, save_path)
19 | elif user_response.lower() == 'n':
20 | print(f'Skipping {file_name}')
21 | else:
22 | raise ValueError('Wrong input. Only accepts Y/N.')
23 | else:
24 | print(f'Downloading {file_name} to {save_path}')
25 | download_file_from_google_drive(file_id, save_path)
26 |
27 |
28 | if __name__ == '__main__':
29 | parser = argparse.ArgumentParser()
30 |
31 | parser.add_argument(
32 | 'method',
33 | type=str,
34 | help=("Options: 'ESRGAN', 'EDVR', 'StyleGAN', 'EDSR', 'DUF', 'DFDNet', 'dlib', 'TOF', 'flownet', 'BasicVSR'. "
35 | "Set to 'all' to download all the models."))
36 | args = parser.parse_args()
37 |
38 | file_ids = {
39 | 'ESRGAN': {
40 | 'ESRGAN_SRx4_DF2KOST_official-ff704c30.pth': # file name
41 | '1b3_bWZTjNO3iL2js1yWkJfjZykcQgvzT', # file id
42 | 'ESRGAN_PSNR_SRx4_DF2K_official-150ff491.pth': '1swaV5iBMFfg-DL6ZyiARztbhutDCWXMM'
43 | },
44 | 'EDVR': {
45 | 'EDVR_L_x4_SR_REDS_official-9f5f5039.pth': '127KXEjlCwfoPC1aXyDkluNwr9elwyHNb',
46 | 'EDVR_L_x4_SR_Vimeo90K_official-162b54e4.pth': '1aVR3lkX6ItCphNLcT7F5bbbC484h4Qqy',
47 | 'EDVR_M_woTSA_x4_SR_REDS_official-1edf645c.pth': '1C_WdN-NyNj-P7SOB5xIVuHl4EBOwd-Ny',
48 | 'EDVR_M_x4_SR_REDS_official-32075921.pth': '1dd6aFj-5w2v08VJTq5mS9OFsD-wALYD6',
49 | 'EDVR_L_x4_SRblur_REDS_official-983d7b8e.pth': '1GZz_87ybR8eAAY3X2HWwI3L6ny7-5Yvl',
50 | 'EDVR_L_deblur_REDS_official-ca46bd8c.pth': '1_ma2tgHscZtkIY2tEJkVdU-UP8bnqBRE',
51 | 'EDVR_L_deblurcomp_REDS_official-0e988e5c.pth': '1fEoSeLFnHSBbIs95Au2W197p8e4ws4DW'
52 | },
53 | 'StyleGAN': {
54 | 'stylegan2_ffhq_config_f_1024_official-3ab41b38.pth': '1qtdsT1FrvKQsFiW3OqOcIb-VS55TVy1g',
55 | 'stylegan2_ffhq_config_f_1024_discriminator_official-a386354a.pth': '1nPqCxm8TkDU3IvXdHCzPUxlBwR5Pd78G',
56 | 'stylegan2_cat_config_f_256_official-0a9173ad.pth': '1gfJkX6XO5pJ2J8LyMdvUgGldz7xwWpBJ',
57 | 'stylegan2_cat_config_f_256_discriminator_official-2c97fd08.pth': '1hy5FEQQl28XvfqpiWvSBd8YnIzsyDRb7',
58 | 'stylegan2_church_config_f_256_official-44ba63bf.pth': '1FCQMZXeOKZyl-xYKbl1Y_x2--rFl-1N_',
59 | 'stylegan2_church_config_f_256_discriminator_official-20cd675b.pth': # noqa: E501
60 | '1BS9ODHkUkhfTGFVfR6alCMGtr9nGm9ox',
61 | 'stylegan2_car_config_f_512_official-e8fcab4f.pth': '14jS-nWNTguDSd1kTIX-tBHp2WdvK7hva',
62 | 'stylegan2_car_config_f_512_discriminator_official-5008e3d1.pth': '1UxkAzZ0zvw4KzBVOUpShCivsdXBS8Zi2',
63 | 'stylegan2_horse_config_f_256_official-26d57fee.pth': '12QsZ-mrO8_4gC0UrO15Jb3ykcQ88HxFx',
64 | 'stylegan2_horse_config_f_256_discriminator_official-be6c4c33.pth': '1me4ybSib72xA9ZxmzKsHDtP-eNCKw_X4'
65 | },
66 | 'EDSR': {
67 | 'EDSR_Mx2_f64b16_DIV2K_official-3ba7b086.pth': '1mREMGVDymId3NzIc2u90sl_X4-pb4ZcV',
68 | 'EDSR_Mx3_f64b16_DIV2K_official-6908f88a.pth': '1EriqQqlIiRyPbrYGBbwr_FZzvb3iwqz5',
69 | 'EDSR_Mx4_f64b16_DIV2K_official-0c287733.pth': '1bCK6cFYU01uJudLgUUe-jgx-tZ3ikOWn',
70 | 'EDSR_Lx2_f256b32_DIV2K_official-be38e77d.pth': '15257lZCRZ0V6F9LzTyZFYbbPrqNjKyMU',
71 | 'EDSR_Lx3_f256b32_DIV2K_official-3660f70d.pth': '18q_D434sLG_rAZeHGonAX8dkqjoyZ2su',
72 | 'EDSR_Lx4_f256b32_DIV2K_official-76ee1c8f.pth': '1GCi30YYCzgMCcgheGWGusP9aWKOAy5vl'
73 | },
74 | 'DUF': {
75 | 'DUF_x2_16L_official-39537cb9.pth': '1e91cEZOlUUk35keK9EnuK0F54QegnUKo',
76 | 'DUF_x3_16L_official-34ce53ec.pth': '1XN6aQj20esM7i0hxTbfiZr_SL8i4PZ76',
77 | 'DUF_x4_16L_official-bf8f0cfa.pth': '1V_h9U1CZgLSHTv1ky2M3lvuH-hK5hw_J',
78 | 'DUF_x4_28L_official-cbada450.pth': '1M8w0AMBJW65MYYD-_8_be0cSH_SHhDQ4',
79 | 'DUF_x4_52L_official-483d2c78.pth': '1GcmEWNr7mjTygi-QCOVgQWOo5OCNbh_T'
80 | },
81 | 'TOF': {
82 | 'tof_x4_vimeo90k_official-32c9e01f.pth': '1TgQiXXsvkTBFrQ1D0eKPgL10tQGu0gKb'
83 | },
84 | 'DFDNet': {
85 | 'DFDNet_dict_512-f79685f0.pth': '1iH00oMsoN_1OJaEQw3zP7_wqiAYMnY79',
86 | 'DFDNet_official-d1fa5650.pth': '1u6Sgcp8gVoy4uVTrOJKD3y9RuqH2JBAe'
87 | },
88 | 'dlib': {
89 | 'mmod_human_face_detector-4cb19393.dat': '1FUM-hcoxNzFCOpCWbAUStBBMiU4uIGIL',
90 | 'shape_predictor_5_face_landmarks-c4b1e980.dat': '1PNPSmFjmbuuUDd5Mg5LDxyk7tu7TQv2F',
91 | 'shape_predictor_68_face_landmarks-fbdc2cb8.dat': '1IneH-O-gNkG0SQpNCplwxtOAtRCkG2ni'
92 | },
93 | 'flownet': {
94 | 'spynet_sintel_final-3d2a1287.pth': '1VZz1cikwTRVX7zXoD247DB7n5Tj_LQpF'
95 | },
96 | 'BasicVSR': {
97 | 'BasicVSR_REDS4-543c8261.pth': '1wLWdz18lWf9Z7lomHPkdySZ-_GV2920p',
98 | 'BasicVSR_Vimeo90K_BDx4-e9bf46eb.pth': '1baaf4RSpzs_zcDAF_s2CyArrGvLgmXxW',
99 | 'BasicVSR_Vimeo90K_BIx4-2a29695a.pth': '1ykIu1jv5wo95Kca2TjlieJFxeV4VVfHP',
100 | 'EDVR_REDS_pretrained_for_IconVSR-f62a2f1e.pth': '1ShfwddugTmT3_kB8VL6KpCMrIpEO5sBi',
101 | 'EDVR_Vimeo90K_pretrained_for_IconVSR-ee48ee92.pth': '16vR262NDVyVv5Q49xp2Sb-Llu05f63tt',
102 | 'IconVSR_REDS-aaa5367f.pth': '1b8ir754uIAFUSJ8YW_cmPzqer19AR7Hz',
103 | 'IconVSR_Vimeo90K_BDx4-cfcb7e00.pth': '13lp55s-YTd-fApx8tTy24bbHsNIGXdAH',
104 | 'IconVSR_Vimeo90K_BIx4-35fec07c.pth': '1lWUB36ERjFbAspr-8UsopJ6xwOuWjh2g'
105 | }
106 | }
107 |
108 | if args.method == 'all':
109 | for method in file_ids.keys():
110 | download_pretrained_models(method, file_ids[method])
111 | else:
112 | download_pretrained_models(args.method, file_ids[args.method])
113 |
--------------------------------------------------------------------------------
/scripts/matlab_scripts/back_projection/backprojection.m:
--------------------------------------------------------------------------------
1 | function [im_h] = backprojection(im_h, im_l, maxIter)
2 |
3 | [row_l, col_l,~] = size(im_l);
4 | [row_h, col_h,~] = size(im_h);
5 |
6 | p = fspecial('gaussian', 5, 1);
7 | p = p.^2;
8 | p = p./sum(p(:));
9 |
10 | im_l = double(im_l);
11 | im_h = double(im_h);
12 |
13 | for ii = 1:maxIter
14 | im_l_s = imresize(im_h, [row_l, col_l], 'bicubic');
15 | im_diff = im_l - im_l_s;
16 | im_diff = imresize(im_diff, [row_h, col_h], 'bicubic');
17 | im_h(:,:,1) = im_h(:,:,1) + conv2(im_diff(:,:,1), p, 'same');
18 | im_h(:,:,2) = im_h(:,:,2) + conv2(im_diff(:,:,2), p, 'same');
19 | im_h(:,:,3) = im_h(:,:,3) + conv2(im_diff(:,:,3), p, 'same');
20 | end
21 |
--------------------------------------------------------------------------------
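
For readers without MATLAB at hand, the iterative back-projection above amounts to: downscale the current HR estimate to the LR grid, take the residual against the true LR image, upscale that residual, smooth it with the squared-and-renormalized 5x5 Gaussian kernel, and add it back to the HR estimate. A rough NumPy/OpenCV sketch of the same loop (cv2's bicubic resize is not bit-identical to MATLAB's imresize, so results will differ slightly):

import cv2
import numpy as np

def gaussian_kernel(size=5, sigma=1.0):
    ax = np.arange(size) - (size - 1) / 2.0
    xx, yy = np.meshgrid(ax, ax)
    k = np.exp(-(xx ** 2 + yy ** 2) / (2.0 * sigma ** 2))
    return k / k.sum()

def back_projection(im_h, im_l, max_iter=20):
    # p = fspecial('gaussian', 5, 1).^2, renormalized, as in backprojection.m
    p = gaussian_kernel(5, 1.0) ** 2
    p /= p.sum()
    h, w = im_l.shape[:2]
    H, W = im_h.shape[:2]
    im_h = im_h.astype(np.float64)
    im_l = im_l.astype(np.float64)
    for _ in range(max_iter):
        im_l_s = cv2.resize(im_h, (w, h), interpolation=cv2.INTER_CUBIC)
        diff = cv2.resize(im_l - im_l_s, (W, H), interpolation=cv2.INTER_CUBIC)
        for c in range(im_h.shape[2]):
            # conv2(..., 'same') with zero padding -> filter2D with constant border
            im_h[:, :, c] += cv2.filter2D(diff[:, :, c], -1, p, borderType=cv2.BORDER_CONSTANT)
    return im_h
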
/scripts/matlab_scripts/back_projection/main_bp.m:
--------------------------------------------------------------------------------
1 | clear; close all; clc;
2 |
3 | LR_folder = './LR'; % LR
4 | preout_folder = './results'; % pre output
5 | save_folder = './results_20bp';
6 | filepaths = dir(fullfile(preout_folder, '*.png'));
7 | max_iter = 20;
8 |
9 | if ~ exist(save_folder, 'dir')
10 | mkdir(save_folder);
11 | end
12 |
13 | for idx_im = 1:length(filepaths)
14 | fprintf([num2str(idx_im) '\n']);
15 | im_name = filepaths(idx_im).name;
16 | im_LR = im2double(imread(fullfile(LR_folder, im_name)));
17 | im_out = im2double(imread(fullfile(preout_folder, im_name)));
18 | %tic
19 | im_out = backprojection(im_out, im_LR, max_iter);
20 | %toc
21 | imwrite(im_out, fullfile(save_folder, im_name));
22 | end
23 |
--------------------------------------------------------------------------------
/scripts/matlab_scripts/back_projection/main_reverse_filter.m:
--------------------------------------------------------------------------------
1 | clear; close all; clc;
2 |
3 | LR_folder = './LR'; % LR
4 | preout_folder = './results'; % pre output
5 | save_folder = './results_20if';
6 | filepaths = dir(fullfile(preout_folder, '*.png'));
7 | max_iter = 20;
8 |
9 | if ~ exist(save_folder, 'dir')
10 | mkdir(save_folder);
11 | end
12 |
13 | for idx_im = 1:length(filepaths)
14 | fprintf([num2str(idx_im) '\n']);
15 | im_name = filepaths(idx_im).name;
16 | im_LR = im2double(imread(fullfile(LR_folder, im_name)));
17 | im_out = im2double(imread(fullfile(preout_folder, im_name)));
18 | J = imresize(im_LR,4,'bicubic');
19 | %tic
20 | for m = 1:max_iter
21 | im_out = im_out + (J - imresize(imresize(im_out,1/4,'bicubic'),4,'bicubic'));
22 | end
23 | %toc
24 | imwrite(im_out, fullfile(save_folder, im_name));
25 | end
26 |
--------------------------------------------------------------------------------
/scripts/matlab_scripts/generate_LR_Vimeo90K.m:
--------------------------------------------------------------------------------
1 | function generate_LR_Vimeo90K()
2 | %% matlab code to generate bicubic-downsampled images for the Vimeo90K dataset
3 |
4 | up_scale = 4;
5 | mod_scale = 4;
6 | idx = 0;
7 | filepaths = dir('/home/xtwang/datasets/vimeo90k/vimeo_septuplet/sequences/*/*/*.png');
8 | for i = 1 : length(filepaths)
9 | [~,imname,ext] = fileparts(filepaths(i).name);
10 | folder_path = filepaths(i).folder;
11 | save_LR_folder = strrep(folder_path,'vimeo_septuplet','vimeo_septuplet_matlabLRx4');
12 | if ~exist(save_LR_folder, 'dir')
13 | mkdir(save_LR_folder);
14 | end
15 | if isempty(imname)
16 | disp('Ignore . folder.');
17 | elseif strcmp(imname, '.')
18 | disp('Ignore .. folder.');
19 | else
20 | idx = idx + 1;
21 | str_result = sprintf('%d\t%s.\n', idx, imname);
22 | fprintf(str_result);
23 | % read image
24 | img = imread(fullfile(folder_path, [imname, ext]));
25 | img = im2double(img);
26 | % modcrop
27 | img = modcrop(img, mod_scale);
28 | % LR
29 | im_LR = imresize(img, 1/up_scale, 'bicubic');
30 | if exist('save_LR_folder', 'var')
31 | imwrite(im_LR, fullfile(save_LR_folder, [imname, '.png']));
32 | end
33 | end
34 | end
35 | end
36 |
37 | %% modcrop
38 | function img = modcrop(img, modulo)
39 | if size(img,3) == 1
40 | sz = size(img);
41 | sz = sz - mod(sz, modulo);
42 | img = img(1:sz(1), 1:sz(2));
43 | else
44 | tmpsz = size(img);
45 | sz = tmpsz(1:2);
46 | sz = sz - mod(sz, modulo);
47 | img = img(1:sz(1), 1:sz(2),:);
48 | end
49 | end
50 |
--------------------------------------------------------------------------------
/scripts/matlab_scripts/generate_bicubic_img.m:
--------------------------------------------------------------------------------
1 | function generate_bicubic_img()
2 | %% matlab code to generate mod images, bicubic-downsampled images and
3 | %% bicubic-upsampled images
4 |
5 | %% set configurations
6 | % comment the unnecessary lines
7 | input_folder = '../../datasets/Set5/original';
8 | save_mod_folder = '../../datasets/Set5/GTmod12';
9 | save_lr_folder = '../../datasets/Set5/LRbicx2';
10 | % save_bic_folder = '';
11 |
12 | mod_scale = 12;
13 | up_scale = 2;
14 |
15 | if exist('save_mod_folder', 'var')
16 | if exist(save_mod_folder, 'dir')
17 | disp(['It will cover ', save_mod_folder]);
18 | else
19 | mkdir(save_mod_folder);
20 | end
21 | end
22 | if exist('save_lr_folder', 'var')
23 | if exist(save_lr_folder, 'dir')
24 | disp(['It will cover ', save_lr_folder]);
25 | else
26 | mkdir(save_lr_folder);
27 | end
28 | end
29 | if exist('save_bic_folder', 'var')
30 | if exist(save_bic_folder, 'dir')
31 | disp(['It will cover ', save_bic_folder]);
32 | else
33 | mkdir(save_bic_folder);
34 | end
35 | end
36 |
37 | idx = 0;
38 | filepaths = dir(fullfile(input_folder,'*.*'));
39 | for i = 1 : length(filepaths)
40 | [paths, img_name, ext] = fileparts(filepaths(i).name);
41 | if isempty(img_name)
42 | disp('Ignore . folder.');
43 | elseif strcmp(img_name, '.')
44 | disp('Ignore .. folder.');
45 | else
46 | idx = idx + 1;
47 | str_result = sprintf('%d\t%s.\n', idx, img_name);
48 | fprintf(str_result);
49 |
50 | % read image
51 | img = imread(fullfile(input_folder, [img_name, ext]));
52 | img = im2double(img);
53 |
54 | % modcrop
55 | img = modcrop(img, mod_scale);
56 | if exist('save_mod_folder', 'var')
57 | imwrite(img, fullfile(save_mod_folder, [img_name, '.png']));
58 | end
59 |
60 | % LR
61 | im_lr = imresize(img, 1/up_scale, 'bicubic');
62 | if exist('save_lr_folder', 'var')
63 | imwrite(im_lr, fullfile(save_lr_folder, [img_name, '.png']));
64 | end
65 |
66 | % Bicubic
67 | if exist('save_bic_folder', 'var')
68 | im_bicubic = imresize(im_lr, up_scale, 'bicubic');
69 | imwrite(im_bicubic, fullfile(save_bic_folder, [img_name, '.png']));
70 | end
71 | end
72 | end
73 | end
74 |
75 | %% modcrop
76 | function img = modcrop(img, modulo)
77 | if size(img,3) == 1
78 | sz = size(img);
79 | sz = sz - mod(sz, modulo);
80 | img = img(1:sz(1), 1:sz(2));
81 | else
82 | tmpsz = size(img);
83 | sz = tmpsz(1:2);
84 | sz = sz - mod(sz, modulo);
85 | img = img(1:sz(1), 1:sz(2),:);
86 | end
87 | end
88 |
--------------------------------------------------------------------------------
/scripts/metrics/calculate_fid_folder.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import math
3 | import numpy as np
4 | import torch
5 | from torch.utils.data import DataLoader
6 |
7 | from basicsr.data import build_dataset
8 | from basicsr.metrics.fid import calculate_fid, extract_inception_features, load_patched_inception_v3
9 |
10 |
11 | def calculate_fid_folder():
12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13 |
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('folder', type=str, help='Path to the folder.')
16 | parser.add_argument('--fid_stats', type=str, help='Path to the dataset fid statistics.')
17 | parser.add_argument('--batch_size', type=int, default=64)
18 | parser.add_argument('--num_sample', type=int, default=50000)
19 | parser.add_argument('--num_workers', type=int, default=4)
20 | parser.add_argument('--backend', type=str, default='disk', help='io backend for dataset. Option: disk, lmdb')
21 | args = parser.parse_args()
22 |
23 | # inception model
24 | inception = load_patched_inception_v3(device)
25 |
26 | # create dataset
27 | opt = {}
28 | opt['name'] = 'SingleImageDataset'
29 | opt['type'] = 'SingleImageDataset'
30 | opt['dataroot_lq'] = args.folder
31 | opt['io_backend'] = dict(type=args.backend)
32 | opt['mean'] = [0.5, 0.5, 0.5]
33 | opt['std'] = [0.5, 0.5, 0.5]
34 | dataset = build_dataset(opt)
35 |
36 | # create dataloader
37 | data_loader = DataLoader(
38 | dataset=dataset,
39 | batch_size=args.batch_size,
40 | shuffle=False,
41 | num_workers=args.num_workers,
42 | sampler=None,
43 | drop_last=False)
44 | args.num_sample = min(args.num_sample, len(dataset))
45 | total_batch = math.ceil(args.num_sample / args.batch_size)
46 |
47 | def data_generator(data_loader, total_batch):
48 | for idx, data in enumerate(data_loader):
49 | if idx >= total_batch:
50 | break
51 | else:
52 | yield data['lq']
53 |
54 | features = extract_inception_features(data_generator(data_loader, total_batch), inception, total_batch, device)
55 | features = features.numpy()
56 | total_len = features.shape[0]
57 | features = features[:args.num_sample]
58 | print(f'Extracted {total_len} features, use the first {features.shape[0]} features to calculate stats.')
59 |
60 | sample_mean = np.mean(features, 0)
61 | sample_cov = np.cov(features, rowvar=False)
62 |
63 | # load the dataset stats
64 | stats = torch.load(args.fid_stats)
65 | real_mean = stats['mean']
66 | real_cov = stats['cov']
67 |
68 | # calculate FID metric
69 | fid = calculate_fid(sample_mean, sample_cov, real_mean, real_cov)
70 | print('fid:', fid)
71 |
72 |
73 | if __name__ == '__main__':
74 | calculate_fid_folder()
75 |
--------------------------------------------------------------------------------
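
For reference, the final `calculate_fid(sample_mean, sample_cov, real_mean, real_cov)` call computes the standard Frechet Inception Distance between the two Gaussians fitted to the Inception features: FID = ||mu1 - mu2||^2 + Tr(C1 + C2 - 2*(C1*C2)^(1/2)). A small self-contained sketch of that formula with SciPy (independent of the basicsr implementation, which may handle numerical edge cases differently):

import numpy as np
from scipy import linalg

def frechet_distance(mu1, cov1, mu2, cov2):
    """FID between N(mu1, cov1) and N(mu2, cov2)."""
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(cov1.dot(cov2), disp=False)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # drop tiny imaginary parts from numerical noise
    return diff.dot(diff) + np.trace(cov1 + cov2 - 2 * covmean)

# toy check: identical statistics give (numerically) zero distance
mu, cov = np.zeros(4), np.eye(4)
print(frechet_distance(mu, cov, mu, cov))  # ~0.0
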
/scripts/metrics/calculate_fid_stats_from_datasets.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import math
3 | import numpy as np
4 | import torch
5 | from torch.utils.data import DataLoader
6 |
7 | from basicsr.data import build_dataset
8 | from basicsr.metrics.fid import extract_inception_features, load_patched_inception_v3
9 |
10 |
11 | def calculate_stats_from_dataset():
12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13 |
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--num_sample', type=int, default=50000)
16 | parser.add_argument('--batch_size', type=int, default=64)
17 | parser.add_argument('--size', type=int, default=512)
18 | parser.add_argument('--dataroot', type=str, default='datasets/ffhq')
19 | args = parser.parse_args()
20 |
21 | # inception model
22 | inception = load_patched_inception_v3(device)
23 |
24 | # create dataset
25 | opt = {}
26 | opt['name'] = 'FFHQ'
27 | opt['type'] = 'FFHQDataset'
28 | opt['dataroot_gt'] = f'datasets/ffhq/ffhq_{args.size}.lmdb'
29 | opt['io_backend'] = dict(type='lmdb')
30 | opt['use_hflip'] = False
31 | opt['mean'] = [0.5, 0.5, 0.5]
32 | opt['std'] = [0.5, 0.5, 0.5]
33 | dataset = build_dataset(opt)
34 |
35 | # create dataloader
36 | data_loader = DataLoader(
37 | dataset=dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, sampler=None, drop_last=False)
38 | total_batch = math.ceil(args.num_sample / args.batch_size)
39 |
40 | def data_generator(data_loader, total_batch):
41 | for idx, data in enumerate(data_loader):
42 | if idx >= total_batch:
43 | break
44 | else:
45 | yield data['gt']
46 |
47 | features = extract_inception_features(data_generator(data_loader, total_batch), inception, total_batch, device)
48 | features = features.numpy()
49 | total_len = features.shape[0]
50 | features = features[:args.num_sample]
51 | print(f'Extracted {total_len} features, use the first {features.shape[0]} features to calculate stats.')
52 | mean = np.mean(features, 0)
53 | cov = np.cov(features, rowvar=False)
54 |
55 | save_path = f'inception_{opt["name"]}_{args.size}.pth'
56 | torch.save(
57 | dict(name=opt['name'], size=args.size, mean=mean, cov=cov), save_path, _use_new_zipfile_serialization=False)
58 |
59 |
60 | if __name__ == '__main__':
61 | calculate_stats_from_dataset()
62 |
--------------------------------------------------------------------------------
/scripts/metrics/calculate_lpips.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import glob
3 | import numpy as np
4 | import os.path as osp
5 | from torchvision.transforms.functional import normalize
6 |
7 | from basicsr.utils import img2tensor
8 |
9 | try:
10 | import lpips
11 | except ImportError:
12 | print('Please install lpips: pip install lpips')
13 |
14 |
15 | def main():
16 | # Configurations
17 | # -------------------------------------------------------------------------
18 | folder_gt = 'datasets/celeba/celeba_512_validation'
19 | folder_restored = 'datasets/celeba/celeba_512_validation_lq'
20 | # crop_border = 4
21 | suffix = ''
22 | # -------------------------------------------------------------------------
23 | loss_fn_vgg = lpips.LPIPS(net='vgg').cuda() # RGB, normalized to [-1,1]
24 | lpips_all = []
25 | img_list = sorted(glob.glob(osp.join(folder_gt, '*')))
26 |
27 | mean = [0.5, 0.5, 0.5]
28 | std = [0.5, 0.5, 0.5]
29 | for i, img_path in enumerate(img_list):
30 | basename, ext = osp.splitext(osp.basename(img_path))
31 | img_gt = cv2.imread(img_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
32 | img_restored = cv2.imread(osp.join(folder_restored, basename + suffix + ext), cv2.IMREAD_UNCHANGED).astype(
33 | np.float32) / 255.
34 |
35 | img_gt, img_restored = img2tensor([img_gt, img_restored], bgr2rgb=True, float32=True)
36 | # norm to [-1, 1]
37 | normalize(img_gt, mean, std, inplace=True)
38 | normalize(img_restored, mean, std, inplace=True)
39 |
40 | # calculate lpips
41 | lpips_val = loss_fn_vgg(img_restored.unsqueeze(0).cuda(), img_gt.unsqueeze(0).cuda())
42 |
43 | print(f'{i+1:3d}: {basename:25}. \tLPIPS: {lpips_val:.6f}.')
44 | lpips_all.append(lpips_val)
45 |
46 | print(f'Average: LPIPS: {sum(lpips_all) / len(lpips_all):.6f}')
47 |
48 |
49 | if __name__ == '__main__':
50 | main()
51 |
--------------------------------------------------------------------------------
/scripts/metrics/calculate_niqe.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import cv2
3 | import os
4 | import warnings
5 |
6 | from basicsr.metrics import calculate_niqe
7 | from basicsr.utils import scandir
8 |
9 |
10 | def main(args):
11 |
12 | niqe_all = []
13 | img_list = sorted(scandir(args.input, recursive=True, full_path=True))
14 |
15 | for i, img_path in enumerate(img_list):
16 | basename, _ = os.path.splitext(os.path.basename(img_path))
17 | img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
18 |
19 | with warnings.catch_warnings():
20 | warnings.simplefilter('ignore', category=RuntimeWarning)
21 | niqe_score = calculate_niqe(img, args.crop_border, input_order='HWC', convert_to='y')
22 | print(f'{i+1:3d}: {basename:25}. \tNIQE: {niqe_score:.6f}')
23 | niqe_all.append(niqe_score)
24 |
25 | print(args.input)
26 | print(f'Average: NIQE: {sum(niqe_all) / len(niqe_all):.6f}')
27 |
28 |
29 | if __name__ == '__main__':
30 | parser = argparse.ArgumentParser()
31 | parser.add_argument('--input', type=str, default='datasets/val_set14/Set14', help='Input path')
32 | parser.add_argument('--crop_border', type=int, default=0, help='Crop border for each side')
33 | args = parser.parse_args()
34 | main(args)
35 |
--------------------------------------------------------------------------------
/scripts/metrics/calculate_psnr_ssim.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import cv2
3 | import numpy as np
4 | from os import path as osp
5 |
6 | from basicsr.metrics import calculate_psnr, calculate_ssim
7 | from basicsr.utils import scandir
8 | from basicsr.utils.matlab_functions import bgr2ycbcr
9 |
10 |
11 | def main(args):
12 | """Calculate PSNR and SSIM for images.
13 | """
14 | psnr_all = []
15 | ssim_all = []
16 | img_list_gt = sorted(list(scandir(args.gt, recursive=True, full_path=True)))
17 | img_list_restored = sorted(list(scandir(args.restored, recursive=True, full_path=True)))
18 |
19 | if args.test_y_channel:
20 | print('Testing Y channel.')
21 | else:
22 | print('Testing RGB channels.')
23 |
24 | for i, img_path in enumerate(img_list_gt):
25 | basename, ext = osp.splitext(osp.basename(img_path))
26 | img_gt = cv2.imread(img_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
27 | if args.suffix == '':
28 | img_path_restored = img_list_restored[i]
29 | else:
30 | img_path_restored = osp.join(args.restored, basename + args.suffix + ext)
31 | img_restored = cv2.imread(img_path_restored, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
32 |
33 | if args.correct_mean_var:
34 | mean_l = []
35 | std_l = []
36 | for j in range(3):
37 | mean_l.append(np.mean(img_gt[:, :, j]))
38 | std_l.append(np.std(img_gt[:, :, j]))
39 | for j in range(3):
40 | # correct twice
41 | mean = np.mean(img_restored[:, :, j])
42 | img_restored[:, :, j] = img_restored[:, :, j] - mean + mean_l[j]
43 | std = np.std(img_restored[:, :, j])
44 | img_restored[:, :, j] = img_restored[:, :, j] / std * std_l[j]
45 |
46 | mean = np.mean(img_restored[:, :, j])
47 | img_restored[:, :, j] = img_restored[:, :, j] - mean + mean_l[j]
48 | std = np.std(img_restored[:, :, j])
49 | img_restored[:, :, j] = img_restored[:, :, j] / std * std_l[j]
50 |
51 | if args.test_y_channel and img_gt.ndim == 3 and img_gt.shape[2] == 3:
52 | img_gt = bgr2ycbcr(img_gt, y_only=True)
53 | img_restored = bgr2ycbcr(img_restored, y_only=True)
54 |
55 | # calculate PSNR and SSIM
56 | psnr = calculate_psnr(img_gt * 255, img_restored * 255, crop_border=args.crop_border, input_order='HWC')
57 | ssim = calculate_ssim(img_gt * 255, img_restored * 255, crop_border=args.crop_border, input_order='HWC')
58 | print(f'{i+1:3d}: {basename:25}. \tPSNR: {psnr:.6f} dB, \tSSIM: {ssim:.6f}')
59 | psnr_all.append(psnr)
60 | ssim_all.append(ssim)
61 | print(args.gt)
62 | print(args.restored)
63 | print(f'Average: PSNR: {sum(psnr_all) / len(psnr_all):.6f} dB, SSIM: {sum(ssim_all) / len(ssim_all):.6f}')
64 |
65 |
66 | if __name__ == '__main__':
67 | parser = argparse.ArgumentParser()
68 | parser.add_argument('--gt', type=str, default='datasets/val_set14/Set14', help='Path to gt (Ground-Truth)')
69 | parser.add_argument('--restored', type=str, default='results/Set14', help='Path to restored images')
70 | parser.add_argument('--crop_border', type=int, default=0, help='Crop border for each side')
71 | parser.add_argument('--suffix', type=str, default='', help='Suffix for restored images')
72 | parser.add_argument(
73 | '--test_y_channel',
74 | action='store_true',
75 | help='If True, test Y channel (In MatLab YCbCr format). If False, test RGB channels.')
76 | parser.add_argument('--correct_mean_var', action='store_true', help='Correct the mean and var of restored images.')
77 | args = parser.parse_args()
78 | main(args)
79 |
--------------------------------------------------------------------------------
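
As a sanity reference for the numbers printed above: with images scaled back to [0, 255], PSNR is 10 * log10(255^2 / MSE), computed after the optional crop-border trimming and Y-channel conversion. A minimal NumPy sketch of the PSNR part only (the basicsr `calculate_psnr` additionally handles crop_border and input_order):

import numpy as np

def psnr(img_gt, img_restored, max_val=255.0):
    # both images are arrays in [0, max_val] with the same shape
    mse = np.mean((img_gt.astype(np.float64) - img_restored.astype(np.float64)) ** 2)
    if mse == 0:
        return float('inf')
    return 10.0 * np.log10(max_val ** 2 / mse)

gt = np.full((8, 8), 128.0)
restored = gt + 4.0        # constant error of 4 -> MSE = 16
print(psnr(gt, restored))  # 10*log10(255^2/16) ~ 36.09 dB
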
/scripts/metrics/calculate_stylegan2_fid.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import math
3 | import numpy as np
4 | import torch
5 | from torch import nn
6 |
7 | from basicsr.archs.stylegan2_arch import StyleGAN2Generator
8 | from basicsr.metrics.fid import calculate_fid, extract_inception_features, load_patched_inception_v3
9 |
10 |
11 | def calculate_stylegan2_fid():
12 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13 |
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('ckpt', type=str, help='Path to the stylegan2 checkpoint.')
16 | parser.add_argument('fid_stats', type=str, help='Path to the dataset fid statistics.')
17 | parser.add_argument('--size', type=int, default=256)
18 | parser.add_argument('--channel_multiplier', type=int, default=2)
19 | parser.add_argument('--batch_size', type=int, default=64)
20 | parser.add_argument('--num_sample', type=int, default=50000)
21 | parser.add_argument('--truncation', type=float, default=1)
22 | parser.add_argument('--truncation_mean', type=int, default=4096)
23 | args = parser.parse_args()
24 |
25 | # create stylegan2 model
26 | generator = StyleGAN2Generator(
27 | out_size=args.size,
28 | num_style_feat=512,
29 | num_mlp=8,
30 | channel_multiplier=args.channel_multiplier,
31 | resample_kernel=(1, 3, 3, 1))
32 | generator.load_state_dict(torch.load(args.ckpt)['params_ema'])
33 | generator = nn.DataParallel(generator).eval().to(device)
34 |
35 | if args.truncation < 1:
36 | with torch.no_grad():
37 | truncation_latent = generator.mean_latent(args.truncation_mean)
38 | else:
39 | truncation_latent = None
40 |
41 | # inception model
42 | inception = load_patched_inception_v3(device)
43 |
44 | total_batch = math.ceil(args.num_sample / args.batch_size)
45 |
46 | def sample_generator(total_batch):
47 | for _ in range(total_batch):
48 | with torch.no_grad():
49 | latent = torch.randn(args.batch_size, 512, device=device)
50 | samples, _ = generator([latent], truncation=args.truncation, truncation_latent=truncation_latent)
51 | yield samples
52 |
53 | features = extract_inception_features(sample_generator(total_batch), inception, total_batch, device)
54 | features = features.numpy()
55 | total_len = features.shape[0]
56 | features = features[:args.num_sample]
57 |     print(f'Extracted {total_len} features; using the first {features.shape[0]} to compute the statistics.')
58 | sample_mean = np.mean(features, 0)
59 | sample_cov = np.cov(features, rowvar=False)
60 |
61 | # load the dataset stats
62 | stats = torch.load(args.fid_stats)
63 | real_mean = stats['mean']
64 | real_cov = stats['cov']
65 |
66 | # calculate FID metric
67 | fid = calculate_fid(sample_mean, sample_cov, real_mean, real_cov)
68 | print('fid:', fid)
69 |
70 |
71 | if __name__ == '__main__':
72 | calculate_stylegan2_fid()
73 |
--------------------------------------------------------------------------------
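The `calculate_fid(sample_mean, sample_cov, real_mean, real_cov)` call above computes the Fréchet distance between the two Gaussians fitted to the Inception features: FID = ||mu1 - mu2||^2 + Tr(C1 + C2 - 2*(C1*C2)^(1/2)). A minimal NumPy/SciPy sketch of that formula (a plain illustration of the math, not the basicsr implementation, which additionally guards against numerical issues):

import numpy as np
from scipy import linalg

def frechet_distance(mu1, cov1, mu2, cov2):
    """Frechet distance between N(mu1, cov1) and N(mu2, cov2)."""
    diff = mu1 - mu2
    covmean, _ = linalg.sqrtm(cov1 @ cov2, disp=False)  # matrix square root of cov1 @ cov2
    if np.iscomplexobj(covmean):                        # discard tiny imaginary parts from sqrtm
        covmean = covmean.real
    return float(diff @ diff + np.trace(cov1) + np.trace(cov2) - 2 * np.trace(covmean))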
/scripts/model_conversion/convert_dfdnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from basicsr.archs.dfdnet_arch import DFDNet
4 | from basicsr.archs.vgg_arch import NAMES
5 |
6 |
7 | def convert_net(ori_net, crt_net):
8 |
9 | for crt_k, _ in crt_net.items():
10 | # vgg feature extractor
11 | if 'vgg_extractor' in crt_k:
12 | ori_k = crt_k.replace('vgg_extractor', 'VggExtract').replace('vgg_net', 'model')
13 | if 'mean' in crt_k:
14 | ori_k = ori_k.replace('mean', 'RGB_mean')
15 | elif 'std' in crt_k:
16 | ori_k = ori_k.replace('std', 'RGB_std')
17 | else:
18 | idx = NAMES['vgg19'].index(crt_k.split('.')[2])
19 | if 'weight' in crt_k:
20 | ori_k = f'VggExtract.model.features.{idx}.weight'
21 | else:
22 | ori_k = f'VggExtract.model.features.{idx}.bias'
23 | elif 'attn_blocks' in crt_k:
24 | if 'left_eye' in crt_k:
25 | ori_k = crt_k.replace('attn_blocks.left_eye', 'le')
26 | elif 'right_eye' in crt_k:
27 | ori_k = crt_k.replace('attn_blocks.right_eye', 're')
28 | elif 'mouth' in crt_k:
29 | ori_k = crt_k.replace('attn_blocks.mouth', 'mo')
30 | elif 'nose' in crt_k:
31 | ori_k = crt_k.replace('attn_blocks.nose', 'no')
32 | else:
33 | raise ValueError('Wrong!')
34 | elif 'multi_scale_dilation' in crt_k:
35 | if 'conv_blocks' in crt_k:
36 | _, _, c, d, e = crt_k.split('.')
37 | ori_k = f'MSDilate.conv{int(c)+1}.{d}.{e}'
38 | else:
39 | ori_k = crt_k.replace('multi_scale_dilation.conv_fusion', 'MSDilate.convi')
40 |
41 | elif crt_k.startswith('upsample'):
42 | ori_k = crt_k.replace('upsample', 'up')
43 | if 'scale_block' in crt_k:
44 | ori_k = ori_k.replace('scale_block', 'ScaleModel1')
45 | elif 'shift_block' in crt_k:
46 | ori_k = ori_k.replace('shift_block', 'ShiftModel1')
47 |
48 | elif 'upsample4' in crt_k and 'body' in crt_k:
49 | ori_k = ori_k.replace('body', 'Model')
50 |
51 | else:
52 |             print('unprocessed key: ', crt_k)
53 |
54 | # replace
55 | if crt_net[crt_k].size() != ori_net[ori_k].size():
56 | raise ValueError('Wrong tensor size: \n'
57 | f'crt_net: {crt_net[crt_k].size()}\n'
58 | f'ori_net: {ori_net[ori_k].size()}')
59 | else:
60 | crt_net[crt_k] = ori_net[ori_k]
61 |
62 | return crt_net
63 |
64 |
65 | if __name__ == '__main__':
66 | ori_net = torch.load('experiments/pretrained_models/DFDNet/DFDNet_official_original.pth')
67 | dfd_net = DFDNet(64, dict_path='experiments/pretrained_models/DFDNet/DFDNet_dict_512.pth')
68 | crt_net = dfd_net.state_dict()
69 | crt_net_params = convert_net(ori_net, crt_net)
70 |
71 | torch.save(
72 | dict(params=crt_net_params),
73 | 'experiments/pretrained_models/DFDNet/DFDNet_official.pth',
74 | _use_new_zipfile_serialization=False)
75 |
--------------------------------------------------------------------------------
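The converted DFDNet checkpoint is saved with the weights under a 'params' key, matching the convention used elsewhere in basicsr. A minimal sketch of loading it back (the paths are the ones used in the script above):

import torch

from basicsr.archs.dfdnet_arch import DFDNet

dfd_net = DFDNet(64, dict_path='experiments/pretrained_models/DFDNet/DFDNet_dict_512.pth')
state = torch.load('experiments/pretrained_models/DFDNet/DFDNet_official.pth', map_location='cpu')
dfd_net.load_state_dict(state['params'], strict=True)  # every key was converted above, so a strict load should succeed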
/scripts/model_conversion/convert_ridnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from collections import OrderedDict
3 |
4 | from basicsr.archs.ridnet_arch import RIDNet
5 |
6 | if __name__ == '__main__':
7 | ori_net_checkpoint = torch.load(
8 | 'experiments/pretrained_models/RIDNet/RIDNet_official_original.pt', map_location=lambda storage, loc: storage)
9 | rid_net = RIDNet(3, 64, 3)
10 | new_ridnet_dict = OrderedDict()
11 |
12 | rid_net_namelist = []
13 | for name, param in rid_net.named_parameters():
14 | rid_net_namelist.append(name)
15 |
16 | count = 0
17 | for name, param in ori_net_checkpoint.items():
18 | new_ridnet_dict[rid_net_namelist[count]] = param
19 | count += 1
20 |
21 | rid_net.load_state_dict(new_ridnet_dict)
22 | torch.save(rid_net.state_dict(), 'experiments/pretrained_models/RIDNet/RIDNet.pth')
23 |
--------------------------------------------------------------------------------
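The RIDNet conversion pairs parameters purely by iteration order, so it assumes the official checkpoint and the basicsr `RIDNet(3, 64, 3)` expose the same parameters in the same order. A small sanity-check sketch in that spirit (the shape comparison is an added safeguard, not part of the original script):

import torch

from basicsr.archs.ridnet_arch import RIDNet

ckpt = torch.load('experiments/pretrained_models/RIDNet/RIDNet_official_original.pt', map_location='cpu')
net = RIDNet(3, 64, 3)
new_shapes = {name: p.shape for name, p in net.named_parameters()}
names = list(new_shapes.keys())

assert len(ckpt) == len(names), 'parameter count mismatch'
for new_name, (old_name, old_param) in zip(names, ckpt.items()):
    assert old_param.shape == new_shapes[new_name], (
        f'{old_name} -> {new_name}: {tuple(old_param.shape)} vs {tuple(new_shapes[new_name])}')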
/scripts/model_conversion/convert_stylegan.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | from basicsr.archs.stylegan2_arch import StyleGAN2Discriminator, StyleGAN2Generator
4 |
5 |
6 | def convert_net_g(ori_net, crt_net):
7 | """Convert network generator."""
8 |
9 | for crt_k, crt_v in crt_net.items():
10 | if 'style_mlp' in crt_k:
11 | ori_k = crt_k.replace('style_mlp', 'style')
12 | elif 'constant_input.weight' in crt_k:
13 | ori_k = crt_k.replace('constant_input.weight', 'input.input')
14 | # style conv1
15 | elif 'style_conv1.modulated_conv' in crt_k:
16 | ori_k = crt_k.replace('style_conv1.modulated_conv', 'conv1.conv')
17 | elif 'style_conv1' in crt_k:
18 | if crt_v.shape == torch.Size([1]):
19 | ori_k = crt_k.replace('style_conv1', 'conv1.noise')
20 | else:
21 | ori_k = crt_k.replace('style_conv1', 'conv1')
22 | # style conv
23 | elif 'style_convs' in crt_k:
24 | ori_k = crt_k.replace('style_convs', 'convs').replace('modulated_conv', 'conv')
25 | if crt_v.shape == torch.Size([1]):
26 | ori_k = ori_k.replace('.weight', '.noise.weight')
27 | # to_rgb1
28 | elif 'to_rgb1.modulated_conv' in crt_k:
29 | ori_k = crt_k.replace('to_rgb1.modulated_conv', 'to_rgb1.conv')
30 | # to_rgbs
31 | elif 'to_rgbs' in crt_k:
32 | ori_k = crt_k.replace('modulated_conv', 'conv')
33 | elif 'noises' in crt_k:
34 | ori_k = crt_k.replace('.noise', '.noise_')
35 | else:
36 | ori_k = crt_k
37 |
38 | # replace
39 | if crt_net[crt_k].size() != ori_net[ori_k].size():
40 | raise ValueError('Wrong tensor size: \n'
41 | f'crt_net: {crt_net[crt_k].size()}\n'
42 | f'ori_net: {ori_net[ori_k].size()}')
43 | else:
44 | crt_net[crt_k] = ori_net[ori_k]
45 |
46 | return crt_net
47 |
48 |
49 | def convert_net_d(ori_net, crt_net):
50 | """Convert network discriminator."""
51 |
52 | for crt_k, _ in crt_net.items():
53 | if 'conv_body' in crt_k:
54 | ori_k = crt_k.replace('conv_body', 'convs')
55 | else:
56 | ori_k = crt_k
57 |
58 | # replace
59 | if crt_net[crt_k].size() != ori_net[ori_k].size():
60 | raise ValueError('Wrong tensor size: \n'
61 | f'crt_net: {crt_net[crt_k].size()}\n'
62 | f'ori_net: {ori_net[ori_k].size()}')
63 | else:
64 | crt_net[crt_k] = ori_net[ori_k]
65 | return crt_net
66 |
67 |
68 | if __name__ == '__main__':
69 | """Convert official stylegan2 weights from stylegan2-pytorch."""
70 |
71 | # configuration
72 | ori_net = torch.load('experiments/pretrained_models/stylegan2-ffhq.pth')
73 | save_path_g = 'experiments/pretrained_models/stylegan2_ffhq_config_f_1024_official.pth' # noqa: E501
74 | save_path_d = 'experiments/pretrained_models/stylegan2_ffhq_config_f_1024_discriminator_official.pth' # noqa: E501
75 | out_size = 1024
76 | channel_multiplier = 1
77 |
78 | # convert generator
79 | crt_net = StyleGAN2Generator(out_size, num_style_feat=512, num_mlp=8, channel_multiplier=channel_multiplier)
80 | crt_net = crt_net.state_dict()
81 |
82 | crt_net_params_ema = convert_net_g(ori_net['g_ema'], crt_net)
83 | torch.save(dict(params_ema=crt_net_params_ema, latent_avg=ori_net['latent_avg']), save_path_g)
84 |
85 | # convert discriminator
86 | crt_net = StyleGAN2Discriminator(out_size, channel_multiplier=channel_multiplier)
87 | crt_net = crt_net.state_dict()
88 |
89 | crt_net_params = convert_net_d(ori_net['d'], crt_net)
90 | torch.save(dict(params=crt_net_params), save_path_d)
91 |
--------------------------------------------------------------------------------
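The converted generator is stored under a 'params_ema' key (together with 'latent_avg'), which is the key that scripts/metrics/calculate_stylegan2_fid.py reads via `torch.load(args.ckpt)['params_ema']`. A minimal sketch of loading the file written above, using the same configuration as the conversion:

import torch

from basicsr.archs.stylegan2_arch import StyleGAN2Generator

generator = StyleGAN2Generator(1024, num_style_feat=512, num_mlp=8, channel_multiplier=1)
ckpt = torch.load('experiments/pretrained_models/stylegan2_ffhq_config_f_1024_official.pth', map_location='cpu')
generator.load_state_dict(ckpt['params_ema'])
generator.eval()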
/scripts/publish_models.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import subprocess
3 | import torch
4 | from os import path as osp
5 | from torch.serialization import _is_zipfile, _open_file_like
6 |
7 |
8 | def update_sha(paths):
9 | print('# Update sha ...')
10 | for idx, path in enumerate(paths):
11 | print(f'{idx+1:03d}: Processing {path}')
12 | net = torch.load(path, map_location=torch.device('cpu'))
13 | basename = osp.basename(path)
14 | if 'params' not in net and 'params_ema' not in net:
15 | user_response = input(f'WARN: Model {basename} does not have "params"/"params_ema" key. '
16 | 'Do you still want to continue? Y/N\n')
17 | if user_response.lower() == 'y':
18 | pass
19 | elif user_response.lower() == 'n':
20 | raise ValueError('Please modify..')
21 | else:
22 | raise ValueError('Wrong input. Only accepts Y/N.')
23 |
24 | if '-' in basename:
25 | # check whether the sha is the latest
26 | old_sha = basename.split('-')[1].split('.')[0]
27 | new_sha = subprocess.check_output(['sha256sum', path]).decode()[:8]
28 | if old_sha != new_sha:
29 | final_file = path.split('-')[0] + f'-{new_sha}.pth'
30 | print(f'\tSave from {path} to {final_file}')
31 | subprocess.Popen(['mv', path, final_file])
32 | else:
33 | sha = subprocess.check_output(['sha256sum', path]).decode()[:8]
34 | final_file = path.split('.pth')[0] + f'-{sha}.pth'
35 | print(f'\tSave from {path} to {final_file}')
36 | subprocess.Popen(['mv', path, final_file])
37 |
38 |
39 | def convert_to_backward_compatible_models(paths):
40 | """Convert to backward compatible pth files.
41 |
42 |     PyTorch 1.6 uses an updated version of torch.save. To stay compatible
43 |     with earlier PyTorch versions, save the checkpoint with
44 |     _use_new_zipfile_serialization=False.
45 | """
46 | print('# Convert to backward compatible pth files ...')
47 | for idx, path in enumerate(paths):
48 | print(f'{idx+1:03d}: Processing {path}')
49 | flag_need_conversion = False
50 | with _open_file_like(path, 'rb') as opened_file:
51 | if _is_zipfile(opened_file):
52 | flag_need_conversion = True
53 |
54 | if flag_need_conversion:
55 | net = torch.load(path, map_location=torch.device('cpu'))
56 | print('\tConverting to compatible pth file...')
57 | torch.save(net, path, _use_new_zipfile_serialization=False)
58 |
59 |
60 | if __name__ == '__main__':
61 | paths = glob.glob('experiments/pretrained_models/*.pth') + glob.glob('experiments/pretrained_models/**/*.pth')
62 | convert_to_backward_compatible_models(paths)
63 | update_sha(paths)
64 |
--------------------------------------------------------------------------------
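The suffix appended by update_sha() is simply the first 8 hex digits of the file's SHA-256, obtained here by shelling out to `sha256sum`. An equivalent pure-Python sketch of that step using hashlib (the path below is a placeholder):

import hashlib

def sha_suffix(path, length=8):
    """Return the first `length` hex characters of the file's SHA-256."""
    sha = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):  # read in 1 MB chunks
            sha.update(chunk)
    return sha.hexdigest()[:length]

# e.g. 'RIDNet.pth' would be renamed to f"RIDNet-{sha_suffix('RIDNet.pth')}.pth"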
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | from setuptools import find_packages, setup
4 |
5 | import os
6 | import subprocess
7 | import time
8 | import torch
9 | from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension
10 |
11 | version_file = 'basicsr/version.py'
12 |
13 |
14 | def readme():
15 | with open('README.md', encoding='utf-8') as f:
16 | content = f.read()
17 | return content
18 |
19 |
20 | def get_git_hash():
21 |
22 | def _minimal_ext_cmd(cmd):
23 | # construct minimal environment
24 | env = {}
25 | for k in ['SYSTEMROOT', 'PATH', 'HOME']:
26 | v = os.environ.get(k)
27 | if v is not None:
28 | env[k] = v
29 | # LANGUAGE is used on win32
30 | env['LANGUAGE'] = 'C'
31 | env['LANG'] = 'C'
32 | env['LC_ALL'] = 'C'
33 | out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
34 | return out
35 |
36 | try:
37 | out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
38 | sha = out.strip().decode('ascii')
39 | except OSError:
40 | sha = 'unknown'
41 |
42 | return sha
43 |
44 |
45 | def get_hash():
46 | if os.path.exists('.git'):
47 | sha = get_git_hash()[:7]
48 | # currently ignore this
49 | # elif os.path.exists(version_file):
50 | # try:
51 | # from basicsr.version import __version__
52 | # sha = __version__.split('+')[-1]
53 | # except ImportError:
54 | # raise ImportError('Unable to get git version')
55 | else:
56 | sha = 'unknown'
57 |
58 | return sha
59 |
60 |
61 | def write_version_py():
62 | content = """# GENERATED VERSION FILE
63 | # TIME: {}
64 | __version__ = '{}'
65 | __gitsha__ = '{}'
66 | version_info = ({})
67 | """
68 | sha = get_hash()
69 | with open('VERSION', 'r') as f:
70 | SHORT_VERSION = f.read().strip()
71 | VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')])
72 |
73 | version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO)
74 | with open(version_file, 'w') as f:
75 | f.write(version_file_str)
76 |
77 |
78 | def get_version():
79 | with open(version_file, 'r') as f:
80 | exec(compile(f.read(), version_file, 'exec'))
81 | return locals()['__version__']
82 |
83 |
84 | def make_cuda_ext(name, module, sources, sources_cuda=None):
85 | if sources_cuda is None:
86 | sources_cuda = []
87 | define_macros = []
88 | extra_compile_args = {'cxx': []}
89 |
90 | if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
91 | define_macros += [('WITH_CUDA', None)]
92 | extension = CUDAExtension
93 | extra_compile_args['nvcc'] = [
94 | '-D__CUDA_NO_HALF_OPERATORS__',
95 | '-D__CUDA_NO_HALF_CONVERSIONS__',
96 | '-D__CUDA_NO_HALF2_OPERATORS__',
97 | ]
98 | sources += sources_cuda
99 | else:
100 | print(f'Compiling {name} without CUDA')
101 | extension = CppExtension
102 |
103 | return extension(
104 | name=f'{module}.{name}',
105 | sources=[os.path.join(*module.split('.'), p) for p in sources],
106 | define_macros=define_macros,
107 | extra_compile_args=extra_compile_args)
108 |
109 |
110 | def get_requirements(filename='requirements.txt'):
111 | here = os.path.dirname(os.path.realpath(__file__))
112 | with open(os.path.join(here, filename), 'r') as f:
113 | requires = [line.replace('\n', '') for line in f.readlines()]
114 | return requires
115 |
116 |
117 | if __name__ == '__main__':
118 | cuda_ext = os.getenv('BASICSR_EXT') # whether compile cuda ext
119 | if cuda_ext == 'True':
120 | ext_modules = [
121 | make_cuda_ext(
122 | name='deform_conv_ext',
123 | module='basicsr.ops.dcn',
124 | sources=['src/deform_conv_ext.cpp'],
125 | sources_cuda=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']),
126 | make_cuda_ext(
127 | name='fused_act_ext',
128 | module='basicsr.ops.fused_act',
129 | sources=['src/fused_bias_act.cpp'],
130 | sources_cuda=['src/fused_bias_act_kernel.cu']),
131 | make_cuda_ext(
132 | name='upfirdn2d_ext',
133 | module='basicsr.ops.upfirdn2d',
134 | sources=['src/upfirdn2d.cpp'],
135 | sources_cuda=['src/upfirdn2d_kernel.cu']),
136 | ]
137 | else:
138 | ext_modules = []
139 |
140 | # write_version_py()
141 | setup(
142 | name='basicsr',
143 | # version=get_version(),
144 | description='Open Source Image and Video Super-Resolution Toolbox',
145 | long_description=readme(),
146 | long_description_content_type='text/markdown',
147 | author='Xintao Wang',
148 | author_email='xintao.wang@outlook.com',
149 | keywords='computer vision, restoration, super resolution',
150 | url='https://github.com/xinntao/BasicSR',
151 | include_package_data=True,
152 | packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')),
153 | classifiers=[
154 | 'Development Status :: 4 - Beta',
155 | 'License :: OSI Approved :: Apache Software License',
156 | 'Operating System :: OS Independent',
157 | 'Programming Language :: Python :: 3',
158 | 'Programming Language :: Python :: 3.7',
159 | 'Programming Language :: Python :: 3.8',
160 | ],
161 | license='Apache License 2.0',
162 | setup_requires=['cython', 'numpy'],
163 | install_requires=get_requirements(),
164 | ext_modules=ext_modules,
165 | cmdclass={'build_ext': BuildExtension},
166 | zip_safe=False)
167 |
--------------------------------------------------------------------------------
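Two small notes on setup.py: the CUDA extensions in `__main__` are only compiled when the environment variable BASICSR_EXT is set to the string 'True', and write_version_py() fills the template above from the VERSION file plus the short git hash (the call is commented out in `__main__` here). For illustration, with a VERSION file containing 1.3.5 and a hypothetical git sha abc1234, the generated basicsr/version.py would read:

# GENERATED VERSION FILE
# TIME: Mon Jan  1 00:00:00 2024
__version__ = '1.3.5'
__gitsha__ = 'abc1234'
version_info = (1, 3, 5)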
/test_recon.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import cv2
3 | import glob
4 | import os
5 | from tqdm import tqdm
6 | import torch
7 | from yaml import load  # noqa: F401  (unused)
8 | import pdb  # noqa: F401  (unused debug leftover)
9 | import numpy as np
10 | import matplotlib as m  # noqa: F401  (unused; only pyplot is needed)
11 | import matplotlib.pyplot as plt
12 | from PIL import Image  # used only by the commented-out per-map saving in vis_weight()
13 | import pyiqa
14 |
15 | from basicsr.utils import img2tensor, tensor2img, imwrite
16 | from basicsr.archs.adacode_arch import AdaCodeSRNet
17 | from basicsr.archs.femasr_arch import FeMaSRNet
18 | from basicsr.archs.adacode_contrast_arch import AdaCodeSRNet_Contrast
19 | from basicsr.utils.download_util import load_file_from_url
20 |
21 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
22 |
23 | #eval metrics
24 | metric_funcs = {}
25 | metric_funcs['psnr'] = pyiqa.create_metric('psnr', device=device, crop_border=4, test_y_channel=True)
26 | metric_funcs['ssim'] = pyiqa.create_metric('ssim', device=device, crop_border=4, test_y_channel=True)
27 | metric_funcs['lpips'] = pyiqa.create_metric('lpips', device=device)
28 |
29 | def main(args):
30 |     """Reconstruction demo: reconstruct HQ images with the AdaCode model and report PSNR/SSIM/LPIPS.
31 |     """
32 | metric_results = {'psnr': 0, 'ssim': 0, 'lpips': 0}
33 |
34 | weight_path = args.weight
35 |
36 | # set up the model
37 | model_params = torch.load(weight_path)['params']
38 | codebook_dim = np.array([v.size() for k,v in model_params.items() if 'quantize_group' in k])
39 | codebook_dim_list = []
40 | for k in codebook_dim:
41 | temp = k.tolist()
42 | temp.insert(0,32)
43 | codebook_dim_list.append(temp)
44 | # recon_model = FeMaSRNet(codebook_params=[[32, 512, 256]], LQ_stage=False, scale_factor=2).to(device)
45 | recon_model = AdaCodeSRNet_Contrast(codebook_params=codebook_dim_list, LQ_stage=False, AdaCode_stage=True, batch_size=2, weight_softmax=False).to(device)
46 | recon_model.load_state_dict(model_params, strict=False)
47 | recon_model.eval()
48 |
49 | os.makedirs(args.output, exist_ok=True)
50 | if os.path.isfile(args.input):
51 | paths = [args.input]
52 | else:
53 | paths = sorted(glob.glob(os.path.join(args.input, '*.png')))
54 |
55 | pbar = tqdm(total=len(paths), unit='image')
56 | for idx, path in enumerate(paths):
57 | # try:
58 | img_name = os.path.basename(path)
59 | pbar.set_description(f'Test {img_name}')
60 |
61 | # recon
62 | img_HR = cv2.imread(path, cv2.IMREAD_UNCHANGED)
63 | img_HR_tensor = img2tensor(img_HR).to(device) / 255.
64 | img_HR_tensor = img_HR_tensor.unsqueeze(0)
65 |
66 | max_size = args.max_size ** 2
67 | h, w = img_HR_tensor.shape[2:]
68 | if h * w < max_size:
69 | output_HR = recon_model.test(img_HR_tensor, vis_weight=args.vis_weight)
70 | else:
71 | output_HR = recon_model.test_tile(img_HR_tensor, vis_weight=args.vis_weight)
72 |
73 | if args.vis_weight:
74 | weight_map = output_HR[1]
75 | vis_weight(weight_map, os.path.join(args.output, 'weight_map', img_name))
76 | output = output_HR[0]
77 | else:
78 | output = output_HR
79 | output_img = tensor2img(output)
80 |
81 | imwrite(output_img, os.path.join(args.output, f'{img_name}'))
82 |
83 | for name in metric_funcs.keys():
84 | metric_results[name] += metric_funcs[name](img_HR_tensor, output).item()
85 | pbar.update(1)
86 | # except:
87 | # print(path, ' fails.')
88 | pbar.close()
89 |
90 | for name in metric_results.keys():
91 | metric_results[name] /= len(paths)
92 | print('Result for {}'.format(args.weight))
93 | print(metric_results)
94 |
95 |
96 | def vis_weight(weight, save_path):
97 | # weight: B x n x 1 x H x W
98 | weight = weight.cpu().numpy()
99 | # normalize weights
100 | # norm_weight = weight
101 | norm_weight = (weight - weight.mean()) / weight.std() / 2
102 | norm_weight = np.abs(norm_weight)
103 | norm_weight *= 255
104 | norm_weight = np.clip(norm_weight, 0, 255)
105 | norm_weight = norm_weight.astype(np.uint8)
106 | # visualize
107 | display_grid = np.zeros((weight.shape[3], (weight.shape[4]+1)*weight.shape[1]-1))
108 | for img_id in range(len(norm_weight)):
109 | for c in range(norm_weight.shape[1]):
110 | display_grid[:, c*weight.shape[4]+c:(c+1)*weight.shape[4]+c] = norm_weight[img_id, c, 0, :, :]
111 | # weight_path = save_path.split('.')[0] + '_{}.png'.format(str(c))
112 | # Image.fromarray(norm_weight[img_id, c, 0, :, :]).save(weight_path)
113 | plt.figure(figsize=(30,150))
114 | plt.axis('off')
115 | plt.imshow(display_grid)
116 | plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
117 | plt.close()
118 |
119 |
120 |
121 |
122 | if __name__ == '__main__':
123 | parser = argparse.ArgumentParser()
124 | parser.add_argument('-i', '--input', type=str, default='inputs', help='Input image or folder')
125 | parser.add_argument('-w', '--weight', type=str, default=None, help='path for model weights')
126 | parser.add_argument('-o', '--output', type=str, default='results', help='Output folder')
127 | parser.add_argument('--suffix', type=str, default='', help='Suffix of the restored image')
128 |     parser.add_argument('--max_size', type=int, default=600000, help='Images whose area is below max_size^2 are processed whole; larger inputs use tiled inference')
129 | parser.add_argument('--vis_weight', action='store_true', help='visualize weight map')
130 | args = parser.parse_args()
131 |
132 | if args.vis_weight:
133 | os.makedirs(os.path.join(args.output, 'weight_map'), exist_ok=True)
134 |
135 | main(args)
136 |
--------------------------------------------------------------------------------
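In main() above, the codebook configuration is recovered from the checkpoint itself: the shape of every parameter whose key contains 'quantize_group' is collected, and 32 is prepended to each shape to form `codebook_params`. A toy illustration of just that transformation (the two shapes below are hypothetical; only the reshaping is the point):

import numpy as np

# hypothetical shapes of the 'quantize_group' codebook tensors found in a checkpoint
codebook_dim = np.array([[512, 256], [1024, 512]])

codebook_dim_list = []
for k in codebook_dim:
    temp = k.tolist()
    temp.insert(0, 32)
    codebook_dim_list.append(temp)

print(codebook_dim_list)  # [[32, 512, 256], [32, 1024, 512]]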
/testset/inpaint/0801_s009.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0801_s009.png
--------------------------------------------------------------------------------
/testset/inpaint/0804_s003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0804_s003.png
--------------------------------------------------------------------------------
/testset/inpaint/0804_s006.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0804_s006.png
--------------------------------------------------------------------------------
/testset/inpaint/0808_s010.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0808_s010.png
--------------------------------------------------------------------------------
/testset/inpaint/0809_s006.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0809_s006.png
--------------------------------------------------------------------------------
/testset/inpaint/0813_s009.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0813_s009.png
--------------------------------------------------------------------------------
/testset/inpaint/0817_s003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0817_s003.png
--------------------------------------------------------------------------------
/testset/inpaint/0817_s012.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0817_s012.png
--------------------------------------------------------------------------------
/testset/inpaint/0830_s007.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0830_s007.png
--------------------------------------------------------------------------------
/testset/inpaint/0831_s006.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0831_s006.png
--------------------------------------------------------------------------------
/testset/inpaint/0833_s003.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0833_s003.png
--------------------------------------------------------------------------------
/testset/inpaint/0836_s001.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0836_s001.png
--------------------------------------------------------------------------------
/testset/inpaint/0836_s010.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0836_s010.png
--------------------------------------------------------------------------------
/testset/inpaint/0839_s012.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0839_s012.png
--------------------------------------------------------------------------------
/testset/inpaint/0842_s002.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0842_s002.png
--------------------------------------------------------------------------------
/testset/inpaint/0848_s009.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0848_s009.png
--------------------------------------------------------------------------------
/testset/inpaint/0849_s004.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0849_s004.png
--------------------------------------------------------------------------------
/testset/inpaint/0852_s010.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0852_s010.png
--------------------------------------------------------------------------------
/testset/inpaint/0856_s006.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0856_s006.png
--------------------------------------------------------------------------------
/testset/inpaint/0856_s010.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0856_s010.png
--------------------------------------------------------------------------------
/testset/inpaint/0863_s004.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0863_s004.png
--------------------------------------------------------------------------------
/testset/inpaint/0865_s011.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0865_s011.png
--------------------------------------------------------------------------------
/testset/inpaint/0866_s005.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0866_s005.png
--------------------------------------------------------------------------------
/testset/inpaint/0870_s005.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0870_s005.png
--------------------------------------------------------------------------------
/testset/inpaint/0874_s007.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/0874_s007.png
--------------------------------------------------------------------------------
/testset/inpaint/64010.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/inpaint/64010.png
--------------------------------------------------------------------------------
/testset/sr/0003.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/0003.jpg
--------------------------------------------------------------------------------
/testset/sr/0014.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/0014.jpg
--------------------------------------------------------------------------------
/testset/sr/0015.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/0015.jpg
--------------------------------------------------------------------------------
/testset/sr/0030.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/0030.jpg
--------------------------------------------------------------------------------
/testset/sr/0032.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/0032.jpg
--------------------------------------------------------------------------------
/testset/sr/0054.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/0054.jpg
--------------------------------------------------------------------------------
/testset/sr/0068.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/0068.jpg
--------------------------------------------------------------------------------
/testset/sr/49370.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/49370.png
--------------------------------------------------------------------------------
/testset/sr/ADE_val_00000015.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/ADE_val_00000015.jpg
--------------------------------------------------------------------------------
/testset/sr/ADE_val_00000114.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/ADE_val_00000114.jpg
--------------------------------------------------------------------------------
/testset/sr/Canon_004_LR4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/Canon_004_LR4.png
--------------------------------------------------------------------------------
/testset/sr/Canon_045_LR4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/Canon_045_LR4.png
--------------------------------------------------------------------------------
/testset/sr/OST_009.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/OST_009.png
--------------------------------------------------------------------------------
/testset/sr/OST_020.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/OST_020.png
--------------------------------------------------------------------------------
/testset/sr/OST_120.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/OST_120.png
--------------------------------------------------------------------------------
/testset/sr/building.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/building.png
--------------------------------------------------------------------------------
/testset/sr/butterfly.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/butterfly.png
--------------------------------------------------------------------------------
/testset/sr/butterfly2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/butterfly2.png
--------------------------------------------------------------------------------
/testset/sr/chip.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/chip.png
--------------------------------------------------------------------------------
/testset/sr/comic1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/comic1.png
--------------------------------------------------------------------------------
/testset/sr/comic2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/comic2.png
--------------------------------------------------------------------------------
/testset/sr/comic3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/comic3.png
--------------------------------------------------------------------------------
/testset/sr/computer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/computer.png
--------------------------------------------------------------------------------
/testset/sr/dog.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/dog.png
--------------------------------------------------------------------------------
/testset/sr/dped_crop00061.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/dped_crop00061.png
--------------------------------------------------------------------------------
/testset/sr/foreman.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/foreman.png
--------------------------------------------------------------------------------
/testset/sr/frog.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/frog.png
--------------------------------------------------------------------------------
/testset/sr/oldphoto3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/oldphoto3.png
--------------------------------------------------------------------------------
/testset/sr/oldphoto6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/oldphoto6.png
--------------------------------------------------------------------------------
/testset/sr/painting.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/painting.png
--------------------------------------------------------------------------------
/testset/sr/pattern.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/pattern.png
--------------------------------------------------------------------------------
/testset/sr/ppt3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/ppt3.png
--------------------------------------------------------------------------------
/testset/sr/tiger.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kechunl/AdaCode/123b22958021b2342f2cc1c5a3672a3d8eba8f1a/testset/sr/tiger.png
--------------------------------------------------------------------------------
/vis_codebook.py:
--------------------------------------------------------------------------------
1 | from itertools import count  # noqa: F401  (unused leftover import)
2 | from tokenize import PlainToken  # noqa: F401  (unused leftover import)
3 | import torch
4 | import torchvision.transforms as tf
5 | from torchvision.utils import save_image
6 | import torchvision.utils as tvu
7 |
8 | import numpy as np
9 | import os
10 | import random
11 | from tqdm import tqdm
12 | import cv2
13 | from matplotlib import pyplot as plt
14 | import seaborn as sns
15 |
16 | from basicsr.utils.misc import set_random_seed
17 | from basicsr.utils import img2tensor, tensor2img, imwrite
18 | from basicsr.archs.femasr_arch import FeMaSRNet
19 |
20 |
21 | def reconstruct_ost(model, data_dir, save_dir, maxnum=100):
22 |
23 | texture_classes = list(os.listdir(data_dir))
24 | texture_classes.remove('manga109')
25 | code_idx_dict = {}
26 | for tc in texture_classes:
27 | img_name_list = os.listdir(os.path.join(data_dir, tc))
28 | random.shuffle(img_name_list)
29 | tmp_code_idx_list = []
30 | for img_name in tqdm(img_name_list[:maxnum]):
31 | img_path = os.path.join(data_dir, tc, img_name)
32 |
33 | img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
34 | img_tensor = img2tensor(img).to(device) / 255.
35 | img_tensor = img_tensor.unsqueeze(0)
36 |
37 | rec, _, _, indices = model(img_tensor)
38 | indices = indices[0]
39 |
40 | save_path = os.path.join(save_dir, tc, img_name)
41 | if not os.path.exists(os.path.join(save_dir, tc)):
42 | os.makedirs(os.path.join(save_dir, tc), exist_ok=True)
43 | imwrite(tensor2img(rec), save_path)
44 |
45 | save_org_dir = save_dir.replace('rec', 'org')
46 | save_org_path = os.path.join(save_org_dir, tc, img_name)
47 | if not os.path.exists(os.path.join(save_org_dir, tc)):
48 | os.makedirs(os.path.join(save_org_dir, tc), exist_ok=True)
49 | imwrite(tensor2img(img_tensor), save_org_path)
50 |
51 | tmp_code_idx_list.append(indices)
52 | code_idx_dict[tc] = tmp_code_idx_list
53 |
54 | torch.save(code_idx_dict, './tmp_code_vis/code_idx_dict.pth')
55 |
56 |
57 | def vis_hrp(model, code_list_path, save_dir, samples_each_class=16):
58 | code_idx_dict = torch.load(code_list_path)
59 | classes = list(code_idx_dict.keys())
60 |
61 | latent_size = 8
62 | color_palette = sns.color_palette()
63 | for idx, (key, value) in enumerate(code_idx_dict.items()):
64 | all_idx = torch.cat([x.flatten() for x in value])
65 |
66 | plt.figure(figsize=(16, 8))
67 | sns.histplot(all_idx.cpu().numpy(), color=color_palette[idx])
68 | plt.xlabel(key, fontsize=30)
69 | plt.ylabel('Count', fontsize=30)
70 | plt.savefig(f'./tmp_code_vis/code_stat/code_index_bincount_{key}.pdf')
71 |
72 | counts = all_idx.bincount()
73 | dist = counts / sum(counts)
74 | dist = dist.cpu().numpy()
75 |
76 | vis_tex_samples = []
77 | for sid in range(32):
78 | vis_tex_map = np.random.choice(np.arange(dist.shape[0]), latent_size ** 2, p=dist)
79 | vis_tex_map = torch.from_numpy(vis_tex_map).to(all_idx)
80 | vis_tex_map = vis_tex_map.reshape(1, 1, latent_size, latent_size)
81 | vis_tex_img = model.decode_indices(vis_tex_map)
82 | vis_tex_samples.append(vis_tex_img)
83 | vis_tex_samples = torch.cat(vis_tex_samples, dim=0)
84 | save_image(vis_tex_samples, f'./tmp_code_vis/tmp_tex_vis/{key}.jpg', normalize=True, nrow=16)
85 |
86 |
87 | def vis_single_code(model, codenum, save_path, up_factor=4):
88 | code_idx = torch.arange(codenum).reshape(codenum, 1, 1, 1)
89 | code_idx = code_idx.repeat(1, 1, up_factor, up_factor)
90 | output_img = model.decode_indices(code_idx)
91 | output_img = tvu.make_grid(output_img, nrow=32)
92 | save_image(output_img, save_path)
93 |
94 | def vis_rand_single_code(model, codenum, save_path, up_factor=4, vis_num=5):
95 | code_idx = torch.randint(codenum, (vis_num,)).reshape(vis_num, 1, 1, 1)
96 | code_idx = code_idx.repeat(1, 1, up_factor, up_factor)
97 | output_img = model.decode_indices(code_idx)
98 | output_img = tvu.make_grid(output_img, nrow=32)
99 | save_image(output_img, save_path)
100 |
101 |
102 | if __name__ == '__main__':
103 | # set random seeds to ensure reproducibility
104 | set_random_seed(123)
105 |
106 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
107 |
108 | # set up the model
109 | weight_path = './experiments/pretrained_models/S1_c2_512x256_net_g_best.pth'
110 | codebook_size = os.path.basename(weight_path).split('_')[2].split('x')
111 | vqgan = FeMaSRNet(codebook_params=[[32, int(codebook_size[0]), int(codebook_size[1])]], LQ_stage=False).to(device)
112 | vqgan.load_state_dict(torch.load(weight_path)['params'], strict=False)
113 | vqgan.eval()
114 |
115 | os.makedirs('results/codebook_vis', exist_ok=True)
116 | vis_single_code(vqgan, int(codebook_size[0]), 'results/codebook_vis/{}.png'.format(os.path.basename(weight_path).split('.')[0]))
117 | # vis_rand_single_code(vqgan, 256, 'codebook_vis/sample_ffhq.png', vis_num=10)
118 |
119 | # reconstruct_ost(vqgan, '../datasets/SR_OST_datasets/OutdoorSceneTrain_v2/', './tmp_code_vis/ost_rec', maxnum=1000)
120 | # vis_hrp(vqgan, './tmp_code_vis/code_idx_dict.pth', './tmp_code_vis/')
121 |
--------------------------------------------------------------------------------
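In the `__main__` block of vis_codebook.py the codebook shape is parsed from the checkpoint filename: the third underscore-separated token encodes '<codebook size>x<embedding dim>'. For the bundled filename this works out as follows (a worked illustration of the existing parsing, not new behaviour):

import os

weight_path = './experiments/pretrained_models/S1_c2_512x256_net_g_best.pth'
codebook_size = os.path.basename(weight_path).split('_')[2].split('x')
print(codebook_size)  # ['512', '256']

codebook_params = [[32, int(codebook_size[0]), int(codebook_size[1])]]
print(codebook_params)  # [[32, 512, 256]]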