├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yaml ├── CITATION.cff ├── LICENSE.txt ├── MANIFEST.in ├── README.md ├── VERSION ├── basicsr ├── __init__.py ├── archs │ ├── __init__.py │ ├── arch_util.py │ ├── discriminator_arch.py │ ├── mlpmixer_util.py │ └── mma_arch.py ├── data │ ├── __init__.py │ ├── data_sampler.py │ ├── data_util.py │ ├── degradations.py │ ├── imagenet_paired_dataset.py │ ├── paired_image_dataset.py │ ├── prefetch_dataloader.py │ └── transforms.py ├── losses │ ├── __init__.py │ ├── basic_loss.py │ ├── gan_loss.py │ └── loss_util.py ├── metrics │ ├── __init__.py │ ├── fid.py │ ├── metric_util.py │ ├── psnr_ssim.py │ └── test_metrics │ │ └── test_psnr_ssim.py ├── models │ ├── __init__.py │ ├── base_model.py │ ├── lr_scheduler.py │ └── sr_model.py ├── ops │ ├── __init__.py │ ├── dcn │ │ ├── __init__.py │ │ ├── deform_conv.py │ │ └── src │ │ │ ├── deform_conv_cuda.cpp │ │ │ ├── deform_conv_cuda_kernel.cu │ │ │ └── deform_conv_ext.cpp │ ├── fused_act │ │ ├── __init__.py │ │ ├── fused_act.py │ │ └── src │ │ │ ├── fused_bias_act.cpp │ │ │ └── fused_bias_act_kernel.cu │ └── upfirdn2d │ │ ├── __init__.py │ │ ├── src │ │ ├── upfirdn2d.cpp │ │ └── upfirdn2d_kernel.cu │ │ └── upfirdn2d.py ├── test.py ├── train.py └── utils │ ├── __init__.py │ ├── color_util.py │ ├── diffjpeg.py │ ├── dist_util.py │ ├── download_util.py │ ├── file_client.py │ ├── flow_util.py │ ├── img_process_util.py │ ├── img_util.py │ ├── lmdb_util.py │ ├── logger.py │ ├── matlab_functions.py │ ├── misc.py │ ├── options.py │ ├── plot_util.py │ └── registry.py ├── main.jpg ├── options ├── test │ └── MMA │ │ ├── test_MMA_x2.yml │ │ ├── test_MMA_x3.yml │ │ └── test_MMA_x4.yml └── train │ └── MMA │ ├── train_MMA_x2_finetune.yml │ ├── train_MMA_x2_pretrain.yml │ ├── train_MMA_x3_finetune.yml │ ├── train_MMA_x3_pretrain.yml │ ├── train_MMA_x4_finetune.yml │ └── train_MMA_x4_pretrain.yml ├── requirements.txt ├── results.jpg ├── setup.cfg └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # ignored folders 2 | datasets/* 3 | experiments/* 4 | results/* 5 | tb_logger/* 6 | wandb/* 7 | tmp/* 8 | 9 | docs/api 10 | scripts/__init__.py 11 | 12 | *.DS_Store 13 | .idea 14 | 15 | # ignored files 16 | version.py 17 | 18 | # ignored files with suffix 19 | *.html 20 | *.png 21 | *.jpeg 22 | *.gif 23 | *.pth 24 | *.zip 25 | 26 | # template 27 | 28 | # Byte-compiled / optimized / DLL files 29 | __pycache__/ 30 | *.py[cod] 31 | *$py.class 32 | 33 | # C extensions 34 | *.so 35 | 36 | # Distribution / packaging 37 | .Python 38 | build/ 39 | develop-eggs/ 40 | dist/ 41 | downloads/ 42 | eggs/ 43 | .eggs/ 44 | lib/ 45 | lib64/ 46 | parts/ 47 | sdist/ 48 | var/ 49 | wheels/ 50 | *.egg-info/ 51 | .installed.cfg 52 | *.egg 53 | MANIFEST 54 | 55 | # PyInstaller 56 | # Usually these files are written by a python script from a template 57 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
58 | *.manifest 59 | *.spec 60 | 61 | # Installer logs 62 | pip-log.txt 63 | pip-delete-this-directory.txt 64 | 65 | # Unit test / coverage reports 66 | htmlcov/ 67 | .tox/ 68 | .coverage 69 | .coverage.* 70 | .cache 71 | nosetests.xml 72 | coverage.xml 73 | *.cover 74 | .hypothesis/ 75 | .pytest_cache/ 76 | 77 | # Translations 78 | *.mo 79 | *.pot 80 | 81 | # Django stuff: 82 | *.log 83 | local_settings.py 84 | db.sqlite3 85 | 86 | # Flask stuff: 87 | instance/ 88 | .webassets-cache 89 | 90 | # Scrapy stuff: 91 | .scrapy 92 | 93 | # Sphinx documentation 94 | docs/_build/ 95 | 96 | # PyBuilder 97 | target/ 98 | 99 | # Jupyter Notebook 100 | .ipynb_checkpoints 101 | 102 | # pyenv 103 | .python-version 104 | 105 | # celery beat schedule file 106 | celerybeat-schedule 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | # flake8 3 | - repo: https://github.com/PyCQA/flake8 4 | rev: 3.8.3 5 | hooks: 6 | - id: flake8 7 | args: ["--config=setup.cfg", "--ignore=W504, W503"] 8 | 9 | # modify known_third_party 10 | - repo: https://github.com/asottile/seed-isort-config 11 | rev: v2.2.0 12 | hooks: 13 | - id: seed-isort-config 14 | 15 | # isort 16 | - repo: https://github.com/timothycrosley/isort 17 | rev: 5.2.2 18 | hooks: 19 | - id: isort 20 | 21 | # yapf 22 | - repo: https://github.com/pre-commit/mirrors-yapf 23 | rev: v0.30.0 24 | hooks: 25 | - id: yapf 26 | 27 | # codespell 28 | - repo: https://github.com/codespell-project/codespell 29 | rev: v2.1.0 30 | hooks: 31 | - id: codespell 32 | 33 | # pre-commit-hooks 34 | - repo: https://github.com/pre-commit/pre-commit-hooks 35 | rev: v3.2.0 36 | hooks: 37 | - id: trailing-whitespace # Trim trailing whitespace 38 | - id: check-yaml # Attempt to load all yaml files to verify syntax 39 | - id: check-merge-conflict # Check for files that contain merge conflict strings 40 | - id: double-quote-string-fixer # Replace double quoted strings with single quoted strings 41 | - id: end-of-file-fixer # Make sure files end in a newline and only a newline 42 | - id: requirements-txt-fixer # Sort entries in requirements.txt and remove incorrect entry for pkg-resources==0.0.0 43 | - id: fix-encoding-pragma # Remove the coding pragma: # -*- coding: utf-8 -*- 44 | args: ["--remove"] 45 | - id: mixed-line-ending # Replace or check mixed line ending 46 | args: ["--fix=lf"] 47 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the version of Python and other tools you might need 9 | build: 10 | os: ubuntu-20.04 11 | tools: 12 | python: "3.8" 13 | # You can also specify other tool versions: 14 | # nodejs: "16" 15 | # rust: "1.55" 16 | # golang: "1.17" 17 | 18 | # Build documentation in the docs/ directory with Sphinx 19 | 
sphinx: 20 | configuration: docs/conf.py 21 | 22 | # If using Sphinx, optionally build your docs in additional formats such as PDF 23 | # formats: 24 | # - pdf 25 | 26 | # Optionally declare the Python requirements required to build your docs 27 | python: 28 | install: 29 | - requirements: docs/requirements.txt 30 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you use this project, please cite it as below." 3 | title: "BasicSR: Open Source Image and Video Restoration Toolbox" 4 | version: 1.3.5 5 | date-released: 2022-02-16 6 | url: "https://github.com/XPixelGroup/BasicSR" 7 | license: Apache-2.0 8 | authors: 9 | - family-names: Wang 10 | given-names: Xintao 11 | - family-names: Xie 12 | given-names: Liangbin 13 | - family-names: Yu 14 | given-names: Ke 15 | - family-names: Chan 16 | given-names: Kelvin C.K. 17 | - family-names: Loy 18 | given-names: Chen Change 19 | - family-names: Dong 20 | given-names: Chao 21 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include basicsr/ops/dcn/src/*.cu basicsr/ops/dcn/src/*.cpp 2 | include basicsr/ops/fused_act/src/*.cu basicsr/ops/fused_act/src/*.cpp 3 | include basicsr/ops/upfirdn2d/src/*.cu basicsr/ops/upfirdn2d/src/*.cpp 4 | include basicsr/metrics/niqe_pris_params.npz 5 | include VERSION 6 | include requirements.txt 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Activating Wider Areas in Image Super-Resolution 2 | 3 | Cheng Cheng, Hang Wang, [Hongbin Sun](https://gr.xjtu.edu.cn/en/web/hsun/home) 4 | 5 | #### 🔥🔥🔥 News 6 | 7 | - **2023-07-16:** This repo is released. 8 | 9 | 10 | [[arXiv](http://arxiv.org/abs/2403.08330)] 11 | 12 | 13 | --- 14 | 15 | > **Abstract:** *The prevalence of convolution neural networks (CNNs) and vision transformers (ViTs) has markedly revolutionized the area of single-image super-resolution (SISR). To further boost the SR performances, several techniques, such as residual learning and attention mechanism, are introduced, which can be largely attributed to a wider range of activated area, that is, the input pixels that strongly influence the SR results. However, the possibility of further improving SR performance through another versatile vision backbone remains an unresolved challenge. To address this issue, in this paper, we unleash the representation potential of the modern state space model, i.e., Vision Mamba (Vim), in the context of SISR. Specifically, we present three recipes for better utilization of Vim-based models: 1) Integration into a MetaFormer-style block; 2) Pre-training on a larger and broader dataset; 3) Employing complementary attention mechanism, upon which we introduce the MMA. 16 | The resulting network MMA is capable of finding the most relevant and representative input pixels to reconstruct the corresponding high-resolution images. Comprehensive experimental analysis reveals that MMA not only achieves competitive or even superior performance compared to state-of-the-art SISR methods but also maintains relatively low memory and computational overheads (e.g., +0.5 dB PSNR elevation on Manga109 dataset with 19.8 M parameters at the scale of 2). 
Furthermore, MMA proves its versatility in lightweight SR applications. Through this work, we aim to illuminate the potential applications of state space models in the broader realm of image processing rather than SISR, encouraging further exploration in this innovative direction.* 17 | 18 | ![Intro](main.jpg) 19 | 20 | ![Intro](results.jpg) 21 | 22 | 23 | ## TODO 24 | - Update lightweight results 25 | 26 | 27 | ## Dependencies 28 | 29 | - Python 3.10 30 | - PyTorch 2.1.1 31 | - NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads) 32 | 33 | ``` 34 | # Clone the GitHub repo and go to the default directory 'MMA'. 35 | 36 | git clone https://github.com/ArsenalCheng/MMA.git 37 | conda create -n MMA python=3.8 38 | conda activate MMA 39 | pip install -r requirements.txt 40 | python setup.py develop 41 | ``` 42 | 43 | 44 | ## Contents 45 | 46 | 1. [Datasets](#Datasets) 47 | 1. [Models](#Models) 48 | 1. [Training](#Training) 49 | 1. [Testing](#Testing) 50 | 1. [Results](#Results) 51 | 1. [Citation](#Citation) 52 | 1. [Acknowledgements](#Acknowledgements) 53 | 54 | --- 55 | 56 | ## Datasets 57 | 58 | The training and testing sets used in this work can be downloaded as follows: 59 | 60 | | Training Set | Testing Set | 61 | | :----------------------------------------------------------- | :----------------------------------------------------------: | 62 | | [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) (800 training images, 100 validation images) + [Flickr2K](https://cv.snu.ac.kr/research/EDSR/Flickr2K.tar) (2650 images) [complete training dataset DF2K: [Google Drive](https://drive.google.com/file/d/1TubDkirxl4qAWelfOnpwaSKoj3KLAIG4/view?usp=share_link)] | Set5 + Set14 + BSD100 + Urban100 + Manga109 [complete testing dataset: [Google Drive](https://drive.google.com/file/d/1yMbItvFKVaCT93yPWmlP3883XtJ-wSee/view?usp=sharing)] | 63 | 64 | 65 | ## Models 66 | 67 | | Method | Scale | Dataset | PSNR (dB) | SSIM | Model Zoo | 68 | | :-------- | :----: | :-------: | :------: | :-------: | :----: | 69 | | MMA | 2 | Urban100 | 34.13 | 0.9446 | [Google Drive](https://drive.google.com/drive/folders/1OW9fdzqrMh-j_OA6VC77ZMvKCnOtfiqq?usp=drive_link) | 70 | | MMA | 3 | Urban100 | 29.93 | 0.8829 | [Google Drive](https://drive.google.com/drive/folders/1wF7kdCV_JdwKghzeS-zzUOPt3dMUSzqG?usp=drive_link)| 71 | | MMA | 4 | Urban100 | 27.64 | 0.8272 | [Google Drive](https://drive.google.com/drive/folders/1HDULsB8jJKLNfrV0_Xs6Us4CMtH4T3_I?usp=drive_link)| 72 | 73 | 74 | ## Training 75 | 76 | - Download the training (DF2K, already processed) and testing (Set5, Set14, BSD100, Urban100, Manga109, already processed) datasets, and place them in `datasets/`. 77 | 78 | - Run the following scripts. The training configuration is in `options/train/`. 79 | 80 | ``` 81 | # MMA-x2, input=64x64, 8 GPUs 82 | torchrun --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/MMA/train_MMA_x2_pretrain.yml --launcher pytorch 83 | # Then change the "pretrain_network_g" in options/train/MMA/train_MMA_x2_finetune.yml to the best checkpoint obtained during pretraining. 84 | torchrun --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/MMA/train_MMA_x2_finetune.yml --launcher pytorch 85 | 86 | # MMA-x3, input=64x64, 8 GPUs 87 | torchrun --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/MMA/train_MMA_x3_pretrain.yml --launcher pytorch 88 | # Then change the "pretrain_network_g" in options/train/MMA/train_MMA_x3_finetune.yml to the best checkpoint obtained during pretraining.
89 | torchrun --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/MMA/train_MMA_x3_finetune.yml --launcher pytorch 90 | 91 | # MMA-x4, input=64x64, 8 GPUs 92 | torchrun --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/MMA/train_MMA_x4_pretrain.yml --launcher pytorch 93 | # Then change the "pretrain_network_g" in options/train/MMA/train_MMA_x4_finetune.yml to the best checkpoint obtained during pretraining. 94 | torchrun --nproc_per_node=8 --master_port=4321 basicsr/train.py -opt options/train/MMA/train_MMA_x4_finetune.yml --launcher pytorch 95 | 96 | ``` 97 | 98 | - Training logs and checkpoints are saved in `experiments/`. 99 | 100 | ## Testing 101 | 102 | ### Test images with HR 103 | 104 | - Download the pre-trained [models](https://drive.google.com/drive/folders/1j32og3Yn7z_je1m_yR3qLEz2yqOLuhv4?usp=drive_link) and place them in `experiments/pretrained_models/`. 105 | 106 | We provide pre-trained models for image SR: MMA (x2, x3, x4). 107 | 108 | - Download the testing (Set5, Set14, BSD100, Urban100, Manga109) datasets, and place them in `datasets/`. 109 | 110 | - Run the following scripts. The testing configuration is in `options/test/` (e.g., [test_MMA_x2.yml](options/test/MMA/test_MMA_x2.yml)). 111 | 112 | ```shell 113 | # MMA, reproduces results in Table 1 of the main paper 114 | 115 | python basicsr/test.py -opt options/test/MMA/test_MMA_x2.yml 116 | python basicsr/test.py -opt options/test/MMA/test_MMA_x3.yml 117 | python basicsr/test.py -opt options/test/MMA/test_MMA_x4.yml 118 | ``` 119 | 120 | - The output images are saved in `results/`. For a quick programmatic sanity check without the YAML pipeline, see the sketch at the end of this README. 121 | 122 | 123 | 124 | 125 | ## Citation 126 | 127 | If you find the code helpful in your research or work, please cite the following paper(s). 128 | 129 | ``` 130 | @misc{cheng2024activating, 131 | title={Activating Wider Areas in Image Super-Resolution}, 132 | author={Cheng Cheng and Hang Wang and Hongbin Sun}, 133 | year={2024}, 134 | eprint={2403.08330}, 135 | archivePrefix={arXiv}, 136 | primaryClass={cs.CV} 137 | } 138 | ``` 139 | 140 | ## Acknowledgements 141 | 142 | This code is built on [BasicSR](https://github.com/XPixelGroup/BasicSR).
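## Quick Architecture Check (Unofficial Sketch)

As a supplement to the YAML-driven `basicsr/test.py` pipeline above, the snippet below shows one way to instantiate the `MMA` architecture directly from Python for a quick shape check. This is a minimal sketch, not part of the official workflow: it assumes a CUDA-capable GPU and a `mamba_ssm` build that supports `bimamba_type='v2'` (as required by `basicsr/archs/mlpmixer_util.py`), it uses the default hyper-parameters from `basicsr/archs/mma_arch.py` (the released checkpoints may differ), and the checkpoint path in the comments is purely illustrative.

```python
import torch

from basicsr.archs.mma_arch import MMA

# Build an x4 model with the defaults declared in mma_arch.py.
# The released checkpoints may use different num_feat / num_block values;
# check the corresponding YAML files under options/ before loading weights.
model = MMA(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=16, upscale=4).cuda().eval()

# Optionally load a downloaded checkpoint (the path below is hypothetical).
# BasicSR checkpoints typically store the weights under the 'params' key.
# state = torch.load('experiments/pretrained_models/MMA_x4.pth', map_location='cuda')
# model.load_state_dict(state['params'], strict=True)

lr = torch.rand(1, 3, 64, 64).cuda()  # RGB low-resolution input in [0, 1]
with torch.no_grad():
    sr = model(lr)
print(sr.shape)  # expected: torch.Size([1, 3, 256, 256])
```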
143 | -------------------------------------------------------------------------------- /VERSION: -------------------------------------------------------------------------------- 1 | 1.4.2 2 | -------------------------------------------------------------------------------- /basicsr/__init__.py: -------------------------------------------------------------------------------- 1 | # https://github.com/xinntao/BasicSR 2 | # flake8: noqa 3 | from .archs import * 4 | from .data import * 5 | from .losses import * 6 | from .metrics import * 7 | from .models import * 8 | from .ops import * 9 | from .test import * 10 | from .train import * 11 | from .utils import * 12 | from .version import __gitsha__, __version__ 13 | -------------------------------------------------------------------------------- /basicsr/archs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from basicsr.utils.registry import ARCH_REGISTRY 7 | 8 | __all__ = ['build_network'] 9 | 10 | # automatically scan and import arch modules for registry 11 | # scan all the files under the 'archs' folder and collect files ending with '_arch.py' 12 | arch_folder = osp.dirname(osp.abspath(__file__)) 13 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] 14 | # import all the arch modules 15 | _arch_modules = [importlib.import_module(f'basicsr.archs.{file_name}') for file_name in arch_filenames] 16 | 17 | 18 | def build_network(opt): 19 | opt = deepcopy(opt) 20 | network_type = opt.pop('type') 21 | net = ARCH_REGISTRY.get(network_type)(**opt) 22 | logger = get_root_logger() 23 | logger.info(f'Network [{net.__class__.__name__}] is created.') 24 | return net 25 | -------------------------------------------------------------------------------- /basicsr/archs/discriminator_arch.py: -------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | from torch.nn import functional as F 3 | from torch.nn.utils import spectral_norm 4 | 5 | from basicsr.utils.registry import ARCH_REGISTRY 6 | 7 | 8 | @ARCH_REGISTRY.register() 9 | class VGGStyleDiscriminator(nn.Module): 10 | """VGG style discriminator with input size 128 x 128 or 256 x 256. 11 | 12 | It is used to train SRGAN, ESRGAN, and VideoGAN. 13 | 14 | Args: 15 | num_in_ch (int): Channel number of inputs. Default: 3. 16 | num_feat (int): Channel number of base intermediate features.Default: 64. 
17 | """ 18 | 19 | def __init__(self, num_in_ch, num_feat, input_size=128): 20 | super(VGGStyleDiscriminator, self).__init__() 21 | self.input_size = input_size 22 | assert self.input_size == 128 or self.input_size == 256, ( 23 | f'input size must be 128 or 256, but received {input_size}') 24 | 25 | self.conv0_0 = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True) 26 | self.conv0_1 = nn.Conv2d(num_feat, num_feat, 4, 2, 1, bias=False) 27 | self.bn0_1 = nn.BatchNorm2d(num_feat, affine=True) 28 | 29 | self.conv1_0 = nn.Conv2d(num_feat, num_feat * 2, 3, 1, 1, bias=False) 30 | self.bn1_0 = nn.BatchNorm2d(num_feat * 2, affine=True) 31 | self.conv1_1 = nn.Conv2d(num_feat * 2, num_feat * 2, 4, 2, 1, bias=False) 32 | self.bn1_1 = nn.BatchNorm2d(num_feat * 2, affine=True) 33 | 34 | self.conv2_0 = nn.Conv2d(num_feat * 2, num_feat * 4, 3, 1, 1, bias=False) 35 | self.bn2_0 = nn.BatchNorm2d(num_feat * 4, affine=True) 36 | self.conv2_1 = nn.Conv2d(num_feat * 4, num_feat * 4, 4, 2, 1, bias=False) 37 | self.bn2_1 = nn.BatchNorm2d(num_feat * 4, affine=True) 38 | 39 | self.conv3_0 = nn.Conv2d(num_feat * 4, num_feat * 8, 3, 1, 1, bias=False) 40 | self.bn3_0 = nn.BatchNorm2d(num_feat * 8, affine=True) 41 | self.conv3_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False) 42 | self.bn3_1 = nn.BatchNorm2d(num_feat * 8, affine=True) 43 | 44 | self.conv4_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False) 45 | self.bn4_0 = nn.BatchNorm2d(num_feat * 8, affine=True) 46 | self.conv4_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False) 47 | self.bn4_1 = nn.BatchNorm2d(num_feat * 8, affine=True) 48 | 49 | if self.input_size == 256: 50 | self.conv5_0 = nn.Conv2d(num_feat * 8, num_feat * 8, 3, 1, 1, bias=False) 51 | self.bn5_0 = nn.BatchNorm2d(num_feat * 8, affine=True) 52 | self.conv5_1 = nn.Conv2d(num_feat * 8, num_feat * 8, 4, 2, 1, bias=False) 53 | self.bn5_1 = nn.BatchNorm2d(num_feat * 8, affine=True) 54 | 55 | self.linear1 = nn.Linear(num_feat * 8 * 4 * 4, 100) 56 | self.linear2 = nn.Linear(100, 1) 57 | 58 | # activation function 59 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 60 | 61 | def forward(self, x): 62 | assert x.size(2) == self.input_size, (f'Input size must be identical to input_size, but received {x.size()}.') 63 | 64 | feat = self.lrelu(self.conv0_0(x)) 65 | feat = self.lrelu(self.bn0_1(self.conv0_1(feat))) # output spatial size: /2 66 | 67 | feat = self.lrelu(self.bn1_0(self.conv1_0(feat))) 68 | feat = self.lrelu(self.bn1_1(self.conv1_1(feat))) # output spatial size: /4 69 | 70 | feat = self.lrelu(self.bn2_0(self.conv2_0(feat))) 71 | feat = self.lrelu(self.bn2_1(self.conv2_1(feat))) # output spatial size: /8 72 | 73 | feat = self.lrelu(self.bn3_0(self.conv3_0(feat))) 74 | feat = self.lrelu(self.bn3_1(self.conv3_1(feat))) # output spatial size: /16 75 | 76 | feat = self.lrelu(self.bn4_0(self.conv4_0(feat))) 77 | feat = self.lrelu(self.bn4_1(self.conv4_1(feat))) # output spatial size: /32 78 | 79 | if self.input_size == 256: 80 | feat = self.lrelu(self.bn5_0(self.conv5_0(feat))) 81 | feat = self.lrelu(self.bn5_1(self.conv5_1(feat))) # output spatial size: / 64 82 | 83 | # spatial size: (4, 4) 84 | feat = feat.view(feat.size(0), -1) 85 | feat = self.lrelu(self.linear1(feat)) 86 | out = self.linear2(feat) 87 | return out 88 | 89 | 90 | @ARCH_REGISTRY.register(suffix='basicsr') 91 | class UNetDiscriminatorSN(nn.Module): 92 | """Defines a U-Net discriminator with spectral normalization (SN) 93 | 94 | It is used in Real-ESRGAN: Training Real-World Blind 
Super-Resolution with Pure Synthetic Data. 95 | 96 | Arg: 97 | num_in_ch (int): Channel number of inputs. Default: 3. 98 | num_feat (int): Channel number of base intermediate features. Default: 64. 99 | skip_connection (bool): Whether to use skip connections between U-Net. Default: True. 100 | """ 101 | 102 | def __init__(self, num_in_ch, num_feat=64, skip_connection=True): 103 | super(UNetDiscriminatorSN, self).__init__() 104 | self.skip_connection = skip_connection 105 | norm = spectral_norm 106 | # the first convolution 107 | self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1) 108 | # downsample 109 | self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False)) 110 | self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False)) 111 | self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False)) 112 | # upsample 113 | self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False)) 114 | self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False)) 115 | self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False)) 116 | # extra convolutions 117 | self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) 118 | self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False)) 119 | self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1) 120 | 121 | def forward(self, x): 122 | # downsample 123 | x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True) 124 | x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True) 125 | x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True) 126 | x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True) 127 | 128 | # upsample 129 | x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False) 130 | x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True) 131 | 132 | if self.skip_connection: 133 | x4 = x4 + x2 134 | x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False) 135 | x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True) 136 | 137 | if self.skip_connection: 138 | x5 = x5 + x1 139 | x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False) 140 | x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True) 141 | 142 | if self.skip_connection: 143 | x6 = x6 + x0 144 | 145 | # extra convolutions 146 | out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True) 147 | out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True) 148 | out = self.conv9(out) 149 | 150 | return out 151 | -------------------------------------------------------------------------------- /basicsr/archs/mlpmixer_util.py: -------------------------------------------------------------------------------- 1 | import numbers 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from einops import rearrange 6 | from mamba_ssm import Mamba 7 | 8 | 9 | class ChannelAttention(nn.Module): 10 | """Channel attention used in RCAN. 11 | Args: 12 | num_feat (int): Channel number of intermediate features. 13 | squeeze_factor (int): Channel squeeze factor. Default: 16. 
14 | """ 15 | 16 | def __init__(self, num_feat, squeeze_factor=16): 17 | super(ChannelAttention, self).__init__() 18 | self.attention = nn.Sequential( 19 | nn.AdaptiveAvgPool2d(1), 20 | nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0), 21 | nn.ReLU(inplace=True), 22 | nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0), 23 | nn.Sigmoid()) 24 | 25 | def forward(self, x): 26 | y = self.attention(x) 27 | return x * y 28 | 29 | 30 | class CAB(nn.Module): 31 | 32 | def __init__(self, num_feat, compress_ratio=3, squeeze_factor=30): 33 | super(CAB, self).__init__() 34 | 35 | self.cab = nn.Sequential( 36 | nn.Conv2d(num_feat, num_feat // compress_ratio, 3, 1, 1), 37 | nn.GELU(), 38 | nn.Conv2d(num_feat // compress_ratio, num_feat, 3, 1, 1), 39 | ChannelAttention(num_feat, squeeze_factor) 40 | ) 41 | 42 | def forward(self, x): 43 | return self.cab(x) 44 | 45 | 46 | def to_3d(x): 47 | return rearrange(x, 'b c h w -> b (h w) c') 48 | 49 | 50 | def to_4d(x, h, w): 51 | return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 52 | 53 | 54 | class BiasFree_LayerNorm(nn.Module): 55 | def __init__(self, normalized_shape): 56 | super(BiasFree_LayerNorm, self).__init__() 57 | if isinstance(normalized_shape, numbers.Integral): 58 | normalized_shape = (normalized_shape,) 59 | normalized_shape = torch.Size(normalized_shape) 60 | 61 | assert len(normalized_shape) == 1 62 | 63 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 64 | self.normalized_shape = normalized_shape 65 | 66 | def forward(self, x): 67 | sigma = x.var(-1, keepdim=True, unbiased=False) 68 | return x / torch.sqrt(sigma + 1e-5) * self.weight 69 | 70 | 71 | class WithBias_LayerNorm(nn.Module): 72 | def __init__(self, normalized_shape): 73 | super(WithBias_LayerNorm, self).__init__() 74 | if isinstance(normalized_shape, numbers.Integral): 75 | normalized_shape = (normalized_shape,) 76 | normalized_shape = torch.Size(normalized_shape) 77 | 78 | assert len(normalized_shape) == 1 79 | 80 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 81 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 82 | self.normalized_shape = normalized_shape 83 | 84 | def forward(self, x): 85 | mu = x.mean(-1, keepdim=True) 86 | sigma = x.var(-1, keepdim=True, unbiased=False) 87 | return (x - mu) / torch.sqrt(sigma + 1e-5) * self.weight + self.bias 88 | 89 | 90 | class LayerNorm(nn.Module): 91 | def __init__(self, dim, LayerNorm_type='WithBias'): 92 | super(LayerNorm, self).__init__() 93 | if LayerNorm_type == 'BiasFree': 94 | self.body = BiasFree_LayerNorm(dim) 95 | else: 96 | self.body = WithBias_LayerNorm(dim) 97 | 98 | def forward(self, x): 99 | h, w = x.shape[-2:] 100 | return to_4d(self.body(to_3d(x)), h, w) 101 | 102 | 103 | # Gated-Dconv Feed-Forward Network (GDFN) 104 | class FeedForward(nn.Module): 105 | def __init__(self, dim, ffn_expansion_factor, bias): 106 | super(FeedForward, self).__init__() 107 | 108 | hidden_features = int(dim*ffn_expansion_factor) 109 | 110 | self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias) 111 | 112 | self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, 113 | groups=hidden_features*2, bias=bias) 114 | 115 | self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) 116 | 117 | def forward(self, x): 118 | x = self.project_in(x) 119 | x1, x2 = self.dwconv(x).chunk(2, dim=1) 120 | x = F.gelu(x1) * x2 121 | x = self.project_out(x) 122 | return x 123 | 124 | 125 | class MambaLayer(nn.Module): 126 | def 
__init__(self, dim): 127 | super(MambaLayer, self).__init__() 128 | self.mamba = Mamba(d_model=dim, d_state=16, d_conv=4, expand=2, bimamba_type='v2') 129 | self.dim = dim 130 | 131 | def forward(self, x): 132 | B, C = x.shape[:2] 133 | assert C == self.dim 134 | n_tokens = x.shape[2:].numel() 135 | img_dims = x.shape[2:] 136 | x_flat = x.reshape(B, C, n_tokens).transpose(-1, -2) 137 | out = self.mamba(x_flat) 138 | out = out.transpose(-1, -2).reshape(B, C, *img_dims) 139 | return out 140 | 141 | 142 | class HybridAttBlock(nn.Module): 143 | def __init__(self, dim): 144 | super(HybridAttBlock, self).__init__() 145 | self.mamba = MambaLayer(dim=dim) 146 | self.channel_attention = CAB(num_feat=dim, compress_ratio=4) 147 | 148 | def forward(self, x): 149 | att1 = self.mamba(x) 150 | att2 = self.channel_attention(x) 151 | return att1 + att2 152 | 153 | 154 | class MLPMixer(nn.Module): 155 | def __init__(self, dim, ffn_expansion_factor=2.66, bias=True, LayerNorm_type='WithBias'): 156 | super(MLPMixer, self).__init__() 157 | self.dim = dim 158 | self.norm1 = LayerNorm(dim, LayerNorm_type) 159 | self.attn = HybridAttBlock(dim=dim) 160 | self.norm2 = LayerNorm(dim, LayerNorm_type) 161 | self.ffn = FeedForward(dim, ffn_expansion_factor, bias) 162 | 163 | def forward(self, x): 164 | x = x + self.attn(self.norm1(x)) 165 | x = x + self.ffn(self.norm2(x)) 166 | return x 167 | 168 | 169 | if __name__ == '__main__': 170 | input = torch.rand([1, 64, 32, 32]).cuda() 171 | layer = HybridAttBlock(dim=64).cuda() 172 | output = layer(input) 173 | # print(output.shape) 174 | -------------------------------------------------------------------------------- /basicsr/archs/mma_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from basicsr.archs.arch_util import Upsample, make_layer 4 | from basicsr.utils.registry import ARCH_REGISTRY 5 | try: 6 | from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn 7 | except ImportError: 8 | RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None 9 | from basicsr.archs.mlpmixer_util import MLPMixer 10 | 11 | 12 | @ARCH_REGISTRY.register() 13 | class MMA(nn.Module): 14 | """MMA network structure. 
15 | """ 16 | 17 | def __init__(self, 18 | num_in_ch, 19 | num_out_ch, 20 | num_feat=64, 21 | num_block=16, 22 | upscale=4, 23 | img_range=255., 24 | rgb_mean=(0.4488, 0.4371, 0.4040)): 25 | super(MMA, self).__init__() 26 | 27 | self.img_range = img_range 28 | self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1) 29 | 30 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 31 | self.body = make_layer(MLPMixer, num_block, dim=num_feat) 32 | self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 33 | self.upsample = Upsample(upscale, num_feat) 34 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 35 | 36 | def forward(self, x): 37 | self.mean = self.mean.type_as(x) 38 | 39 | x = (x - self.mean) * self.img_range 40 | x = self.conv_first(x) 41 | res = self.conv_after_body(self.body(x)) 42 | res += x 43 | 44 | x = self.conv_last(self.upsample(res)) 45 | x = x / self.img_range + self.mean 46 | 47 | return x -------------------------------------------------------------------------------- /basicsr/data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import numpy as np 3 | import random 4 | import torch 5 | import torch.utils.data 6 | from copy import deepcopy 7 | from functools import partial 8 | from os import path as osp 9 | 10 | from basicsr.data.prefetch_dataloader import PrefetchDataLoader 11 | from basicsr.utils import get_root_logger, scandir 12 | from basicsr.utils.dist_util import get_dist_info 13 | from basicsr.utils.registry import DATASET_REGISTRY 14 | 15 | __all__ = ['build_dataset', 'build_dataloader'] 16 | 17 | # automatically scan and import dataset modules for registry 18 | # scan all the files under the data folder with '_dataset' in file names 19 | data_folder = osp.dirname(osp.abspath(__file__)) 20 | dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] 21 | # import all the dataset modules 22 | _dataset_modules = [importlib.import_module(f'basicsr.data.{file_name}') for file_name in dataset_filenames] 23 | 24 | 25 | def build_dataset(dataset_opt): 26 | """Build dataset from options. 27 | 28 | Args: 29 | dataset_opt (dict): Configuration for dataset. It must contain: 30 | name (str): Dataset name. 31 | type (str): Dataset type. 32 | """ 33 | dataset_opt = deepcopy(dataset_opt) 34 | dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt) 35 | logger = get_root_logger() 36 | logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.') 37 | return dataset 38 | 39 | 40 | def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None): 41 | """Build dataloader. 42 | 43 | Args: 44 | dataset (torch.utils.data.Dataset): Dataset. 45 | dataset_opt (dict): Dataset options. It contains the following keys: 46 | phase (str): 'train' or 'val'. 47 | num_worker_per_gpu (int): Number of workers for each GPU. 48 | batch_size_per_gpu (int): Training batch size for each GPU. 49 | num_gpu (int): Number of GPUs. Used only in the train phase. 50 | Default: 1. 51 | dist (bool): Whether in distributed training. Used only in the train 52 | phase. Default: False. 53 | sampler (torch.utils.data.sampler): Data sampler. Default: None. 54 | seed (int | None): Seed. 
Default: None 55 | """ 56 | phase = dataset_opt['phase'] 57 | rank, _ = get_dist_info() 58 | if phase == 'train': 59 | if dist: # distributed training 60 | batch_size = dataset_opt['batch_size_per_gpu'] 61 | num_workers = dataset_opt['num_worker_per_gpu'] 62 | else: # non-distributed training 63 | multiplier = 1 if num_gpu == 0 else num_gpu 64 | batch_size = dataset_opt['batch_size_per_gpu'] * multiplier 65 | num_workers = dataset_opt['num_worker_per_gpu'] * multiplier 66 | dataloader_args = dict( 67 | dataset=dataset, 68 | batch_size=batch_size, 69 | shuffle=False, 70 | num_workers=num_workers, 71 | sampler=sampler, 72 | drop_last=True) 73 | if sampler is None: 74 | dataloader_args['shuffle'] = True 75 | dataloader_args['worker_init_fn'] = partial( 76 | worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None 77 | elif phase in ['val', 'test']: # validation 78 | dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0) 79 | else: 80 | raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.") 81 | 82 | dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) 83 | dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False) 84 | 85 | prefetch_mode = dataset_opt.get('prefetch_mode') 86 | if prefetch_mode == 'cpu': # CPUPrefetcher 87 | num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) 88 | logger = get_root_logger() 89 | logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}') 90 | return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args) 91 | else: 92 | # prefetch_mode=None: Normal dataloader 93 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher 94 | return torch.utils.data.DataLoader(**dataloader_args) 95 | 96 | 97 | def worker_init_fn(worker_id, num_workers, rank, seed): 98 | # Set the worker seed to num_workers * rank + worker_id + seed 99 | worker_seed = num_workers * rank + worker_id + seed 100 | np.random.seed(worker_seed) 101 | random.seed(worker_seed) 102 | -------------------------------------------------------------------------------- /basicsr/data/data_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data.sampler import Sampler 4 | 5 | 6 | class EnlargedSampler(Sampler): 7 | """Sampler that restricts data loading to a subset of the dataset. 8 | 9 | Modified from torch.utils.data.distributed.DistributedSampler 10 | Support enlarging the dataset for iteration-based training, for saving 11 | time when restart the dataloader after each epoch 12 | 13 | Args: 14 | dataset (torch.utils.data.Dataset): Dataset used for sampling. 15 | num_replicas (int | None): Number of processes participating in 16 | the training. It is usually the world_size. 17 | rank (int | None): Rank of the current process within num_replicas. 18 | ratio (int): Enlarging ratio. Default: 1. 
19 | """ 20 | 21 | def __init__(self, dataset, num_replicas, rank, ratio=1): 22 | self.dataset = dataset 23 | self.num_replicas = num_replicas 24 | self.rank = rank 25 | self.epoch = 0 26 | self.num_samples = math.ceil(len(self.dataset) * ratio / self.num_replicas) 27 | self.total_size = self.num_samples * self.num_replicas 28 | 29 | def __iter__(self): 30 | # deterministically shuffle based on epoch 31 | g = torch.Generator() 32 | g.manual_seed(self.epoch) 33 | indices = torch.randperm(self.total_size, generator=g).tolist() 34 | 35 | dataset_size = len(self.dataset) 36 | indices = [v % dataset_size for v in indices] 37 | 38 | # subsample 39 | indices = indices[self.rank:self.total_size:self.num_replicas] 40 | assert len(indices) == self.num_samples 41 | 42 | return iter(indices) 43 | 44 | def __len__(self): 45 | return self.num_samples 46 | 47 | def set_epoch(self, epoch): 48 | self.epoch = epoch 49 | -------------------------------------------------------------------------------- /basicsr/data/imagenet_paired_dataset.py: -------------------------------------------------------------------------------- 1 | import pdb 2 | 3 | import cv2 4 | import numpy as np 5 | import os.path as osp 6 | from torch.utils import data as data 7 | from torchvision.transforms.functional import normalize 8 | import os 9 | from basicsr.data.data_util import paths_from_lmdb, scandir 10 | from basicsr.data.transforms import augment, paired_random_crop 11 | from basicsr.utils import FileClient, imfrombytes, img2tensor 12 | from basicsr.utils.matlab_functions import imresize, rgb2ycbcr 13 | from basicsr.utils.registry import DATASET_REGISTRY 14 | 15 | 16 | @DATASET_REGISTRY.register() 17 | class ImageNetPairedDataset(data.Dataset): 18 | 19 | def __init__(self, opt): 20 | super(ImageNetPairedDataset, self).__init__() 21 | self.opt = opt 22 | # file client (io backend) 23 | self.file_client = None 24 | self.io_backend_opt = opt['io_backend'] 25 | self.mean = opt['mean'] if 'mean' in opt else None 26 | self.std = opt['std'] if 'std' in opt else None 27 | self.gt_folder = opt['dataroot_gt'] 28 | 29 | if self.io_backend_opt['type'] == 'lmdb': 30 | self.io_backend_opt['db_paths'] = [self.gt_folder] 31 | self.io_backend_opt['client_keys'] = ['gt'] 32 | self.paths = paths_from_lmdb(self.gt_folder) 33 | elif 'meta_info_file' in self.opt: 34 | with open(self.opt['meta_info_file'], 'r') as fin: 35 | self.paths = [osp.join(self.gt_folder, line.split(' ')[0]) for line in fin] 36 | else: 37 | self.paths = [] 38 | class_list = os.listdir(self.gt_folder) 39 | for item in class_list: 40 | current_dir = self.gt_folder + '/' + item 41 | current_paths = os.listdir(current_dir) 42 | current_paths = [current_dir + '/' + i for i in current_paths] 43 | self.paths += current_paths 44 | 45 | def __getitem__(self, index): 46 | if self.file_client is None: 47 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 48 | 49 | scale = self.opt['scale'] 50 | 51 | # Load gt and lq images. Dimension order: HWC; channel order: BGR; 52 | # image range: [0, 1], float32. 
53 | gt_path = self.paths[index] 54 | img_bytes = self.file_client.get(gt_path, 'gt') 55 | img_gt = imfrombytes(img_bytes, float32=True) 56 | 57 | # modcrop 58 | size_h, size_w, _ = img_gt.shape 59 | size_h = size_h - size_h % scale 60 | size_w = size_w - size_w % scale 61 | img_gt = img_gt[0:size_h, 0:size_w, :] 62 | 63 | # generate training pairs 64 | size_h = max(size_h, self.opt['gt_size']) 65 | size_w = max(size_w, self.opt['gt_size']) 66 | img_gt = cv2.resize(img_gt, (size_w, size_h)) 67 | img_lq = imresize(img_gt, 1 / scale) 68 | 69 | img_gt = np.ascontiguousarray(img_gt, dtype=np.float32) 70 | img_lq = np.ascontiguousarray(img_lq, dtype=np.float32) 71 | 72 | # augmentation for training 73 | if self.opt['phase'] == 'train': 74 | gt_size = self.opt['gt_size'] 75 | # random crop 76 | img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) 77 | # flip, rotation 78 | img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot']) 79 | 80 | # color space transform 81 | if 'color' in self.opt and self.opt['color'] == 'y': 82 | img_gt = rgb2ycbcr(img_gt, y_only=True)[..., None] 83 | img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None] 84 | 85 | # crop the unmatched GT images during validation or testing, especially for SR benchmark datasets 86 | # TODO: It is better to update the datasets, rather than force to crop 87 | if self.opt['phase'] != 'train': 88 | img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :] 89 | 90 | # BGR to RGB, HWC to CHW, numpy to tensor 91 | img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) 92 | # normalize 93 | if self.mean is not None or self.std is not None: 94 | normalize(img_lq, self.mean, self.std, inplace=True) 95 | normalize(img_gt, self.mean, self.std, inplace=True) 96 | 97 | return {'lq': img_lq, 'gt': img_gt, 'gt_path': gt_path} 98 | 99 | def __len__(self): 100 | return len(self.paths) 101 | -------------------------------------------------------------------------------- /basicsr/data/paired_image_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils import data as data 2 | from torchvision.transforms.functional import normalize 3 | 4 | from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file 5 | from basicsr.data.transforms import augment, paired_random_crop 6 | from basicsr.utils import FileClient, bgr2ycbcr, imfrombytes, img2tensor 7 | from basicsr.utils.registry import DATASET_REGISTRY 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class PairedImageDataset(data.Dataset): 12 | """Paired image dataset for image restoration. 13 | 14 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs. 15 | 16 | There are three modes: 17 | 18 | 1. **lmdb**: Use lmdb files. If opt['io_backend'] == lmdb. 19 | 2. **meta_info_file**: Use meta information file to generate paths. \ 20 | If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None. 21 | 3. **folder**: Scan folders to generate paths. The rest. 22 | 23 | Args: 24 | opt (dict): Config for train datasets. It contains the following keys: 25 | dataroot_gt (str): Data root path for gt. 26 | dataroot_lq (str): Data root path for lq. 27 | meta_info_file (str): Path for meta information file. 28 | io_backend (dict): IO backend type and other kwarg. 29 | filename_tmpl (str): Template for each filename. Note that the template excludes the file extension. 30 | Default: '{}'. 
31 | gt_size (int): Cropped patched size for gt patches. 32 | use_hflip (bool): Use horizontal flips. 33 | use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation). 34 | scale (bool): Scale, which will be added automatically. 35 | phase (str): 'train' or 'val'. 36 | """ 37 | 38 | def __init__(self, opt): 39 | super(PairedImageDataset, self).__init__() 40 | self.opt = opt 41 | # file client (io backend) 42 | self.file_client = None 43 | self.io_backend_opt = opt['io_backend'] 44 | self.mean = opt['mean'] if 'mean' in opt else None 45 | self.std = opt['std'] if 'std' in opt else None 46 | 47 | self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] 48 | if 'filename_tmpl' in opt: 49 | self.filename_tmpl = opt['filename_tmpl'] 50 | else: 51 | self.filename_tmpl = '{}' 52 | 53 | if self.io_backend_opt['type'] == 'lmdb': 54 | self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder] 55 | self.io_backend_opt['client_keys'] = ['lq', 'gt'] 56 | self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt']) 57 | elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None: 58 | self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'], 59 | self.opt['meta_info_file'], self.filename_tmpl) 60 | else: 61 | self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl) 62 | 63 | def __getitem__(self, index): 64 | if self.file_client is None: 65 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 66 | 67 | scale = self.opt['scale'] 68 | 69 | # Load gt and lq images. Dimension order: HWC; channel order: BGR; 70 | # image range: [0, 1], float32. 71 | gt_path = self.paths[index]['gt_path'] 72 | img_bytes = self.file_client.get(gt_path, 'gt') 73 | img_gt = imfrombytes(img_bytes, float32=True) 74 | lq_path = self.paths[index]['lq_path'] 75 | img_bytes = self.file_client.get(lq_path, 'lq') 76 | img_lq = imfrombytes(img_bytes, float32=True) 77 | 78 | # augmentation for training 79 | if self.opt['phase'] == 'train': 80 | gt_size = self.opt['gt_size'] 81 | # random crop 82 | img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) 83 | # flip, rotation 84 | img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot']) 85 | 86 | # color space transform 87 | if 'color' in self.opt and self.opt['color'] == 'y': 88 | img_gt = bgr2ycbcr(img_gt, y_only=True)[..., None] 89 | img_lq = bgr2ycbcr(img_lq, y_only=True)[..., None] 90 | 91 | # crop the unmatched GT images during validation or testing, especially for SR benchmark datasets 92 | # TODO: It is better to update the datasets, rather than force to crop 93 | if self.opt['phase'] != 'train': 94 | img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :] 95 | 96 | # BGR to RGB, HWC to CHW, numpy to tensor 97 | img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) 98 | # normalize 99 | if self.mean is not None or self.std is not None: 100 | normalize(img_lq, self.mean, self.std, inplace=True) 101 | normalize(img_gt, self.mean, self.std, inplace=True) 102 | 103 | return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path} 104 | 105 | def __len__(self): 106 | return len(self.paths) -------------------------------------------------------------------------------- /basicsr/data/prefetch_dataloader.py: 
-------------------------------------------------------------------------------- 1 | import queue as Queue 2 | import threading 3 | import torch 4 | from torch.utils.data import DataLoader 5 | 6 | 7 | class PrefetchGenerator(threading.Thread): 8 | """A general prefetch generator. 9 | 10 | Reference: https://stackoverflow.com/questions/7323664/python-generator-pre-fetch 11 | 12 | Args: 13 | generator: Python generator. 14 | num_prefetch_queue (int): Number of prefetch queue. 15 | """ 16 | 17 | def __init__(self, generator, num_prefetch_queue): 18 | threading.Thread.__init__(self) 19 | self.queue = Queue.Queue(num_prefetch_queue) 20 | self.generator = generator 21 | self.daemon = True 22 | self.start() 23 | 24 | def run(self): 25 | for item in self.generator: 26 | self.queue.put(item) 27 | self.queue.put(None) 28 | 29 | def __next__(self): 30 | next_item = self.queue.get() 31 | if next_item is None: 32 | raise StopIteration 33 | return next_item 34 | 35 | def __iter__(self): 36 | return self 37 | 38 | 39 | class PrefetchDataLoader(DataLoader): 40 | """Prefetch version of dataloader. 41 | 42 | Reference: https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# 43 | 44 | TODO: 45 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in 46 | ddp. 47 | 48 | Args: 49 | num_prefetch_queue (int): Number of prefetch queue. 50 | kwargs (dict): Other arguments for dataloader. 51 | """ 52 | 53 | def __init__(self, num_prefetch_queue, **kwargs): 54 | self.num_prefetch_queue = num_prefetch_queue 55 | super(PrefetchDataLoader, self).__init__(**kwargs) 56 | 57 | def __iter__(self): 58 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) 59 | 60 | 61 | class CPUPrefetcher(): 62 | """CPU prefetcher. 63 | 64 | Args: 65 | loader: Dataloader. 66 | """ 67 | 68 | def __init__(self, loader): 69 | self.ori_loader = loader 70 | self.loader = iter(loader) 71 | 72 | def next(self): 73 | try: 74 | return next(self.loader) 75 | except StopIteration: 76 | return None 77 | 78 | def reset(self): 79 | self.loader = iter(self.ori_loader) 80 | 81 | 82 | class CUDAPrefetcher(): 83 | """CUDA prefetcher. 84 | 85 | Reference: https://github.com/NVIDIA/apex/issues/304# 86 | 87 | It may consume more GPU memory. 88 | 89 | Args: 90 | loader: Dataloader. 91 | opt (dict): Options. 92 | """ 93 | 94 | def __init__(self, loader, opt): 95 | self.ori_loader = loader 96 | self.loader = iter(loader) 97 | self.opt = opt 98 | self.stream = torch.cuda.Stream() 99 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') 100 | self.preload() 101 | 102 | def preload(self): 103 | try: 104 | self.batch = next(self.loader) # self.batch is a dict 105 | except StopIteration: 106 | self.batch = None 107 | return None 108 | # put tensors to gpu 109 | with torch.cuda.stream(self.stream): 110 | for k, v in self.batch.items(): 111 | if torch.is_tensor(v): 112 | self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True) 113 | 114 | def next(self): 115 | torch.cuda.current_stream().wait_stream(self.stream) 116 | batch = self.batch 117 | self.preload() 118 | return batch 119 | 120 | def reset(self): 121 | self.loader = iter(self.ori_loader) 122 | self.preload() 123 | -------------------------------------------------------------------------------- /basicsr/data/transforms.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import random 3 | import torch 4 | 5 | 6 | def mod_crop(img, scale): 7 | """Mod crop images, used during testing. 
8 | 9 | Args: 10 | img (ndarray): Input image. 11 | scale (int): Scale factor. 12 | 13 | Returns: 14 | ndarray: Result image. 15 | """ 16 | img = img.copy() 17 | if img.ndim in (2, 3): 18 | h, w = img.shape[0], img.shape[1] 19 | h_remainder, w_remainder = h % scale, w % scale 20 | img = img[:h - h_remainder, :w - w_remainder, ...] 21 | else: 22 | raise ValueError(f'Wrong img ndim: {img.ndim}.') 23 | return img 24 | 25 | 26 | def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None): 27 | """Paired random crop. Support Numpy array and Tensor inputs. 28 | 29 | It crops lists of lq and gt images with corresponding locations. 30 | 31 | Args: 32 | img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images 33 | should have the same shape. If the input is an ndarray, it will 34 | be transformed to a list containing itself. 35 | img_lqs (list[ndarray] | ndarray): LQ images. Note that all images 36 | should have the same shape. If the input is an ndarray, it will 37 | be transformed to a list containing itself. 38 | gt_patch_size (int): GT patch size. 39 | scale (int): Scale factor. 40 | gt_path (str): Path to ground-truth. Default: None. 41 | 42 | Returns: 43 | list[ndarray] | ndarray: GT images and LQ images. If returned results 44 | only have one element, just return ndarray. 45 | """ 46 | 47 | if not isinstance(img_gts, list): 48 | img_gts = [img_gts] 49 | if not isinstance(img_lqs, list): 50 | img_lqs = [img_lqs] 51 | 52 | # determine input type: Numpy array or Tensor 53 | input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy' 54 | 55 | if input_type == 'Tensor': 56 | h_lq, w_lq = img_lqs[0].size()[-2:] 57 | h_gt, w_gt = img_gts[0].size()[-2:] 58 | else: 59 | h_lq, w_lq = img_lqs[0].shape[0:2] 60 | h_gt, w_gt = img_gts[0].shape[0:2] 61 | lq_patch_size = gt_patch_size // scale 62 | 63 | if h_gt != h_lq * scale or w_gt != w_lq * scale: 64 | raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ', 65 | f'multiplication of LQ ({h_lq}, {w_lq}).') 66 | if h_lq < lq_patch_size or w_lq < lq_patch_size: 67 | raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size ' 68 | f'({lq_patch_size}, {lq_patch_size}). ' 69 | f'Please remove {gt_path}.') 70 | 71 | # randomly choose top and left coordinates for lq patch 72 | top = random.randint(0, h_lq - lq_patch_size) 73 | left = random.randint(0, w_lq - lq_patch_size) 74 | 75 | # crop lq patch 76 | if input_type == 'Tensor': 77 | img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs] 78 | else: 79 | img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs] 80 | 81 | # crop corresponding gt patch 82 | top_gt, left_gt = int(top * scale), int(left * scale) 83 | if input_type == 'Tensor': 84 | img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts] 85 | else: 86 | img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts] 87 | if len(img_gts) == 1: 88 | img_gts = img_gts[0] 89 | if len(img_lqs) == 1: 90 | img_lqs = img_lqs[0] 91 | return img_gts, img_lqs 92 | 93 | 94 | def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False): 95 | """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees). 96 | 97 | We use vertical flip and transpose for rotation implementation. 98 | All the images in the list use the same augmentation. 99 | 100 | Args: 101 | imgs (list[ndarray] | ndarray): Images to be augmented. 
If the input 102 | is an ndarray, it will be transformed to a list. 103 | hflip (bool): Horizontal flip. Default: True. 104 | rotation (bool): Ratotation. Default: True. 105 | flows (list[ndarray]: Flows to be augmented. If the input is an 106 | ndarray, it will be transformed to a list. 107 | Dimension is (h, w, 2). Default: None. 108 | return_status (bool): Return the status of flip and rotation. 109 | Default: False. 110 | 111 | Returns: 112 | list[ndarray] | ndarray: Augmented images and flows. If returned 113 | results only have one element, just return ndarray. 114 | 115 | """ 116 | hflip = hflip and random.random() < 0.5 117 | vflip = rotation and random.random() < 0.5 118 | rot90 = rotation and random.random() < 0.5 119 | 120 | def _augment(img): 121 | if hflip: # horizontal 122 | cv2.flip(img, 1, img) 123 | if vflip: # vertical 124 | cv2.flip(img, 0, img) 125 | if rot90: 126 | img = img.transpose(1, 0, 2) 127 | return img 128 | 129 | def _augment_flow(flow): 130 | if hflip: # horizontal 131 | cv2.flip(flow, 1, flow) 132 | flow[:, :, 0] *= -1 133 | if vflip: # vertical 134 | cv2.flip(flow, 0, flow) 135 | flow[:, :, 1] *= -1 136 | if rot90: 137 | flow = flow.transpose(1, 0, 2) 138 | flow = flow[:, :, [1, 0]] 139 | return flow 140 | 141 | if not isinstance(imgs, list): 142 | imgs = [imgs] 143 | imgs = [_augment(img) for img in imgs] 144 | if len(imgs) == 1: 145 | imgs = imgs[0] 146 | 147 | if flows is not None: 148 | if not isinstance(flows, list): 149 | flows = [flows] 150 | flows = [_augment_flow(flow) for flow in flows] 151 | if len(flows) == 1: 152 | flows = flows[0] 153 | return imgs, flows 154 | else: 155 | if return_status: 156 | return imgs, (hflip, vflip, rot90) 157 | else: 158 | return imgs 159 | 160 | 161 | def img_rotate(img, angle, center=None, scale=1.0): 162 | """Rotate image. 163 | 164 | Args: 165 | img (ndarray): Image to be rotated. 166 | angle (float): Rotation angle in degrees. Positive values mean 167 | counter-clockwise rotation. 168 | center (tuple[int]): Rotation center. If the center is None, 169 | initialize it as the center of the image. Default: None. 170 | scale (float): Isotropic scale factor. Default: 1.0. 171 | """ 172 | (h, w) = img.shape[:2] 173 | 174 | if center is None: 175 | center = (w // 2, h // 2) 176 | 177 | matrix = cv2.getRotationMatrix2D(center, angle, scale) 178 | rotated_img = cv2.warpAffine(img, matrix, (w, h)) 179 | return rotated_img 180 | -------------------------------------------------------------------------------- /basicsr/losses/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from basicsr.utils.registry import LOSS_REGISTRY 7 | from .gan_loss import g_path_regularize, gradient_penalty_loss, r1_penalty 8 | 9 | __all__ = ['build_loss', 'gradient_penalty_loss', 'r1_penalty', 'g_path_regularize'] 10 | 11 | # automatically scan and import loss modules for registry 12 | # scan all the files under the 'losses' folder and collect files ending with '_loss.py' 13 | loss_folder = osp.dirname(osp.abspath(__file__)) 14 | loss_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(loss_folder) if v.endswith('_loss.py')] 15 | # import all the loss modules 16 | _model_modules = [importlib.import_module(f'basicsr.losses.{file_name}') for file_name in loss_filenames] 17 | 18 | 19 | def build_loss(opt): 20 | """Build loss from options. 
21 | 22 | Args: 23 | opt (dict): Configuration. It must contain: 24 | type (str): Model type. 25 | """ 26 | opt = deepcopy(opt) 27 | loss_type = opt.pop('type') 28 | loss = LOSS_REGISTRY.get(loss_type)(**opt) 29 | logger = get_root_logger() 30 | logger.info(f'Loss [{loss.__class__.__name__}] is created.') 31 | return loss 32 | -------------------------------------------------------------------------------- /basicsr/losses/basic_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | 5 | from basicsr.utils.registry import LOSS_REGISTRY 6 | from .loss_util import weighted_loss 7 | 8 | _reduction_modes = ['none', 'mean', 'sum'] 9 | 10 | 11 | @weighted_loss 12 | def l1_loss(pred, target): 13 | return F.l1_loss(pred, target, reduction='none') 14 | 15 | 16 | @weighted_loss 17 | def mse_loss(pred, target): 18 | return F.mse_loss(pred, target, reduction='none') 19 | 20 | 21 | @weighted_loss 22 | def charbonnier_loss(pred, target, eps=1e-12): 23 | return torch.sqrt((pred - target)**2 + eps) 24 | 25 | 26 | @LOSS_REGISTRY.register() 27 | class L1Loss(nn.Module): 28 | """L1 (mean absolute error, MAE) loss. 29 | 30 | Args: 31 | loss_weight (float): Loss weight for L1 loss. Default: 1.0. 32 | reduction (str): Specifies the reduction to apply to the output. 33 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. 34 | """ 35 | 36 | def __init__(self, loss_weight=1.0, reduction='mean'): 37 | super(L1Loss, self).__init__() 38 | if reduction not in ['none', 'mean', 'sum']: 39 | raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}') 40 | 41 | self.loss_weight = loss_weight 42 | self.reduction = reduction 43 | 44 | def forward(self, pred, target, weight=None, **kwargs): 45 | """ 46 | Args: 47 | pred (Tensor): of shape (N, C, H, W). Predicted tensor. 48 | target (Tensor): of shape (N, C, H, W). Ground truth tensor. 49 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. 50 | """ 51 | return self.loss_weight * l1_loss(pred, target, weight, reduction=self.reduction) 52 | 53 | 54 | @LOSS_REGISTRY.register() 55 | class MSELoss(nn.Module): 56 | """MSE (L2) loss. 57 | 58 | Args: 59 | loss_weight (float): Loss weight for MSE loss. Default: 1.0. 60 | reduction (str): Specifies the reduction to apply to the output. 61 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. 62 | """ 63 | 64 | def __init__(self, loss_weight=1.0, reduction='mean'): 65 | super(MSELoss, self).__init__() 66 | if reduction not in ['none', 'mean', 'sum']: 67 | raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}') 68 | 69 | self.loss_weight = loss_weight 70 | self.reduction = reduction 71 | 72 | def forward(self, pred, target, weight=None, **kwargs): 73 | """ 74 | Args: 75 | pred (Tensor): of shape (N, C, H, W). Predicted tensor. 76 | target (Tensor): of shape (N, C, H, W). Ground truth tensor. 77 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. 78 | """ 79 | return self.loss_weight * mse_loss(pred, target, weight, reduction=self.reduction) 80 | 81 | 82 | @LOSS_REGISTRY.register() 83 | class CharbonnierLoss(nn.Module): 84 | """Charbonnier loss (one variant of Robust L1Loss, a differentiable 85 | variant of L1Loss). 86 | 87 | Described in "Deep Laplacian Pyramid Networks for Fast and Accurate 88 | Super-Resolution". 
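
For reference, the element-wise penalty implemented by `charbonnier_loss` above (and wrapped by the `CharbonnierLoss` class documented here) simply restates the code in math form:

$$\ell(x, y) = \sqrt{(x - y)^2 + \varepsilon},$$

a smooth approximation of the absolute error that remains differentiable at $x = y$; the `eps` argument controls the curvature near zero.
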
89 | 90 | Args: 91 | loss_weight (float): Loss weight for L1 loss. Default: 1.0. 92 | reduction (str): Specifies the reduction to apply to the output. 93 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. 94 | eps (float): A value used to control the curvature near zero. Default: 1e-12. 95 | """ 96 | 97 | def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-12): 98 | super(CharbonnierLoss, self).__init__() 99 | if reduction not in ['none', 'mean', 'sum']: 100 | raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: {_reduction_modes}') 101 | 102 | self.loss_weight = loss_weight 103 | self.reduction = reduction 104 | self.eps = eps 105 | 106 | def forward(self, pred, target, weight=None, **kwargs): 107 | """ 108 | Args: 109 | pred (Tensor): of shape (N, C, H, W). Predicted tensor. 110 | target (Tensor): of shape (N, C, H, W). Ground truth tensor. 111 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. 112 | """ 113 | return self.loss_weight * charbonnier_loss(pred, target, weight, eps=self.eps, reduction=self.reduction) 114 | 115 | 116 | @LOSS_REGISTRY.register() 117 | class WeightedTVLoss(L1Loss): 118 | """Weighted TV loss. 119 | 120 | Args: 121 | loss_weight (float): Loss weight. Default: 1.0. 122 | """ 123 | 124 | def __init__(self, loss_weight=1.0, reduction='mean'): 125 | if reduction not in ['mean', 'sum']: 126 | raise ValueError(f'Unsupported reduction mode: {reduction}. Supported ones are: mean | sum') 127 | super(WeightedTVLoss, self).__init__(loss_weight=loss_weight, reduction=reduction) 128 | 129 | def forward(self, pred, weight=None): 130 | if weight is None: 131 | y_weight = None 132 | x_weight = None 133 | else: 134 | y_weight = weight[:, :, :-1, :] 135 | x_weight = weight[:, :, :, :-1] 136 | 137 | y_diff = super().forward(pred[:, :, :-1, :], pred[:, :, 1:, :], weight=y_weight) 138 | x_diff = super().forward(pred[:, :, :, :-1], pred[:, :, :, 1:], weight=x_weight) 139 | 140 | loss = x_diff + y_diff 141 | 142 | return loss 143 | -------------------------------------------------------------------------------- /basicsr/losses/gan_loss.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import autograd as autograd 4 | from torch import nn as nn 5 | from torch.nn import functional as F 6 | 7 | from basicsr.utils.registry import LOSS_REGISTRY 8 | 9 | 10 | @LOSS_REGISTRY.register() 11 | class GANLoss(nn.Module): 12 | """Define GAN loss. 13 | 14 | Args: 15 | gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'. 16 | real_label_val (float): The value for real label. Default: 1.0. 17 | fake_label_val (float): The value for fake label. Default: 0.0. 18 | loss_weight (float): Loss weight. Default: 1.0. 19 | Note that loss_weight is only for generators; and it is always 1.0 20 | for discriminators. 
21 | """ 22 | 23 | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0): 24 | super(GANLoss, self).__init__() 25 | self.gan_type = gan_type 26 | self.loss_weight = loss_weight 27 | self.real_label_val = real_label_val 28 | self.fake_label_val = fake_label_val 29 | 30 | if self.gan_type == 'vanilla': 31 | self.loss = nn.BCEWithLogitsLoss() 32 | elif self.gan_type == 'lsgan': 33 | self.loss = nn.MSELoss() 34 | elif self.gan_type == 'wgan': 35 | self.loss = self._wgan_loss 36 | elif self.gan_type == 'wgan_softplus': 37 | self.loss = self._wgan_softplus_loss 38 | elif self.gan_type == 'hinge': 39 | self.loss = nn.ReLU() 40 | else: 41 | raise NotImplementedError(f'GAN type {self.gan_type} is not implemented.') 42 | 43 | def _wgan_loss(self, input, target): 44 | """wgan loss. 45 | 46 | Args: 47 | input (Tensor): Input tensor. 48 | target (bool): Target label. 49 | 50 | Returns: 51 | Tensor: wgan loss. 52 | """ 53 | return -input.mean() if target else input.mean() 54 | 55 | def _wgan_softplus_loss(self, input, target): 56 | """wgan loss with soft plus. softplus is a smooth approximation to the 57 | ReLU function. 58 | 59 | In StyleGAN2, it is called: 60 | Logistic loss for discriminator; 61 | Non-saturating loss for generator. 62 | 63 | Args: 64 | input (Tensor): Input tensor. 65 | target (bool): Target label. 66 | 67 | Returns: 68 | Tensor: wgan loss. 69 | """ 70 | return F.softplus(-input).mean() if target else F.softplus(input).mean() 71 | 72 | def get_target_label(self, input, target_is_real): 73 | """Get target label. 74 | 75 | Args: 76 | input (Tensor): Input tensor. 77 | target_is_real (bool): Whether the target is real or fake. 78 | 79 | Returns: 80 | (bool | Tensor): Target tensor. Return bool for wgan, otherwise, 81 | return Tensor. 82 | """ 83 | 84 | if self.gan_type in ['wgan', 'wgan_softplus']: 85 | return target_is_real 86 | target_val = (self.real_label_val if target_is_real else self.fake_label_val) 87 | return input.new_ones(input.size()) * target_val 88 | 89 | def forward(self, input, target_is_real, is_disc=False): 90 | """ 91 | Args: 92 | input (Tensor): The input for the loss module, i.e., the network 93 | prediction. 94 | target_is_real (bool): Whether the targe is real or fake. 95 | is_disc (bool): Whether the loss for discriminators or not. 96 | Default: False. 97 | 98 | Returns: 99 | Tensor: GAN loss value. 
100 | """ 101 | target_label = self.get_target_label(input, target_is_real) 102 | if self.gan_type == 'hinge': 103 | if is_disc: # for discriminators in hinge-gan 104 | input = -input if target_is_real else input 105 | loss = self.loss(1 + input).mean() 106 | else: # for generators in hinge-gan 107 | loss = -input.mean() 108 | else: # other gan types 109 | loss = self.loss(input, target_label) 110 | 111 | # loss_weight is always 1.0 for discriminators 112 | return loss if is_disc else loss * self.loss_weight 113 | 114 | 115 | @LOSS_REGISTRY.register() 116 | class MultiScaleGANLoss(GANLoss): 117 | """ 118 | MultiScaleGANLoss accepts a list of predictions 119 | """ 120 | 121 | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0, loss_weight=1.0): 122 | super(MultiScaleGANLoss, self).__init__(gan_type, real_label_val, fake_label_val, loss_weight) 123 | 124 | def forward(self, input, target_is_real, is_disc=False): 125 | """ 126 | The input is a list of tensors, or a list of (a list of tensors) 127 | """ 128 | if isinstance(input, list): 129 | loss = 0 130 | for pred_i in input: 131 | if isinstance(pred_i, list): 132 | # Only compute GAN loss for the last layer 133 | # in case of multiscale feature matching 134 | pred_i = pred_i[-1] 135 | # Safe operation: 0-dim tensor calling self.mean() does nothing 136 | loss_tensor = super().forward(pred_i, target_is_real, is_disc).mean() 137 | loss += loss_tensor 138 | return loss / len(input) 139 | else: 140 | return super().forward(input, target_is_real, is_disc) 141 | 142 | 143 | def r1_penalty(real_pred, real_img): 144 | """R1 regularization for discriminator. The core idea is to 145 | penalize the gradient on real data alone: when the 146 | generator distribution produces the true data distribution 147 | and the discriminator is equal to 0 on the data manifold, the 148 | gradient penalty ensures that the discriminator cannot create 149 | a non-zero gradient orthogonal to the data manifold without 150 | suffering a loss in the GAN game. 151 | 152 | Reference: Eq. 9 in Which training methods for GANs do actually converge. 153 | """ 154 | grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0] 155 | grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean() 156 | return grad_penalty 157 | 158 | 159 | def g_path_regularize(fake_img, latents, mean_path_length, decay=0.01): 160 | noise = torch.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3]) 161 | grad = autograd.grad(outputs=(fake_img * noise).sum(), inputs=latents, create_graph=True)[0] 162 | path_lengths = torch.sqrt(grad.pow(2).sum(2).mean(1)) 163 | 164 | path_mean = mean_path_length + decay * (path_lengths.mean() - mean_path_length) 165 | 166 | path_penalty = (path_lengths - path_mean).pow(2).mean() 167 | 168 | return path_penalty, path_lengths.detach().mean(), path_mean.detach() 169 | 170 | 171 | def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None): 172 | """Calculate gradient penalty for wgan-gp. 173 | 174 | Args: 175 | discriminator (nn.Module): Network for the discriminator. 176 | real_data (Tensor): Real input data. 177 | fake_data (Tensor): Fake input data. 178 | weight (Tensor): Weight tensor. Default: None. 179 | 180 | Returns: 181 | Tensor: A tensor for gradient penalty. 
182 | """ 183 | 184 | batch_size = real_data.size(0) 185 | alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1)) 186 | 187 | # interpolate between real_data and fake_data 188 | interpolates = alpha * real_data + (1. - alpha) * fake_data 189 | interpolates = autograd.Variable(interpolates, requires_grad=True) 190 | 191 | disc_interpolates = discriminator(interpolates) 192 | gradients = autograd.grad( 193 | outputs=disc_interpolates, 194 | inputs=interpolates, 195 | grad_outputs=torch.ones_like(disc_interpolates), 196 | create_graph=True, 197 | retain_graph=True, 198 | only_inputs=True)[0] 199 | 200 | if weight is not None: 201 | gradients = gradients * weight 202 | 203 | gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean() 204 | if weight is not None: 205 | gradients_penalty /= torch.mean(weight) 206 | 207 | return gradients_penalty 208 | -------------------------------------------------------------------------------- /basicsr/losses/loss_util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch 3 | from torch.nn import functional as F 4 | 5 | 6 | def reduce_loss(loss, reduction): 7 | """Reduce loss as specified. 8 | 9 | Args: 10 | loss (Tensor): Elementwise loss tensor. 11 | reduction (str): Options are 'none', 'mean' and 'sum'. 12 | 13 | Returns: 14 | Tensor: Reduced loss tensor. 15 | """ 16 | reduction_enum = F._Reduction.get_enum(reduction) 17 | # none: 0, elementwise_mean:1, sum: 2 18 | if reduction_enum == 0: 19 | return loss 20 | elif reduction_enum == 1: 21 | return loss.mean() 22 | else: 23 | return loss.sum() 24 | 25 | 26 | def weight_reduce_loss(loss, weight=None, reduction='mean'): 27 | """Apply element-wise weight and reduce loss. 28 | 29 | Args: 30 | loss (Tensor): Element-wise loss. 31 | weight (Tensor): Element-wise weights. Default: None. 32 | reduction (str): Same as built-in losses of PyTorch. Options are 33 | 'none', 'mean' and 'sum'. Default: 'mean'. 34 | 35 | Returns: 36 | Tensor: Loss values. 37 | """ 38 | # if weight is specified, apply element-wise weight 39 | if weight is not None: 40 | assert weight.dim() == loss.dim() 41 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1) 42 | loss = loss * weight 43 | 44 | # if weight is not specified or reduction is sum, just reduce the loss 45 | if weight is None or reduction == 'sum': 46 | loss = reduce_loss(loss, reduction) 47 | # if reduction is mean, then compute mean over weight region 48 | elif reduction == 'mean': 49 | if weight.size(1) > 1: 50 | weight = weight.sum() 51 | else: 52 | weight = weight.sum() * loss.size(1) 53 | loss = loss.sum() / weight 54 | 55 | return loss 56 | 57 | 58 | def weighted_loss(loss_func): 59 | """Create a weighted version of a given loss function. 60 | 61 | To use this decorator, the loss function must have the signature like 62 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 63 | element-wise loss without any reduction. This decorator will add weight 64 | and reduction arguments to the function. The decorated function will have 65 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 66 | **kwargs)`. 
67 | 68 | :Example: 69 | 70 | >>> import torch 71 | >>> @weighted_loss 72 | >>> def l1_loss(pred, target): 73 | >>> return (pred - target).abs() 74 | 75 | >>> pred = torch.Tensor([0, 2, 3]) 76 | >>> target = torch.Tensor([1, 1, 1]) 77 | >>> weight = torch.Tensor([1, 0, 1]) 78 | 79 | >>> l1_loss(pred, target) 80 | tensor(1.3333) 81 | >>> l1_loss(pred, target, weight) 82 | tensor(1.5000) 83 | >>> l1_loss(pred, target, reduction='none') 84 | tensor([1., 1., 2.]) 85 | >>> l1_loss(pred, target, weight, reduction='sum') 86 | tensor(3.) 87 | """ 88 | 89 | @functools.wraps(loss_func) 90 | def wrapper(pred, target, weight=None, reduction='mean', **kwargs): 91 | # get element-wise loss 92 | loss = loss_func(pred, target, **kwargs) 93 | loss = weight_reduce_loss(loss, weight, reduction) 94 | return loss 95 | 96 | return wrapper 97 | 98 | 99 | def get_local_weights(residual, ksize): 100 | """Get local weights for generating the artifact map of LDL. 101 | 102 | It is only called by the `get_refined_artifact_map` function. 103 | 104 | Args: 105 | residual (Tensor): Residual between predicted and ground truth images. 106 | ksize (Int): size of the local window. 107 | 108 | Returns: 109 | Tensor: weight for each pixel to be discriminated as an artifact pixel 110 | """ 111 | 112 | pad = (ksize - 1) // 2 113 | residual_pad = F.pad(residual, pad=[pad, pad, pad, pad], mode='reflect') 114 | 115 | unfolded_residual = residual_pad.unfold(2, ksize, 1).unfold(3, ksize, 1) 116 | pixel_level_weight = torch.var(unfolded_residual, dim=(-1, -2), unbiased=True, keepdim=True).squeeze(-1).squeeze(-1) 117 | 118 | return pixel_level_weight 119 | 120 | 121 | def get_refined_artifact_map(img_gt, img_output, img_ema, ksize): 122 | """Calculate the artifact map of LDL 123 | (Details or Artifacts: A Locally Discriminative Learning Approach to Realistic Image Super-Resolution. In CVPR 2022) 124 | 125 | Args: 126 | img_gt (Tensor): ground truth images. 127 | img_output (Tensor): output images given by the optimizing model. 128 | img_ema (Tensor): output images given by the ema model. 129 | ksize (Int): size of the local window. 130 | 131 | Returns: 132 | overall_weight: weight for each pixel to be discriminated as an artifact pixel 133 | (calculated based on both local and global observations). 134 | """ 135 | 136 | residual_ema = torch.sum(torch.abs(img_gt - img_ema), 1, keepdim=True) 137 | residual_sr = torch.sum(torch.abs(img_gt - img_output), 1, keepdim=True) 138 | 139 | patch_level_weight = torch.var(residual_sr.clone(), dim=(-1, -2, -3), keepdim=True)**(1 / 5) 140 | pixel_level_weight = get_local_weights(residual_sr.clone(), ksize) 141 | overall_weight = patch_level_weight * pixel_level_weight 142 | 143 | overall_weight[residual_sr < residual_ema] = 0 144 | 145 | return overall_weight 146 | -------------------------------------------------------------------------------- /basicsr/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from basicsr.utils.registry import METRIC_REGISTRY 4 | from .psnr_ssim import calculate_psnr, calculate_ssim 5 | 6 | __all__ = ['calculate_psnr', 'calculate_ssim'] 7 | 8 | 9 | def calculate_metric(data, opt): 10 | """Calculate metric from data and options. 11 | 12 | Args: 13 | opt (dict): Configuration. It must contain: 14 | type (str): Model type. 
15 | """ 16 | opt = deepcopy(opt) 17 | metric_type = opt.pop('type') 18 | metric = METRIC_REGISTRY.get(metric_type)(**data, **opt) 19 | return metric 20 | -------------------------------------------------------------------------------- /basicsr/metrics/fid.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | from scipy import linalg 5 | from tqdm import tqdm 6 | 7 | from basicsr.archs.inception import InceptionV3 8 | 9 | 10 | def load_patched_inception_v3(device='cuda', resize_input=True, normalize_input=False): 11 | # we may not resize the input, but in [rosinality/stylegan2-pytorch] it 12 | # does resize the input. 13 | inception = InceptionV3([3], resize_input=resize_input, normalize_input=normalize_input) 14 | inception = nn.DataParallel(inception).eval().to(device) 15 | return inception 16 | 17 | 18 | @torch.no_grad() 19 | def extract_inception_features(data_generator, inception, len_generator=None, device='cuda'): 20 | """Extract inception features. 21 | 22 | Args: 23 | data_generator (generator): A data generator. 24 | inception (nn.Module): Inception model. 25 | len_generator (int): Length of the data_generator to show the 26 | progressbar. Default: None. 27 | device (str): Device. Default: cuda. 28 | 29 | Returns: 30 | Tensor: Extracted features. 31 | """ 32 | if len_generator is not None: 33 | pbar = tqdm(total=len_generator, unit='batch', desc='Extract') 34 | else: 35 | pbar = None 36 | features = [] 37 | 38 | for data in data_generator: 39 | if pbar: 40 | pbar.update(1) 41 | data = data.to(device) 42 | feature = inception(data)[0].view(data.shape[0], -1) 43 | features.append(feature.to('cpu')) 44 | if pbar: 45 | pbar.close() 46 | features = torch.cat(features, 0) 47 | return features 48 | 49 | 50 | def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6): 51 | """Numpy implementation of the Frechet Distance. 52 | 53 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) and X_2 ~ N(mu_2, C_2) is: 54 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 55 | Stable version by Dougal J. Sutherland. 56 | 57 | Args: 58 | mu1 (np.array): The sample mean over activations. 59 | sigma1 (np.array): The covariance matrix over activations for generated samples. 60 | mu2 (np.array): The sample mean over activations, precalculated on an representative data set. 61 | sigma2 (np.array): The covariance matrix over activations, precalculated on an representative data set. 62 | 63 | Returns: 64 | float: The Frechet Distance. 65 | """ 66 | assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths' 67 | assert sigma1.shape == sigma2.shape, ('Two covariances have different dimensions') 68 | 69 | cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False) 70 | 71 | # Product might be almost singular 72 | if not np.isfinite(cov_sqrt).all(): 73 | print('Product of cov matrices is singular. 
Adding {eps} to diagonal of cov estimates') 74 | offset = np.eye(sigma1.shape[0]) * eps 75 | cov_sqrt = linalg.sqrtm((sigma1 + offset) @ (sigma2 + offset)) 76 | 77 | # Numerical error might give slight imaginary component 78 | if np.iscomplexobj(cov_sqrt): 79 | if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3): 80 | m = np.max(np.abs(cov_sqrt.imag)) 81 | raise ValueError(f'Imaginary component {m}') 82 | cov_sqrt = cov_sqrt.real 83 | 84 | mean_diff = mu1 - mu2 85 | mean_norm = mean_diff @ mean_diff 86 | trace = np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(cov_sqrt) 87 | fid = mean_norm + trace 88 | 89 | return fid 90 | -------------------------------------------------------------------------------- /basicsr/metrics/metric_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from basicsr.utils import bgr2ycbcr 4 | 5 | 6 | def reorder_image(img, input_order='HWC'): 7 | """Reorder images to 'HWC' order. 8 | 9 | If the input_order is (h, w), return (h, w, 1); 10 | If the input_order is (c, h, w), return (h, w, c); 11 | If the input_order is (h, w, c), return as it is. 12 | 13 | Args: 14 | img (ndarray): Input image. 15 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 16 | If the input image shape is (h, w), input_order will not have 17 | effects. Default: 'HWC'. 18 | 19 | Returns: 20 | ndarray: reordered image. 21 | """ 22 | 23 | if input_order not in ['HWC', 'CHW']: 24 | raise ValueError(f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'") 25 | if len(img.shape) == 2: 26 | img = img[..., None] 27 | if input_order == 'CHW': 28 | img = img.transpose(1, 2, 0) 29 | return img 30 | 31 | 32 | def to_y_channel(img): 33 | """Change to Y channel of YCbCr. 34 | 35 | Args: 36 | img (ndarray): Images with range [0, 255]. 37 | 38 | Returns: 39 | (ndarray): Images with range [0, 255] (float type) without round. 40 | """ 41 | img = img.astype(np.float32) / 255. 42 | if img.ndim == 3 and img.shape[2] == 3: 43 | img = bgr2ycbcr(img, y_only=True) 44 | img = img[..., None] 45 | return img * 255. 
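
As a quick illustration of how the two helpers above are typically chained before computing a Y-channel metric, here is a minimal sketch. It is not part of metric_util.py: it assumes the basicsr package is importable, the random 64x64 images are placeholders, and the plain PSNR formula at the end is only for demonstration.

```python
import numpy as np

from basicsr.metrics.metric_util import reorder_image, to_y_channel

# Two uint8 BGR images in CHW order (e.g. coming from a tensor pipeline).
img1 = np.random.randint(0, 256, (3, 64, 64), dtype=np.uint8)
img2 = np.random.randint(0, 256, (3, 64, 64), dtype=np.uint8)

# Reorder to HWC, then keep only the Y channel (float values in [0, 255]).
img1_y = to_y_channel(reorder_image(img1, input_order='CHW'))
img2_y = to_y_channel(reorder_image(img2, input_order='CHW'))

# Plain PSNR on the Y channel, using the usual 8-bit peak value of 255.
mse = np.mean((img1_y - img2_y)**2)
psnr = 10. * np.log10(255.**2 / mse)
print(f'Y-channel PSNR: {psnr:.4f} dB')
```
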
46 | -------------------------------------------------------------------------------- /basicsr/metrics/test_metrics/test_psnr_ssim.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | 4 | from basicsr.metrics import calculate_psnr, calculate_ssim 5 | from basicsr.metrics.psnr_ssim import calculate_psnr_pt, calculate_ssim_pt 6 | from basicsr.utils import img2tensor 7 | 8 | 9 | def test(img_path, img_path2, crop_border, test_y_channel=False): 10 | img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED) 11 | img2 = cv2.imread(img_path2, cv2.IMREAD_UNCHANGED) 12 | 13 | # --------------------- Numpy --------------------- 14 | psnr = calculate_psnr(img, img2, crop_border=crop_border, input_order='HWC', test_y_channel=test_y_channel) 15 | ssim = calculate_ssim(img, img2, crop_border=crop_border, input_order='HWC', test_y_channel=test_y_channel) 16 | print(f'\tNumpy\tPSNR: {psnr:.6f} dB, \tSSIM: {ssim:.6f}') 17 | 18 | # --------------------- PyTorch (CPU) --------------------- 19 | img = img2tensor(img / 255., bgr2rgb=True, float32=True).unsqueeze_(0) 20 | img2 = img2tensor(img2 / 255., bgr2rgb=True, float32=True).unsqueeze_(0) 21 | 22 | psnr_pth = calculate_psnr_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel) 23 | ssim_pth = calculate_ssim_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel) 24 | print(f'\tTensor (CPU) \tPSNR: {psnr_pth[0]:.6f} dB, \tSSIM: {ssim_pth[0]:.6f}') 25 | 26 | # --------------------- PyTorch (GPU) --------------------- 27 | img = img.cuda() 28 | img2 = img2.cuda() 29 | psnr_pth = calculate_psnr_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel) 30 | ssim_pth = calculate_ssim_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel) 31 | print(f'\tTensor (GPU) \tPSNR: {psnr_pth[0]:.6f} dB, \tSSIM: {ssim_pth[0]:.6f}') 32 | 33 | psnr_pth = calculate_psnr_pt( 34 | torch.repeat_interleave(img, 2, dim=0), 35 | torch.repeat_interleave(img2, 2, dim=0), 36 | crop_border=crop_border, 37 | test_y_channel=test_y_channel) 38 | ssim_pth = calculate_ssim_pt( 39 | torch.repeat_interleave(img, 2, dim=0), 40 | torch.repeat_interleave(img2, 2, dim=0), 41 | crop_border=crop_border, 42 | test_y_channel=test_y_channel) 43 | print(f'\tTensor (GPU batch) \tPSNR: {psnr_pth[0]:.6f}, {psnr_pth[1]:.6f} dB,' 44 | f'\tSSIM: {ssim_pth[0]:.6f}, {ssim_pth[1]:.6f}') 45 | 46 | 47 | if __name__ == '__main__': 48 | test('tests/data/bic/baboon.png', 'tests/data/gt/baboon.png', crop_border=4, test_y_channel=False) 49 | test('tests/data/bic/baboon.png', 'tests/data/gt/baboon.png', crop_border=4, test_y_channel=True) 50 | 51 | test('tests/data/bic/comic.png', 'tests/data/gt/comic.png', crop_border=4, test_y_channel=False) 52 | test('tests/data/bic/comic.png', 'tests/data/gt/comic.png', crop_border=4, test_y_channel=True) 53 | -------------------------------------------------------------------------------- /basicsr/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from basicsr.utils.registry import MODEL_REGISTRY 7 | 8 | __all__ = ['build_model'] 9 | 10 | # automatically scan and import model modules for registry 11 | # scan all the files under the 'models' folder and collect files ending with '_model.py' 12 | model_folder = osp.dirname(osp.abspath(__file__)) 13 | model_filenames = 
[osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] 14 | # import all the model modules 15 | _model_modules = [importlib.import_module(f'basicsr.models.{file_name}') for file_name in model_filenames] 16 | 17 | 18 | def build_model(opt): 19 | """Build model from options. 20 | 21 | Args: 22 | opt (dict): Configuration. It must contain: 23 | model_type (str): Model type. 24 | """ 25 | opt = deepcopy(opt) 26 | model = MODEL_REGISTRY.get(opt['model_type'])(opt) 27 | logger = get_root_logger() 28 | logger.info(f'Model [{model.__class__.__name__}] is created.') 29 | return model 30 | -------------------------------------------------------------------------------- /basicsr/models/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import Counter 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | 5 | 6 | class MultiStepRestartLR(_LRScheduler): 7 | """ MultiStep with restarts learning rate scheme. 8 | 9 | Args: 10 | optimizer (torch.nn.optimizer): Torch optimizer. 11 | milestones (list): Iterations that will decrease learning rate. 12 | gamma (float): Decrease ratio. Default: 0.1. 13 | restarts (list): Restart iterations. Default: [0]. 14 | restart_weights (list): Restart weights at each restart iteration. 15 | Default: [1]. 16 | last_epoch (int): Used in _LRScheduler. Default: -1. 17 | """ 18 | 19 | def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1): 20 | self.milestones = Counter(milestones) 21 | self.gamma = gamma 22 | self.restarts = restarts 23 | self.restart_weights = restart_weights 24 | assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.' 25 | super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) 26 | 27 | def get_lr(self): 28 | if self.last_epoch in self.restarts: 29 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 30 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 31 | if self.last_epoch not in self.milestones: 32 | return [group['lr'] for group in self.optimizer.param_groups] 33 | return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups] 34 | 35 | 36 | def get_position_from_periods(iteration, cumulative_period): 37 | """Get the position from a period list. 38 | 39 | It will return the index of the right-closest number in the period list. 40 | For example, the cumulative_period = [100, 200, 300, 400], 41 | if iteration == 50, return 0; 42 | if iteration == 210, return 2; 43 | if iteration == 300, return 2. 44 | 45 | Args: 46 | iteration (int): Current iteration. 47 | cumulative_period (list[int]): Cumulative period list. 48 | 49 | Returns: 50 | int: The position of the right-closest number in the period list. 51 | """ 52 | for i, period in enumerate(cumulative_period): 53 | if iteration <= period: 54 | return i 55 | 56 | 57 | class CosineAnnealingRestartLR(_LRScheduler): 58 | """ Cosine annealing with restarts learning rate scheme. 59 | 60 | An example of config: 61 | periods = [10, 10, 10, 10] 62 | restart_weights = [1, 0.5, 0.5, 0.5] 63 | eta_min=1e-7 64 | 65 | It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the 66 | scheduler will restart with the weights in restart_weights. 67 | 68 | Args: 69 | optimizer (torch.nn.optimizer): Torch optimizer. 70 | periods (list): Period for each cosine anneling cycle. 
71 | restart_weights (list): Restart weights at each restart iteration. 72 | Default: [1]. 73 | eta_min (float): The minimum lr. Default: 0. 74 | last_epoch (int): Used in _LRScheduler. Default: -1. 75 | """ 76 | 77 | def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1): 78 | self.periods = periods 79 | self.restart_weights = restart_weights 80 | self.eta_min = eta_min 81 | assert (len(self.periods) == len( 82 | self.restart_weights)), 'periods and restart_weights should have the same length.' 83 | self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))] 84 | super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch) 85 | 86 | def get_lr(self): 87 | idx = get_position_from_periods(self.last_epoch, self.cumulative_period) 88 | current_weight = self.restart_weights[idx] 89 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] 90 | current_period = self.periods[idx] 91 | 92 | return [ 93 | self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * 94 | (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period))) 95 | for base_lr in self.base_lrs 96 | ] 97 | -------------------------------------------------------------------------------- /basicsr/ops/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArsenalCheng/MMA/256cead19ff9540f0364f562c83486eec13c8a4a/basicsr/ops/__init__.py -------------------------------------------------------------------------------- /basicsr/ops/dcn/__init__.py: -------------------------------------------------------------------------------- 1 | from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv, 2 | modulated_deform_conv) 3 | 4 | __all__ = [ 5 | 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv', 6 | 'modulated_deform_conv' 7 | ] 8 | -------------------------------------------------------------------------------- /basicsr/ops/dcn/src/deform_conv_ext.cpp: -------------------------------------------------------------------------------- 1 | // modify from 2 | // https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c 3 | 4 | #include 5 | #include 6 | 7 | #include 8 | #include 9 | 10 | #define WITH_CUDA // always use cuda 11 | #ifdef WITH_CUDA 12 | int deform_conv_forward_cuda(at::Tensor input, at::Tensor weight, 13 | at::Tensor offset, at::Tensor output, 14 | at::Tensor columns, at::Tensor ones, int kW, 15 | int kH, int dW, int dH, int padW, int padH, 16 | int dilationW, int dilationH, int group, 17 | int deformable_group, int im2col_step); 18 | 19 | int deform_conv_backward_input_cuda(at::Tensor input, at::Tensor offset, 20 | at::Tensor gradOutput, at::Tensor gradInput, 21 | at::Tensor gradOffset, at::Tensor weight, 22 | at::Tensor columns, int kW, int kH, int dW, 23 | int dH, int padW, int padH, int dilationW, 24 | int dilationH, int group, 25 | int deformable_group, int im2col_step); 26 | 27 | int deform_conv_backward_parameters_cuda( 28 | at::Tensor input, at::Tensor offset, at::Tensor gradOutput, 29 | at::Tensor gradWeight, // at::Tensor gradBias, 30 | at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, 31 | int padW, int padH, int dilationW, int dilationH, int group, 32 | int deformable_group, float scale, int im2col_step); 33 | 34 | void modulated_deform_conv_cuda_forward( 35 | 
at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, 36 | at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, 37 | int kernel_h, int kernel_w, const int stride_h, const int stride_w, 38 | const int pad_h, const int pad_w, const int dilation_h, 39 | const int dilation_w, const int group, const int deformable_group, 40 | const bool with_bias); 41 | 42 | void modulated_deform_conv_cuda_backward( 43 | at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, 44 | at::Tensor offset, at::Tensor mask, at::Tensor columns, 45 | at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, 46 | at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, 47 | int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, 48 | int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, 49 | const bool with_bias); 50 | #endif 51 | 52 | int deform_conv_forward(at::Tensor input, at::Tensor weight, 53 | at::Tensor offset, at::Tensor output, 54 | at::Tensor columns, at::Tensor ones, int kW, 55 | int kH, int dW, int dH, int padW, int padH, 56 | int dilationW, int dilationH, int group, 57 | int deformable_group, int im2col_step) { 58 | if (input.device().is_cuda()) { 59 | #ifdef WITH_CUDA 60 | return deform_conv_forward_cuda(input, weight, offset, output, columns, 61 | ones, kW, kH, dW, dH, padW, padH, dilationW, dilationH, group, 62 | deformable_group, im2col_step); 63 | #else 64 | AT_ERROR("deform conv is not compiled with GPU support"); 65 | #endif 66 | } 67 | AT_ERROR("deform conv is not implemented on CPU"); 68 | } 69 | 70 | int deform_conv_backward_input(at::Tensor input, at::Tensor offset, 71 | at::Tensor gradOutput, at::Tensor gradInput, 72 | at::Tensor gradOffset, at::Tensor weight, 73 | at::Tensor columns, int kW, int kH, int dW, 74 | int dH, int padW, int padH, int dilationW, 75 | int dilationH, int group, 76 | int deformable_group, int im2col_step) { 77 | if (input.device().is_cuda()) { 78 | #ifdef WITH_CUDA 79 | return deform_conv_backward_input_cuda(input, offset, gradOutput, 80 | gradInput, gradOffset, weight, columns, kW, kH, dW, dH, padW, padH, 81 | dilationW, dilationH, group, deformable_group, im2col_step); 82 | #else 83 | AT_ERROR("deform conv is not compiled with GPU support"); 84 | #endif 85 | } 86 | AT_ERROR("deform conv is not implemented on CPU"); 87 | } 88 | 89 | int deform_conv_backward_parameters( 90 | at::Tensor input, at::Tensor offset, at::Tensor gradOutput, 91 | at::Tensor gradWeight, // at::Tensor gradBias, 92 | at::Tensor columns, at::Tensor ones, int kW, int kH, int dW, int dH, 93 | int padW, int padH, int dilationW, int dilationH, int group, 94 | int deformable_group, float scale, int im2col_step) { 95 | if (input.device().is_cuda()) { 96 | #ifdef WITH_CUDA 97 | return deform_conv_backward_parameters_cuda(input, offset, gradOutput, 98 | gradWeight, columns, ones, kW, kH, dW, dH, padW, padH, dilationW, 99 | dilationH, group, deformable_group, scale, im2col_step); 100 | #else 101 | AT_ERROR("deform conv is not compiled with GPU support"); 102 | #endif 103 | } 104 | AT_ERROR("deform conv is not implemented on CPU"); 105 | } 106 | 107 | void modulated_deform_conv_forward( 108 | at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, 109 | at::Tensor offset, at::Tensor mask, at::Tensor output, at::Tensor columns, 110 | int kernel_h, int kernel_w, const int stride_h, const int stride_w, 111 | const int pad_h, const int pad_w, const int dilation_h, 112 | const int dilation_w, 
const int group, const int deformable_group, 113 | const bool with_bias) { 114 | if (input.device().is_cuda()) { 115 | #ifdef WITH_CUDA 116 | return modulated_deform_conv_cuda_forward(input, weight, bias, ones, 117 | offset, mask, output, columns, kernel_h, kernel_w, stride_h, 118 | stride_w, pad_h, pad_w, dilation_h, dilation_w, group, 119 | deformable_group, with_bias); 120 | #else 121 | AT_ERROR("modulated deform conv is not compiled with GPU support"); 122 | #endif 123 | } 124 | AT_ERROR("modulated deform conv is not implemented on CPU"); 125 | } 126 | 127 | void modulated_deform_conv_backward( 128 | at::Tensor input, at::Tensor weight, at::Tensor bias, at::Tensor ones, 129 | at::Tensor offset, at::Tensor mask, at::Tensor columns, 130 | at::Tensor grad_input, at::Tensor grad_weight, at::Tensor grad_bias, 131 | at::Tensor grad_offset, at::Tensor grad_mask, at::Tensor grad_output, 132 | int kernel_h, int kernel_w, int stride_h, int stride_w, int pad_h, 133 | int pad_w, int dilation_h, int dilation_w, int group, int deformable_group, 134 | const bool with_bias) { 135 | if (input.device().is_cuda()) { 136 | #ifdef WITH_CUDA 137 | return modulated_deform_conv_cuda_backward(input, weight, bias, ones, 138 | offset, mask, columns, grad_input, grad_weight, grad_bias, grad_offset, 139 | grad_mask, grad_output, kernel_h, kernel_w, stride_h, stride_w, 140 | pad_h, pad_w, dilation_h, dilation_w, group, deformable_group, 141 | with_bias); 142 | #else 143 | AT_ERROR("modulated deform conv is not compiled with GPU support"); 144 | #endif 145 | } 146 | AT_ERROR("modulated deform conv is not implemented on CPU"); 147 | } 148 | 149 | 150 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 151 | m.def("deform_conv_forward", &deform_conv_forward, 152 | "deform forward"); 153 | m.def("deform_conv_backward_input", &deform_conv_backward_input, 154 | "deform_conv_backward_input"); 155 | m.def("deform_conv_backward_parameters", 156 | &deform_conv_backward_parameters, 157 | "deform_conv_backward_parameters"); 158 | m.def("modulated_deform_conv_forward", 159 | &modulated_deform_conv_forward, 160 | "modulated deform conv forward"); 161 | m.def("modulated_deform_conv_backward", 162 | &modulated_deform_conv_backward, 163 | "modulated deform conv backward"); 164 | } 165 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | 3 | __all__ = ['FusedLeakyReLU', 'fused_leaky_relu'] 4 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/fused_act.py: -------------------------------------------------------------------------------- 1 | # modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501 2 | 3 | import os 4 | import torch 5 | from torch import nn 6 | from torch.autograd import Function 7 | 8 | BASICSR_JIT = os.getenv('BASICSR_JIT') 9 | if BASICSR_JIT == 'True': 10 | from torch.utils.cpp_extension import load 11 | module_path = os.path.dirname(__file__) 12 | fused_act_ext = load( 13 | 'fused', 14 | sources=[ 15 | os.path.join(module_path, 'src', 'fused_bias_act.cpp'), 16 | os.path.join(module_path, 'src', 'fused_bias_act_kernel.cu'), 17 | ], 18 | ) 19 | else: 20 | try: 21 | from . import fused_act_ext 22 | except ImportError: 23 | pass 24 | # avoid annoying print output 25 | # print(f'Cannot import deform_conv_ext. 
Error: {error}. You may need to: \n ' 26 | # '1. compile with BASICSR_EXT=True. or\n ' 27 | # '2. set BASICSR_JIT=True during running') 28 | 29 | 30 | class FusedLeakyReLUFunctionBackward(Function): 31 | 32 | @staticmethod 33 | def forward(ctx, grad_output, out, negative_slope, scale): 34 | ctx.save_for_backward(out) 35 | ctx.negative_slope = negative_slope 36 | ctx.scale = scale 37 | 38 | empty = grad_output.new_empty(0) 39 | 40 | grad_input = fused_act_ext.fused_bias_act(grad_output, empty, out, 3, 1, negative_slope, scale) 41 | 42 | dim = [0] 43 | 44 | if grad_input.ndim > 2: 45 | dim += list(range(2, grad_input.ndim)) 46 | 47 | grad_bias = grad_input.sum(dim).detach() 48 | 49 | return grad_input, grad_bias 50 | 51 | @staticmethod 52 | def backward(ctx, gradgrad_input, gradgrad_bias): 53 | out, = ctx.saved_tensors 54 | gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, 55 | ctx.scale) 56 | 57 | return gradgrad_out, None, None, None 58 | 59 | 60 | class FusedLeakyReLUFunction(Function): 61 | 62 | @staticmethod 63 | def forward(ctx, input, bias, negative_slope, scale): 64 | empty = input.new_empty(0) 65 | out = fused_act_ext.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 66 | ctx.save_for_backward(out) 67 | ctx.negative_slope = negative_slope 68 | ctx.scale = scale 69 | 70 | return out 71 | 72 | @staticmethod 73 | def backward(ctx, grad_output): 74 | out, = ctx.saved_tensors 75 | 76 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale) 77 | 78 | return grad_input, grad_bias, None, None 79 | 80 | 81 | class FusedLeakyReLU(nn.Module): 82 | 83 | def __init__(self, channel, negative_slope=0.2, scale=2**0.5): 84 | super().__init__() 85 | 86 | self.bias = nn.Parameter(torch.zeros(channel)) 87 | self.negative_slope = negative_slope 88 | self.scale = scale 89 | 90 | def forward(self, input): 91 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 92 | 93 | 94 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5): 95 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 96 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/src/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act.cpp 2 | #include 3 | 4 | 5 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, 6 | const torch::Tensor& bias, 7 | const torch::Tensor& refer, 8 | int act, int grad, float alpha, float scale); 9 | 10 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 11 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 12 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 13 | 14 | torch::Tensor fused_bias_act(const torch::Tensor& input, 15 | const torch::Tensor& bias, 16 | const torch::Tensor& refer, 17 | int act, int grad, float alpha, float scale) { 18 | CHECK_CUDA(input); 19 | CHECK_CUDA(bias); 20 | 21 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 22 | } 23 | 24 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 25 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 26 | } 27 | -------------------------------------------------------------------------------- /basicsr/ops/fused_act/src/fused_bias_act_kernel.cu: 
-------------------------------------------------------------------------------- 1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_bias_act_kernel.cu 2 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 3 | // 4 | // This work is made available under the Nvidia Source Code License-NC. 5 | // To view a copy of this license, visit 6 | // https://nvlabs.github.io/stylegan2/license.html 7 | 8 | #include 9 | 10 | #include 11 | #include 12 | #include 13 | #include 14 | 15 | #include 16 | #include 17 | 18 | 19 | template 20 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 21 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 22 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 23 | 24 | scalar_t zero = 0.0; 25 | 26 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 27 | scalar_t x = p_x[xi]; 28 | 29 | if (use_bias) { 30 | x += p_b[(xi / step_b) % size_b]; 31 | } 32 | 33 | scalar_t ref = use_ref ? p_ref[xi] : zero; 34 | 35 | scalar_t y; 36 | 37 | switch (act * 10 + grad) { 38 | default: 39 | case 10: y = x; break; 40 | case 11: y = x; break; 41 | case 12: y = 0.0; break; 42 | 43 | case 30: y = (x > 0.0) ? x : x * alpha; break; 44 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 45 | case 32: y = 0.0; break; 46 | } 47 | 48 | out[xi] = y * scale; 49 | } 50 | } 51 | 52 | 53 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 54 | int act, int grad, float alpha, float scale) { 55 | int curDevice = -1; 56 | cudaGetDevice(&curDevice); 57 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 58 | 59 | auto x = input.contiguous(); 60 | auto b = bias.contiguous(); 61 | auto ref = refer.contiguous(); 62 | 63 | int use_bias = b.numel() ? 1 : 0; 64 | int use_ref = ref.numel() ? 
1 : 0; 65 | 66 | int size_x = x.numel(); 67 | int size_b = b.numel(); 68 | int step_b = 1; 69 | 70 | for (int i = 1 + 1; i < x.dim(); i++) { 71 | step_b *= x.size(i); 72 | } 73 | 74 | int loop_x = 4; 75 | int block_size = 4 * 32; 76 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 77 | 78 | auto y = torch::empty_like(x); 79 | 80 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 81 | fused_bias_act_kernel<<>>( 82 | y.data_ptr(), 83 | x.data_ptr(), 84 | b.data_ptr(), 85 | ref.data_ptr(), 86 | act, 87 | grad, 88 | alpha, 89 | scale, 90 | loop_x, 91 | size_x, 92 | step_b, 93 | size_b, 94 | use_bias, 95 | use_ref 96 | ); 97 | }); 98 | 99 | return y; 100 | } 101 | -------------------------------------------------------------------------------- /basicsr/ops/upfirdn2d/__init__.py: -------------------------------------------------------------------------------- 1 | from .upfirdn2d import upfirdn2d 2 | 3 | __all__ = ['upfirdn2d'] 4 | -------------------------------------------------------------------------------- /basicsr/ops/upfirdn2d/src/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.cpp 2 | #include 3 | 4 | 5 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 6 | int up_x, int up_y, int down_x, int down_y, 7 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 8 | 9 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 10 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 11 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 12 | 13 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 14 | int up_x, int up_y, int down_x, int down_y, 15 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 16 | CHECK_CUDA(input); 17 | CHECK_CUDA(kernel); 18 | 19 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 20 | } 21 | 22 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 23 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 24 | } 25 | -------------------------------------------------------------------------------- /basicsr/ops/upfirdn2d/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | # modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501 2 | 3 | import os 4 | import torch 5 | from torch.autograd import Function 6 | from torch.nn import functional as F 7 | 8 | BASICSR_JIT = os.getenv('BASICSR_JIT') 9 | if BASICSR_JIT == 'True': 10 | from torch.utils.cpp_extension import load 11 | module_path = os.path.dirname(__file__) 12 | upfirdn2d_ext = load( 13 | 'upfirdn2d', 14 | sources=[ 15 | os.path.join(module_path, 'src', 'upfirdn2d.cpp'), 16 | os.path.join(module_path, 'src', 'upfirdn2d_kernel.cu'), 17 | ], 18 | ) 19 | else: 20 | try: 21 | from . import upfirdn2d_ext 22 | except ImportError: 23 | pass 24 | # avoid annoying print output 25 | # print(f'Cannot import deform_conv_ext. Error: {error}. You may need to: \n ' 26 | # '1. compile with BASICSR_EXT=True. or\n ' 27 | # '2. 
set BASICSR_JIT=True during running') 28 | 29 | 30 | class UpFirDn2dBackward(Function): 31 | 32 | @staticmethod 33 | def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size): 34 | 35 | up_x, up_y = up 36 | down_x, down_y = down 37 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 38 | 39 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 40 | 41 | grad_input = upfirdn2d_ext.upfirdn2d( 42 | grad_output, 43 | grad_kernel, 44 | down_x, 45 | down_y, 46 | up_x, 47 | up_y, 48 | g_pad_x0, 49 | g_pad_x1, 50 | g_pad_y0, 51 | g_pad_y1, 52 | ) 53 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 54 | 55 | ctx.save_for_backward(kernel) 56 | 57 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 58 | 59 | ctx.up_x = up_x 60 | ctx.up_y = up_y 61 | ctx.down_x = down_x 62 | ctx.down_y = down_y 63 | ctx.pad_x0 = pad_x0 64 | ctx.pad_x1 = pad_x1 65 | ctx.pad_y0 = pad_y0 66 | ctx.pad_y1 = pad_y1 67 | ctx.in_size = in_size 68 | ctx.out_size = out_size 69 | 70 | return grad_input 71 | 72 | @staticmethod 73 | def backward(ctx, gradgrad_input): 74 | kernel, = ctx.saved_tensors 75 | 76 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 77 | 78 | gradgrad_out = upfirdn2d_ext.upfirdn2d( 79 | gradgrad_input, 80 | kernel, 81 | ctx.up_x, 82 | ctx.up_y, 83 | ctx.down_x, 84 | ctx.down_y, 85 | ctx.pad_x0, 86 | ctx.pad_x1, 87 | ctx.pad_y0, 88 | ctx.pad_y1, 89 | ) 90 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], 91 | # ctx.out_size[1], ctx.in_size[3]) 92 | gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]) 93 | 94 | return gradgrad_out, None, None, None, None, None, None, None, None 95 | 96 | 97 | class UpFirDn2d(Function): 98 | 99 | @staticmethod 100 | def forward(ctx, input, kernel, up, down, pad): 101 | up_x, up_y = up 102 | down_x, down_y = down 103 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 104 | 105 | kernel_h, kernel_w = kernel.shape 106 | _, channel, in_h, in_w = input.shape 107 | ctx.in_size = input.shape 108 | 109 | input = input.reshape(-1, in_h, in_w, 1) 110 | 111 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 112 | 113 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 114 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 115 | ctx.out_size = (out_h, out_w) 116 | 117 | ctx.up = (up_x, up_y) 118 | ctx.down = (down_x, down_y) 119 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 120 | 121 | g_pad_x0 = kernel_w - pad_x0 - 1 122 | g_pad_y0 = kernel_h - pad_y0 - 1 123 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 124 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 125 | 126 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 127 | 128 | out = upfirdn2d_ext.upfirdn2d(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1) 129 | # out = out.view(major, out_h, out_w, minor) 130 | out = out.view(-1, channel, out_h, out_w) 131 | 132 | return out 133 | 134 | @staticmethod 135 | def backward(ctx, grad_output): 136 | kernel, grad_kernel = ctx.saved_tensors 137 | 138 | grad_input = UpFirDn2dBackward.apply( 139 | grad_output, 140 | kernel, 141 | grad_kernel, 142 | ctx.up, 143 | ctx.down, 144 | ctx.pad, 145 | ctx.g_pad, 146 | ctx.in_size, 147 | ctx.out_size, 148 | ) 149 | 150 | return grad_input, None, None, None, None 151 | 152 | 153 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 154 | if input.device.type == 'cpu': 155 | out = upfirdn2d_native(input, kernel, 
up, up, down, down, pad[0], pad[1], pad[0], pad[1]) 156 | else: 157 | out = UpFirDn2d.apply(input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])) 158 | 159 | return out 160 | 161 | 162 | def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1): 163 | _, channel, in_h, in_w = input.shape 164 | input = input.reshape(-1, in_h, in_w, 1) 165 | 166 | _, in_h, in_w, minor = input.shape 167 | kernel_h, kernel_w = kernel.shape 168 | 169 | out = input.view(-1, in_h, 1, in_w, 1, minor) 170 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 171 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 172 | 173 | out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]) 174 | out = out[:, max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ] 175 | 176 | out = out.permute(0, 3, 1, 2) 177 | out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]) 178 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 179 | out = F.conv2d(out, w) 180 | out = out.reshape( 181 | -1, 182 | minor, 183 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 184 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 185 | ) 186 | out = out.permute(0, 2, 3, 1) 187 | out = out[:, ::down_y, ::down_x, :] 188 | 189 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 190 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 191 | 192 | return out.view(-1, channel, out_h, out_w) 193 | -------------------------------------------------------------------------------- /basicsr/test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | from os import path as osp 4 | 5 | from basicsr.data import build_dataloader, build_dataset 6 | from basicsr.models import build_model 7 | from basicsr.utils import get_env_info, get_root_logger, get_time_str, make_exp_dirs 8 | from basicsr.utils.options import dict2str, parse_options 9 | 10 | 11 | def test_pipeline(root_path): 12 | # parse options, set distributed setting, set ramdom seed 13 | opt, _ = parse_options(root_path, is_train=False) 14 | 15 | torch.backends.cudnn.benchmark = True 16 | # torch.backends.cudnn.deterministic = True 17 | 18 | # mkdir and initialize loggers 19 | make_exp_dirs(opt) 20 | log_file = osp.join(opt['path']['log'], f"test_{opt['name']}_{get_time_str()}.log") 21 | logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file) 22 | logger.info(get_env_info()) 23 | logger.info(dict2str(opt)) 24 | 25 | # create test dataset and dataloader 26 | test_loaders = [] 27 | for _, dataset_opt in sorted(opt['datasets'].items()): 28 | test_set = build_dataset(dataset_opt) 29 | test_loader = build_dataloader( 30 | test_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed']) 31 | logger.info(f"Number of test images in {dataset_opt['name']}: {len(test_set)}") 32 | test_loaders.append(test_loader) 33 | 34 | # create model 35 | model = build_model(opt) 36 | 37 | for test_loader in test_loaders: 38 | test_set_name = test_loader.dataset.opt['name'] 39 | logger.info(f'Testing {test_set_name}...') 40 | model.validation(test_loader, current_iter=opt['name'], tb_logger=None, save_img=opt['val']['save_img']) 41 | 42 | 43 | if __name__ == '__main__': 44 | root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) 45 | test_pipeline(root_path) 46 | 
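
For completeness, a short sketch of how the test pipeline above is usually launched. The `-opt` command-line flag and the YAML path are assumptions based on the standard BasicSR entry point (the option file is read from the command line by `parse_options`); neither is defined in test.py itself.

```python
# Typical invocation from the repository root (assumed BasicSR-style flag):
#
#   python basicsr/test.py -opt options/test/<experiment>/<config>.yml
#
# Programmatic equivalent, injecting the same argument before calling the
# pipeline; the YAML path below is a placeholder.
import sys
from os import path as osp

from basicsr.test import test_pipeline

sys.argv = ['basicsr/test.py', '-opt', 'path/to/test_config.yml']  # placeholder option file
root_path = osp.abspath(osp.join('basicsr', osp.pardir))  # repo root when run from it
test_pipeline(root_path)
```
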
-------------------------------------------------------------------------------- /basicsr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .color_util import bgr2ycbcr, rgb2ycbcr, rgb2ycbcr_pt, ycbcr2bgr, ycbcr2rgb 2 | from .diffjpeg import DiffJPEG 3 | from .file_client import FileClient 4 | from .img_process_util import USMSharp, usm_sharp 5 | from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img 6 | from .logger import AvgTimer, MessageLogger, get_env_info, get_root_logger, init_tb_logger, init_wandb_logger 7 | from .misc import check_resume, get_time_str, make_exp_dirs, mkdir_and_rename, scandir, set_random_seed, sizeof_fmt 8 | from .options import yaml_load 9 | 10 | __all__ = [ 11 | # color_util.py 12 | 'bgr2ycbcr', 13 | 'rgb2ycbcr', 14 | 'rgb2ycbcr_pt', 15 | 'ycbcr2bgr', 16 | 'ycbcr2rgb', 17 | # file_client.py 18 | 'FileClient', 19 | # img_util.py 20 | 'img2tensor', 21 | 'tensor2img', 22 | 'imfrombytes', 23 | 'imwrite', 24 | 'crop_border', 25 | # logger.py 26 | 'MessageLogger', 27 | 'AvgTimer', 28 | 'init_tb_logger', 29 | 'init_wandb_logger', 30 | 'get_root_logger', 31 | 'get_env_info', 32 | # misc.py 33 | 'set_random_seed', 34 | 'get_time_str', 35 | 'mkdir_and_rename', 36 | 'make_exp_dirs', 37 | 'scandir', 38 | 'check_resume', 39 | 'sizeof_fmt', 40 | # diffjpeg 41 | 'DiffJPEG', 42 | # img_process_util 43 | 'USMSharp', 44 | 'usm_sharp', 45 | # options 46 | 'yaml_load' 47 | ] 48 | -------------------------------------------------------------------------------- /basicsr/utils/color_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def rgb2ycbcr(img, y_only=False): 6 | """Convert a RGB image to YCbCr image. 7 | 8 | This function produces the same results as Matlab's `rgb2ycbcr` function. 9 | It implements the ITU-R BT.601 conversion for standard-definition 10 | television. See more details in 11 | https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. 12 | 13 | It differs from a similar function in cv2.cvtColor: `RGB <-> YCrCb`. 14 | In OpenCV, it implements a JPEG conversion. See more details in 15 | https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. 16 | 17 | Args: 18 | img (ndarray): The input image. It accepts: 19 | 1. np.uint8 type with range [0, 255]; 20 | 2. np.float32 type with range [0, 1]. 21 | y_only (bool): Whether to only return Y channel. Default: False. 22 | 23 | Returns: 24 | ndarray: The converted YCbCr image. The output image has the same type 25 | and range as input image. 26 | """ 27 | img_type = img.dtype 28 | img = _convert_input_type_range(img) 29 | if y_only: 30 | out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0 31 | else: 32 | out_img = np.matmul( 33 | img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128] 34 | out_img = _convert_output_type_range(out_img, img_type) 35 | return out_img 36 | 37 | 38 | def bgr2ycbcr(img, y_only=False): 39 | """Convert a BGR image to YCbCr image. 40 | 41 | The bgr version of rgb2ycbcr. 42 | It implements the ITU-R BT.601 conversion for standard-definition 43 | television. See more details in 44 | https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. 45 | 46 | It differs from a similar function in cv2.cvtColor: `BGR <-> YCrCb`. 47 | In OpenCV, it implements a JPEG conversion. See more details in 48 | https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. 
49 | 50 | Args: 51 | img (ndarray): The input image. It accepts: 52 | 1. np.uint8 type with range [0, 255]; 53 | 2. np.float32 type with range [0, 1]. 54 | y_only (bool): Whether to only return Y channel. Default: False. 55 | 56 | Returns: 57 | ndarray: The converted YCbCr image. The output image has the same type 58 | and range as input image. 59 | """ 60 | img_type = img.dtype 61 | img = _convert_input_type_range(img) 62 | if y_only: 63 | out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0 64 | else: 65 | out_img = np.matmul( 66 | img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128] 67 | out_img = _convert_output_type_range(out_img, img_type) 68 | return out_img 69 | 70 | 71 | def ycbcr2rgb(img): 72 | """Convert a YCbCr image to RGB image. 73 | 74 | This function produces the same results as Matlab's ycbcr2rgb function. 75 | It implements the ITU-R BT.601 conversion for standard-definition 76 | television. See more details in 77 | https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. 78 | 79 | It differs from a similar function in cv2.cvtColor: `YCrCb <-> RGB`. 80 | In OpenCV, it implements a JPEG conversion. See more details in 81 | https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. 82 | 83 | Args: 84 | img (ndarray): The input image. It accepts: 85 | 1. np.uint8 type with range [0, 255]; 86 | 2. np.float32 type with range [0, 1]. 87 | 88 | Returns: 89 | ndarray: The converted RGB image. The output image has the same type 90 | and range as input image. 91 | """ 92 | img_type = img.dtype 93 | img = _convert_input_type_range(img) * 255 94 | out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], 95 | [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] # noqa: E126 96 | out_img = _convert_output_type_range(out_img, img_type) 97 | return out_img 98 | 99 | 100 | def ycbcr2bgr(img): 101 | """Convert a YCbCr image to BGR image. 102 | 103 | The bgr version of ycbcr2rgb. 104 | It implements the ITU-R BT.601 conversion for standard-definition 105 | television. See more details in 106 | https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. 107 | 108 | It differs from a similar function in cv2.cvtColor: `YCrCb <-> BGR`. 109 | In OpenCV, it implements a JPEG conversion. See more details in 110 | https://en.wikipedia.org/wiki/YCbCr#JPEG_conversion. 111 | 112 | Args: 113 | img (ndarray): The input image. It accepts: 114 | 1. np.uint8 type with range [0, 255]; 115 | 2. np.float32 type with range [0, 1]. 116 | 117 | Returns: 118 | ndarray: The converted BGR image. The output image has the same type 119 | and range as input image. 120 | """ 121 | img_type = img.dtype 122 | img = _convert_input_type_range(img) * 255 123 | out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0], 124 | [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921] # noqa: E126 125 | out_img = _convert_output_type_range(out_img, img_type) 126 | return out_img 127 | 128 | 129 | def _convert_input_type_range(img): 130 | """Convert the type and range of the input image. 131 | 132 | It converts the input image to np.float32 type and range of [0, 1]. 133 | It is mainly used for pre-processing the input image in colorspace 134 | conversion functions such as rgb2ycbcr and ycbcr2rgb. 135 | 136 | Args: 137 | img (ndarray): The input image. It accepts: 138 | 1. np.uint8 type with range [0, 255]; 139 | 2. np.float32 type with range [0, 1]. 
140 | 141 | Returns: 142 | (ndarray): The converted image with type of np.float32 and range of 143 | [0, 1]. 144 | """ 145 | img_type = img.dtype 146 | img = img.astype(np.float32) 147 | if img_type == np.float32: 148 | pass 149 | elif img_type == np.uint8: 150 | img /= 255. 151 | else: 152 | raise TypeError(f'The img type should be np.float32 or np.uint8, but got {img_type}') 153 | return img 154 | 155 | 156 | def _convert_output_type_range(img, dst_type): 157 | """Convert the type and range of the image according to dst_type. 158 | 159 | It converts the image to desired type and range. If `dst_type` is np.uint8, 160 | images will be converted to np.uint8 type with range [0, 255]. If 161 | `dst_type` is np.float32, it converts the image to np.float32 type with 162 | range [0, 1]. 163 | It is mainly used for post-processing images in colorspace conversion 164 | functions such as rgb2ycbcr and ycbcr2rgb. 165 | 166 | Args: 167 | img (ndarray): The image to be converted with np.float32 type and 168 | range [0, 255]. 169 | dst_type (np.uint8 | np.float32): If dst_type is np.uint8, it 170 | converts the image to np.uint8 type with range [0, 255]. If 171 | dst_type is np.float32, it converts the image to np.float32 type 172 | with range [0, 1]. 173 | 174 | Returns: 175 | (ndarray): The converted image with desired type and range. 176 | """ 177 | if dst_type not in (np.uint8, np.float32): 178 | raise TypeError(f'The dst_type should be np.float32 or np.uint8, but got {dst_type}') 179 | if dst_type == np.uint8: 180 | img = img.round() 181 | else: 182 | img /= 255. 183 | return img.astype(dst_type) 184 | 185 | 186 | def rgb2ycbcr_pt(img, y_only=False): 187 | """Convert RGB images to YCbCr images (PyTorch version). 188 | 189 | It implements the ITU-R BT.601 conversion for standard-definition television. See more details in 190 | https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion. 191 | 192 | Args: 193 | img (Tensor): Images with shape (n, 3, h, w), the range [0, 1], float, RGB format. 194 | y_only (bool): Whether to only return Y channel. Default: False. 195 | 196 | Returns: 197 | (Tensor): converted images with the shape (n, 3/1, h, w), the range [0, 1], float. 198 | """ 199 | if y_only: 200 | weight = torch.tensor([[65.481], [128.553], [24.966]]).to(img) 201 | out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + 16.0 202 | else: 203 | weight = torch.tensor([[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]).to(img) 204 | bias = torch.tensor([16, 128, 128]).view(1, 3, 1, 1).to(img) 205 | out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias 206 | 207 | out_img = out_img / 255. 
208 | return out_img 209 | -------------------------------------------------------------------------------- /basicsr/utils/dist_util.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 2 | import functools 3 | import os 4 | import subprocess 5 | import torch 6 | import torch.distributed as dist 7 | import torch.multiprocessing as mp 8 | 9 | 10 | def init_dist(launcher, backend='nccl', **kwargs): 11 | if mp.get_start_method(allow_none=True) is None: 12 | mp.set_start_method('spawn') 13 | if launcher == 'pytorch': 14 | _init_dist_pytorch(backend, **kwargs) 15 | elif launcher == 'slurm': 16 | _init_dist_slurm(backend, **kwargs) 17 | else: 18 | raise ValueError(f'Invalid launcher type: {launcher}') 19 | 20 | 21 | def _init_dist_pytorch(backend, **kwargs): 22 | rank = int(os.environ['RANK']) 23 | num_gpus = torch.cuda.device_count() 24 | torch.cuda.set_device(rank % num_gpus) 25 | dist.init_process_group(backend=backend, **kwargs) 26 | 27 | 28 | def _init_dist_slurm(backend, port=None): 29 | """Initialize slurm distributed training environment. 30 | 31 | If argument ``port`` is not specified, then the master port will be system 32 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 33 | environment variable, then a default port ``29500`` will be used. 34 | 35 | Args: 36 | backend (str): Backend of torch.distributed. 37 | port (int, optional): Master port. Defaults to None. 38 | """ 39 | proc_id = int(os.environ['SLURM_PROCID']) 40 | ntasks = int(os.environ['SLURM_NTASKS']) 41 | node_list = os.environ['SLURM_NODELIST'] 42 | num_gpus = torch.cuda.device_count() 43 | torch.cuda.set_device(proc_id % num_gpus) 44 | addr = subprocess.getoutput(f'scontrol show hostname {node_list} | head -n1') 45 | # specify master port 46 | if port is not None: 47 | os.environ['MASTER_PORT'] = str(port) 48 | elif 'MASTER_PORT' in os.environ: 49 | pass # use MASTER_PORT in the environment variable 50 | else: 51 | # 29500 is torch.distributed default port 52 | os.environ['MASTER_PORT'] = '29500' 53 | os.environ['MASTER_ADDR'] = addr 54 | os.environ['WORLD_SIZE'] = str(ntasks) 55 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 56 | os.environ['RANK'] = str(proc_id) 57 | dist.init_process_group(backend=backend) 58 | 59 | 60 | def get_dist_info(): 61 | if dist.is_available(): 62 | initialized = dist.is_initialized() 63 | else: 64 | initialized = False 65 | if initialized: 66 | rank = dist.get_rank() 67 | world_size = dist.get_world_size() 68 | else: 69 | rank = 0 70 | world_size = 1 71 | return rank, world_size 72 | 73 | 74 | def master_only(func): 75 | 76 | @functools.wraps(func) 77 | def wrapper(*args, **kwargs): 78 | rank, _ = get_dist_info() 79 | if rank == 0: 80 | return func(*args, **kwargs) 81 | 82 | return wrapper 83 | -------------------------------------------------------------------------------- /basicsr/utils/download_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import requests 4 | from torch.hub import download_url_to_file, get_dir 5 | from tqdm import tqdm 6 | from urllib.parse import urlparse 7 | 8 | from .misc import sizeof_fmt 9 | 10 | 11 | def download_file_from_google_drive(file_id, save_path): 12 | """Download files from google drive. 
13 | 14 | Reference: https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive 15 | 16 | Args: 17 | file_id (str): File id. 18 | save_path (str): Save path. 19 | """ 20 | 21 | session = requests.Session() 22 | URL = 'https://docs.google.com/uc?export=download' 23 | params = {'id': file_id} 24 | 25 | response = session.get(URL, params=params, stream=True) 26 | token = get_confirm_token(response) 27 | if token: 28 | params['confirm'] = token 29 | response = session.get(URL, params=params, stream=True) 30 | 31 | # get file size 32 | response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) 33 | if 'Content-Range' in response_file_size.headers: 34 | file_size = int(response_file_size.headers['Content-Range'].split('/')[1]) 35 | else: 36 | file_size = None 37 | 38 | save_response_content(response, save_path, file_size) 39 | 40 | 41 | def get_confirm_token(response): 42 | for key, value in response.cookies.items(): 43 | if key.startswith('download_warning'): 44 | return value 45 | return None 46 | 47 | 48 | def save_response_content(response, destination, file_size=None, chunk_size=32768): 49 | if file_size is not None: 50 | pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') 51 | 52 | readable_file_size = sizeof_fmt(file_size) 53 | else: 54 | pbar = None 55 | 56 | with open(destination, 'wb') as f: 57 | downloaded_size = 0 58 | for chunk in response.iter_content(chunk_size): 59 | downloaded_size += chunk_size 60 | if pbar is not None: 61 | pbar.update(1) 62 | pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}') 63 | if chunk: # filter out keep-alive new chunks 64 | f.write(chunk) 65 | if pbar is not None: 66 | pbar.close() 67 | 68 | 69 | def load_file_from_url(url, model_dir=None, progress=True, file_name=None): 70 | """Load file form http url, will download models if necessary. 71 | 72 | Reference: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py 73 | 74 | Args: 75 | url (str): URL to be downloaded. 76 | model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. 77 | Default: None. 78 | progress (bool): Whether to show the download progress. Default: True. 79 | file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. 80 | 81 | Returns: 82 | str: The path to the downloaded file. 83 | """ 84 | if model_dir is None: # use the pytorch hub_dir 85 | hub_dir = get_dir() 86 | model_dir = os.path.join(hub_dir, 'checkpoints') 87 | 88 | os.makedirs(model_dir, exist_ok=True) 89 | 90 | parts = urlparse(url) 91 | filename = os.path.basename(parts.path) 92 | if file_name is not None: 93 | filename = file_name 94 | cached_file = os.path.abspath(os.path.join(model_dir, filename)) 95 | if not os.path.exists(cached_file): 96 | print(f'Downloading: "{url}" to {cached_file}\n') 97 | download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) 98 | return cached_file 99 | -------------------------------------------------------------------------------- /basicsr/utils/file_client.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501 2 | from abc import ABCMeta, abstractmethod 3 | 4 | 5 | class BaseStorageBackend(metaclass=ABCMeta): 6 | """Abstract class of storage backends. 
7 | 8 | All backends need to implement two apis: ``get()`` and ``get_text()``. 9 | ``get()`` reads the file as a byte stream and ``get_text()`` reads the file 10 | as texts. 11 | """ 12 | 13 | @abstractmethod 14 | def get(self, filepath): 15 | pass 16 | 17 | @abstractmethod 18 | def get_text(self, filepath): 19 | pass 20 | 21 | 22 | class MemcachedBackend(BaseStorageBackend): 23 | """Memcached storage backend. 24 | 25 | Attributes: 26 | server_list_cfg (str): Config file for memcached server list. 27 | client_cfg (str): Config file for memcached client. 28 | sys_path (str | None): Additional path to be appended to `sys.path`. 29 | Default: None. 30 | """ 31 | 32 | def __init__(self, server_list_cfg, client_cfg, sys_path=None): 33 | if sys_path is not None: 34 | import sys 35 | sys.path.append(sys_path) 36 | try: 37 | import mc 38 | except ImportError: 39 | raise ImportError('Please install memcached to enable MemcachedBackend.') 40 | 41 | self.server_list_cfg = server_list_cfg 42 | self.client_cfg = client_cfg 43 | self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg) 44 | # mc.pyvector servers as a point which points to a memory cache 45 | self._mc_buffer = mc.pyvector() 46 | 47 | def get(self, filepath): 48 | filepath = str(filepath) 49 | import mc 50 | self._client.Get(filepath, self._mc_buffer) 51 | value_buf = mc.ConvertBuffer(self._mc_buffer) 52 | return value_buf 53 | 54 | def get_text(self, filepath): 55 | raise NotImplementedError 56 | 57 | 58 | class HardDiskBackend(BaseStorageBackend): 59 | """Raw hard disks storage backend.""" 60 | 61 | def get(self, filepath): 62 | filepath = str(filepath) 63 | with open(filepath, 'rb') as f: 64 | value_buf = f.read() 65 | return value_buf 66 | 67 | def get_text(self, filepath): 68 | filepath = str(filepath) 69 | with open(filepath, 'r') as f: 70 | value_buf = f.read() 71 | return value_buf 72 | 73 | 74 | class LmdbBackend(BaseStorageBackend): 75 | """Lmdb storage backend. 76 | 77 | Args: 78 | db_paths (str | list[str]): Lmdb database paths. 79 | client_keys (str | list[str]): Lmdb client keys. Default: 'default'. 80 | readonly (bool, optional): Lmdb environment parameter. If True, 81 | disallow any write operations. Default: True. 82 | lock (bool, optional): Lmdb environment parameter. If False, when 83 | concurrent access occurs, do not lock the database. Default: False. 84 | readahead (bool, optional): Lmdb environment parameter. If False, 85 | disable the OS filesystem readahead mechanism, which may improve 86 | random read performance when a database is larger than RAM. 87 | Default: False. 88 | 89 | Attributes: 90 | db_paths (list): Lmdb database path. 91 | _client (list): A list of several lmdb envs. 
92 | """ 93 | 94 | def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs): 95 | try: 96 | import lmdb 97 | except ImportError: 98 | raise ImportError('Please install lmdb to enable LmdbBackend.') 99 | 100 | if isinstance(client_keys, str): 101 | client_keys = [client_keys] 102 | 103 | if isinstance(db_paths, list): 104 | self.db_paths = [str(v) for v in db_paths] 105 | elif isinstance(db_paths, str): 106 | self.db_paths = [str(db_paths)] 107 | assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, ' 108 | f'but received {len(client_keys)} and {len(self.db_paths)}.') 109 | 110 | self._client = {} 111 | for client, path in zip(client_keys, self.db_paths): 112 | self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs) 113 | 114 | def get(self, filepath, client_key): 115 | """Get values according to the filepath from one lmdb named client_key. 116 | 117 | Args: 118 | filepath (str | obj:`Path`): Here, filepath is the lmdb key. 119 | client_key (str): Used for distinguishing different lmdb envs. 120 | """ 121 | filepath = str(filepath) 122 | assert client_key in self._client, (f'client_key {client_key} is not in lmdb clients.') 123 | client = self._client[client_key] 124 | with client.begin(write=False) as txn: 125 | value_buf = txn.get(filepath.encode('ascii')) 126 | return value_buf 127 | 128 | def get_text(self, filepath): 129 | raise NotImplementedError 130 | 131 | 132 | class FileClient(object): 133 | """A general file client to access files in different backend. 134 | 135 | The client loads a file or text in a specified backend from its path 136 | and return it as a binary file. it can also register other backend 137 | accessor with a given name and backend class. 138 | 139 | Attributes: 140 | backend (str): The storage backend type. Options are "disk", 141 | "memcached" and "lmdb". 142 | client (:obj:`BaseStorageBackend`): The backend object. 143 | """ 144 | 145 | _backends = { 146 | 'disk': HardDiskBackend, 147 | 'memcached': MemcachedBackend, 148 | 'lmdb': LmdbBackend, 149 | } 150 | 151 | def __init__(self, backend='disk', **kwargs): 152 | if backend not in self._backends: 153 | raise ValueError(f'Backend {backend} is not supported. Currently supported ones' 154 | f' are {list(self._backends.keys())}') 155 | self.backend = backend 156 | self.client = self._backends[backend](**kwargs) 157 | 158 | def get(self, filepath, client_key='default'): 159 | # client_key is used only for lmdb, where different fileclients have 160 | # different lmdb environments. 161 | if self.backend == 'lmdb': 162 | return self.client.get(filepath, client_key) 163 | else: 164 | return self.client.get(filepath) 165 | 166 | def get_text(self, filepath): 167 | return self.client.get_text(filepath) 168 | -------------------------------------------------------------------------------- /basicsr/utils/flow_util.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501 2 | import cv2 3 | import numpy as np 4 | import os 5 | 6 | 7 | def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs): 8 | """Read an optical flow map. 9 | 10 | Args: 11 | flow_path (ndarray or str): Flow path. 12 | quantize (bool): whether to read quantized pair, if set to True, 13 | remaining args will be passed to :func:`dequantize_flow`. 
14 | concat_axis (int): The axis that dx and dy are concatenated, 15 | can be either 0 or 1. Ignored if quantize is False. 16 | 17 | Returns: 18 | ndarray: Optical flow represented as a (h, w, 2) numpy array 19 | """ 20 | if quantize: 21 | assert concat_axis in [0, 1] 22 | cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED) 23 | if cat_flow.ndim != 2: 24 | raise IOError(f'{flow_path} is not a valid quantized flow file, its dimension is {cat_flow.ndim}.') 25 | assert cat_flow.shape[concat_axis] % 2 == 0 26 | dx, dy = np.split(cat_flow, 2, axis=concat_axis) 27 | flow = dequantize_flow(dx, dy, *args, **kwargs) 28 | else: 29 | with open(flow_path, 'rb') as f: 30 | try: 31 | header = f.read(4).decode('utf-8') 32 | except Exception: 33 | raise IOError(f'Invalid flow file: {flow_path}') 34 | else: 35 | if header != 'PIEH': 36 | raise IOError(f'Invalid flow file: {flow_path}, header does not contain PIEH') 37 | 38 | w = np.fromfile(f, np.int32, 1).squeeze() 39 | h = np.fromfile(f, np.int32, 1).squeeze() 40 | flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2)) 41 | 42 | return flow.astype(np.float32) 43 | 44 | 45 | def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs): 46 | """Write optical flow to file. 47 | 48 | If the flow is not quantized, it will be saved as a .flo file losslessly, 49 | otherwise a jpeg image which is lossy but of much smaller size. (dx and dy 50 | will be concatenated horizontally into a single image if quantize is True.) 51 | 52 | Args: 53 | flow (ndarray): (h, w, 2) array of optical flow. 54 | filename (str): Output filepath. 55 | quantize (bool): Whether to quantize the flow and save it to 2 jpeg 56 | images. If set to True, remaining args will be passed to 57 | :func:`quantize_flow`. 58 | concat_axis (int): The axis that dx and dy are concatenated, 59 | can be either 0 or 1. Ignored if quantize is False. 60 | """ 61 | if not quantize: 62 | with open(filename, 'wb') as f: 63 | f.write('PIEH'.encode('utf-8')) 64 | np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f) 65 | flow = flow.astype(np.float32) 66 | flow.tofile(f) 67 | f.flush() 68 | else: 69 | assert concat_axis in [0, 1] 70 | dx, dy = quantize_flow(flow, *args, **kwargs) 71 | dxdy = np.concatenate((dx, dy), axis=concat_axis) 72 | os.makedirs(os.path.dirname(filename), exist_ok=True) 73 | cv2.imwrite(filename, dxdy) 74 | 75 | 76 | def quantize_flow(flow, max_val=0.02, norm=True): 77 | """Quantize flow to [0, 255]. 78 | 79 | After this step, the size of flow will be much smaller, and can be 80 | dumped as jpeg images. 81 | 82 | Args: 83 | flow (ndarray): (h, w, 2) array of optical flow. 84 | max_val (float): Maximum value of flow, values beyond 85 | [-max_val, max_val] will be truncated. 86 | norm (bool): Whether to divide flow values by image width/height. 87 | 88 | Returns: 89 | tuple[ndarray]: Quantized dx and dy. 90 | """ 91 | h, w, _ = flow.shape 92 | dx = flow[..., 0] 93 | dy = flow[..., 1] 94 | if norm: 95 | dx = dx / w # avoid inplace operations 96 | dy = dy / h 97 | # use 255 levels instead of 256 to make sure 0 is 0 after dequantization. 98 | flow_comps = [quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]] 99 | return tuple(flow_comps) 100 | 101 | 102 | def dequantize_flow(dx, dy, max_val=0.02, denorm=True): 103 | """Recover from quantized flow. 104 | 105 | Args: 106 | dx (ndarray): Quantized dx. 107 | dy (ndarray): Quantized dy. 108 | max_val (float): Maximum value used when quantizing. 
109 | denorm (bool): Whether to multiply flow values with width/height. 110 | 111 | Returns: 112 | ndarray: Dequantized flow. 113 | """ 114 | assert dx.shape == dy.shape 115 | assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1) 116 | 117 | dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]] 118 | 119 | if denorm: 120 | dx *= dx.shape[1] 121 | dy *= dx.shape[0] 122 | flow = np.dstack((dx, dy)) 123 | return flow 124 | 125 | 126 | def quantize(arr, min_val, max_val, levels, dtype=np.int64): 127 | """Quantize an array of (-inf, inf) to [0, levels-1]. 128 | 129 | Args: 130 | arr (ndarray): Input array. 131 | min_val (scalar): Minimum value to be clipped. 132 | max_val (scalar): Maximum value to be clipped. 133 | levels (int): Quantization levels. 134 | dtype (np.type): The type of the quantized array. 135 | 136 | Returns: 137 | tuple: Quantized array. 138 | """ 139 | if not (isinstance(levels, int) and levels > 1): 140 | raise ValueError(f'levels must be a positive integer, but got {levels}') 141 | if min_val >= max_val: 142 | raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})') 143 | 144 | arr = np.clip(arr, min_val, max_val) - min_val 145 | quantized_arr = np.minimum(np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1) 146 | 147 | return quantized_arr 148 | 149 | 150 | def dequantize(arr, min_val, max_val, levels, dtype=np.float64): 151 | """Dequantize an array. 152 | 153 | Args: 154 | arr (ndarray): Input array. 155 | min_val (scalar): Minimum value to be clipped. 156 | max_val (scalar): Maximum value to be clipped. 157 | levels (int): Quantization levels. 158 | dtype (np.type): The type of the dequantized array. 159 | 160 | Returns: 161 | tuple: Dequantized array. 162 | """ 163 | if not (isinstance(levels, int) and levels > 1): 164 | raise ValueError(f'levels must be a positive integer, but got {levels}') 165 | if min_val >= max_val: 166 | raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})') 167 | 168 | dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - min_val) / levels + min_val 169 | 170 | return dequantized_arr 171 | -------------------------------------------------------------------------------- /basicsr/utils/img_process_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import torch 4 | from torch.nn import functional as F 5 | 6 | 7 | def filter2D(img, kernel): 8 | """PyTorch version of cv2.filter2D 9 | 10 | Args: 11 | img (Tensor): (b, c, h, w) 12 | kernel (Tensor): (b, k, k) 13 | """ 14 | k = kernel.size(-1) 15 | b, c, h, w = img.size() 16 | if k % 2 == 1: 17 | img = F.pad(img, (k // 2, k // 2, k // 2, k // 2), mode='reflect') 18 | else: 19 | raise ValueError('Wrong kernel size') 20 | 21 | ph, pw = img.size()[-2:] 22 | 23 | if kernel.size(0) == 1: 24 | # apply the same kernel to all batch images 25 | img = img.view(b * c, 1, ph, pw) 26 | kernel = kernel.view(1, 1, k, k) 27 | return F.conv2d(img, kernel, padding=0).view(b, c, h, w) 28 | else: 29 | img = img.view(1, b * c, ph, pw) 30 | kernel = kernel.view(b, 1, k, k).repeat(1, c, 1, 1).view(b * c, 1, k, k) 31 | return F.conv2d(img, kernel, groups=b * c).view(b, c, h, w) 32 | 33 | 34 | def usm_sharp(img, weight=0.5, radius=50, threshold=10): 35 | """USM sharpening. 36 | 37 | Input image: I; Blurry image: B. 38 | 1. sharp = I + weight * (I - B) 39 | 2. Mask = 1 if abs(I - B) > threshold, else: 0 40 | 3. Blur mask: 41 | 4. 
Out = Mask * sharp + (1 - Mask) * I 42 | 43 | 44 | Args: 45 | img (Numpy array): Input image, HWC, BGR; float32, [0, 1]. 46 | weight (float): Sharp weight. Default: 0.5. 47 | radius (float): Kernel size of Gaussian blur. Default: 50. 48 | threshold (int): Threshold on the residual (scaled to [0, 255]) used to build the sharpening mask. Default: 10. 49 | """ 50 | if radius % 2 == 0: 51 | radius += 1 52 | blur = cv2.GaussianBlur(img, (radius, radius), 0) 53 | residual = img - blur 54 | mask = np.abs(residual) * 255 > threshold 55 | mask = mask.astype('float32') 56 | soft_mask = cv2.GaussianBlur(mask, (radius, radius), 0) 57 | 58 | sharp = img + weight * residual 59 | sharp = np.clip(sharp, 0, 1) 60 | return soft_mask * sharp + (1 - soft_mask) * img 61 | 62 | 63 | class USMSharp(torch.nn.Module): 64 | 65 | def __init__(self, radius=50, sigma=0): 66 | super(USMSharp, self).__init__() 67 | if radius % 2 == 0: 68 | radius += 1 69 | self.radius = radius 70 | kernel = cv2.getGaussianKernel(radius, sigma) 71 | kernel = torch.FloatTensor(np.dot(kernel, kernel.transpose())).unsqueeze_(0) 72 | self.register_buffer('kernel', kernel) 73 | 74 | def forward(self, img, weight=0.5, threshold=10): 75 | blur = filter2D(img, self.kernel) 76 | residual = img - blur 77 | 78 | mask = torch.abs(residual) * 255 > threshold 79 | mask = mask.float() 80 | soft_mask = filter2D(mask, self.kernel) 81 | sharp = img + weight * residual 82 | sharp = torch.clip(sharp, 0, 1) 83 | return soft_mask * sharp + (1 - soft_mask) * img 84 | -------------------------------------------------------------------------------- /basicsr/utils/img_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import numpy as np 4 | import os 5 | import torch 6 | from torchvision.utils import make_grid 7 | 8 | 9 | def img2tensor(imgs, bgr2rgb=True, float32=True): 10 | """Numpy array to tensor. 11 | 12 | Args: 13 | imgs (list[ndarray] | ndarray): Input images. 14 | bgr2rgb (bool): Whether to change bgr to rgb. 15 | float32 (bool): Whether to change to float32. 16 | 17 | Returns: 18 | list[tensor] | tensor: Tensor images. If returned results only have 19 | one element, just return tensor. 20 | """ 21 | 22 | def _totensor(img, bgr2rgb, float32): 23 | if img.shape[2] == 3 and bgr2rgb: 24 | if img.dtype == 'float64': 25 | img = img.astype('float32') 26 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 27 | img = torch.from_numpy(img.transpose(2, 0, 1)) 28 | if float32: 29 | img = img.float() 30 | return img 31 | 32 | if isinstance(imgs, list): 33 | return [_totensor(img, bgr2rgb, float32) for img in imgs] 34 | else: 35 | return _totensor(imgs, bgr2rgb, float32) 36 | 37 | 38 | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): 39 | """Convert torch Tensors into image numpy arrays. 40 | 41 | After clamping to [min, max], values will be normalized to [0, 1]. 42 | 43 | Args: 44 | tensor (Tensor or list[Tensor]): Accept shapes: 45 | 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); 46 | 2) 3D Tensor of shape (3/1 x H x W); 47 | 3) 2D Tensor of shape (H x W). 48 | Tensor channel should be in RGB order. 49 | rgb2bgr (bool): Whether to change rgb to bgr. 50 | out_type (numpy type): output types. If ``np.uint8``, transform outputs 51 | to uint8 type with range [0, 255]; otherwise, float type with 52 | range [0, 1]. Default: ``np.uint8``. 53 | min_max (tuple[int]): min and max values for clamp. 54 | 55 | Returns: 56 | (ndarray or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of 57 | shape (H x W). The channel order is BGR.
58 | """ 59 | if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): 60 | raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') 61 | 62 | if torch.is_tensor(tensor): 63 | tensor = [tensor] 64 | result = [] 65 | for _tensor in tensor: 66 | _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) 67 | _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) 68 | 69 | n_dim = _tensor.dim() 70 | if n_dim == 4: 71 | img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() 72 | img_np = img_np.transpose(1, 2, 0) 73 | if rgb2bgr: 74 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) 75 | elif n_dim == 3: 76 | img_np = _tensor.numpy() 77 | img_np = img_np.transpose(1, 2, 0) 78 | if img_np.shape[2] == 1: # gray image 79 | img_np = np.squeeze(img_np, axis=2) 80 | else: 81 | if rgb2bgr: 82 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) 83 | elif n_dim == 2: 84 | img_np = _tensor.numpy() 85 | else: 86 | raise TypeError(f'Only support 4D, 3D or 2D tensor. But received with dimension: {n_dim}') 87 | if out_type == np.uint8: 88 | # Unlike MATLAB, numpy.uint8() WILL NOT round by default. 89 | img_np = (img_np * 255.0).round() 90 | img_np = img_np.astype(out_type) 91 | result.append(img_np) 92 | if len(result) == 1: 93 | result = result[0] 94 | return result 95 | 96 | 97 | def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)): 98 | """This implementation is slightly faster than tensor2img. 99 | It now only supports torch tensor with shape (1, c, h, w). 100 | 101 | Args: 102 | tensor (Tensor): Now only support torch tensor with (1, c, h, w). 103 | rgb2bgr (bool): Whether to change rgb to bgr. Default: True. 104 | min_max (tuple[int]): min and max values for clamp. 105 | """ 106 | output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0) 107 | output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255 108 | output = output.type(torch.uint8).cpu().numpy() 109 | if rgb2bgr: 110 | output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) 111 | return output 112 | 113 | 114 | def imfrombytes(content, flag='color', float32=False): 115 | """Read an image from bytes. 116 | 117 | Args: 118 | content (bytes): Image bytes got from files or other streams. 119 | flag (str): Flags specifying the color type of a loaded image, 120 | candidates are `color`, `grayscale` and `unchanged`. 121 | float32 (bool): Whether to change to float32. If True, will also norm 122 | to [0, 1]. Default: False. 123 | 124 | Returns: 125 | ndarray: Loaded image array. 126 | """ 127 | img_np = np.frombuffer(content, np.uint8) 128 | imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED} 129 | img = cv2.imdecode(img_np, imread_flags[flag]) 130 | if float32: 131 | img = img.astype(np.float32) / 255. 132 | return img 133 | 134 | 135 | def imwrite(img, file_path, params=None, auto_mkdir=True): 136 | """Write image to file. 137 | 138 | Args: 139 | img (ndarray): Image array to be written. 140 | file_path (str): Image file path. 141 | params (None or list): Same as opencv's :func:`imwrite` interface. 142 | auto_mkdir (bool): If the parent folder of `file_path` does not exist, 143 | whether to create it automatically. 144 | 145 | Returns: 146 | bool: Successful or not.
147 | """ 148 | if auto_mkdir: 149 | dir_name = os.path.abspath(os.path.dirname(file_path)) 150 | os.makedirs(dir_name, exist_ok=True) 151 | ok = cv2.imwrite(file_path, img, params) 152 | if not ok: 153 | raise IOError('Failed in writing images.') 154 | 155 | 156 | def crop_border(imgs, crop_border): 157 | """Crop borders of images. 158 | 159 | Args: 160 | imgs (list[ndarray] | ndarray): Images with shape (h, w, c). 161 | crop_border (int): Crop border for each end of height and width. 162 | 163 | Returns: 164 | list[ndarray]: Cropped images. 165 | """ 166 | if crop_border == 0: 167 | return imgs 168 | else: 169 | if isinstance(imgs, list): 170 | return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs] 171 | else: 172 | return imgs[crop_border:-crop_border, crop_border:-crop_border, ...] 173 | -------------------------------------------------------------------------------- /basicsr/utils/lmdb_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import lmdb 3 | import sys 4 | from multiprocessing import Pool 5 | from os import path as osp 6 | from tqdm import tqdm 7 | 8 | 9 | def make_lmdb_from_imgs(data_path, 10 | lmdb_path, 11 | img_path_list, 12 | keys, 13 | batch=5000, 14 | compress_level=1, 15 | multiprocessing_read=False, 16 | n_thread=40, 17 | map_size=None): 18 | """Make lmdb from images. 19 | 20 | Contents of lmdb. The file structure is: 21 | 22 | :: 23 | 24 | example.lmdb 25 | ├── data.mdb 26 | ├── lock.mdb 27 | ├── meta_info.txt 28 | 29 | The data.mdb and lock.mdb are standard lmdb files and you can refer to 30 | https://lmdb.readthedocs.io/en/release/ for more details. 31 | 32 | The meta_info.txt is a specified txt file to record the meta information 33 | of our datasets. It will be automatically created when preparing 34 | datasets by our provided dataset tools. 35 | Each line in the txt file records 1)image name (with extension), 36 | 2)image shape, and 3)compression level, separated by a white space. 37 | 38 | For example, the meta information could be: 39 | `000_00000000.png (720,1280,3) 1`, which means: 40 | 1) image name (with extension): 000_00000000.png; 41 | 2) image shape: (720,1280,3); 42 | 3) compression level: 1 43 | 44 | We use the image name without extension as the lmdb key. 45 | 46 | If `multiprocessing_read` is True, it will read all the images to memory 47 | using multiprocessing. Thus, your server needs to have enough memory. 48 | 49 | Args: 50 | data_path (str): Data path for reading images. 51 | lmdb_path (str): Lmdb save path. 52 | img_path_list (list[str]): Image path list. 53 | keys (list[str]): Used for lmdb keys. 54 | batch (int): After processing batch images, lmdb commits. 55 | Default: 5000. 56 | compress_level (int): Compress level when encoding images. Default: 1. 57 | multiprocessing_read (bool): Whether to use multiprocessing to read all 58 | the images to memory. Default: False. 59 | n_thread (int): For multiprocessing. 60 | map_size (int | None): Map size for lmdb env. If None, use the 61 | estimated size from images.
Default: None 62 | """ 63 | 64 | assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, ' 65 | f'but got {len(img_path_list)} and {len(keys)}') 66 | print(f'Create lmdb for {data_path}, save to {lmdb_path}...') 67 | print(f'Total images: {len(img_path_list)}') 68 | if not lmdb_path.endswith('.lmdb'): 69 | raise ValueError("lmdb_path must end with '.lmdb'.") 70 | if osp.exists(lmdb_path): 71 | print(f'Folder {lmdb_path} already exists. Exit.') 72 | sys.exit(1) 73 | 74 | if multiprocessing_read: 75 | # read all the images to memory (multiprocessing) 76 | dataset = {} # use dict to keep the order for multiprocessing 77 | shapes = {} 78 | print(f'Read images with multiprocessing, #thread: {n_thread} ...') 79 | pbar = tqdm(total=len(img_path_list), unit='image') 80 | 81 | def callback(arg): 82 | """Get the image data and update pbar.""" 83 | key, dataset[key], shapes[key] = arg 84 | pbar.update(1) 85 | pbar.set_description(f'Read {key}') 86 | 87 | pool = Pool(n_thread) 88 | for path, key in zip(img_path_list, keys): 89 | pool.apply_async(read_img_worker, args=(osp.join(data_path, path), key, compress_level), callback=callback) 90 | pool.close() 91 | pool.join() 92 | pbar.close() 93 | print(f'Finish reading {len(img_path_list)} images.') 94 | 95 | # create lmdb environment 96 | if map_size is None: 97 | # obtain data size for one image 98 | img = cv2.imread(osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED) 99 | _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) 100 | data_size_per_img = img_byte.nbytes 101 | print('Data size per image is: ', data_size_per_img) 102 | data_size = data_size_per_img * len(img_path_list) 103 | map_size = data_size * 10 104 | 105 | env = lmdb.open(lmdb_path, map_size=map_size) 106 | 107 | # write data to lmdb 108 | pbar = tqdm(total=len(img_path_list), unit='chunk') 109 | txn = env.begin(write=True) 110 | txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') 111 | for idx, (path, key) in enumerate(zip(img_path_list, keys)): 112 | pbar.update(1) 113 | pbar.set_description(f'Write {key}') 114 | key_byte = key.encode('ascii') 115 | if multiprocessing_read: 116 | img_byte = dataset[key] 117 | h, w, c = shapes[key] 118 | else: 119 | _, img_byte, img_shape = read_img_worker(osp.join(data_path, path), key, compress_level) 120 | h, w, c = img_shape 121 | 122 | txn.put(key_byte, img_byte) 123 | # write meta information 124 | txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n') 125 | if idx % batch == 0: 126 | txn.commit() 127 | txn = env.begin(write=True) 128 | pbar.close() 129 | txn.commit() 130 | env.close() 131 | txt_file.close() 132 | print('\nFinish writing lmdb.') 133 | 134 | 135 | def read_img_worker(path, key, compress_level): 136 | """Read image worker. 137 | 138 | Args: 139 | path (str): Image path. 140 | key (str): Image key. 141 | compress_level (int): Compress level when encoding images. 142 | 143 | Returns: 144 | str: Image key. 145 | byte: Image byte. 146 | tuple[int]: Image shape. 147 | """ 148 | 149 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED) 150 | if img.ndim == 2: 151 | h, w = img.shape 152 | c = 1 153 | else: 154 | h, w, c = img.shape 155 | _, img_byte = cv2.imencode('.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) 156 | return (key, img_byte, (h, w, c)) 157 | 158 | 159 | class LmdbMaker(): 160 | """LMDB Maker. 161 | 162 | Args: 163 | lmdb_path (str): Lmdb save path. 164 | map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
165 | batch (int): After processing batch images, lmdb commits. 166 | Default: 5000. 167 | compress_level (int): Compress level when encoding images. Default: 1. 168 | """ 169 | 170 | def __init__(self, lmdb_path, map_size=1024**4, batch=5000, compress_level=1): 171 | if not lmdb_path.endswith('.lmdb'): 172 | raise ValueError("lmdb_path must end with '.lmdb'.") 173 | if osp.exists(lmdb_path): 174 | print(f'Folder {lmdb_path} already exists. Exit.') 175 | sys.exit(1) 176 | 177 | self.lmdb_path = lmdb_path 178 | self.batch = batch 179 | self.compress_level = compress_level 180 | self.env = lmdb.open(lmdb_path, map_size=map_size) 181 | self.txn = self.env.begin(write=True) 182 | self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') 183 | self.counter = 0 184 | 185 | def put(self, img_byte, key, img_shape): 186 | self.counter += 1 187 | key_byte = key.encode('ascii') 188 | self.txn.put(key_byte, img_byte) 189 | # write meta information 190 | h, w, c = img_shape 191 | self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n') 192 | if self.counter % self.batch == 0: 193 | self.txn.commit() 194 | self.txn = self.env.begin(write=True) 195 | 196 | def close(self): 197 | self.txn.commit() 198 | self.env.close() 199 | self.txt_file.close() 200 | -------------------------------------------------------------------------------- /basicsr/utils/logger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import time 4 | 5 | from .dist_util import get_dist_info, master_only 6 | 7 | initialized_logger = {} 8 | 9 | 10 | class AvgTimer(): 11 | 12 | def __init__(self, window=200): 13 | self.window = window # average window 14 | self.current_time = 0 15 | self.total_time = 0 16 | self.count = 0 17 | self.avg_time = 0 18 | self.start() 19 | 20 | def start(self): 21 | self.start_time = self.tic = time.time() 22 | 23 | def record(self): 24 | self.count += 1 25 | self.toc = time.time() 26 | self.current_time = self.toc - self.tic 27 | self.total_time += self.current_time 28 | # calculate average time 29 | self.avg_time = self.total_time / self.count 30 | 31 | # reset 32 | if self.count > self.window: 33 | self.count = 0 34 | self.total_time = 0 35 | 36 | self.tic = time.time() 37 | 38 | def get_current_time(self): 39 | return self.current_time 40 | 41 | def get_avg_time(self): 42 | return self.avg_time 43 | 44 | 45 | class MessageLogger(): 46 | """Message logger for printing. 47 | 48 | Args: 49 | opt (dict): Config. It contains the following keys: 50 | name (str): Exp name. 51 | logger (dict): Contains 'print_freq' (str) for logger interval. 52 | train (dict): Contains 'total_iter' (int) for total iters. 53 | use_tb_logger (bool): Use tensorboard logger. 54 | start_iter (int): Start iter. Default: 1. 55 | tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None. 56 | """ 57 | 58 | def __init__(self, opt, start_iter=1, tb_logger=None): 59 | self.exp_name = opt['name'] 60 | self.interval = opt['logger']['print_freq'] 61 | self.start_iter = start_iter 62 | self.max_iters = opt['train']['total_iter'] 63 | self.use_tb_logger = opt['logger']['use_tb_logger'] 64 | self.tb_logger = tb_logger 65 | self.start_time = time.time() 66 | self.logger = get_root_logger() 67 | 68 | def reset_start_time(self): 69 | self.start_time = time.time() 70 | 71 | @master_only 72 | def __call__(self, log_vars): 73 | """Format logging message. 
74 | 75 | Args: 76 | log_vars (dict): It contains the following keys: 77 | epoch (int): Epoch number. 78 | iter (int): Current iter. 79 | lrs (list): List for learning rates. 80 | 81 | time (float): Iter time. 82 | data_time (float): Data time for each iter. 83 | """ 84 | # epoch, iter, learning rates 85 | epoch = log_vars.pop('epoch') 86 | current_iter = log_vars.pop('iter') 87 | lrs = log_vars.pop('lrs') 88 | 89 | message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, iter:{current_iter:8,d}, lr:(') 90 | for v in lrs: 91 | message += f'{v:.3e},' 92 | message += ')] ' 93 | 94 | # time and estimated time 95 | if 'time' in log_vars.keys(): 96 | iter_time = log_vars.pop('time') 97 | data_time = log_vars.pop('data_time') 98 | 99 | total_time = time.time() - self.start_time 100 | time_sec_avg = total_time / (current_iter - self.start_iter + 1) 101 | eta_sec = time_sec_avg * (self.max_iters - current_iter - 1) 102 | eta_str = str(datetime.timedelta(seconds=int(eta_sec))) 103 | message += f'[eta: {eta_str}, ' 104 | message += f'time (data): {iter_time:.3f} ({data_time:.3f})] ' 105 | 106 | # other items, especially losses 107 | for k, v in log_vars.items(): 108 | message += f'{k}: {v:.4e} ' 109 | # tensorboard logger 110 | if self.use_tb_logger and 'debug' not in self.exp_name: 111 | if k.startswith('l_'): 112 | self.tb_logger.add_scalar(f'losses/{k}', v, current_iter) 113 | else: 114 | self.tb_logger.add_scalar(k, v, current_iter) 115 | self.logger.info(message) 116 | 117 | 118 | @master_only 119 | def init_tb_logger(log_dir): 120 | from torch.utils.tensorboard import SummaryWriter 121 | tb_logger = SummaryWriter(log_dir=log_dir) 122 | return tb_logger 123 | 124 | 125 | @master_only 126 | def init_wandb_logger(opt): 127 | """We now only use wandb to sync tensorboard log.""" 128 | import wandb 129 | logger = get_root_logger() 130 | 131 | project = opt['logger']['wandb']['project'] 132 | resume_id = opt['logger']['wandb'].get('resume_id') 133 | if resume_id: 134 | wandb_id = resume_id 135 | resume = 'allow' 136 | logger.warning(f'Resume wandb logger with id={wandb_id}.') 137 | else: 138 | wandb_id = wandb.util.generate_id() 139 | resume = 'never' 140 | 141 | wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True) 142 | 143 | logger.info(f'Use wandb logger with id={wandb_id}; project={project}.') 144 | 145 | 146 | def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None): 147 | """Get the root logger. 148 | 149 | The logger will be initialized if it has not been initialized. By default a 150 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 151 | also be added. 152 | 153 | Args: 154 | logger_name (str): root logger name. Default: 'basicsr'. 155 | log_file (str | None): The log filename. If specified, a FileHandler 156 | will be added to the root logger. 157 | log_level (int): The root logger level. Note that only the process of 158 | rank 0 is affected, while other processes will set the level to 159 | "Error" and be silent most of the time. 160 | 161 | Returns: 162 | logging.Logger: The root logger. 
163 | """ 164 | logger = logging.getLogger(logger_name) 165 | # if the logger has been initialized, just return it 166 | if logger_name in initialized_logger: 167 | return logger 168 | 169 | format_str = '%(asctime)s %(levelname)s: %(message)s' 170 | stream_handler = logging.StreamHandler() 171 | stream_handler.setFormatter(logging.Formatter(format_str)) 172 | logger.addHandler(stream_handler) 173 | logger.propagate = False 174 | rank, _ = get_dist_info() 175 | if rank != 0: 176 | logger.setLevel('ERROR') 177 | elif log_file is not None: 178 | logger.setLevel(log_level) 179 | # add file handler 180 | file_handler = logging.FileHandler(log_file, 'w') 181 | file_handler.setFormatter(logging.Formatter(format_str)) 182 | file_handler.setLevel(log_level) 183 | logger.addHandler(file_handler) 184 | initialized_logger[logger_name] = True 185 | return logger 186 | 187 | 188 | def get_env_info(): 189 | """Get environment information. 190 | 191 | Currently, only log the software version. 192 | """ 193 | import torch 194 | import torchvision 195 | 196 | from basicsr.version import __version__ 197 | msg = r""" 198 | ____ _ _____ ____ 199 | / __ ) ____ _ _____ (_)_____/ ___/ / __ \ 200 | / __ |/ __ `// ___// // ___/\__ \ / /_/ / 201 | / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/ 202 | /_____/ \__,_//____//_/ \___//____//_/ |_| 203 | ______ __ __ __ __ 204 | / ____/____ ____ ____/ / / / __ __ _____ / /__ / / 205 | / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / / 206 | / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/ 207 | \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_) 208 | """ 209 | msg += ('\nVersion Information: ' 210 | f'\n\tBasicSR: {__version__}' 211 | f'\n\tPyTorch: {torch.__version__}' 212 | f'\n\tTorchVision: {torchvision.__version__}') 213 | return msg 214 | -------------------------------------------------------------------------------- /basicsr/utils/misc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import random 4 | import time 5 | import torch 6 | from os import path as osp 7 | 8 | from .dist_util import master_only 9 | 10 | 11 | def set_random_seed(seed): 12 | """Set random seeds.""" 13 | random.seed(seed) 14 | np.random.seed(seed) 15 | torch.manual_seed(seed) 16 | torch.cuda.manual_seed(seed) 17 | torch.cuda.manual_seed_all(seed) 18 | 19 | 20 | def get_time_str(): 21 | return time.strftime('%Y%m%d_%H%M%S', time.localtime()) 22 | 23 | 24 | def mkdir_and_rename(path): 25 | """mkdirs. If path exists, rename it with timestamp and create a new one. 26 | 27 | Args: 28 | path (str): Folder path. 29 | """ 30 | if osp.exists(path): 31 | new_name = path + '_archived_' + get_time_str() 32 | print(f'Path already exists. Rename it to {new_name}', flush=True) 33 | os.rename(path, new_name) 34 | os.makedirs(path, exist_ok=True) 35 | 36 | 37 | @master_only 38 | def make_exp_dirs(opt): 39 | """Make dirs for experiments.""" 40 | path_opt = opt['path'].copy() 41 | if opt['is_train']: 42 | mkdir_and_rename(path_opt.pop('experiments_root')) 43 | else: 44 | mkdir_and_rename(path_opt.pop('results_root')) 45 | for key, path in path_opt.items(): 46 | if ('strict_load' in key) or ('pretrain_network' in key) or ('resume' in key) or ('param_key' in key): 47 | continue 48 | else: 49 | os.makedirs(path, exist_ok=True) 50 | 51 | 52 | def scandir(dir_path, suffix=None, recursive=False, full_path=False): 53 | """Scan a directory to find the interested files. 54 | 55 | Args: 56 | dir_path (str): Path of the directory. 
57 | suffix (str | tuple(str), optional): File suffix that we are 58 | interested in. Default: None. 59 | recursive (bool, optional): If set to True, recursively scan the 60 | directory. Default: False. 61 | full_path (bool, optional): If set to True, include the dir_path. 62 | Default: False. 63 | 64 | Returns: 65 | A generator for all the interested files with relative paths. 66 | """ 67 | 68 | if (suffix is not None) and not isinstance(suffix, (str, tuple)): 69 | raise TypeError('"suffix" must be a string or tuple of strings') 70 | 71 | root = dir_path 72 | 73 | def _scandir(dir_path, suffix, recursive): 74 | for entry in os.scandir(dir_path): 75 | if not entry.name.startswith('.') and entry.is_file(): 76 | if full_path: 77 | return_path = entry.path 78 | else: 79 | return_path = osp.relpath(entry.path, root) 80 | 81 | if suffix is None: 82 | yield return_path 83 | elif return_path.endswith(suffix): 84 | yield return_path 85 | else: 86 | if recursive: 87 | yield from _scandir(entry.path, suffix=suffix, recursive=recursive) 88 | else: 89 | continue 90 | 91 | return _scandir(dir_path, suffix=suffix, recursive=recursive) 92 | 93 | 94 | def check_resume(opt, resume_iter): 95 | """Check resume states and pretrain_network paths. 96 | 97 | Args: 98 | opt (dict): Options. 99 | resume_iter (int): Resume iteration. 100 | """ 101 | if opt['path']['resume_state']: 102 | # get all the networks 103 | networks = [key for key in opt.keys() if key.startswith('network_')] 104 | flag_pretrain = False 105 | for network in networks: 106 | if opt['path'].get(f'pretrain_{network}') is not None: 107 | flag_pretrain = True 108 | if flag_pretrain: 109 | print('pretrain_network path will be ignored during resuming.') 110 | # set pretrained model paths 111 | for network in networks: 112 | name = f'pretrain_{network}' 113 | basename = network.replace('network_', '') 114 | if opt['path'].get('ignore_resume_networks') is None or (network 115 | not in opt['path']['ignore_resume_networks']): 116 | opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth') 117 | print(f"Set {name} to {opt['path'][name]}") 118 | 119 | # change param_key to params in resume 120 | param_keys = [key for key in opt['path'].keys() if key.startswith('param_key')] 121 | for param_key in param_keys: 122 | if opt['path'][param_key] == 'params_ema': 123 | opt['path'][param_key] = 'params' 124 | print(f'Set {param_key} to params') 125 | 126 | 127 | def sizeof_fmt(size, suffix='B'): 128 | """Get human readable file size. 129 | 130 | Args: 131 | size (int): File size. 132 | suffix (str): Suffix. Default: 'B'. 133 | 134 | Return: 135 | str: Formatted file size. 136 | """ 137 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: 138 | if abs(size) < 1024.0: 139 | return f'{size:3.1f} {unit}{suffix}' 140 | size /= 1024.0 141 | return f'{size:3.1f} Y{suffix}' 142 | -------------------------------------------------------------------------------- /basicsr/utils/options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import random 4 | import torch 5 | import yaml 6 | from collections import OrderedDict 7 | from os import path as osp 8 | 9 | from basicsr.utils import set_random_seed 10 | from basicsr.utils.dist_util import get_dist_info, init_dist, master_only 11 | 12 | 13 | def ordered_yaml(): 14 | """Support OrderedDict for yaml. 15 | 16 | Returns: 17 | tuple: yaml Loader and Dumper. 
18 | """ 19 | try: 20 | from yaml import CDumper as Dumper 21 | from yaml import CLoader as Loader 22 | except ImportError: 23 | from yaml import Dumper, Loader 24 | 25 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG 26 | 27 | def dict_representer(dumper, data): 28 | return dumper.represent_dict(data.items()) 29 | 30 | def dict_constructor(loader, node): 31 | return OrderedDict(loader.construct_pairs(node)) 32 | 33 | Dumper.add_representer(OrderedDict, dict_representer) 34 | Loader.add_constructor(_mapping_tag, dict_constructor) 35 | return Loader, Dumper 36 | 37 | 38 | def yaml_load(f): 39 | """Load yaml file or string. 40 | 41 | Args: 42 | f (str): File path or a python string. 43 | 44 | Returns: 45 | dict: Loaded dict. 46 | """ 47 | if os.path.isfile(f): 48 | with open(f, 'r') as f: 49 | return yaml.load(f, Loader=ordered_yaml()[0]) 50 | else: 51 | return yaml.load(f, Loader=ordered_yaml()[0]) 52 | 53 | 54 | def dict2str(opt, indent_level=1): 55 | """dict to string for printing options. 56 | 57 | Args: 58 | opt (dict): Option dict. 59 | indent_level (int): Indent level. Default: 1. 60 | 61 | Return: 62 | (str): Option string for printing. 63 | """ 64 | msg = '\n' 65 | for k, v in opt.items(): 66 | if isinstance(v, dict): 67 | msg += ' ' * (indent_level * 2) + k + ':[' 68 | msg += dict2str(v, indent_level + 1) 69 | msg += ' ' * (indent_level * 2) + ']\n' 70 | else: 71 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' 72 | return msg 73 | 74 | 75 | def _postprocess_yml_value(value): 76 | # None 77 | if value == '~' or value.lower() == 'none': 78 | return None 79 | # bool 80 | if value.lower() == 'true': 81 | return True 82 | elif value.lower() == 'false': 83 | return False 84 | # !!float number 85 | if value.startswith('!!float'): 86 | return float(value.replace('!!float', '')) 87 | # number 88 | if value.isdigit(): 89 | return int(value) 90 | elif value.replace('.', '', 1).isdigit() and value.count('.') < 2: 91 | return float(value) 92 | # list 93 | if value.startswith('['): 94 | return eval(value) 95 | # str 96 | return value 97 | 98 | 99 | def parse_options(root_path, is_train=True): 100 | parser = argparse.ArgumentParser() 101 | parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.') 102 | parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher') 103 | parser.add_argument('--auto_resume', action='store_true') 104 | parser.add_argument('--debug', action='store_true') 105 | parser.add_argument('--local_rank', type=int, default=0) 106 | parser.add_argument( 107 | '--force_yml', nargs='+', default=None, help='Force to update yml files. 
Examples: train:ema_decay=0.999') 108 | args = parser.parse_args() 109 | 110 | # parse yml to dict 111 | opt = yaml_load(args.opt) 112 | 113 | # distributed settings 114 | if args.launcher == 'none': 115 | opt['dist'] = False 116 | print('Disable distributed.', flush=True) 117 | else: 118 | opt['dist'] = True 119 | if args.launcher == 'slurm' and 'dist_params' in opt: 120 | init_dist(args.launcher, **opt['dist_params']) 121 | else: 122 | init_dist(args.launcher) 123 | opt['rank'], opt['world_size'] = get_dist_info() 124 | 125 | # random seed 126 | seed = opt.get('manual_seed') 127 | if seed is None: 128 | seed = random.randint(1, 10000) 129 | opt['manual_seed'] = seed 130 | set_random_seed(seed + opt['rank']) 131 | 132 | # force to update yml options 133 | if args.force_yml is not None: 134 | for entry in args.force_yml: 135 | # now do not support creating new keys 136 | keys, value = entry.split('=') 137 | keys, value = keys.strip(), value.strip() 138 | value = _postprocess_yml_value(value) 139 | eval_str = 'opt' 140 | for key in keys.split(':'): 141 | eval_str += f'["{key}"]' 142 | eval_str += '=value' 143 | # using exec function 144 | exec(eval_str) 145 | 146 | opt['auto_resume'] = args.auto_resume 147 | opt['is_train'] = is_train 148 | 149 | # debug setting 150 | if args.debug and not opt['name'].startswith('debug'): 151 | opt['name'] = 'debug_' + opt['name'] 152 | 153 | if opt['num_gpu'] == 'auto': 154 | opt['num_gpu'] = torch.cuda.device_count() 155 | 156 | # datasets 157 | for phase, dataset in opt['datasets'].items(): 158 | # for multiple datasets, e.g., val_1, val_2; test_1, test_2 159 | phase = phase.split('_')[0] 160 | dataset['phase'] = phase 161 | if 'scale' in opt: 162 | dataset['scale'] = opt['scale'] 163 | if dataset.get('dataroot_gt') is not None: 164 | dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) 165 | if dataset.get('dataroot_lq') is not None: 166 | dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) 167 | 168 | # paths 169 | for key, val in opt['path'].items(): 170 | if (val is not None) and ('resume_state' in key or 'pretrain_network' in key): 171 | opt['path'][key] = osp.expanduser(val) 172 | 173 | if is_train: 174 | experiments_root = opt['path'].get('experiments_root') 175 | if experiments_root is None: 176 | experiments_root = osp.join(root_path, 'experiments') 177 | experiments_root = osp.join(experiments_root, opt['name']) 178 | 179 | opt['path']['experiments_root'] = experiments_root 180 | opt['path']['models'] = osp.join(experiments_root, 'models') 181 | opt['path']['training_states'] = osp.join(experiments_root, 'training_states') 182 | opt['path']['log'] = experiments_root 183 | opt['path']['visualization'] = osp.join(experiments_root, 'visualization') 184 | 185 | # change some options for debug mode 186 | if 'debug' in opt['name']: 187 | if 'val' in opt: 188 | opt['val']['val_freq'] = 8 189 | opt['logger']['print_freq'] = 1 190 | opt['logger']['save_checkpoint_freq'] = 8 191 | else: # test 192 | results_root = opt['path'].get('results_root') 193 | if results_root is None: 194 | results_root = osp.join(root_path, 'results') 195 | results_root = osp.join(results_root, opt['name']) 196 | 197 | opt['path']['results_root'] = results_root 198 | opt['path']['log'] = results_root 199 | opt['path']['visualization'] = osp.join(results_root, 'visualization') 200 | 201 | return opt, args 202 | 203 | 204 | @master_only 205 | def copy_opt_file(opt_file, experiments_root): 206 | # copy the yml file to the experiment root 207 | import sys 208 
| import time 209 | from shutil import copyfile 210 | cmd = ' '.join(sys.argv) 211 | filename = osp.join(experiments_root, osp.basename(opt_file)) 212 | copyfile(opt_file, filename) 213 | 214 | with open(filename, 'r+') as f: 215 | lines = f.readlines() 216 | lines.insert(0, f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {cmd}\n\n') 217 | f.seek(0) 218 | f.writelines(lines) 219 | -------------------------------------------------------------------------------- /basicsr/utils/plot_util.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | 4 | def read_data_from_tensorboard(log_path, tag): 5 | """Get raw data (steps and values) from tensorboard events. 6 | 7 | Args: 8 | log_path (str): Path to the tensorboard log. 9 | tag (str): tag to be read. 10 | """ 11 | from tensorboard.backend.event_processing.event_accumulator import EventAccumulator 12 | 13 | # tensorboard event 14 | event_acc = EventAccumulator(log_path) 15 | event_acc.Reload() 16 | scalar_list = event_acc.Tags()['scalars'] 17 | print('tag list: ', scalar_list) 18 | steps = [int(s.step) for s in event_acc.Scalars(tag)] 19 | values = [s.value for s in event_acc.Scalars(tag)] 20 | return steps, values 21 | 22 | 23 | def read_data_from_txt_2v(path, pattern, step_one=False): 24 | """Read data from txt with 2 returned values (usually [step, value]). 25 | 26 | Args: 27 | path (str): path to the txt file. 28 | pattern (str): re (regular expression) pattern. 29 | step_one (bool): add 1 to steps. Default: False. 30 | """ 31 | with open(path) as f: 32 | lines = f.readlines() 33 | lines = [line.strip() for line in lines] 34 | steps = [] 35 | values = [] 36 | 37 | pattern = re.compile(pattern) 38 | for line in lines: 39 | match = pattern.match(line) 40 | if match: 41 | steps.append(int(match.group(1))) 42 | values.append(float(match.group(2))) 43 | if step_one: 44 | steps = [v + 1 for v in steps] 45 | return steps, values 46 | 47 | 48 | def read_data_from_txt_1v(path, pattern): 49 | """Read data from txt with 1 returned values. 50 | 51 | Args: 52 | path (str): path to the txt file. 53 | pattern (str): re (regular expression) pattern. 54 | """ 55 | with open(path) as f: 56 | lines = f.readlines() 57 | lines = [line.strip() for line in lines] 58 | data = [] 59 | 60 | pattern = re.compile(pattern) 61 | for line in lines: 62 | match = pattern.match(line) 63 | if match: 64 | data.append(float(match.group(1))) 65 | return data 66 | 67 | 68 | def smooth_data(values, smooth_weight): 69 | """ Smooth data using 1st-order IIR low-pass filter (what tensorflow does). 70 | 71 | Reference: https://github.com/tensorflow/tensorboard/blob/f801ebf1f9fbfe2baee1ddd65714d0bccc640fb1/tensorboard/plugins/scalar/vz_line_chart/vz-line-chart.ts#L704 # noqa: E501 72 | 73 | Args: 74 | values (list): A list of values to be smoothed. 75 | smooth_weight (float): Smooth weight. 
76 | """ 77 | values_sm = [] 78 | last_sm_value = values[0] 79 | for value in values: 80 | value_sm = last_sm_value * smooth_weight + (1 - smooth_weight) * value 81 | values_sm.append(value_sm) 82 | last_sm_value = value_sm 83 | return values_sm 84 | -------------------------------------------------------------------------------- /basicsr/utils/registry.py: -------------------------------------------------------------------------------- 1 | # Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501 2 | 3 | 4 | class Registry(): 5 | """ 6 | The registry that provides name -> object mapping, to support third-party 7 | users' custom modules. 8 | 9 | To create a registry (e.g. a backbone registry): 10 | 11 | .. code-block:: python 12 | 13 | BACKBONE_REGISTRY = Registry('BACKBONE') 14 | 15 | To register an object: 16 | 17 | .. code-block:: python 18 | 19 | @BACKBONE_REGISTRY.register() 20 | class MyBackbone(): 21 | ... 22 | 23 | Or: 24 | 25 | .. code-block:: python 26 | 27 | BACKBONE_REGISTRY.register(MyBackbone) 28 | """ 29 | 30 | def __init__(self, name): 31 | """ 32 | Args: 33 | name (str): the name of this registry 34 | """ 35 | self._name = name 36 | self._obj_map = {} 37 | 38 | def _do_register(self, name, obj, suffix=None): 39 | if isinstance(suffix, str): 40 | name = name + '_' + suffix 41 | 42 | assert (name not in self._obj_map), (f"An object named '{name}' was already registered " 43 | f"in '{self._name}' registry!") 44 | self._obj_map[name] = obj 45 | 46 | def register(self, obj=None, suffix=None): 47 | """ 48 | Register the given object under the the name `obj.__name__`. 49 | Can be used as either a decorator or not. 50 | See docstring of this class for usage. 51 | """ 52 | if obj is None: 53 | # used as a decorator 54 | def deco(func_or_class): 55 | name = func_or_class.__name__ 56 | self._do_register(name, func_or_class, suffix) 57 | return func_or_class 58 | 59 | return deco 60 | 61 | # used as a function call 62 | name = obj.__name__ 63 | self._do_register(name, obj, suffix) 64 | 65 | def get(self, name, suffix='basicsr'): 66 | ret = self._obj_map.get(name) 67 | if ret is None: 68 | ret = self._obj_map.get(name + '_' + suffix) 69 | print(f'Name {name} is not found, use name: {name}_{suffix}!') 70 | if ret is None: 71 | raise KeyError(f"No object named '{name}' found in '{self._name}' registry!") 72 | return ret 73 | 74 | def __contains__(self, name): 75 | return name in self._obj_map 76 | 77 | def __iter__(self): 78 | return iter(self._obj_map.items()) 79 | 80 | def keys(self): 81 | return self._obj_map.keys() 82 | 83 | 84 | DATASET_REGISTRY = Registry('dataset') 85 | ARCH_REGISTRY = Registry('arch') 86 | MODEL_REGISTRY = Registry('model') 87 | LOSS_REGISTRY = Registry('loss') 88 | METRIC_REGISTRY = Registry('metric') 89 | -------------------------------------------------------------------------------- /main.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArsenalCheng/MMA/256cead19ff9540f0364f562c83486eec13c8a4a/main.jpg -------------------------------------------------------------------------------- /options/test/MMA/test_MMA_x2.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: MMA_x2_inference 3 | model_type: SRModel 4 | scale: 2 5 | num_gpu: 1 # set num_gpu: 0 for cpu mode 6 | manual_seed: 10 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | test_1: 11 | name: Set5 12 | 
type: PairedImageDataset 13 | dataroot_gt: ~ 14 | dataroot_lq: ~ 15 | io_backend: 16 | type: disk 17 | 18 | test_2: 19 | name: Set14 20 | type: PairedImageDataset 21 | dataroot_gt: ~ 22 | dataroot_lq: ~ 23 | io_backend: 24 | type: disk 25 | 26 | test_3: 27 | name: Urban100 28 | type: PairedImageDataset 29 | dataroot_gt: ~ 30 | dataroot_lq: ~ 31 | io_backend: 32 | type: disk 33 | 34 | test_4: 35 | name: BSD100 36 | type: PairedImageDataset 37 | dataroot_gt: ~ 38 | dataroot_lq: ~ 39 | io_backend: 40 | type: disk 41 | 42 | test_5: 43 | name: Manga109 44 | type: PairedImageDataset 45 | dataroot_gt: ~ 46 | dataroot_lq: ~ 47 | io_backend: 48 | type: disk 49 | 50 | # network structures 51 | network_g: 52 | type: MMA 53 | num_in_ch: 3 54 | num_out_ch: 3 55 | num_feat: 192 56 | num_block: 24 57 | upscale: 2 58 | img_range: 1. 59 | rgb_mean: [0.4488, 0.4371, 0.4040] 60 | 61 | # path 62 | path: 63 | pretrain_network_g: ~ 64 | strict_load_g: true 65 | resume_state: ~ 66 | 67 | # validation settings 68 | val: 69 | save_img: true 70 | suffix: ~ # add suffix to saved images, if None, use exp name 71 | 72 | metrics: 73 | psnr: # metric name, can be arbitrary 74 | type: calculate_psnr 75 | crop_border: 2 76 | test_y_channel: true 77 | ssim: 78 | type: calculate_ssim 79 | crop_border: 2 80 | test_y_channel: true 81 | -------------------------------------------------------------------------------- /options/test/MMA/test_MMA_x3.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: MMA_x3_inference 3 | model_type: SRModel 4 | scale: 3 5 | num_gpu: 1 # set num_gpu: 0 for cpu mode 6 | manual_seed: 10 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | test_1: 11 | name: Set5 12 | type: PairedImageDataset 13 | dataroot_gt: ~ 14 | dataroot_lq: ~ 15 | io_backend: 16 | type: disk 17 | 18 | test_2: 19 | name: Set14 20 | type: PairedImageDataset 21 | dataroot_gt: ~ 22 | dataroot_lq: ~ 23 | io_backend: 24 | type: disk 25 | 26 | test_3: 27 | name: Urban100 28 | type: PairedImageDataset 29 | dataroot_gt: ~ 30 | dataroot_lq: ~ 31 | io_backend: 32 | type: disk 33 | 34 | test_4: 35 | name: BSD100 36 | type: PairedImageDataset 37 | dataroot_gt: ~ 38 | dataroot_lq: ~ 39 | io_backend: 40 | type: disk 41 | 42 | test_5: 43 | name: Manga109 44 | type: PairedImageDataset 45 | dataroot_gt: ~ 46 | dataroot_lq: ~ 47 | io_backend: 48 | type: disk 49 | 50 | # network structures 51 | network_g: 52 | type: MMA 53 | num_in_ch: 3 54 | num_out_ch: 3 55 | num_feat: 192 56 | num_block: 24 57 | upscale: 3 58 | res_scale: 0.1 59 | img_range: 1. 
60 | rgb_mean: [0.4488, 0.4371, 0.4040] 61 | 62 | # path 63 | path: 64 | pretrain_network_g: ~ 65 | strict_load_g: true 66 | resume_state: ~ 67 | 68 | # validation settings 69 | val: 70 | save_img: true 71 | suffix: ~ # add suffix to saved images, if None, use exp name 72 | 73 | metrics: 74 | psnr: # metric name, can be arbitrary 75 | type: calculate_psnr 76 | crop_border: 3 77 | test_y_channel: true 78 | ssim: 79 | type: calculate_ssim 80 | crop_border: 3 81 | test_y_channel: true 82 | -------------------------------------------------------------------------------- /options/test/MMA/test_MMA_x4.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: MMA_x4_inference 3 | model_type: SRModel 4 | scale: 4 5 | num_gpu: 1 # set num_gpu: 0 for cpu mode 6 | manual_seed: 10 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | test_1: 11 | name: Set5 12 | type: PairedImageDataset 13 | dataroot_gt: ~ 14 | dataroot_lq: ~ 15 | io_backend: 16 | type: disk 17 | 18 | test_2: 19 | name: Set14 20 | type: PairedImageDataset 21 | dataroot_gt: ~ 22 | dataroot_lq: ~ 23 | io_backend: 24 | type: disk 25 | 26 | test_3: 27 | name: Urban100 28 | type: PairedImageDataset 29 | dataroot_gt: ~ 30 | dataroot_lq: ~ 31 | io_backend: 32 | type: disk 33 | 34 | test_4: 35 | name: BSD100 36 | type: PairedImageDataset 37 | dataroot_gt: ~ 38 | dataroot_lq: ~ 39 | io_backend: 40 | type: disk 41 | 42 | test_5: 43 | name: Manga109 44 | type: PairedImageDataset 45 | dataroot_gt: ~ 46 | dataroot_lq: ~ 47 | io_backend: 48 | type: disk 49 | 50 | # network structures 51 | network_g: 52 | type: MMA 53 | num_in_ch: 3 54 | num_out_ch: 3 55 | num_feat: 192 56 | num_block: 24 57 | upscale: 4 58 | img_range: 1. 59 | rgb_mean: [0.4488, 0.4371, 0.4040] 60 | 61 | # path 62 | path: 63 | pretrain_network_g: ~ 64 | strict_load_g: true 65 | resume_state: ~ 66 | 67 | # validation settings 68 | val: 69 | save_img: true 70 | suffix: ~ # add suffix to saved images, if None, use exp name 71 | 72 | metrics: 73 | psnr: # metric name, can be arbitrary 74 | type: calculate_psnr 75 | crop_border: 4 76 | test_y_channel: true 77 | ssim: 78 | type: calculate_ssim 79 | crop_border: 4 80 | test_y_channel: true 81 | -------------------------------------------------------------------------------- /options/train/MMA/train_MMA_x2_finetune.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: MMA_x2_finetune 3 | model_type: SRModel 4 | scale: 2 5 | num_gpu: 1 # set num_gpu: 0 for cpu mode 6 | manual_seed: 10 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: DIV2K 12 | type: PairedImageDataset 13 | dataroot_gt: ~ 14 | dataroot_lq: ~ 15 | io_backend: 16 | type: disk 17 | 18 | gt_size: 128 19 | use_hflip: true 20 | use_rot: true 21 | 22 | # data loader 23 | use_shuffle: true 24 | num_worker_per_gpu: 4 25 | batch_size_per_gpu: 4 26 | dataset_enlarge_ratio: 1 27 | prefetch_mode: cuda 28 | pin_memory: True 29 | 30 | val: 31 | name: Set5 32 | type: PairedImageDataset 33 | dataroot_gt: ~ 34 | dataroot_lq: ~ 35 | io_backend: 36 | type: disk 37 | 38 | # network structures 39 | network_g: 40 | type: MMA 41 | num_in_ch: 3 42 | num_out_ch: 3 43 | num_feat: 192 44 | num_block: 24 45 | upscale: 2 46 | img_range: 1. 
47 | rgb_mean: [0.4488, 0.4371, 0.4040] 48 | 49 | # path 50 | path: 51 | pretrain_network_g: ~ 52 | strict_load_g: true 53 | resume_state: ~ 54 | 55 | # training settings 56 | train: 57 | ema_decay: 0.999 58 | optim_g: 59 | type: Adam 60 | lr: !!float 1e-5 61 | weight_decay: 0 62 | betas: [0.9, 0.99] 63 | 64 | scheduler: 65 | type: MultiStepLR 66 | milestones: [125000, 200000, 225000, 240000] 67 | gamma: 0.5 68 | 69 | total_iter: 250000 70 | warmup_iter: -1 # no warm up 71 | 72 | # losses 73 | pixel_opt: 74 | type: L1Loss 75 | loss_weight: 1.0 76 | reduction: mean 77 | 78 | # validation settings 79 | val: 80 | val_freq: !!float 5e3 81 | save_img: false 82 | 83 | metrics: 84 | psnr: # metric name, can be arbitrary 85 | type: calculate_psnr 86 | crop_border: 2 87 | test_y_channel: true 88 | 89 | ssim: # metric name, can be arbitrary 90 | type: calculate_ssim 91 | crop_border: 2 92 | test_y_channel: true 93 | 94 | # logging settings 95 | logger: 96 | print_freq: 100 97 | save_checkpoint_freq: !!float 5e4 98 | use_tb_logger: true 99 | wandb: 100 | project: ~ 101 | resume_id: ~ 102 | 103 | # dist training settings 104 | dist_params: 105 | backend: nccl 106 | port: 29500 107 | -------------------------------------------------------------------------------- /options/train/MMA/train_MMA_x2_pretrain.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: MMA_x2_pretrain 3 | model_type: SRModel 4 | scale: 2 5 | num_gpu: 1 # set num_gpu: 0 for cpu mode 6 | manual_seed: 10 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: ImageNet 12 | type: ImageNetPairedDataset 13 | dataroot_gt: ~ 14 | io_backend: 15 | type: disk 16 | 17 | gt_size: 128 18 | use_hflip: true 19 | use_rot: true 20 | 21 | # data loader 22 | use_shuffle: true 23 | num_worker_per_gpu: 4 24 | batch_size_per_gpu: 4 25 | dataset_enlarge_ratio: 1 26 | prefetch_mode: cuda 27 | pin_memory: True 28 | 29 | val: 30 | name: Set5 31 | type: PairedImageDataset 32 | dataroot_gt: ~ 33 | dataroot_lq: ~ 34 | io_backend: 35 | type: disk 36 | 37 | # network structures 38 | network_g: 39 | type: MMA 40 | num_in_ch: 3 41 | num_out_ch: 3 42 | num_feat: 192 43 | num_block: 24 44 | upscale: 2 45 | img_range: 1. 
46 | rgb_mean: [0.4488, 0.4371, 0.4040] 47 | 48 | # path 49 | path: 50 | pretrain_network_g: ~ 51 | strict_load_g: true 52 | resume_state: ~ 53 | 54 | # training settings 55 | train: 56 | ema_decay: 0.999 57 | optim_g: 58 | type: Adam 59 | lr: !!float 2e-4 60 | weight_decay: 0 61 | betas: [0.9, 0.99] 62 | 63 | scheduler: 64 | type: MultiStepLR 65 | milestones: [300000, 500000, 650000, 700000, 750000] 66 | gamma: 0.5 67 | 68 | total_iter: 800000 69 | warmup_iter: -1 # no warm up 70 | 71 | # losses 72 | pixel_opt: 73 | type: L1Loss 74 | loss_weight: 1.0 75 | reduction: mean 76 | 77 | # validation settings 78 | val: 79 | val_freq: !!float 5e3 80 | save_img: false 81 | 82 | metrics: 83 | psnr: # metric name, can be arbitrary 84 | type: calculate_psnr 85 | crop_border: 2 86 | test_y_channel: true 87 | 88 | ssim: # metric name, can be arbitrary 89 | type: calculate_ssim 90 | crop_border: 2 91 | test_y_channel: true 92 | 93 | # logging settings 94 | logger: 95 | print_freq: 100 96 | save_checkpoint_freq: !!float 5e4 97 | use_tb_logger: true 98 | wandb: 99 | project: ~ 100 | resume_id: ~ 101 | 102 | # dist training settings 103 | dist_params: 104 | backend: nccl 105 | port: 29500 106 | -------------------------------------------------------------------------------- /options/train/MMA/train_MMA_x3_finetune.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: MMA_x3_Finetune 3 | model_type: SRModel 4 | scale: 3 5 | num_gpu: 1 # set num_gpu: 0 for cpu mode 6 | manual_seed: 10 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: DIV2K 12 | type: PairedImageDataset 13 | dataroot_gt: ~ 14 | dataroot_lq: ~ 15 | io_backend: 16 | type: disk 17 | 18 | gt_size: 192 19 | use_hflip: true 20 | use_rot: true 21 | 22 | # data loader 23 | use_shuffle: true 24 | num_worker_per_gpu: 4 25 | batch_size_per_gpu: 4 26 | dataset_enlarge_ratio: 1 27 | prefetch_mode: cuda 28 | pin_memory: True 29 | 30 | val: 31 | name: Set5 32 | type: PairedImageDataset 33 | dataroot_gt: ~ 34 | dataroot_lq: ~ 35 | io_backend: 36 | type: disk 37 | 38 | # network structures 39 | network_g: 40 | type: MMA 41 | num_in_ch: 3 42 | num_out_ch: 3 43 | num_feat: 192 44 | num_block: 24 45 | upscale: 3 46 | img_range: 1. 
47 | rgb_mean: [0.4488, 0.4371, 0.4040] 48 | 49 | # path 50 | path: 51 | pretrain_network_g: ~ 52 | strict_load_g: true 53 | resume_state: ~ 54 | 55 | # training settings 56 | train: 57 | ema_decay: 0.999 58 | optim_g: 59 | type: Adam 60 | lr: !!float 1e-5 61 | weight_decay: 0 62 | betas: [0.9, 0.99] 63 | 64 | scheduler: 65 | type: MultiStepLR 66 | milestones: [125000, 200000, 225000, 240000] 67 | gamma: 0.5 68 | 69 | total_iter: 250000 70 | warmup_iter: -1 # no warm up 71 | 72 | # losses 73 | pixel_opt: 74 | type: L1Loss 75 | loss_weight: 1.0 76 | reduction: mean 77 | 78 | # validation settings 79 | val: 80 | val_freq: !!float 5e3 81 | save_img: false 82 | 83 | metrics: 84 | psnr: # metric name, can be arbitrary 85 | type: calculate_psnr 86 | crop_border: 3 87 | test_y_channel: true 88 | 89 | ssim: # metric name, can be arbitrary 90 | type: calculate_ssim 91 | crop_border: 3 92 | test_y_channel: true 93 | 94 | # logging settings 95 | logger: 96 | print_freq: 100 97 | save_checkpoint_freq: !!float 5e4 98 | use_tb_logger: true 99 | wandb: 100 | project: ~ 101 | resume_id: ~ 102 | 103 | # dist training settings 104 | dist_params: 105 | backend: nccl 106 | port: 29500 107 | -------------------------------------------------------------------------------- /options/train/MMA/train_MMA_x3_pretrain.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: MMA_x3_pretrain 3 | model_type: SRModel 4 | scale: 3 5 | num_gpu: 1 # set num_gpu: 0 for cpu mode 6 | manual_seed: 10 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: ImageNet 12 | type: ImageNetPairedDataset 13 | dataroot_gt: ~ 14 | io_backend: 15 | type: disk 16 | 17 | gt_size: 192 18 | use_hflip: true 19 | use_rot: true 20 | 21 | # data loader 22 | use_shuffle: true 23 | num_worker_per_gpu: 4 24 | batch_size_per_gpu: 4 25 | dataset_enlarge_ratio: 1 26 | prefetch_mode: cuda 27 | pin_memory: True 28 | 29 | val: 30 | name: Set5 31 | type: PairedImageDataset 32 | dataroot_gt: ~ 33 | dataroot_lq: ~ 34 | io_backend: 35 | type: disk 36 | 37 | # network structures 38 | network_g: 39 | type: MMA 40 | num_in_ch: 3 41 | num_out_ch: 3 42 | num_feat: 192 43 | num_block: 24 44 | upscale: 3 45 | img_range: 1. 
46 | rgb_mean: [0.4488, 0.4371, 0.4040] 47 | 48 | # path 49 | path: 50 | pretrain_network_g: ~ 51 | strict_load_g: true 52 | resume_state: ~ 53 | 54 | # training settings 55 | train: 56 | ema_decay: 0.999 57 | optim_g: 58 | type: Adam 59 | lr: !!float 2e-4 60 | weight_decay: 0 61 | betas: [0.9, 0.99] 62 | 63 | scheduler: 64 | type: MultiStepLR 65 | milestones: [300000, 500000, 650000, 700000, 750000] 66 | gamma: 0.5 67 | 68 | total_iter: 800000 69 | warmup_iter: -1 # no warm up 70 | 71 | # losses 72 | pixel_opt: 73 | type: L1Loss 74 | loss_weight: 1.0 75 | reduction: mean 76 | 77 | # validation settings 78 | val: 79 | val_freq: !!float 5e3 80 | save_img: false 81 | 82 | metrics: 83 | psnr: # metric name, can be arbitrary 84 | type: calculate_psnr 85 | crop_border: 3 86 | test_y_channel: true 87 | 88 | ssim: # metric name, can be arbitrary 89 | type: calculate_ssim 90 | crop_border: 3 91 | test_y_channel: true 92 | 93 | # logging settings 94 | logger: 95 | print_freq: 100 96 | save_checkpoint_freq: !!float 5e4 97 | use_tb_logger: true 98 | wandb: 99 | project: ~ 100 | resume_id: ~ 101 | 102 | # dist training settings 103 | dist_params: 104 | backend: nccl 105 | port: 29500 106 | -------------------------------------------------------------------------------- /options/train/MMA/train_MMA_x4_finetune.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: MMA_x4_Finetune 3 | model_type: SRModel 4 | scale: 4 5 | num_gpu: 1 # set num_gpu: 0 for cpu mode 6 | manual_seed: 10 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: DIV2K 12 | type: PairedImageDataset 13 | dataroot_gt: ~ 14 | dataroot_lq: ~ 15 | io_backend: 16 | type: disk 17 | 18 | gt_size: 256 19 | use_hflip: true 20 | use_rot: true 21 | 22 | # data loader 23 | use_shuffle: true 24 | num_worker_per_gpu: 4 25 | batch_size_per_gpu: 4 26 | dataset_enlarge_ratio: 1 27 | prefetch_mode: cuda 28 | pin_memory: True 29 | 30 | val: 31 | name: Set5 32 | type: PairedImageDataset 33 | dataroot_gt: ~ 34 | dataroot_lq: ~ 35 | io_backend: 36 | type: disk 37 | 38 | # network structures 39 | network_g: 40 | type: MMA 41 | num_in_ch: 3 42 | num_out_ch: 3 43 | num_feat: 192 44 | num_block: 24 45 | upscale: 4 46 | img_range: 1. 
47 | rgb_mean: [0.4488, 0.4371, 0.4040] 48 | 49 | # path 50 | path: 51 | pretrain_network_g: ~ 52 | strict_load_g: true 53 | resume_state: ~ 54 | 55 | # training settings 56 | train: 57 | ema_decay: 0.999 58 | optim_g: 59 | type: Adam 60 | lr: !!float 1e-5 61 | weight_decay: 0 62 | betas: [0.9, 0.99] 63 | 64 | scheduler: 65 | type: MultiStepLR 66 | milestones: [125000, 200000, 225000, 240000] 67 | gamma: 0.5 68 | 69 | total_iter: 250000 70 | warmup_iter: -1 # no warm up 71 | 72 | # losses 73 | pixel_opt: 74 | type: L1Loss 75 | loss_weight: 1.0 76 | reduction: mean 77 | 78 | # validation settings 79 | val: 80 | val_freq: !!float 5e3 81 | save_img: false 82 | 83 | metrics: 84 | psnr: # metric name, can be arbitrary 85 | type: calculate_psnr 86 | crop_border: 4 87 | test_y_channel: true 88 | 89 | ssim: # metric name, can be arbitrary 90 | type: calculate_ssim 91 | crop_border: 4 92 | test_y_channel: true 93 | 94 | # logging settings 95 | logger: 96 | print_freq: 100 97 | save_checkpoint_freq: !!float 5e4 98 | use_tb_logger: true 99 | wandb: 100 | project: ~ 101 | resume_id: ~ 102 | 103 | # dist training settings 104 | dist_params: 105 | backend: nccl 106 | port: 29500 107 | -------------------------------------------------------------------------------- /options/train/MMA/train_MMA_x4_pretrain.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: MMA_x4_pretrain 3 | model_type: SRModel 4 | scale: 4 5 | num_gpu: 1 # set num_gpu: 0 for cpu mode 6 | manual_seed: 10 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: ImageNet 12 | type: ImageNetPairedDataset 13 | dataroot_gt: ~ 14 | io_backend: 15 | type: disk 16 | 17 | gt_size: 256 18 | use_hflip: true 19 | use_rot: true 20 | 21 | # data loader 22 | use_shuffle: true 23 | num_worker_per_gpu: 4 24 | batch_size_per_gpu: 4 25 | dataset_enlarge_ratio: 1 26 | prefetch_mode: cuda 27 | pin_memory: True 28 | 29 | val: 30 | name: Set5 31 | type: PairedImageDataset 32 | dataroot_gt: ~ 33 | dataroot_lq: ~ 34 | io_backend: 35 | type: disk 36 | 37 | # network structures 38 | network_g: 39 | type: MMA 40 | num_in_ch: 3 41 | num_out_ch: 3 42 | num_feat: 192 43 | num_block: 24 44 | upscale: 4 45 | img_range: 1. 
46 | rgb_mean: [0.4488, 0.4371, 0.4040] 47 | 48 | # path 49 | path: 50 | pretrain_network_g: ~ 51 | strict_load_g: true 52 | resume_state: ~ 53 | 54 | # training settings 55 | train: 56 | ema_decay: 0.999 57 | optim_g: 58 | type: Adam 59 | lr: !!float 2e-4 60 | weight_decay: 0 61 | betas: [0.9, 0.99] 62 | 63 | scheduler: 64 | type: MultiStepLR 65 | milestones: [300000, 500000, 650000, 700000, 750000] 66 | gamma: 0.5 67 | 68 | total_iter: 800000 69 | warmup_iter: -1 # no warm up 70 | 71 | # losses 72 | pixel_opt: 73 | type: L1Loss 74 | loss_weight: 1.0 75 | reduction: mean 76 | 77 | # validation settings 78 | val: 79 | val_freq: !!float 5e3 80 | save_img: false 81 | 82 | metrics: 83 | psnr: # metric name, can be arbitrary 84 | type: calculate_psnr 85 | crop_border: 4 86 | test_y_channel: true 87 | 88 | ssim: # metric name, can be arbitrary 89 | type: calculate_ssim 90 | crop_border: 4 91 | test_y_channel: true 92 | 93 | # logging settings 94 | logger: 95 | print_freq: 100 96 | save_checkpoint_freq: !!float 5e4 97 | use_tb_logger: true 98 | wandb: 99 | project: ~ 100 | resume_id: ~ 101 | 102 | # dist training settings 103 | dist_params: 104 | backend: nccl 105 | port: 29500 106 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | addict 2 | future 3 | lmdb 4 | numpy>=1.17 5 | opencv-python 6 | Pillow 7 | pyyaml 8 | requests 9 | scikit-image 10 | scipy 11 | tb-nightly 12 | torch==2.1.1 13 | torchvision 14 | tqdm 15 | yapf 16 | -------------------------------------------------------------------------------- /results.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ArsenalCheng/MMA/256cead19ff9540f0364f562c83486eec13c8a4a/results.jpg -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = 3 | # line break before binary operator (W503) 4 | W503, 5 | # line break after binary operator (W504) 6 | W504, 7 | max-line-length=120 8 | 9 | [yapf] 10 | based_on_style = pep8 11 | column_limit = 120 12 | blank_line_before_nested_class_or_def = true 13 | split_before_expression_after_opening_paren = true 14 | 15 | [isort] 16 | line_length = 120 17 | multi_line_output = 0 18 | known_standard_library = pkg_resources,setuptools 19 | known_first_party = basicsr 20 | known_third_party = PIL,cv2,lmdb,numpy,pytest,requests,scipy,skimage,torch,torchvision,tqdm,yaml 21 | no_lines_before = STDLIB,LOCALFOLDER 22 | default_section = THIRDPARTY 23 | 24 | [codespell] 25 | skip = .git,./docs/build,*.cfg 26 | count = 27 | quiet-level = 3 28 | ignore-words-list = gool 29 | 30 | [aliases] 31 | test=pytest 32 | 33 | [tool:pytest] 34 | addopts=tests/ 35 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import find_packages, setup 4 | 5 | import os 6 | import subprocess 7 | import time 8 | 9 | version_file = 'basicsr/version.py' 10 | 11 | 12 | def readme(): 13 | with open('README.md', encoding='utf-8') as f: 14 | content = f.read() 15 | return content 16 | 17 | 18 | def get_git_hash(): 19 | 20 | def _minimal_ext_cmd(cmd): 21 | # construct minimal environment 22 | env = {} 23 | for k in ['SYSTEMROOT', 'PATH', 
'HOME']: 24 | v = os.environ.get(k) 25 | if v is not None: 26 | env[k] = v 27 | # LANGUAGE is used on win32 28 | env['LANGUAGE'] = 'C' 29 | env['LANG'] = 'C' 30 | env['LC_ALL'] = 'C' 31 | out = subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env).communicate()[0] 32 | return out 33 | 34 | try: 35 | out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) 36 | sha = out.strip().decode('ascii') 37 | except OSError: 38 | sha = 'unknown' 39 | 40 | return sha 41 | 42 | 43 | def get_hash(): 44 | if os.path.exists('.git'): 45 | sha = get_git_hash()[:7] 46 | # currently ignore this 47 | # elif os.path.exists(version_file): 48 | # try: 49 | # from basicsr.version import __version__ 50 | # sha = __version__.split('+')[-1] 51 | # except ImportError: 52 | # raise ImportError('Unable to get git version') 53 | else: 54 | sha = 'unknown' 55 | 56 | return sha 57 | 58 | 59 | def write_version_py(): 60 | content = """# GENERATED VERSION FILE 61 | # TIME: {} 62 | __version__ = '{}' 63 | __gitsha__ = '{}' 64 | version_info = ({}) 65 | """ 66 | sha = get_hash() 67 | with open('VERSION', 'r') as f: 68 | SHORT_VERSION = f.read().strip() 69 | VERSION_INFO = ', '.join([x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')]) 70 | 71 | version_file_str = content.format(time.asctime(), SHORT_VERSION, sha, VERSION_INFO) 72 | with open(version_file, 'w') as f: 73 | f.write(version_file_str) 74 | 75 | 76 | def get_version(): 77 | with open(version_file, 'r') as f: 78 | exec(compile(f.read(), version_file, 'exec')) 79 | return locals()['__version__'] 80 | 81 | 82 | def make_cuda_ext(name, module, sources, sources_cuda=None): 83 | if sources_cuda is None: 84 | sources_cuda = [] 85 | define_macros = [] 86 | extra_compile_args = {'cxx': []} 87 | 88 | if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1': 89 | define_macros += [('WITH_CUDA', None)] 90 | extension = CUDAExtension 91 | extra_compile_args['nvcc'] = [ 92 | '-D__CUDA_NO_HALF_OPERATORS__', 93 | '-D__CUDA_NO_HALF_CONVERSIONS__', 94 | '-D__CUDA_NO_HALF2_OPERATORS__', 95 | ] 96 | sources += sources_cuda 97 | else: 98 | print(f'Compiling {name} without CUDA') 99 | extension = CppExtension 100 | 101 | return extension( 102 | name=f'{module}.{name}', 103 | sources=[os.path.join(*module.split('.'), p) for p in sources], 104 | define_macros=define_macros, 105 | extra_compile_args=extra_compile_args) 106 | 107 | 108 | def get_requirements(filename='requirements.txt'): 109 | here = os.path.dirname(os.path.realpath(__file__)) 110 | with open(os.path.join(here, filename), 'r') as f: 111 | requires = [line.replace('\n', '') for line in f.readlines()] 112 | return requires 113 | 114 | 115 | if __name__ == '__main__': 116 | cuda_ext = os.getenv('BASICSR_EXT') # whether compile cuda ext 117 | if cuda_ext == 'True': 118 | try: 119 | import torch 120 | from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension 121 | except ImportError: 122 | raise ImportError('Unable to import torch - torch is needed to build cuda extensions') 123 | 124 | ext_modules = [ 125 | make_cuda_ext( 126 | name='deform_conv_ext', 127 | module='basicsr.ops.dcn', 128 | sources=['src/deform_conv_ext.cpp'], 129 | sources_cuda=['src/deform_conv_cuda.cpp', 'src/deform_conv_cuda_kernel.cu']), 130 | make_cuda_ext( 131 | name='fused_act_ext', 132 | module='basicsr.ops.fused_act', 133 | sources=['src/fused_bias_act.cpp'], 134 | sources_cuda=['src/fused_bias_act_kernel.cu']), 135 | make_cuda_ext( 136 | name='upfirdn2d_ext', 137 | module='basicsr.ops.upfirdn2d', 138 | 
sources=['src/upfirdn2d.cpp'], 139 | sources_cuda=['src/upfirdn2d_kernel.cu']), 140 | ] 141 | setup_kwargs = dict(cmdclass={'build_ext': BuildExtension}) 142 | else: 143 | ext_modules = [] 144 | setup_kwargs = dict() 145 | 146 | write_version_py() 147 | setup( 148 | name='basicsr', 149 | version=get_version(), 150 | description='Open Source Image and Video Super-Resolution Toolbox', 151 | long_description=readme(), 152 | long_description_content_type='text/markdown', 153 | author='Xintao Wang', 154 | author_email='xintao.wang@outlook.com', 155 | keywords='computer vision, restoration, super resolution', 156 | url='https://github.com/xinntao/BasicSR', 157 | include_package_data=True, 158 | packages=find_packages(exclude=('options', 'datasets', 'experiments', 'results', 'tb_logger', 'wandb')), 159 | classifiers=[ 160 | 'Development Status :: 4 - Beta', 161 | 'License :: OSI Approved :: Apache Software License', 162 | 'Operating System :: OS Independent', 163 | 'Programming Language :: Python :: 3', 164 | 'Programming Language :: Python :: 3.7', 165 | 'Programming Language :: Python :: 3.8', 166 | ], 167 | license='Apache License 2.0', 168 | setup_requires=['cython', 'numpy', 'torch'], 169 | install_requires=get_requirements(), 170 | ext_modules=ext_modules, 171 | zip_safe=False, 172 | **setup_kwargs) 173 | --------------------------------------------------------------------------------
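
Usage sketch (illustrative only, not part of the repository): assuming the package is installed as described by setup.py above and the working directory is the repository root, the snippet below shows how a few of the utilities listed in this dump can be combined — yaml_load and dict2str from basicsr/utils/options.py to load and pretty-print one of the MMA test configs, and scandir plus sizeof_fmt from basicsr/utils/misc.py to enumerate saved outputs. The results directory name is a hypothetical placeholder; in the standard BasicSR workflow such a config would normally be consumed by basicsr/test.py through its required -opt argument (see parse_options above).

import os
from basicsr.utils.misc import scandir, sizeof_fmt
from basicsr.utils.options import dict2str, yaml_load

# Load a test configuration with the OrderedDict-aware YAML loader and print it.
opt = yaml_load('options/test/MMA/test_MMA_x4.yml')
print(dict2str(opt))

# Enumerate saved SR outputs (hypothetical results folder) and report their sizes.
results_dir = 'results/MMA_x4_inference'  # placeholder path, adjust to your run
if os.path.isdir(results_dir):
    for rel_path in scandir(results_dir, suffix='.png', recursive=True):
        size = os.path.getsize(os.path.join(results_dir, rel_path))
        print(rel_path, sizeof_fmt(size))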