├── .github └── workflows │ └── pylint.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .vscode └── settings.json ├── LICENSE ├── README.md ├── README_CN.md ├── VERSION ├── archs ├── __init__.py └── example_arch.py ├── data ├── __init__.py └── example_dataset.py ├── datasets └── README.md ├── experiments └── README.md ├── losses ├── __init__.py └── example_loss.py ├── models ├── __init__.py └── example_model.py ├── options └── example_option.yml ├── requirements.txt ├── scripts └── prepare_example_data.py ├── setup.cfg └── train.py /.github/workflows/pylint.yml: -------------------------------------------------------------------------------- 1 | name: PyLint 2 | 3 | on: [push, pull_request] 4 | 5 | jobs: 6 | build: 7 | 8 | runs-on: ubuntu-latest 9 | strategy: 10 | matrix: 11 | python-version: [3.8] 12 | 13 | steps: 14 | - uses: actions/checkout@v2 15 | - name: Set up Python ${{ matrix.python-version }} 16 | uses: actions/setup-python@v2 17 | with: 18 | python-version: ${{ matrix.python-version }} 19 | 20 | - name: Install dependencies 21 | run: | 22 | python -m pip install --upgrade pip 23 | pip install flake8 yapf isort 24 | 25 | - name: Lint 26 | run: | 27 | flake8 . 
28 | isort --check-only --diff archs/ data/ losses/ models/ scripts/ train.py 29 | yapf -r -d archs/ data/ losses/ models/ scripts/ train.py 30 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | datasets/* 2 | experiments/* 3 | results/* 4 | tb_logger/* 5 | wandb/* 6 | tmp/* 7 | 8 | *.DS_Store 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | pip-wheel-metadata/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | *.py,cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 104 | __pypackages__/ 105 | 106 | # Celery stuff 107 | celerybeat-schedule 108 | celerybeat.pid 109 | 110 | # SageMath parsed files 111 | *.sage.py 112 | 113 | # Environments 114 | .env 115 | .venv 116 | env/ 117 | venv/ 118 | ENV/ 119 | env.bak/ 120 | venv.bak/ 121 | 122 | # Spyder project settings 123 | .spyderproject 124 | .spyproject 125 | 126 | # Rope project settings 127 | .ropeproject 128 | 129 | # mkdocs documentation 130 | /site 131 | 132 | # mypy 133 | .mypy_cache/ 134 | .dmypy.json 135 | dmypy.json 136 | 137 | # Pyre type checker 138 | .pyre/ 139 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | # flake8 3 | - repo: https://github.com/PyCQA/flake8 4 | rev: 3.8.3 5 | hooks: 6 | - id: flake8 7 | args: ["--config=setup.cfg", "--ignore=W504, W503"] 8 | 9 | # modify known_third_party 10 | - repo: https://github.com/asottile/seed-isort-config 11 | rev: v2.2.0 12 | hooks: 13 | - id: seed-isort-config 14 | 15 | # isort 16 | - repo: https://github.com/timothycrosley/isort 17 | rev: 5.2.2 18 | hooks: 19 | - id: isort 20 | 21 | # yapf 22 | - repo: https://github.com/pre-commit/mirrors-yapf 23 | rev: v0.30.0 24 | hooks: 25 | - id: yapf 26 | 27 | # pre-commit-hooks 28 | - repo: https://github.com/pre-commit/pre-commit-hooks 29 | rev: v3.2.0 30 | hooks: 31 | - id: trailing-whitespace # Trim trailing whitespace 32 | - id: check-yaml # Attempt to load all yaml files to verify syntax 33 | - id: check-merge-conflict # Check for files that contain merge conflict strings 34 | - id: 
double-quote-string-fixer # Replace double quoted strings with single quoted strings 35 | - id: end-of-file-fixer # Make sure files end in a newline and only a newline 36 | - id: requirements-txt-fixer # Sort entries in requirements.txt and remove incorrect entry for pkg-resources==0.0.0 37 | - id: fix-encoding-pragma # Remove the coding pragma: # -*- coding: utf-8 -*- 38 | args: ["--remove"] 39 | - id: mixed-line-ending # Replace or check mixed line ending 40 | args: ["--fix=lf"] 41 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "files.trimTrailingWhitespace": true, 3 | "editor.wordWrap": "on", 4 | "editor.rulers": [ 5 | 80, 6 | 120 7 | ], 8 | "editor.renderWhitespace": "all", 9 | "editor.renderControlCharacters": true, 10 | "python.formatting.provider": "yapf", 11 | "python.formatting.yapfArgs": [ 12 | "--style", 13 | "{BASED_ON_STYLE = pep8, BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true, SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true, COLUMN_LIMIT = 120}" 14 | ], 15 | "python.linting.flake8Enabled": true, 16 | "python.linting.flake8Args": [ 17 | "--max-line-length=120" 18 | ] 19 | } 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Xintao Wang and BasicSR-examples contributors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The
above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # :rocket: BasicSR Examples 2 | 3 | [![download](https://img.shields.io/github/downloads/xinntao/BasicSR-examples/total.svg)](https://github.com/xinntao/BasicSR-examples/releases) 4 | [![Open issue](https://img.shields.io/github/issues/xinntao/BasicSR-examples)](https://github.com/xinntao/BasicSR-examples/issues) 5 | [![Closed issue](https://img.shields.io/github/issues-closed/xinntao/BasicSR-examples)](https://github.com/xinntao/BasicSR-examples/issues) 6 | [![LICENSE](https://img.shields.io/github/license/xinntao/basicsr-examples.svg)](https://github.com/xinntao/BasicSR-examples/blob/master/LICENSE) 7 | [![python lint](https://github.com/xinntao/BasicSR-examples/actions/workflows/pylint.yml/badge.svg)](https://github.com/xinntao/BasicSR-examples/blob/master/.github/workflows/pylint.yml) 8 | 9 | [English](README.md) **|** [简体中文](README_CN.md)
10 | [`BasicSR repo`](https://github.com/xinntao/BasicSR) **|** [`simple mode example`](https://github.com/xinntao/BasicSR-examples/tree/master) **|** [`installation mode example`](https://github.com/xinntao/BasicSR-examples/tree/installation) 11 | 12 | In this repository, we give examples to illustrate **how to easily use** [`BasicSR`](https://github.com/xinntao/BasicSR) in **your own project**. 13 | 14 | :triangular_flag_on_post: **Projects that use BasicSR** 15 | - :white_check_mark: [**GFPGAN**](https://github.com/TencentARC/GFPGAN): A practical algorithm for real-world face restoration 16 | - :white_check_mark: [**Real-ESRGAN**](https://github.com/xinntao/Real-ESRGAN): A practical algorithm for general image restoration 17 | 18 | If you use `BasicSR` in your open-source projects, feel free to contact me (by [email](#e-mail-contact) or by opening an issue/pull request). I will add your projects to the above list :blush: 19 | 20 | --- 21 | 22 | If this repo is helpful, please :star: it or recommend it to your friends. Thanks :blush:
23 | Other recommended projects:
24 | :arrow_forward: [facexlib](https://github.com/xinntao/facexlib): A collection that provides useful face-related functions.
25 | :arrow_forward: [HandyView](https://github.com/xinntao/HandyView): A PyQt5-based image viewer that is handy for viewing and comparing images. 26 | 27 | --- 28 | 29 | ## Contents 30 | 31 | - [HowTO use BasicSR](#HowTO-use-BasicSR) 32 | - [As a Template](#As-a-Template) 33 | 34 | ## HowTO use BasicSR 35 | 36 | `BasicSR` can be used in two ways: 37 | - :arrow_right: Git clone the entire BasicSR. In this way, you can see the complete code of BasicSR, and then modify it according to your own needs. 38 | - :arrow_right: Use basicsr as a [python package](https://pypi.org/project/basicsr/#history) (that is, install with pip). It provides the training framework, procedures, and some basic functions. You can easily build your own projects based on basicsr. 39 | ```bash 40 | pip install basicsr 41 | ``` 42 | 43 | Our example mainly focuses on the second one, that is, how to easily and concisely build your own project based on the basicsr package. 44 | 45 | There are two ways to use the python package of basicsr, which are provided in two branches: 46 | 47 | - :arrow_right: [simple mode](https://github.com/xinntao/BasicSR-examples/tree/master): the project can be run **without installation**. But it has limitations: it is inconvenient to import modules with complex hierarchies, and it is not easy to access the functions in this project from other locations 48 | 49 | - :arrow_right: [installation mode](https://github.com/xinntao/BasicSR-examples/tree/installation): you need to install the project by running `python setup.py develop`. After installation, it is more convenient to import and use. 50 | 51 | As a simple introduction and explanation, we use the example of *simple mode*, but we recommend the *installation mode* in practical use. 52 | 53 | ```bash 54 | git clone https://github.com/xinntao/BasicSR-examples.git 55 | cd BasicSR-examples 56 | ``` 57 | 58 | ### Preliminary 59 | 60 | Most deep-learning projects can be divided into the following parts: 61 | 62 | 1.
**data**: defines the training/validation data that is fed into the model training 63 | 2. **arch** (architecture): defines the network structure and the forward steps 64 | 3. **model**: defines the necessary components in training (such as loss) and a complete training process (including forward propagation, back-propagation, gradient optimization, *etc*.), as well as other functions, such as validation, *etc* 65 | 4. Training pipeline: defines the training process, that is, it connects the data loader, model, validation, checkpoint saving, *etc* 66 | 67 | When we are developing a new method, we often improve the **data**, **arch**, and **model**. Most training processes and basic functions are actually shared. So we want to focus on developing the main functions instead of reinventing the wheel. 68 | 69 | Therefore, we have BasicSR, which factors out many shared functions. With BasicSR, we just need to care about the development of **data**, **arch**, and **model**. 70 | 71 | In order to further facilitate the use of BasicSR, we provide the basicsr package. You can easily install it through `pip install basicsr`. After that, you can use the training process of BasicSR and the functions already developed in BasicSR~ 72 | 73 | ### A Simple Example 74 | 75 | Let's use a simple example to illustrate how to use BasicSR to build your own project. 76 | 77 | We provide two sample datasets for demonstration: 78 | 1. [BSDS100](https://github.com/xinntao/BasicSR-examples/releases/download/0.0.0/BSDS100.zip) for training 79 | 1. [Set5](https://github.com/xinntao/BasicSR-examples/releases/download/0.0.0/Set5.zip) for validation 80 | 81 | You can easily download them by running the following command in the BasicSR-examples root path: 82 | 83 | ```bash 84 | python scripts/prepare_example_data.py 85 | ``` 86 | 87 | The sample data are now in the `datasets/example` folder. 88 | 89 | #### :zero: Purpose 90 | 91 | Let's use a Super-Resolution task for the demo.
92 | It takes a low-resolution image as the input and outputs a high-resolution image. 93 | The low-resolution images are synthesized with: 1) CV2 bicubic X4 downsampling, and 2) JPEG compression (quality = 70). 94 | 95 | In order to better explain how to use the arch and model, we 1) use a network structure similar to SRCNN; 2) use L1 and L2 (MSE) loss simultaneously in training. 96 | 97 | So, in this task, we need to: 98 | 99 | 1. Build our own data loader 100 | 1. Determine the architecture 101 | 1. Build our own model 102 | 103 | Let's explain them one by one in the following parts. 104 | 105 | #### :one: data 106 | 107 | We need to implement a new dataset to fulfill our purpose. The dataset is used to feed the data into the model. 108 | 109 | An example of this dataset is in [data/example_dataset.py](data/example_dataset.py). It has the following steps. 110 | 111 | 1. Read Ground-Truth (GT) images. BasicSR provides [FileClient](https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/file_client.py) for easily reading files in a folder, LMDB files, and meta_info txt files. In this example, we use the folder mode. For more reading modes, please refer to [basicsr/data](https://github.com/xinntao/BasicSR/tree/master/basicsr/data) 112 | 1. Synthesize low-resolution images. We can directly implement the data procedures in the `__getitem__(self, index)` function, such as downsampling and adding JPEG compression. Many basic operations can be found in [[basicsr/data/degradations]](https://github.com/xinntao/BasicSR/blob/master/basicsr/data/degradations.py), [[basicsr/data/transforms]](https://github.com/xinntao/BasicSR/blob/master/basicsr/data/transforms.py), and [[basicsr/data/data_util]](https://github.com/xinntao/BasicSR/blob/master/basicsr/data/data_util.py) 113 | 1. Convert to torch tensor and return appropriate information 114 | 115 | **Note**: 116 | 117 | 1. Please add `@DATASET_REGISTRY.register()` before `ExampleDataset`.
This operation is mainly used to prevent the occurrence of a dataset with the same name, which will result in potential bugs 118 | 1. The new dataset file should end with `_dataset.py`, such as `example_dataset.py`. In this way, the program can **automatically** import classes without manual import 119 | 120 | In the [option configuration file](options/example_option.yml), you can use the new dataset: 121 | 122 | ```yaml 123 | datasets: 124 | train: # training dataset 125 | name: ExampleBSDS100 126 | type: ExampleDataset # the class name 127 | 128 | # ----- the followings are the arguments of ExampleDataset ----- # 129 | dataroot_gt: datasets/example/BSDS100 130 | io_backend: 131 | type: disk 132 | 133 | gt_size: 128 134 | use_flip: true 135 | use_rot: true 136 | 137 | # ----- arguments of data loader ----- # 138 | use_shuffle: true 139 | num_worker_per_gpu: 3 140 | batch_size_per_gpu: 16 141 | dataset_enlarge_ratio: 10 142 | prefetch_mode: ~ 143 | 144 | val: # validation dataset 145 | name: ExampleSet5 146 | type: ExampleDataset 147 | dataroot_gt: datasets/example/Set5 148 | io_backend: 149 | type: disk 150 | ``` 151 | 152 | #### :two: arch 153 | 154 | An example of architecture is in [archs/example_arch.py](archs/example_arch.py). It mainly builds the network structure. 155 | 156 | **Note**: 157 | 158 | 1. Add `@ARCH_REGISTRY.register()` before `ExampleArch`, so as to register the newly implemented arch. This operation is mainly used to prevent the occurrence of arch with the same name, resulting in potential bugs 159 | 1. The new arch file should end with `_arch.py`, such as `example_arch.py`. 
In this way, the program can **automatically** import classes without manual import 160 | 161 | In the [option configuration file](options/example_option.yml), you can use the new arch: 162 | 163 | ```yaml 164 | # network structures 165 | network_g: 166 | type: ExampleArch # the class name 167 | 168 | # ----- the followings are the arguments of ExampleArch ----- # 169 | num_in_ch: 3 170 | num_out_ch: 3 171 | num_feat: 64 172 | upscale: 4 173 | ``` 174 | 175 | #### :three: model 176 | 177 | An example of model is in [models/example_model.py](models/example_model.py). It mainly builds the training process of a model. 178 | 179 | In this file: 180 | 1. We inherit `SRModel` from basicsr. Many models have similar operations, so you can inherit and modify from [basicsr/models](https://github.com/xinntao/BasicSR/tree/master/basicsr/models). In this way, you can easily implement your ideas, such as GAN model, video model, *etc*. 181 | 1. Two losses are used: L1 and L2 (MSE) loss 182 | 1. Many other contents, such as `setup_optimizers`, `validation`, `save`, *etc*, are inherited from `SRModel` 183 | 184 | **Note**: 185 | 186 | 1. Add `@MODEL_REGISTRY.register()` before `ExampleModel`, so as to register the newly implemented model. This operation is mainly used to prevent the occurrence of model with the same name, resulting in potential bugs 187 | 1. The new model file should end with `_model.py`, such as `example_model.py`. 
In this way, the program can **automatically** import classes without manual import 188 | 189 | In the [option configuration file](options/example_option.yml), you can use the new model: 190 | 191 | ```yaml 192 | # training settings 193 | train: 194 | optim_g: 195 | type: Adam 196 | lr: !!float 2e-4 197 | weight_decay: 0 198 | betas: [0.9, 0.99] 199 | 200 | scheduler: 201 | type: MultiStepLR 202 | milestones: [50000] 203 | gamma: 0.5 204 | 205 | total_iter: 100000 206 | warmup_iter: -1 # no warm up 207 | 208 | # ----- the followings are the configurations for two losses ----- # 209 | # losses 210 | l1_opt: 211 | type: L1Loss 212 | loss_weight: 1.0 213 | reduction: mean 214 | 215 | l2_opt: 216 | type: MSELoss 217 | loss_weight: 1.0 218 | reduction: mean 219 | ``` 220 | 221 | #### :four: training pipeline 222 | 223 | The whole training pipeline can reuse the [basicsr/train.py](https://github.com/xinntao/BasicSR/blob/master/basicsr/train.py) in BasicSR. 224 | 225 | Based on this, our [train.py](train.py) can be very concise: 226 | 227 | ```python 228 | import os.path as osp 229 | 230 | import archs # noqa: F401 231 | import data # noqa: F401 232 | import models # noqa: F401 233 | from basicsr.train import train_pipeline 234 | 235 | if __name__ == '__main__': 236 | root_path = osp.abspath(osp.join(__file__, osp.pardir)) 237 | train_pipeline(root_path) 238 | 239 | ``` 240 | 241 | #### :five: debug mode 242 | 243 | So far, we have completed the development of our project. We can quickly check whether there is a bug through the `debug` mode: 244 | 245 | ```bash 246 | python train.py -opt options/example_option.yml --debug 247 | ``` 248 | 249 | With `--debug`, the program will enter the debug mode. In the debug mode, the program will output at each iteration, and perform validation every 8 iterations, so that you can easily know whether the program has a bug~ 250 | 251 | #### :six: normal training 252 | 253 | After debugging, we can have the normal training. 
254 | 255 | ```bash 256 | python train.py -opt options/example_option.yml 257 | ``` 258 | 259 | If the training process is interrupted unexpectedly and you need to resume it, please use `--auto_resume` in the command: 260 | 261 | ```bash 262 | python train.py -opt options/example_option.yml --auto_resume 263 | ``` 264 | 265 | So far, you have finished developing your own project using `BasicSR`. Isn't it very convenient~ :grin: 266 | 267 | ## As a Template 268 | 269 | You can use BasicSR-Examples as a template for your project. Here are some modifications you may need. 270 | 271 | 1. Set up the *pre-commit* hook 272 | 1. In the root path, run: 273 | > pre-commit install 274 | 1. Modify the `LICENSE`
275 | This repository uses the *MIT* license; you may change it to another license 276 | 277 | The simple mode does not require many modifications. Those using the installation mode may need more modifications. See [here](https://github.com/xinntao/BasicSR-examples/blob/installation/README.md#As-a-Template) 278 | 279 | ## :e-mail: Contact 280 | 281 | If you have any questions or want to add your project to the list, please email `xintao.wang@outlook.com` or `xintaowang@tencent.com`. 282 | -------------------------------------------------------------------------------- /README_CN.md: -------------------------------------------------------------------------------- 1 | # :rocket: BasicSR Examples 2 | 3 | [![download](https://img.shields.io/github/downloads/xinntao/BasicSR-examples/total.svg)](https://github.com/xinntao/BasicSR-examples/releases) 4 | [![Open issue](https://img.shields.io/github/issues/xinntao/BasicSR-examples)](https://github.com/xinntao/BasicSR-examples/issues) 5 | [![Closed issue](https://img.shields.io/github/issues-closed/xinntao/BasicSR-examples)](https://github.com/xinntao/BasicSR-examples/issues) 6 | [![LICENSE](https://img.shields.io/github/license/xinntao/basicsr-examples.svg)](https://github.com/xinntao/BasicSR-examples/blob/master/LICENSE) 7 | [![python lint](https://github.com/xinntao/BasicSR-examples/actions/workflows/pylint.yml/badge.svg)](https://github.com/xinntao/BasicSR-examples/blob/master/.github/workflows/pylint.yml) 8 | 9 | [English](README.md) **|** [简体中文](README_CN.md)
10 | [`BasicSR repo`](https://github.com/xinntao/BasicSR) **|** [`simple mode example`](https://github.com/xinntao/BasicSR-examples/tree/master) **|** [`installation mode example`](https://github.com/xinntao/BasicSR-examples/tree/installation) 11 | 12 | 在这个仓库中,我们通过简单的例子来说明:如何在**你自己的项目中**轻松地**使用** [`BasicSR`](https://github.com/xinntao/BasicSR)。 13 | 14 | :triangular_flag_on_post: **使用 BasicSR 的项目** 15 | - :white_check_mark: [**GFPGAN**](https://github.com/TencentARC/GFPGAN): 真实场景人脸复原的实用算法 16 | - :white_check_mark: [**Real-ESRGAN**](https://github.com/xinntao/Real-ESRGAN): 通用图像复原的实用算法 17 | 18 | 如果你的开源项目中使用了`BasicSR`, 欢迎联系我 ([邮件](#e-mail-%E8%81%94%E7%B3%BB)或者开一个issue/pull request)。我会将你的开源项目添加到上面的列表中 :blush: 19 | 20 | --- 21 | 22 | 如果你觉得这个项目对你有帮助,欢迎 :star: 这个仓库或推荐给你的朋友。Thanks:blush:
23 | 其他推荐的项目:
24 | :arrow_forward: [facexlib](https://github.com/xinntao/facexlib): 提供实用的人脸相关功能的集合
25 | :arrow_forward: [HandyView](https://github.com/xinntao/HandyView): 基于PyQt5的 方便的看图比图工具 26 | 27 | --- 28 | 29 | ## 目录 30 | 31 | - [如何使用 BasicSR](#HowTO-use-BasicSR) 32 | - [作为Template](#As-a-Template) 33 | 34 | ## HowTO use BasicSR 35 | 36 | `BasicSR` 有两种使用方式: 37 | - :arrow_right: Git clone 整个 BasicSR 的代码。这样可以看到 BasicSR 完整的代码,然后根据你自己的需求进行修改 38 | - :arrow_right: BasicSR 作为一个 [python package](https://pypi.org/project/basicsr/#history) (即可以通过pip安装),提供了训练的框架,流程和一些基本功能。你可以基于 basicsr 方便地搭建你自己的项目 39 | ```bash 40 | pip install basicsr 41 | ``` 42 | 43 | 我们的样例主要针对第二种使用方式,即如何基于 basicsr 这个package来方便简洁地搭建你自己的项目。 44 | 45 | 使用 basicsr 的python package又有两种方式,我们分别提供在两个 branch 中: 46 | - :arrow_right: [简单模式](https://github.com/xinntao/BasicSR-examples/tree/master): 项目的仓库不需要安装,就可以运行使用。但它有局限:不方便 import 复杂的层级关系;在其他位置也不容易访问本项目中的函数 47 | - :arrow_right: [安装模式](https://github.com/xinntao/BasicSR-examples/tree/installation): 项目的仓库需要安装 `python setup.py develop`,安装之后 import 和使用都更加方便 48 | 49 | 作为简单的入门和讲解, 我们使用*简单模式*的样例,但在实际使用中我们推荐*安装模式*。 50 | 51 | ```bash 52 | git clone https://github.com/xinntao/BasicSR-examples.git 53 | cd BasicSR-examples 54 | ``` 55 | 56 | ### 预备 57 | 58 | 大部分的深度学习项目,都可以分为以下几个部分: 59 | 60 | 1. **data**: 定义了训练数据,来喂给模型的训练过程 61 | 2. **arch** (architecture): 定义了网络结构 和 forward 的步骤 62 | 3. **model**: 定义了在训练中必要的组件(比如 loss) 和 一次完整的训练过程(包括前向传播,反向传播,梯度优化等),还有其他功能,比如 validation等 63 | 4. training pipeline: 定义了训练的流程,即把数据 dataloader,模型,validation,保存 checkpoints 等等串联起来 64 | 65 | 当我们开发一个新的方法时,我们往往在改进: **data**, **arch**, **model**;而很多流程、基础的功能其实是共用的。那么,我们希望可以专注于主要功能的开发,而不要重复造轮子。 66 | 67 | 因此便有了 BasicSR,它把很多相似的功能都独立出来,我们只要关心 **data**, **arch**, **model** 的开发即可。 68 | 69 | 为了进一步方便大家使用,我们提供了 basicsr package,大家可以通过 `pip install basicsr` 方便地安装,然后就可以使用 BasicSR 的训练流程以及在 BasicSR 里面已开发好的功能啦~ 70 | 71 | ### 简单的例子 72 | 73 | 下面我们就通过一个简单的例子,来说明如何使用 BasicSR 来搭建你自己的项目。 74 | 75 | 我们提供了两个样例数据来做展示, 76 | 1. 
[BSDS100](https://github.com/xinntao/BasicSR-examples/releases/download/0.0.0/BSDS100.zip) for training 77 | 1. [Set5](https://github.com/xinntao/BasicSR-examples/releases/download/0.0.0/Set5.zip) for validation 78 | 79 | 在 BasicSR-examples 的根目录运行下面的命令来下载: 80 | 81 | ```bash 82 | python scripts/prepare_example_data.py 83 | ``` 84 | 85 | 样例数据就下载在 `datasets/example` 文件夹中。 86 | 87 | #### :zero: 目的 88 | 89 | 我们来假设一个超分辨率 (Super-Resolution) 的任务,输入一张低分辨率的图片,输出高分辨率的图片。低分辨率图片包含了 1) cv2 的 bicubic X4 downsampling 和 2) JPEG 压缩 (quality=70)。 90 | 91 | 为了更好地说明如何使用 arch 和 model,我们想要使用 1) 类似 SRCNN 的网络结构;2) 在训练中同时使用 L1 和 L2 (MSE) loss。 92 | 93 | 那么,在这个任务中,我们要做的是: 94 | 95 | 1. 构建自己的 data loader 96 | 1. 确定使用的 architecture 97 | 1. 构建自己的 model 98 | 99 | 下面我们分别来说明一下。 100 | 101 | #### :one: data 102 | 103 | 这个部分是用来确定喂给模型的数据的。 104 | 105 | 这个 dataset 的例子在[data/example_dataset.py](data/example_dataset.py) 中,它完成了: 106 | 1. 我们读取 Ground-Truth (GT) 的图像。读取的操作,BasicSR 提供了[FileClient](https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/file_client.py), 可以方便地读取 folder, lmdb 和 meta_info txt 指定的文件。在这个例子中,我们通过读取 folder 来说明,更多的读取模式可以参考 [basicsr/data](https://github.com/xinntao/BasicSR/tree/master/basicsr/data) 107 | 1. 合成低分辨率的图像。我们可以直接在 `__getitem__(self, index)` 的函数中实现我们想要的操作,比如降采样和添加 JPEG 压缩。很多基本操作都可以在 [[basicsr/data/degradations]](https://github.com/xinntao/BasicSR/blob/master/basicsr/data/degradations.py), [[basicsr/data/transforms]](https://github.com/xinntao/BasicSR/blob/master/basicsr/data/transforms.py) 和 [[basicsr/data/data_util]](https://github.com/xinntao/BasicSR/blob/master/basicsr/data/data_util.py) 中找到 108 | 1. 转换成 Torch Tensor,返回合适的信息 109 | 110 | **注意**: 111 | 1. 需要在 `ExampleDataset` 前添加 `@DATASET_REGISTRY.register()`,以便注册好新写的 dataset。这个操作主要用来防止出现同名的 dataset,从而带来潜在的 bug 112 | 1.
新写的 dataset 文件要以 `_dataset.py` 结尾,比如 `example_dataset.py`。 这样,程序可以**自动地** import,而不需要手动地 import 113 | 114 | 在 [option 配置文件中](options/example_option.yml)使用新写的 dataset: 115 | 116 | ```yaml 117 | datasets: 118 | train: # training dataset 119 | name: ExampleBSDS100 120 | type: ExampleDataset # the class name 121 | 122 | # ----- the followings are the arguments of ExampleDataset ----- # 123 | dataroot_gt: datasets/example/BSDS100 124 | io_backend: 125 | type: disk 126 | 127 | gt_size: 128 128 | use_flip: true 129 | use_rot: true 130 | 131 | # ----- arguments of data loader ----- # 132 | use_shuffle: true 133 | num_worker_per_gpu: 3 134 | batch_size_per_gpu: 16 135 | dataset_enlarge_ratio: 10 136 | prefetch_mode: ~ 137 | 138 | val: # validation dataset 139 | name: ExampleSet5 140 | type: ExampleDataset 141 | dataroot_gt: datasets/example/Set5 142 | io_backend: 143 | type: disk 144 | ``` 145 | 146 | #### :two: arch 147 | 148 | Architecture 的例子在 [archs/example_arch.py](archs/example_arch.py)中。它主要搭建了网络结构。 149 | 150 | **注意**: 151 | 1. 需要在 `ExampleArch` 前添加 `@ARCH_REGISTRY.register()`,以便注册好新写的 arch。这个操作主要用来防止出现同名的 arch,从而带来潜在的 bug 152 | 1. 新写的 arch 文件要以 `_arch.py` 结尾,比如 `example_arch.py`。 这样,程序可以**自动地** import,而不需要手动地 import 153 | 154 | 在 [option 配置文件中](options/example_option.yml)使用新写的 arch: 155 | 156 | ```yaml 157 | # network structures 158 | network_g: 159 | type: ExampleArch # the class name 160 | 161 | # ----- the followings are the arguments of ExampleArch ----- # 162 | num_in_ch: 3 163 | num_out_ch: 3 164 | num_feat: 64 165 | upscale: 4 166 | ``` 167 | 168 | #### :three: model 169 | 170 | Model 的例子在 [models/example_model.py](models/example_model.py)中。它主要搭建了模型的训练过程。 171 | 在这个文件中: 172 | 1. 我们从 basicsr 中继承了 `SRModel`。很多模型都有相似的操作,因此可以通过继承 [basicsr/models](https://github.com/xinntao/BasicSR/tree/master/basicsr/models) 中的模型来更方便地实现自己的想法,比如GAN模型,Video模型等 173 | 1. 使用了两个 Loss: L1 和 L2 (MSE) loss 174 | 1. 
其他很多内容,比如 `setup_optimizers`, `validation`, `save`等,都是继承于 `SRModel` 175 | 176 | **注意**: 177 | 1. 需要在 `ExampleModel` 前添加 `@MODEL_REGISTRY.register()`,以便注册好新写的 model。这个操作主要用来防止出现同名的 model,从而带来潜在的 bug 178 | 1. 新写的 model 文件要以 `_model.py` 结尾,比如 `example_model.py`。 这样,程序可以**自动地** import,而不需要手动地 import 179 | 180 | 在 [option 配置文件中](options/example_option.yml)使用新写的 model: 181 | 182 | ```yaml 183 | # training settings 184 | train: 185 | optim_g: 186 | type: Adam 187 | lr: !!float 2e-4 188 | weight_decay: 0 189 | betas: [0.9, 0.99] 190 | 191 | scheduler: 192 | type: MultiStepLR 193 | milestones: [50000] 194 | gamma: 0.5 195 | 196 | total_iter: 100000 197 | warmup_iter: -1 # no warm up 198 | 199 | # ----- the followings are the configurations for two losses ----- # 200 | # losses 201 | l1_opt: 202 | type: L1Loss 203 | loss_weight: 1.0 204 | reduction: mean 205 | 206 | l2_opt: 207 | type: MSELoss 208 | loss_weight: 1.0 209 | reduction: mean 210 | ``` 211 | 212 | #### :four: training pipeline 213 | 214 | 整个 training pipeline 可以复用 basicsr 里面的 [basicsr/train.py](https://github.com/xinntao/BasicSR/blob/master/basicsr/train.py)。 215 | 216 | 基于此,我们的 [train.py](train.py)可以非常简洁。 217 | 218 | ```python 219 | import os.path as osp 220 | 221 | import archs # noqa: F401 222 | import data # noqa: F401 223 | import models # noqa: F401 224 | from basicsr.train import train_pipeline 225 | 226 | if __name__ == '__main__': 227 | root_path = osp.abspath(osp.join(__file__, osp.pardir)) 228 | train_pipeline(root_path) 229 | 230 | ``` 231 | 232 | #### :five: debug mode 233 | 234 | 至此,我们已经完成了我们这个项目的开发,下面可以通过 `debug` 模式来快捷地看看是否有问题: 235 | 236 | ```bash 237 | python train.py -opt options/example_option.yml --debug 238 | ``` 239 | 240 | 只要带上 `--debug` 就进入 debug 模式。在 debug 模式中,程序每个iter都会输出,8个iter后就会进行validation,这样可以很方便地知道程序有没有bug啦~ 241 | 242 | #### :six: normal training 243 | 244 | 经过debug没有问题后,我们就可以正式训练了。 245 | 246 | ```bash 247 | python train.py -opt options/example_option.yml 248 | ``` 249 | 250 | 
如果训练过程意外中断需要 resume, 则使用 `--auto_resume` 可以方便地自动resume: 251 | ```bash 252 | python train.py -opt options/example_option.yml --auto_resume 253 | ``` 254 | 255 | 至此,使用 `BasicSR` 开发你自己的项目就介绍完了,是不是很方便呀~ :grin: 256 | 257 | ## As a Template 258 | 259 | 你可以使用 BasicSR-Examples 作为你项目的模板。下面主要展示一下你可能需要的修改。 260 | 261 | 1. 设置 *pre-commit* hook 262 | 1. 在文件夹根目录, 运行 263 | > pre-commit install 264 | 1. 修改 `LICENSE` 文件
265 | 本仓库使用 *MIT* 许可, 根据需要可以修改成其他许可 266 | 267 | 使用 简单模式 的基本不需要修改,使用 安装模式 的可能需要较多修改,参见[这里](https://github.com/xinntao/BasicSR-examples/blob/installation/README_CN.md#As-a-Template) 268 | 269 | ## :e-mail: 联系 270 | 271 | 如果你有任何问题,或者想要添加你的项目到列表中,欢迎电邮 272 | `xintao.wang@outlook.com` or `xintaowang@tencent.com`. 273 | -------------------------------------------------------------------------------- /VERSION: -------------------------------------------------------------------------------- 1 | 0.1.0 2 | -------------------------------------------------------------------------------- /archs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from os import path as osp 3 | 4 | from basicsr.utils import scandir 5 | 6 | # automatically scan and import arch modules for registry 7 | # scan all the files that end with '_arch.py' under the archs folder 8 | arch_folder = osp.dirname(osp.abspath(__file__)) 9 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] 10 | # import all the arch modules 11 | _arch_modules = [importlib.import_module(f'archs.{file_name}') for file_name in arch_filenames] 12 | -------------------------------------------------------------------------------- /archs/example_arch.py: -------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | from torch.nn import functional as F 3 | 4 | from basicsr.archs.arch_util import default_init_weights 5 | from basicsr.utils.registry import ARCH_REGISTRY 6 | 7 | 8 | @ARCH_REGISTRY.register() 9 | class ExampleArch(nn.Module): 10 | """Example architecture. 11 | 12 | Args: 13 | num_in_ch (int): Channel number of inputs. Default: 3. 14 | num_out_ch (int): Channel number of outputs. Default: 3. 15 | num_feat (int): Channel number of intermediate features. Default: 64. 16 | upscale (int): Upsampling factor. Default: 4. 
17 | """ 18 | 19 | def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, upscale=4): 20 | super(ExampleArch, self).__init__() 21 | self.upscale = upscale 22 | 23 | self.conv1 = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 24 | self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 25 | self.conv3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 26 | 27 | self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1) 28 | self.upconv2 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1) 29 | self.pixel_shuffle = nn.PixelShuffle(2) 30 | 31 | self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1) 32 | self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1) 33 | 34 | # activation function 35 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 36 | 37 | # initialization 38 | default_init_weights( 39 | [self.conv1, self.conv2, self.conv3, self.upconv1, self.upconv2, self.conv_hr, self.conv_last], 0.1) 40 | 41 | def forward(self, x): 42 | feat = self.lrelu(self.conv1(x)) 43 | feat = self.lrelu(self.conv2(feat)) 44 | feat = self.lrelu(self.conv3(feat)) 45 | 46 | out = self.lrelu(self.pixel_shuffle(self.upconv1(feat))) 47 | out = self.lrelu(self.pixel_shuffle(self.upconv2(out))) 48 | 49 | out = self.conv_last(self.lrelu(self.conv_hr(out))) 50 | base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False) 51 | out += base 52 | return out 53 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from os import path as osp 3 | 4 | from basicsr.utils import scandir 5 | 6 | # automatically scan and import dataset modules for registry 7 | # scan all the files that end with '_dataset.py' under the data folder 8 | data_folder = osp.dirname(osp.abspath(__file__)) 9 | dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] 10 | # import all the dataset modules 
11 | _dataset_modules = [importlib.import_module(f'data.{file_name}') for file_name in dataset_filenames] 12 | -------------------------------------------------------------------------------- /data/example_dataset.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import torch 4 | from torch.utils import data as data 5 | from torchvision.transforms.functional import normalize 6 | 7 | from basicsr.data.degradations import add_jpg_compression 8 | from basicsr.data.transforms import augment, mod_crop, paired_random_crop 9 | from basicsr.utils import FileClient, imfrombytes, img2tensor, scandir 10 | from basicsr.utils.registry import DATASET_REGISTRY 11 | 12 | 13 | @DATASET_REGISTRY.register() 14 | class ExampleDataset(data.Dataset): 15 | """Example dataset. 16 | 17 | 1. Read GT image 18 | 2. Generate LQ (Low Quality) image with cv2 bicubic downsampling and JPEG compression 19 | 20 | Args: 21 | opt (dict): Config for train datasets. It contains the following keys: 22 | dataroot_gt (str): Data root path for gt. 23 | io_backend (dict): IO backend type and other kwarg. 24 | gt_size (int): Cropped patched size for gt patches. 25 | use_flip (bool): Use horizontal flips. 26 | use_rot (bool): Use rotation (use vertical flip and transposing h 27 | and w for implementation). 28 | 29 | scale (bool): Scale, which will be added automatically. 30 | phase (str): 'train' or 'val'. 
31 | """ 32 | 33 | def __init__(self, opt): 34 | super(ExampleDataset, self).__init__() 35 | self.opt = opt 36 | # file client (io backend) 37 | self.file_client = None 38 | self.io_backend_opt = opt['io_backend'] 39 | self.mean = opt['mean'] if 'mean' in opt else None 40 | self.std = opt['std'] if 'std' in opt else None 41 | 42 | self.gt_folder = opt['dataroot_gt'] 43 | # it now only supports folder mode, for other modes such as lmdb and meta_info file, please see: 44 | # https://github.com/xinntao/BasicSR/blob/master/basicsr/data/ 45 | self.paths = [os.path.join(self.gt_folder, v) for v in list(scandir(self.gt_folder))] 46 | 47 | def __getitem__(self, index): 48 | if self.file_client is None: 49 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 50 | 51 | scale = self.opt['scale'] 52 | 53 | # Load gt images. Dimension order: HWC; channel order: BGR; 54 | # image range: [0, 1], float32. 55 | gt_path = self.paths[index] 56 | img_bytes = self.file_client.get(gt_path, 'gt') 57 | img_gt = imfrombytes(img_bytes, float32=True) 58 | img_gt = mod_crop(img_gt, scale) 59 | 60 | # generate lq image 61 | # downsample 62 | h, w = img_gt.shape[0:2] 63 | img_lq = cv2.resize(img_gt, (w // scale, h // scale), interpolation=cv2.INTER_CUBIC) 64 | # add JPEG compression 65 | img_lq = add_jpg_compression(img_lq, quality=70) 66 | 67 | # augmentation for training 68 | if self.opt['phase'] == 'train': 69 | gt_size = self.opt['gt_size'] 70 | # random crop 71 | img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) 72 | # flip, rotation 73 | img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_flip'], self.opt['use_rot']) 74 | 75 | # BGR to RGB, HWC to CHW, numpy to tensor 76 | img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) 77 | 78 | img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255. 
79 | 80 | # normalize 81 | if self.mean is not None or self.std is not None: 82 | normalize(img_lq, self.mean, self.std, inplace=True) 83 | normalize(img_gt, self.mean, self.std, inplace=True) 84 | 85 | return {'lq': img_lq, 'gt': img_gt, 'lq_path': gt_path, 'gt_path': gt_path} 86 | 87 | def __len__(self): 88 | return len(self.paths) 89 | -------------------------------------------------------------------------------- /datasets/README.md: -------------------------------------------------------------------------------- 1 | # Datasets 2 | 3 | This folder is mainly for storing datasets used for training/validation/testing. 4 | 5 | ## Practice 6 | 7 | 1. Separate your codes and datasets. So it is better to soft link your dataset (such as DIV2K, FFHQ, *etc*) here. 8 | ```bash 9 | ln -s DATASET_PATH ./ 10 | ``` 11 | 12 | ## Example Datasets 13 | 14 | We provide two example datasets for demo. 15 | 16 | 1. [BSDS100](https://github.com/xinntao/BasicSR-examples/releases/download/0.0.0/BSDS100.zip) for training 17 | 1. [Set5](https://github.com/xinntao/BasicSR-examples/releases/download/0.0.0/Set5.zip) for validation 18 | 19 | You can easily download them by running the following command in the BasicSR-examples root path: 20 | 21 | ```bash 22 | python scripts/prepare_example_data.py 23 | ``` 24 | 25 | The example datasets are now in the `datasets/example` folder. 26 | -------------------------------------------------------------------------------- /experiments/README.md: -------------------------------------------------------------------------------- 1 | # Experiments 2 | 3 | Your experiment runs will be put in this folder. 
4 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from os import path as osp 3 | 4 | from basicsr.utils import scandir 5 | 6 | # automatically scan and import loss modules for registry 7 | # scan all the files that end with '_loss.py' under the loss folder 8 | loss_folder = osp.dirname(osp.abspath(__file__)) 9 | loss_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(loss_folder) if v.endswith('_loss.py')] 10 | # import all the loss modules 11 | _model_modules = [importlib.import_module(f'losses.{file_name}') for file_name in loss_filenames] 12 | -------------------------------------------------------------------------------- /losses/example_loss.py: -------------------------------------------------------------------------------- 1 | from torch import nn as nn 2 | from torch.nn import functional as F 3 | 4 | from basicsr.utils.registry import LOSS_REGISTRY 5 | 6 | 7 | @LOSS_REGISTRY.register() 8 | class ExampleLoss(nn.Module): 9 | """Example Loss. 10 | 11 | Args: 12 | loss_weight (float): Loss weight for Example loss. Default: 1.0. 13 | """ 14 | 15 | def __init__(self, loss_weight=1.0): 16 | super(ExampleLoss, self).__init__() 17 | self.loss_weight = loss_weight 18 | 19 | def forward(self, pred, target, **kwargs): 20 | """ 21 | Args: 22 | pred (Tensor): of shape (N, C, H, W). Predicted tensor. 23 | target (Tensor): of shape (N, C, H, W). Ground truth tensor. 24 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None. 
25 | """ 26 | return self.loss_weight * F.l1_loss(pred, target, reduction='mean') 27 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from os import path as osp 3 | 4 | from basicsr.utils import scandir 5 | 6 | # automatically scan and import model modules for registry 7 | # scan all the files that end with '_model.py' under the model folder 8 | model_folder = osp.dirname(osp.abspath(__file__)) 9 | model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] 10 | # import all the model modules 11 | _model_modules = [importlib.import_module(f'models.{file_name}') for file_name in model_filenames] 12 | -------------------------------------------------------------------------------- /models/example_model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | 3 | from basicsr.archs import build_network 4 | from basicsr.losses import build_loss 5 | from basicsr.models.sr_model import SRModel 6 | from basicsr.utils import get_root_logger 7 | from basicsr.utils.registry import MODEL_REGISTRY 8 | 9 | 10 | @MODEL_REGISTRY.register() 11 | class ExampleModel(SRModel): 12 | """Example model based on the SRModel class. 13 | 14 | In this example model, we want to implement a new model that trains with both L1 and L2 loss. 
15 | 16 | New defined functions: 17 | init_training_settings(self) 18 | feed_data(self, data) 19 | optimize_parameters(self, current_iter) 20 | 21 | Inherited functions: 22 | __init__(self, opt) 23 | setup_optimizers(self) 24 | test(self) 25 | dist_validation(self, dataloader, current_iter, tb_logger, save_img) 26 | nondist_validation(self, dataloader, current_iter, tb_logger, save_img) 27 | _log_validation_metric_values(self, current_iter, dataset_name, tb_logger) 28 | get_current_visuals(self) 29 | save(self, epoch, current_iter) 30 | """ 31 | 32 | def init_training_settings(self): 33 | self.net_g.train() 34 | train_opt = self.opt['train'] 35 | 36 | self.ema_decay = train_opt.get('ema_decay', 0) 37 | if self.ema_decay > 0: 38 | logger = get_root_logger() 39 | logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}') 40 | # define network net_g with Exponential Moving Average (EMA) 41 | # net_g_ema is used only for testing on one GPU and saving 42 | # There is no need to wrap with DistributedDataParallel 43 | self.net_g_ema = build_network(self.opt['network_g']).to(self.device) 44 | # load pretrained model 45 | load_path = self.opt['path'].get('pretrain_network_g', None) 46 | if load_path is not None: 47 | self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') 48 | else: 49 | self.model_ema(0) # copy net_g weight 50 | self.net_g_ema.eval() 51 | 52 | # define losses 53 | self.l1_pix = build_loss(train_opt['l1_opt']).to(self.device) 54 | self.l2_pix = build_loss(train_opt['l2_opt']).to(self.device) 55 | 56 | # set up optimizers and schedulers 57 | self.setup_optimizers() 58 | self.setup_schedulers() 59 | 60 | def feed_data(self, data): 61 | self.lq = data['lq'].to(self.device) 62 | if 'gt' in data: 63 | self.gt = data['gt'].to(self.device) 64 | 65 | def optimize_parameters(self, current_iter): 66 | self.optimizer_g.zero_grad() 67 | self.output = self.net_g(self.lq) 68 | 69 | l_total = 0 70 | 
loss_dict = OrderedDict() 71 | # l1 loss 72 | l_l1 = self.l1_pix(self.output, self.gt) 73 | l_total += l_l1 74 | loss_dict['l_l1'] = l_l1 75 | # l2 loss 76 | l_l2 = self.l2_pix(self.output, self.gt) 77 | l_total += l_l2 78 | loss_dict['l_l2'] = l_l2 79 | 80 | l_total.backward() 81 | self.optimizer_g.step() 82 | 83 | self.log_dict = self.reduce_loss_dict(loss_dict) 84 | 85 | if self.ema_decay > 0: 86 | self.model_ema(decay=self.ema_decay) 87 | -------------------------------------------------------------------------------- /options/example_option.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: BasicSR_example 3 | model_type: ExampleModel 4 | scale: 4 5 | num_gpu: 1 # set num_gpu: 0 for cpu mode 6 | manual_seed: 0 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: ExampleBSDS100 12 | type: ExampleDataset 13 | dataroot_gt: datasets/example/BSDS100 14 | io_backend: 15 | type: disk 16 | 17 | gt_size: 128 18 | use_flip: true 19 | use_rot: true 20 | 21 | # data loader 22 | use_shuffle: true 23 | num_worker_per_gpu: 3 24 | batch_size_per_gpu: 16 25 | dataset_enlarge_ratio: 10 26 | prefetch_mode: ~ 27 | 28 | val: 29 | name: ExampleSet5 30 | type: ExampleDataset 31 | dataroot_gt: datasets/example/Set5 32 | io_backend: 33 | type: disk 34 | 35 | # network structures 36 | network_g: 37 | type: ExampleArch 38 | num_in_ch: 3 39 | num_out_ch: 3 40 | num_feat: 64 41 | upscale: 4 42 | 43 | 44 | # path 45 | path: 46 | pretrain_network_g: ~ 47 | strict_load_g: true 48 | resume_state: ~ 49 | 50 | # training settings 51 | train: 52 | optim_g: 53 | type: Adam 54 | lr: !!float 2e-4 55 | weight_decay: 0 56 | betas: [0.9, 0.99] 57 | 58 | scheduler: 59 | type: MultiStepLR 60 | milestones: [50000] 61 | gamma: 0.5 62 | 63 | total_iter: 100000 64 | warmup_iter: -1 # no warm up 65 | 66 | # losses 67 | l1_opt: 68 | type: ExampleLoss 69 | loss_weight: 1.0 70 | 71 | l2_opt: 72 | type: MSELoss 73 | 
loss_weight: 1.0 74 | reduction: mean 75 | 76 | # validation settings 77 | val: 78 | val_freq: !!float 5e3 79 | save_img: false 80 | 81 | metrics: 82 | psnr: # metric name, can be arbitrary 83 | type: calculate_psnr 84 | crop_border: 4 85 | test_y_channel: false 86 | niqe: 87 | type: calculate_niqe 88 | crop_border: 4 89 | 90 | # logging settings 91 | logger: 92 | print_freq: 100 93 | save_checkpoint_freq: !!float 5e3 94 | use_tb_logger: true 95 | wandb: 96 | project: ~ 97 | resume_id: ~ 98 | 99 | # dist training settings 100 | dist_params: 101 | backend: nccl 102 | port: 29500 103 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | basicsr 2 | numpy 3 | opencv-python 4 | requests 5 | torch 6 | torchvision 7 | -------------------------------------------------------------------------------- /scripts/prepare_example_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | 4 | 5 | def main(url, dataset): 6 | # download 7 | print(f'Download {url} ...') 8 | response = requests.get(url) 9 | with open(f'datasets/example/{dataset}.zip', 'wb') as f: 10 | f.write(response.content) 11 | 12 | # unzip 13 | import zipfile 14 | with zipfile.ZipFile(f'datasets/example/{dataset}.zip', 'r') as zip_ref: 15 | zip_ref.extractall(f'datasets/example/{dataset}') 16 | 17 | 18 | if __name__ == '__main__': 19 | """This script will download and prepare the example data: 20 | 1. BSDS100 for training 21 | 2. 
Set5 for testing 22 | """ 23 | os.makedirs('datasets/example', exist_ok=True) 24 | 25 | urls = [ 26 | 'https://github.com/xinntao/BasicSR-examples/releases/download/0.0.0/BSDS100.zip', 27 | 'https://github.com/xinntao/BasicSR-examples/releases/download/0.0.0/Set5.zip' 28 | ] 29 | datasets = ['BSDS100', 'Set5'] 30 | for url, dataset in zip(urls, datasets): 31 | main(url, dataset) 32 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = 3 | # line break before binary operator (W503) 4 | W503, 5 | # line break after binary operator (W504) 6 | W504, 7 | max-line-length=120 8 | 9 | [yapf] 10 | based_on_style = pep8 11 | column_limit = 120 12 | blank_line_before_nested_class_or_def = true 13 | split_before_expression_after_opening_paren = true 14 | 15 | [isort] 16 | line_length = 120 17 | multi_line_output = 0 18 | known_standard_library = pkg_resources,setuptools 19 | known_first_party = basicsr 20 | known_third_party = cv2,requests,torch,torchvision 21 | no_lines_before = STDLIB,LOCALFOLDER 22 | default_section = THIRDPARTY 23 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | import os.path as osp 3 | 4 | import archs 5 | import data 6 | import losses 7 | import models 8 | from basicsr.train import train_pipeline 9 | 10 | if __name__ == '__main__': 11 | root_path = osp.abspath(osp.join(__file__, osp.pardir)) 12 | train_pipeline(root_path) 13 | --------------------------------------------------------------------------------