├── .github
└── workflows
│ └── pylint.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .vscode
└── settings.json
├── LICENSE
├── README.md
├── README_CN.md
├── VERSION
├── archs
├── __init__.py
└── example_arch.py
├── data
├── __init__.py
└── example_dataset.py
├── datasets
└── README.md
├── experiments
└── README.md
├── losses
├── __init__.py
└── example_loss.py
├── models
├── __init__.py
└── example_model.py
├── options
└── example_option.yml
├── requirements.txt
├── scripts
└── prepare_example_data.py
├── setup.cfg
└── train.py
/.github/workflows/pylint.yml:
--------------------------------------------------------------------------------
1 | name: PyLint
2 |
3 | on: [push, pull_request]
4 |
5 | jobs:
6 | build:
7 |
8 | runs-on: ubuntu-latest
9 | strategy:
10 | matrix:
11 | python-version: [3.8]
12 |
13 | steps:
14 | - uses: actions/checkout@v2
15 | - name: Set up Python ${{ matrix.python-version }}
16 | uses: actions/setup-python@v2
17 | with:
18 | python-version: ${{ matrix.python-version }}
19 |
20 | - name: Install dependencies
21 | run: |
22 | python -m pip install --upgrade pip
23 | pip install flake8 yapf isort
24 |
25 | - name: Lint
26 | run: |
27 | flake8 .
28 | isort --check-only --diff archs/ data/ models/ scripts/ train.py
29 | yapf -r -d archs/ data/ models/ scripts/ train.py
30 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | datasets/*
2 | experiments/*
3 | results/*
4 | tb_logger/*
5 | wandb/*
6 | tmp/*
7 |
8 | *.DS_Store
9 |
10 | # Byte-compiled / optimized / DLL files
11 | __pycache__/
12 | *.py[cod]
13 | *$py.class
14 |
15 | # C extensions
16 | *.so
17 |
18 | # Distribution / packaging
19 | .Python
20 | build/
21 | develop-eggs/
22 | dist/
23 | downloads/
24 | eggs/
25 | .eggs/
26 | lib/
27 | lib64/
28 | parts/
29 | sdist/
30 | var/
31 | wheels/
32 | pip-wheel-metadata/
33 | share/python-wheels/
34 | *.egg-info/
35 | .installed.cfg
36 | *.egg
37 | MANIFEST
38 |
39 | # PyInstaller
40 | # Usually these files are written by a python script from a template
41 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
42 | *.manifest
43 | *.spec
44 |
45 | # Installer logs
46 | pip-log.txt
47 | pip-delete-this-directory.txt
48 |
49 | # Unit test / coverage reports
50 | htmlcov/
51 | .tox/
52 | .nox/
53 | .coverage
54 | .coverage.*
55 | .cache
56 | nosetests.xml
57 | coverage.xml
58 | *.cover
59 | *.py,cover
60 | .hypothesis/
61 | .pytest_cache/
62 |
63 | # Translations
64 | *.mo
65 | *.pot
66 |
67 | # Django stuff:
68 | *.log
69 | local_settings.py
70 | db.sqlite3
71 | db.sqlite3-journal
72 |
73 | # Flask stuff:
74 | instance/
75 | .webassets-cache
76 |
77 | # Scrapy stuff:
78 | .scrapy
79 |
80 | # Sphinx documentation
81 | docs/_build/
82 |
83 | # PyBuilder
84 | target/
85 |
86 | # Jupyter Notebook
87 | .ipynb_checkpoints
88 |
89 | # IPython
90 | profile_default/
91 | ipython_config.py
92 |
93 | # pyenv
94 | .python-version
95 |
96 | # pipenv
97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
100 | # install all needed dependencies.
101 | #Pipfile.lock
102 |
103 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
104 | __pypackages__/
105 |
106 | # Celery stuff
107 | celerybeat-schedule
108 | celerybeat.pid
109 |
110 | # SageMath parsed files
111 | *.sage.py
112 |
113 | # Environments
114 | .env
115 | .venv
116 | env/
117 | venv/
118 | ENV/
119 | env.bak/
120 | venv.bak/
121 |
122 | # Spyder project settings
123 | .spyderproject
124 | .spyproject
125 |
126 | # Rope project settings
127 | .ropeproject
128 |
129 | # mkdocs documentation
130 | /site
131 |
132 | # mypy
133 | .mypy_cache/
134 | .dmypy.json
135 | dmypy.json
136 |
137 | # Pyre type checker
138 | .pyre/
139 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | # flake8
3 | - repo: https://github.com/PyCQA/flake8
4 | rev: 3.8.3
5 | hooks:
6 | - id: flake8
7 | args: ["--config=setup.cfg", "--ignore=W504, W503"]
8 |
9 | # modify known_third_party
10 | - repo: https://github.com/asottile/seed-isort-config
11 | rev: v2.2.0
12 | hooks:
13 | - id: seed-isort-config
14 |
15 | # isort
16 | - repo: https://github.com/timothycrosley/isort
17 | rev: 5.2.2
18 | hooks:
19 | - id: isort
20 |
21 | # yapf
22 | - repo: https://github.com/pre-commit/mirrors-yapf
23 | rev: v0.30.0
24 | hooks:
25 | - id: yapf
26 |
27 | # pre-commit-hooks
28 | - repo: https://github.com/pre-commit/pre-commit-hooks
29 | rev: v3.2.0
30 | hooks:
31 | - id: trailing-whitespace # Trim trailing whitespace
32 | - id: check-yaml # Attempt to load all yaml files to verify syntax
33 | - id: check-merge-conflict # Check for files that contain merge conflict strings
34 | - id: double-quote-string-fixer # Replace double quoted strings with single quoted strings
35 | - id: end-of-file-fixer # Make sure files end in a newline and only a newline
36 | - id: requirements-txt-fixer # Sort entries in requirements.txt and remove incorrect entry for pkg-resources==0.0.0
37 | - id: fix-encoding-pragma # Remove the coding pragma: # -*- coding: utf-8 -*-
38 | args: ["--remove"]
39 | - id: mixed-line-ending # Replace or check mixed line ending
40 | args: ["--fix=lf"]
41 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "files.trimTrailingWhitespace": true,
3 | "editor.wordWrap": "on",
4 | "editor.rulers": [
5 | 80,
6 | 120
7 | ],
8 | "editor.renderWhitespace": "all",
9 | "editor.renderControlCharacters": true,
10 | "python.formatting.provider": "yapf",
11 | "python.formatting.yapfArgs": [
12 | "--style",
13 | "{BASED_ON_STYLE = pep8, BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = true, SPLIT_BEFORE_EXPRESSION_AFTER_OPENING_PAREN = true, COLUMN_LIMIT = 120}"
14 | ],
15 | "python.linting.flake8Enabled": true,
16 | "python.linting.flake8Args": [
17 | "--max-line-length=120"
18 | ]
19 | }
20 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Xintao Wang and BasicSR-examples contributors
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # :rocket: BasicSR Examples
2 |
3 | [](https://github.com/xinntao/BasicSR-examples/releases)
4 | [](https://github.com/xinntao/BasicSR-examples/issues)
5 | [](https://github.com/xinntao/BasicSR-examples/issues)
6 | [](https://github.com/xinntao/BasicSR-examples/blob/master/LICENSE)
7 | [](https://github.com/xinntao/BasicSR/blob/master/.github/workflows/pylint.yml)
8 |
9 | [English](README.md) **|** [简体中文](README_CN.md)
10 | [`BasicSR repo`](https://github.com/xinntao/BasicSR) **|** [`simple mode example`](https://github.com/xinntao/BasicSR-examples/tree/master) **|** [`installation mode example`](https://github.com/xinntao/BasicSR-examples/tree/installation)
11 |
12 | In this repository, we give examples to illustrate **how to easily use** [`BasicSR`](https://github.com/xinntao/BasicSR) in **your own project**.
13 |
14 | :triangular_flag_on_post: **Projects that use BasicSR**
15 | - :white_check_mark: [**GFPGAN**](https://github.com/TencentARC/GFPGAN): A practical algorithm for real-world face restoration
16 | - :white_check_mark: [**Real-ESRGAN**](https://github.com/xinntao/Real-ESRGAN): A practical algorithm for general image restoration
17 |
18 | If you use `BasicSR` in your open-source projects, welcome to contact me (by [email](#e-mail-contact) or opening an issue/pull request). I will add your projects to the above list :blush:
19 |
20 | ---
21 |
22 | If this repo is helpful, please help to :star: this repo or recommend it to your friends. Thanks:blush:
23 | Other recommended projects:
24 | :arrow_forward: [facexlib](https://github.com/xinntao/facexlib): A collection that provides useful face-related functions.
25 | :arrow_forward: [HandyView](https://github.com/xinntao/HandyView): A PyQt5-based image viewer that is handy for viewing and comparison.
26 |
27 | ---
28 |
29 | ## Contents
30 |
31 | - [HowTO use BasicSR](#HowTO-use-BasicSR)
32 | - [As a Template](#As-a-Template)
33 |
34 | ## HowTO use BasicSR
35 |
36 | `BasicSR` can be used in two ways:
37 | - :arrow_right: Git clone the entire BasicSR. In this way, you can see the complete codes of BasicSR, and then modify them according to your own needs.
38 | - :arrow_right: Use basicsr as a [python package](https://pypi.org/project/basicsr/#history) (that is, install with pip). It provides the training framework, procedures, and some basic functions. You can easily build your own projects based on basicsr.
39 | ```bash
40 | pip install basicsr
41 | ```
42 |
43 | Our example mainly focuses on the second one, that is, how to easily and concisely build your own project based on the basicsr package.
44 |
45 | There are two ways to use the python package of basicsr, which are provided in two branches:
46 |
47 | - :arrow_right: [simple mode](https://github.com/xinntao/BasicSR-examples/tree/master): the project can be run **without installation**. But it has limitations: it is inconvenient to import complex hierarchical relationships; it is not easy to access the functions in this project from other locations
48 |
49 | - :arrow_right: [installation mode](https://github.com/xinntao/BasicSR-examples/tree/installation): you need to install the project by running `python setup.py develop`. After installation, it is more convenient to import and use.
50 |
51 | As a simple introduction and explanation, we use the example of *simple mode*, but we recommend the *installation mode* in practical use.
52 |
53 | ```bash
54 | git clone https://github.com/xinntao/BasicSR-examples.git
55 | cd BasicSR-examples
56 | ```
57 |
58 | ### Preliminary
59 |
60 | Most deep-learning projects can be divided into the following parts:
61 |
62 | 1. **data**: defines the training/validation data that is fed into the model training
63 | 2. **arch** (architecture): defines the network structure and the forward steps
64 | 3. **model**: defines the necessary components in training (such as loss) and a complete training process (including forward propagation, back-propagation, gradient optimization, *etc*.), as well as other functions, such as validation, *etc*
65 | 4. Training pipeline: defines the training process, that is, connect the data-loader, model, validation, saving checkpoints, *etc*
66 |
67 | When we are developing a new method, we often improve the **data**, **arch**, and **model**. Most training processes and basic functions are actually shared. Then, we hope to focus on the development of main functions instead of building wheels repeatedly.
68 |
69 | Therefore, we have BasicSR, which separates many shared functions. With BasicSR, we just need to care about the development of **data**, **arch**, and **model**.
70 |
71 | In order to further facilitate the use of BasicSR, we provide the basicsr package. You can easily install it through `pip install basicsr`. After that, you can use the training process of BasicSR and the functions already developed in BasicSR~
72 |
73 | ### A Simple Example
74 |
75 | Let's use a simple example to illustrate how to use BasicSR to build your own project.
76 |
77 | We provide two sample data for demonstration:
78 | 1. [BSDS100](https://github.com/xinntao/BasicSR-examples/releases/download/0.0.0/BSDS100.zip) for training
79 | 1. [Set5](https://github.com/xinntao/BasicSR-examples/releases/download/0.0.0/Set5.zip) for validation
80 |
81 | You can easily download them by running the following command in the BasicSR-examples root path:
82 |
83 | ```bash
84 | python scripts/prepare_example_data.py
85 | ```
86 |
87 | The sample data are now in the `datasets/example` folder.
88 |
89 | #### :zero: Purpose
90 |
91 | Let's use a Super-Resolution task for the demo.
92 | It takes a low-resolution image as the input and outputs a high-resolution image.
93 | The low-resolution images contain: 1) CV2 bicubic X4 downsampling, and 2) JPEG compression (quality = 70).
94 |
95 | In order to better explain how to use the arch and model, we use 1) a network structure similar to SRCNN; 2) use L1 and L2 (MSE) loss simultaneously in training.
96 |
97 | So, in this task, what we should do are:
98 |
99 | 1. Build our own data loader
100 | 1. Determine the architecture
101 | 1. Build our own model
102 |
103 | Let's explain it separately in the following parts.
104 |
105 | #### :one: data
106 |
107 | We need to implement a new dataset to fulfill our purpose. The dataset is used to feed the data into the model.
108 |
109 | An example of this dataset is in [data/example_dataset.py](data/example_dataset.py). It has the following steps.
110 |
111 | 1. Read Ground-Truth (GT) images. BasicSR provides [FileClient](https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/file_client.py) for easily reading files in a folder, LMDB file and meta_info txt. In this example, we use the folder mode. For more reading modes, please refer to [basicsr/data](https://github.com/xinntao/BasicSR/tree/master/basicsr/data)
112 | 1. Synthesize low resolution images. We can directly implement the data procedures in the `__getitem__(self, index)` function, such as downsampling and adding JPEG compression. Many basic operations can be found in [[basicsr/data/degradations]](https://github.com/xinntao/BasicSR/blob/master/basicsr/data/degradations.py), [[basicsr/data/transforms]](https://github.com/xinntao/BasicSR/blob/master/basicsr/data/transforms.py), and [[basicsr/data/data_util]](https://github.com/xinntao/BasicSR/blob/master/basicsr/data/data_util.py)
113 | 1. Convert to torch tensor and return appropriate information
114 |
115 | **Note**:
116 |
117 | 1. Please add `@DATASET_REGISTRY.register()` before `ExampleDataset`. This operation is mainly used to prevent the occurrence of a dataset with the same name, which will result in potential bugs
118 | 1. The new dataset file should end with `_dataset.py`, such as `example_dataset.py`. In this way, the program can **automatically** import classes without manual import
119 |
120 | In the [option configuration file](options/example_option.yml), you can use the new dataset:
121 |
122 | ```yaml
123 | datasets:
124 | train: # training dataset
125 | name: ExampleBSDS100
126 | type: ExampleDataset # the class name
127 |
128 | # ----- the followings are the arguments of ExampleDataset ----- #
129 | dataroot_gt: datasets/example/BSDS100
130 | io_backend:
131 | type: disk
132 |
133 | gt_size: 128
134 | use_flip: true
135 | use_rot: true
136 |
137 | # ----- arguments of data loader ----- #
138 | use_shuffle: true
139 | num_worker_per_gpu: 3
140 | batch_size_per_gpu: 16
141 | dataset_enlarge_ratio: 10
142 | prefetch_mode: ~
143 |
144 | val: # validation dataset
145 | name: ExampleSet5
146 | type: ExampleDataset
147 | dataroot_gt: datasets/example/Set5
148 | io_backend:
149 | type: disk
150 | ```
151 |
152 | #### :two: arch
153 |
154 | An example of architecture is in [archs/example_arch.py](archs/example_arch.py). It mainly builds the network structure.
155 |
156 | **Note**:
157 |
158 | 1. Add `@ARCH_REGISTRY.register()` before `ExampleArch`, so as to register the newly implemented arch. This operation is mainly used to prevent the occurrence of arch with the same name, resulting in potential bugs
159 | 1. The new arch file should end with `_arch.py`, such as `example_arch.py`. In this way, the program can **automatically** import classes without manual import
160 |
161 | In the [option configuration file](options/example_option.yml), you can use the new arch:
162 |
163 | ```yaml
164 | # network structures
165 | network_g:
166 | type: ExampleArch # the class name
167 |
168 | # ----- the followings are the arguments of ExampleArch ----- #
169 | num_in_ch: 3
170 | num_out_ch: 3
171 | num_feat: 64
172 | upscale: 4
173 | ```
174 |
175 | #### :three: model
176 |
177 | An example of model is in [models/example_model.py](models/example_model.py). It mainly builds the training process of a model.
178 |
179 | In this file:
180 | 1. We inherit `SRModel` from basicsr. Many models have similar operations, so you can inherit and modify from [basicsr/models](https://github.com/xinntao/BasicSR/tree/master/basicsr/models). In this way, you can easily implement your ideas, such as GAN model, video model, *etc*.
181 | 1. Two losses are used: L1 and L2 (MSE) loss
182 | 1. Many other contents, such as `setup_optimizers`, `validation`, `save`, *etc*, are inherited from `SRModel`
183 |
184 | **Note**:
185 |
186 | 1. Add `@MODEL_REGISTRY.register()` before `ExampleModel`, so as to register the newly implemented model. This operation is mainly used to prevent the occurrence of model with the same name, resulting in potential bugs
187 | 1. The new model file should end with `_model.py`, such as `example_model.py`. In this way, the program can **automatically** import classes without manual import
188 |
189 | In the [option configuration file](options/example_option.yml), you can use the new model:
190 |
191 | ```yaml
192 | # training settings
193 | train:
194 | optim_g:
195 | type: Adam
196 | lr: !!float 2e-4
197 | weight_decay: 0
198 | betas: [0.9, 0.99]
199 |
200 | scheduler:
201 | type: MultiStepLR
202 | milestones: [50000]
203 | gamma: 0.5
204 |
205 | total_iter: 100000
206 | warmup_iter: -1 # no warm up
207 |
208 | # ----- the followings are the configurations for two losses ----- #
209 | # losses
210 | l1_opt:
211 | type: L1Loss
212 | loss_weight: 1.0
213 | reduction: mean
214 |
215 | l2_opt:
216 | type: MSELoss
217 | loss_weight: 1.0
218 | reduction: mean
219 | ```
220 |
221 | #### :four: training pipeline
222 |
223 | The whole training pipeline can reuse the [basicsr/train.py](https://github.com/xinntao/BasicSR/blob/master/basicsr/train.py) in BasicSR.
224 |
225 | Based on this, our [train.py](train.py) can be very concise:
226 |
227 | ```python
228 | import os.path as osp
229 |
230 | import archs # noqa: F401
231 | import data # noqa: F401
232 | import models # noqa: F401
233 | from basicsr.train import train_pipeline
234 |
235 | if __name__ == '__main__':
236 | root_path = osp.abspath(osp.join(__file__, osp.pardir))
237 | train_pipeline(root_path)
238 |
239 | ```
240 |
241 | #### :five: debug mode
242 |
243 | So far, we have completed the development of our project. We can quickly check whether there is a bug through the `debug` mode:
244 |
245 | ```bash
246 | python train.py -opt options/example_option.yml --debug
247 | ```
248 |
249 | With `--debug`, the program will enter the debug mode. In the debug mode, the program will output at each iteration, and perform validation every 8 iterations, so that you can easily know whether the program has a bug~
250 |
251 | #### :six: normal training
252 |
253 | After debugging, we can have the normal training.
254 |
255 | ```bash
256 | python train.py -opt options/example_option.yml
257 | ```
258 |
259 | If the training process is interrupted unexpectedly and resuming is required, please use `--auto_resume` in the command:
260 |
261 | ```bash
262 | python train.py -opt options/example_option.yml --auto_resume
263 | ```
264 |
265 | So far, you have finished developing your own projects using `BasicSR`. Isn't it very convenient~ :grin:
266 |
267 | ## As a Template
268 |
269 | You can use BasicSR-Examples as a template for your project. Here are some modifications you may need.
270 |
271 | 1. Set up the *pre-commit* hook
272 | 1. In the root path, run:
273 | > pre-commit install
274 | 1. Modify the `LICENSE`
275 | This repository uses the *MIT* license, you may change it to other licenses
276 |
277 | The simple mode does not require many modifications. Those using the installation mode may need more modifications. See [here](https://github.com/xinntao/BasicSR-examples/blob/installation/README.md#As-a-Template)
278 |
279 | ## :e-mail: Contact
280 |
281 | If you have any questions or want to add your project to the list, please email `xintao.wang@outlook.com` or `xintaowang@tencent.com`.
282 |
--------------------------------------------------------------------------------
/README_CN.md:
--------------------------------------------------------------------------------
1 | # :rocket: BasicSR Examples
2 |
3 | [](https://github.com/xinntao/BasicSR-examples/releases)
4 | [](https://github.com/xinntao/BasicSR-examples/issues)
5 | [](https://github.com/xinntao/BasicSR-examples/issues)
6 | [](https://github.com/xinntao/BasicSR-examples/blob/master/LICENSE)
7 | [](https://github.com/xinntao/BasicSR/blob/master/.github/workflows/pylint.yml)
8 |
9 | [English](README.md) **|** [简体中文](README_CN.md)
10 | [`BasicSR repo`](https://github.com/xinntao/BasicSR) **|** [`simple mode example`](https://github.com/xinntao/BasicSR-examples/tree/master) **|** [`installation mode example`](https://github.com/xinntao/BasicSR-examples/tree/installation)
11 |
12 | 在这个仓库中,我们通过简单的例子来说明:如何在**你自己的项目中**轻松地**使用** [`BasicSR`](https://github.com/xinntao/BasicSR)。
13 |
14 | :triangular_flag_on_post: **使用 BasicSR 的项目**
15 | - :white_check_mark: [**GFPGAN**](https://github.com/TencentARC/GFPGAN): 真实场景人脸复原的实用算法
16 | - :white_check_mark: [**Real-ESRGAN**](https://github.com/xinntao/Real-ESRGAN): 通用图像复原的实用算法
17 |
18 | 如果你的开源项目中使用了`BasicSR`, 欢迎联系我 ([邮件](#e-mail-%E8%81%94%E7%B3%BB)或者开一个issue/pull request)。我会将你的开源项目添加到上面的列表中 :blush:
19 |
20 | ---
21 |
22 | 如果你觉得这个项目对你有帮助,欢迎 :star: 这个仓库或推荐给你的朋友。Thanks:blush:
23 | 其他推荐的项目:
24 | :arrow_forward: [facexlib](https://github.com/xinntao/facexlib): 提供实用的人脸相关功能的集合
25 | :arrow_forward: [HandyView](https://github.com/xinntao/HandyView): 基于PyQt5的 方便的看图比图工具
26 |
27 | ---
28 |
29 | ## 目录
30 |
31 | - [如何使用 BasicSR](#HowTO-use-BasicSR)
32 | - [作为Template](#As-a-Template)
33 |
34 | ## HowTO use BasicSR
35 |
36 | `BasicSR` 有两种使用方式:
37 | - :arrow_right: Git clone 整个 BasicSR 的代码。这样可以看到 BasicSR 完整的代码,然后根据你自己的需求进行修改
38 | - :arrow_right: BasicSR 作为一个 [python package](https://pypi.org/project/basicsr/#history) (即可以通过pip安装),提供了训练的框架,流程和一些基本功能。你可以基于 basicsr 方便地搭建你自己的项目
39 | ```bash
40 | pip install basicsr
41 | ```
42 |
43 | 我们的样例主要针对第二种使用方式,即如何基于 basicsr 这个package来方便简洁地搭建你自己的项目。
44 |
45 | 使用 basicsr 的python package又有两种方式,我们分别提供在两个 branch 中:
46 | - :arrow_right: [简单模式](https://github.com/xinntao/BasicSR-examples/tree/master): 项目的仓库不需要安装,就可以运行使用。但它有局限:不方便 import 复杂的层级关系;在其他位置也不容易访问本项目中的函数
47 | - :arrow_right: [安装模式](https://github.com/xinntao/BasicSR-examples/tree/installation): 项目的仓库需要安装 `python setup.py develop`,安装之后 import 和使用都更加方便
48 |
49 | 作为简单的入门和讲解, 我们使用*简单模式*的样例,但在实际使用中我们推荐*安装模式*。
50 |
51 | ```bash
52 | git clone https://github.com/xinntao/BasicSR-examples.git
53 | cd BasicSR-examples
54 | ```
55 |
56 | ### 预备
57 |
58 | 大部分的深度学习项目,都可以分为以下几个部分:
59 |
60 | 1. **data**: 定义了训练数据,来喂给模型的训练过程
61 | 2. **arch** (architecture): 定义了网络结构 和 forward 的步骤
62 | 3. **model**: 定义了在训练中必要的组件(比如 loss) 和 一次完整的训练过程(包括前向传播,反向传播,梯度优化等),还有其他功能,比如 validation等
63 | 4. training pipeline: 定义了训练的流程,即把数据 dataloader,模型,validation,保存 checkpoints 等等串联起来
64 |
65 | 当我们开发一个新的方法时,我们往往在改进: **data**, **arch**, **model**;而很多流程、基础的功能其实是共用的。那么,我们希望可以专注于主要功能的开发,而不要重复造轮子。
66 |
67 | 因此便有了 BasicSR,它把很多相似的功能都独立出来,我们只要关心 **data**, **arch**, **model** 的开发即可。
68 |
69 | 为了进一步方便大家使用,我们提供了 basicsr package,大家可以通过 `pip install basicsr` 方便地安装,然后就可以使用 BasicSR 的训练流程以及在 BasicSR 里面已开发好的功能啦~
70 |
71 | ### 简单的例子
72 |
73 | 下面我们就通过一个简单的例子,来说明如何使用 BasicSR 来搭建你自己的项目。
74 |
75 | 我们提供了两个样例数据来做展示,
76 | 1. [BSDS100](https://github.com/xinntao/BasicSR-examples/releases/download/0.0.0/BSDS100.zip) for training
77 | 1. [Set5](https://github.com/xinntao/BasicSR-examples/releases/download/0.0.0/Set5.zip) for validation
78 |
79 | 在 BasicSR-example 的根目录运行下面的命令来下载:
80 |
81 | ```bash
82 | python scripts/prepare_example_data.py
83 | ```
84 |
85 | 样例数据就下载在 `datasets/example` 文件夹中。
86 |
87 | #### :zero: 目的
88 |
89 | 我们来假设一个超分辨率 (Super-Resolution) 的任务,输入一张低分辨率的图片,输出高分辨率的图片。低分辨率图片包含了 1) cv2 的 bicubic X4 downsampling 和 2) JPEG 压缩 (quality=70)。
90 |
91 | 为了更好的说明如何使用 arch 和 model,我们想要使用 1) 类似 SRCNN 的网络结构;2) 在训练中同时使用 L1 和 L2 (MSE) loss。
92 |
93 | 那么,在这个任务中,我们要做的是:
94 |
95 | 1. 构建自己的 data loader
96 | 1. 确定使用的 architecture
97 | 1. 构建自己的 model
98 |
99 | 下面我们分别来说明一下。
100 |
101 | #### :one: data
102 |
103 | 这个部分是用来确定喂给模型的数据的。
104 |
105 | 这个 dataset 的例子在[data/example_dataset.py](data/example_dataset.py) 中,它完成了:
106 | 1. 我们读取 Ground-Truth (GT) 的图像。读取的操作,BasicSR 提供了[FileClient](https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/file_client.py), 可以方便地读取 folder, lmdb 和 meta_info txt 指定的文件。在这个例子中,我们通过读取 folder 来说明,更多的读取模式可以参考 [basicsr/data](https://github.com/xinntao/BasicSR/tree/master/basicsr/data)
107 | 1. 合成低分辨率的图像。我们直接可以在 `__getitem__(self, index)` 的函数中实现我们想要的操作,比如降采样和添加 JPEG 压缩。很多基本操作都可以在 [[basicsr/data/degradations]](https://github.com/xinntao/BasicSR/blob/master/basicsr/data/degradations.py), [[basicsr/data/transforms]](https://github.com/xinntao/BasicSR/blob/master/basicsr/data/transforms.py) 和 [[basicsr/data/data_util]](https://github.com/xinntao/BasicSR/blob/master/basicsr/data/data_util.py) 中找到
108 | 1. 转换成 Torch Tensor,返回合适的信息
109 |
110 | **注意**:
111 | 1. 需要在 `ExampleDataset` 前添加 `@DATASET_REGISTRY.register()`,以便注册好新写的 dataset。这个操作主要用来防止出现同名的 dataset,从而带来潜在的 bug
112 | 1. 新写的 dataset 文件要以 `_dataset.py` 结尾,比如 `example_dataset.py`。 这样,程序可以**自动地** import,而不需要手动地 import
113 |
114 | 在 [option 配置文件中](options/example_option.yml)使用新写的 dataset:
115 |
116 | ```yaml
117 | datasets:
118 | train: # training dataset
119 | name: ExampleBSDS100
120 | type: ExampleDataset # the class name
121 |
122 | # ----- the followings are the arguments of ExampleDataset ----- #
123 | dataroot_gt: datasets/example/BSDS100
124 | io_backend:
125 | type: disk
126 |
127 | gt_size: 128
128 | use_flip: true
129 | use_rot: true
130 |
131 | # ----- arguments of data loader ----- #
132 | use_shuffle: true
133 | num_worker_per_gpu: 3
134 | batch_size_per_gpu: 16
135 | dataset_enlarge_ratio: 10
136 | prefetch_mode: ~
137 |
138 | val: # validation dataset
139 | name: ExampleSet5
140 | type: ExampleDataset
141 | dataroot_gt: datasets/example/Set5
142 | io_backend:
143 | type: disk
144 | ```
145 |
146 | #### :two: arch
147 |
148 | Architecture 的例子在 [archs/example_arch.py](archs/example_arch.py)中。它主要搭建了网络结构。
149 |
150 | **注意**:
151 | 1. 需要在 `ExampleArch` 前添加 `@ARCH_REGISTRY.register()`,以便注册好新写的 arch。这个操作主要用来防止出现同名的 arch,从而带来潜在的 bug
152 | 1. 新写的 arch 文件要以 `_arch.py` 结尾,比如 `example_arch.py`。 这样,程序可以**自动地** import,而不需要手动地 import
153 |
154 | 在 [option 配置文件中](options/example_option.yml)使用新写的 arch:
155 |
156 | ```yaml
157 | # network structures
158 | network_g:
159 | type: ExampleArch # the class name
160 |
161 | # ----- the followings are the arguments of ExampleArch ----- #
162 | num_in_ch: 3
163 | num_out_ch: 3
164 | num_feat: 64
165 | upscale: 4
166 | ```
167 |
168 | #### :three: model
169 |
170 | Model 的例子在 [models/example_model.py](models/example_model.py)中。它主要搭建了模型的训练过程。
171 | 在这个文件中:
172 | 1. 我们从 basicsr 中继承了 `SRModel`。很多模型都有相似的操作,因此可以通过继承 [basicsr/models](https://github.com/xinntao/BasicSR/tree/master/basicsr/models) 中的模型来更方便地实现自己的想法,比如GAN模型,Video模型等
173 | 1. 使用了两个 Loss: L1 和 L2 (MSE) loss
174 | 1. 其他很多内容,比如 `setup_optimizers`, `validation`, `save`等,都是继承于 `SRModel`
175 |
176 | **注意**:
177 | 1. 需要在 `ExampleModel` 前添加 `@MODEL_REGISTRY.register()`,以便注册好新写的 model。这个操作主要用来防止出现同名的 model,从而带来潜在的 bug
178 | 1. 新写的 model 文件要以 `_model.py` 结尾,比如 `example_model.py`。 这样,程序可以**自动地** import,而不需要手动地 import
179 |
180 | 在 [option 配置文件中](options/example_option.yml)使用新写的 model:
181 |
182 | ```yaml
183 | # training settings
184 | train:
185 | optim_g:
186 | type: Adam
187 | lr: !!float 2e-4
188 | weight_decay: 0
189 | betas: [0.9, 0.99]
190 |
191 | scheduler:
192 | type: MultiStepLR
193 | milestones: [50000]
194 | gamma: 0.5
195 |
196 | total_iter: 100000
197 | warmup_iter: -1 # no warm up
198 |
199 | # ----- the followings are the configurations for two losses ----- #
200 | # losses
201 | l1_opt:
202 | type: L1Loss
203 | loss_weight: 1.0
204 | reduction: mean
205 |
206 | l2_opt:
207 | type: MSELoss
208 | loss_weight: 1.0
209 | reduction: mean
210 | ```
211 |
212 | #### :four: training pipeline
213 |
214 | 整个 training pipeline 可以复用 basicsr 里面的 [basicsr/train.py](https://github.com/xinntao/BasicSR/blob/master/basicsr/train.py)。
215 |
216 | 基于此,我们的 [train.py](train.py)可以非常简洁。
217 |
218 | ```python
219 | import os.path as osp
220 |
221 | import archs # noqa: F401
222 | import data # noqa: F401
223 | import models # noqa: F401
224 | from basicsr.train import train_pipeline
225 |
226 | if __name__ == '__main__':
227 | root_path = osp.abspath(osp.join(__file__, osp.pardir))
228 | train_pipeline(root_path)
229 |
230 | ```
231 |
232 | #### :five: debug mode
233 |
234 | 至此,我们已经完成了我们这个项目的开发,下面可以通过 `debug` 模式来快捷地看看是否有问题:
235 |
236 | ```bash
237 | python train.py -opt options/example_option.yml --debug
238 | ```
239 |
240 | 只要带上 `--debug` 就进入 debug 模式。在 debug 模式中,程序每个iter都会输出,8个iter后就会进行validation,这样可以很方便地知道程序有没有bug啦~
241 |
242 | #### :six: normal training
243 |
244 | 经过debug没有问题后,我们就可以正式训练了。
245 |
246 | ```bash
247 | python train.py -opt options/example_option.yml
248 | ```
249 |
250 | 如果训练过程意外中断需要 resume, 则使用 `--auto_resume` 可以方便地自动resume:
251 | ```bash
252 | python train.py -opt options/example_option.yml --auto_resume
253 | ```
254 |
255 | 至此,使用 `BasicSR` 开发你自己的项目就介绍完了,是不是很方便呀~ :grin:
256 |
257 | ## As a Template
258 |
259 | 你可以使用 BasicSR-Examples 作为你项目的模板。下面主要展示一下你可能需要的修改。
260 |
261 | 1. 设置 *pre-commit* hook
262 | 1. 在文件夹根目录, 运行
263 | > pre-commit install
264 | 1. 修改 `LICENSE` 文件
265 | 本仓库使用 *MIT* 许可, 根据需要可以修改成其他许可
266 |
267 | 使用 简单模式 的基本不需要修改,使用 安装模式 的可能需要较多修改,参见[这里](https://github.com/xinntao/BasicSR-examples/blob/installation/README_CN.md#As-a-Template)
268 |
269 | ## :e-mail: 联系
270 |
271 | 如果你有任何问题,或者想要添加你的项目到列表中,欢迎电邮
272 | `xintao.wang@outlook.com` or `xintaowang@tencent.com`.
273 |
--------------------------------------------------------------------------------
/VERSION:
--------------------------------------------------------------------------------
1 | 0.1.0
2 |
--------------------------------------------------------------------------------
/archs/__init__.py:
--------------------------------------------------------------------------------
import importlib
from os import path as osp

from basicsr.utils import scandir

# Auto-import every '*_arch.py' module that lives in this folder so that the
# @ARCH_REGISTRY.register() decorators inside them execute and the
# architectures become available by class name, without any manual imports.
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = []
for filename in scandir(arch_folder):
    if filename.endswith('_arch.py'):
        arch_filenames.append(osp.splitext(osp.basename(filename))[0])
# importing a module runs its registration decorators as a side effect
_arch_modules = [importlib.import_module(f'archs.{module_name}') for module_name in arch_filenames]
--------------------------------------------------------------------------------
/archs/example_arch.py:
--------------------------------------------------------------------------------
1 | from torch import nn as nn
2 | from torch.nn import functional as F
3 |
4 | from basicsr.archs.arch_util import default_init_weights
5 | from basicsr.utils.registry import ARCH_REGISTRY
6 |
7 |
@ARCH_REGISTRY.register()
class ExampleArch(nn.Module):
    """Example super-resolution architecture.

    Three 3x3 convolutions extract features, two pixel-shuffle stages each
    upsample by 2x, and the reconstruction is added to a bilinearly upscaled
    copy of the input (global residual learning).

    NOTE(review): the two fixed PixelShuffle(2) stages give a 4x spatial
    upscale regardless of `upscale`, which only controls the bilinear base —
    assumes upscale == 4; confirm before using other factors.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_out_ch (int): Channel number of outputs. Default: 3.
        num_feat (int): Channel number of intermediate features. Default: 64.
        upscale (int): Upsampling factor. Default: 4.
    """

    def __init__(self, num_in_ch=3, num_out_ch=3, num_feat=64, upscale=4):
        super(ExampleArch, self).__init__()
        self.upscale = upscale

        # feature extraction trunk
        self.conv1 = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)

        # each upconv expands channels 4x so PixelShuffle(2) can trade
        # channels for a 2x spatial upscale
        self.upconv1 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
        self.upconv2 = nn.Conv2d(num_feat, num_feat * 4, 3, 1, 1)
        self.pixel_shuffle = nn.PixelShuffle(2)

        # reconstruction head
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        # shared activation
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

        # small-scale initialization of all conv layers
        default_init_weights(
            [self.conv1, self.conv2, self.conv3, self.upconv1, self.upconv2, self.conv_hr, self.conv_last], 0.1)

    def forward(self, x):
        act = self.lrelu
        # trunk
        feat = act(self.conv1(x))
        feat = act(self.conv2(feat))
        feat = act(self.conv3(feat))
        # two 2x pixel-shuffle upsampling stages
        feat = act(self.pixel_shuffle(self.upconv1(feat)))
        feat = act(self.pixel_shuffle(self.upconv2(feat)))
        # reconstruction plus bilinear global residual
        residual = self.conv_last(act(self.conv_hr(feat)))
        base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False)
        return residual + base
53 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
import importlib
from os import path as osp

from basicsr.utils import scandir

# Auto-discover dataset modules: importing every '*_dataset.py' file in
# this folder triggers each module's DATASET_REGISTRY registration.
data_folder = osp.dirname(osp.abspath(__file__))
dataset_filenames = []
for filename in scandir(data_folder):
    if filename.endswith('_dataset.py'):
        dataset_filenames.append(osp.splitext(osp.basename(filename))[0])
# perform the imports (kept in a list so the modules stay referenced)
_dataset_modules = [importlib.import_module(f'data.{file_name}') for file_name in dataset_filenames]
12 |
--------------------------------------------------------------------------------
/data/example_dataset.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import os
3 | import torch
4 | from torch.utils import data as data
5 | from torchvision.transforms.functional import normalize
6 |
7 | from basicsr.data.degradations import add_jpg_compression
8 | from basicsr.data.transforms import augment, mod_crop, paired_random_crop
9 | from basicsr.utils import FileClient, imfrombytes, img2tensor, scandir
10 | from basicsr.utils.registry import DATASET_REGISTRY
11 |
12 |
@DATASET_REGISTRY.register()
class ExampleDataset(data.Dataset):
    """Example dataset that synthesizes LQ images from GT images.

    Pipeline:
        1. Read a ground-truth (GT) image from disk.
        2. Generate the low-quality (LQ) counterpart with cv2 bicubic
           downsampling followed by JPEG compression.

    Args:
        opt (dict): Config for train datasets. It contains the following keys:
            dataroot_gt (str): Data root path for gt.
            io_backend (dict): IO backend type and other kwarg.
            gt_size (int): Cropped patched size for gt patches.
            use_flip (bool): Use horizontal flips.
            use_rot (bool): Use rotation (use vertical flip and transposing
                h and w for implementation).
            scale (bool): Scale, which will be added automatically.
            phase (str): 'train' or 'val'.
    """

    def __init__(self, opt):
        super(ExampleDataset, self).__init__()
        self.opt = opt
        # the FileClient is created lazily on first access (see __getitem__)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.mean = opt.get('mean')
        self.std = opt.get('std')

        self.gt_folder = opt['dataroot_gt']
        # folder mode only; for other modes such as lmdb and meta_info file, see:
        # https://github.com/xinntao/BasicSR/blob/master/basicsr/data/
        self.paths = [os.path.join(self.gt_folder, name) for name in scandir(self.gt_folder)]

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        scale = self.opt['scale']

        # load GT image: dimension order HWC, channel order BGR,
        # value range [0, 1], float32
        gt_path = self.paths[index]
        img_bytes = self.file_client.get(gt_path, 'gt')
        img_gt = mod_crop(imfrombytes(img_bytes, float32=True), scale)

        # synthesize the LQ image: bicubic downsample, then JPEG compression
        height, width = img_gt.shape[0:2]
        img_lq = cv2.resize(img_gt, (width // scale, height // scale), interpolation=cv2.INTER_CUBIC)
        img_lq = add_jpg_compression(img_lq, quality=70)

        # training-time augmentation: paired random crop, flip and rotation
        if self.opt['phase'] == 'train':
            gt_size = self.opt['gt_size']
            img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
            img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_flip'], self.opt['use_rot'])

        # BGR -> RGB, HWC -> CHW, numpy -> tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)

        # quantize LQ to 8-bit levels, mimicking a stored image
        img_lq = torch.clamp((img_lq * 255.0).round(), 0, 255) / 255.

        # optional channel-wise normalization
        if self.mean is not None or self.std is not None:
            normalize(img_lq, self.mean, self.std, inplace=True)
            normalize(img_gt, self.mean, self.std, inplace=True)

        return {'lq': img_lq, 'gt': img_gt, 'lq_path': gt_path, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)
89 |
--------------------------------------------------------------------------------
/datasets/README.md:
--------------------------------------------------------------------------------
1 | # Datasets
2 |
3 | This folder is mainly for storing datasets used for training/validation/testing.
4 |
5 | ## Practice
6 |
7 | 1. Separate your codes and datasets. So it is better to soft link your dataset (such as DIV2K, FFHQ, *etc*) here.
8 | ```bash
9 | ln -s DATASET_PATH ./
10 | ```
11 |
12 | ## Example Datasets
13 |
14 | We provide two example datasets for demo.
15 |
16 | 1. [BSDS100](https://github.com/xinntao/BasicSR-examples/releases/download/0.0.0/BSDS100.zip) for training
17 | 1. [Set5](https://github.com/xinntao/BasicSR-examples/releases/download/0.0.0/Set5.zip) for validation
18 |
19 | You can easily download them by running the following command in the BasicSR-examples root path:
20 |
21 | ```bash
22 | python scripts/prepare_example_data.py
23 | ```
24 |
25 | The example datasets are now in the `datasets/example` folder.
26 |
--------------------------------------------------------------------------------
/experiments/README.md:
--------------------------------------------------------------------------------
1 | # Experiments
2 |
3 | Your experiment runs will be put in this folder.
4 |
--------------------------------------------------------------------------------
/losses/__init__.py:
--------------------------------------------------------------------------------
import importlib
from os import path as osp

from basicsr.utils import scandir

# automatically scan and import loss modules for registry
# scan all the files that end with '_loss.py' under the losses folder
loss_folder = osp.dirname(osp.abspath(__file__))
loss_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(loss_folder) if v.endswith('_loss.py')]
# import all the loss modules
# (fixed name: was `_model_modules`, a copy-paste leftover from models/__init__.py)
_loss_modules = [importlib.import_module(f'losses.{file_name}') for file_name in loss_filenames]
12 |
--------------------------------------------------------------------------------
/losses/example_loss.py:
--------------------------------------------------------------------------------
1 | from torch import nn as nn
2 | from torch.nn import functional as F
3 |
4 | from basicsr.utils.registry import LOSS_REGISTRY
5 |
6 |
@LOSS_REGISTRY.register()
class ExampleLoss(nn.Module):
    """Example loss: weighted mean L1 (pixel-wise absolute difference).

    Args:
        loss_weight (float): Loss weight for Example loss. Default: 1.0.
    """

    def __init__(self, loss_weight=1.0):
        super(ExampleLoss, self).__init__()
        self.loss_weight = loss_weight

    def forward(self, pred, target, **kwargs):
        """Compute the weighted L1 loss between prediction and ground truth.

        Args:
            pred (Tensor): of shape (N, C, H, W). Predicted tensor.
            target (Tensor): of shape (N, C, H, W). Ground truth tensor.

        Extra keyword arguments (e.g. an element-wise `weight`) are accepted
        for interface compatibility but ignored by this loss.
        """
        base_loss = F.l1_loss(pred, target, reduction='mean')
        return base_loss * self.loss_weight
27 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from os import path as osp
3 |
4 | from basicsr.utils import scandir
5 |
6 | # automatically scan and import model modules for registry
7 | # scan all the files that end with '_model.py' under the model folder
8 | model_folder = osp.dirname(osp.abspath(__file__))
9 | model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
10 | # import all the model modules
11 | _model_modules = [importlib.import_module(f'models.{file_name}') for file_name in model_filenames]
12 |
--------------------------------------------------------------------------------
/models/example_model.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 |
3 | from basicsr.archs import build_network
4 | from basicsr.losses import build_loss
5 | from basicsr.models.sr_model import SRModel
6 | from basicsr.utils import get_root_logger
7 | from basicsr.utils.registry import MODEL_REGISTRY
8 |
9 |
@MODEL_REGISTRY.register()
class ExampleModel(SRModel):
    """Example model based on the SRModel class.

    In this example model, we want to implement a new model that trains with both L1 and L2 loss.

    New defined functions:
        init_training_settings(self)
        feed_data(self, data)
        optimize_parameters(self, current_iter)

    Inherited functions:
        __init__(self, opt)
        setup_optimizers(self)
        test(self)
        dist_validation(self, dataloader, current_iter, tb_logger, save_img)
        nondist_validation(self, dataloader, current_iter, tb_logger, save_img)
        _log_validation_metric_values(self, current_iter, dataset_name, tb_logger)
        get_current_visuals(self)
        save(self, epoch, current_iter)
    """

    def init_training_settings(self):
        """Set up losses, optional EMA network, optimizers and schedulers.

        Reads the 'train' section of self.opt. Presumably invoked once by
        the inherited SRModel during model construction — confirm against
        the basicsr SRModel implementation.
        """
        self.net_g.train()
        train_opt = self.opt['train']

        # EMA is disabled when 'ema_decay' is absent or 0
        self.ema_decay = train_opt.get('ema_decay', 0)
        if self.ema_decay > 0:
            logger = get_root_logger()
            logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
            # define network net_g with Exponential Moving Average (EMA)
            # net_g_ema is used only for testing on one GPU and saving
            # There is no need to wrap with DistributedDataParallel
            self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
            # load pretrained model (the 'params_ema' weights) if a path is given,
            # otherwise initialize the EMA network from net_g
            load_path = self.opt['path'].get('pretrain_network_g', None)
            if load_path is not None:
                self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
            else:
                self.model_ema(0)  # copy net_g weight
            self.net_g_ema.eval()

        # define losses: this model trains with BOTH an L1 and an L2 term,
        # each built from its own sub-option ('l1_opt' / 'l2_opt')
        self.l1_pix = build_loss(train_opt['l1_opt']).to(self.device)
        self.l2_pix = build_loss(train_opt['l2_opt']).to(self.device)

        # set up optimizers and schedulers (helpers inherited from the base model)
        self.setup_optimizers()
        self.setup_schedulers()

    def feed_data(self, data):
        """Move one dataloader batch onto the model device.

        Args:
            data (dict): Expects key 'lq'; key 'gt' is optional (absent at
                pure inference time).
        """
        self.lq = data['lq'].to(self.device)
        if 'gt' in data:
            self.gt = data['gt'].to(self.device)

    def optimize_parameters(self, current_iter):
        """Run one training step: forward pass, L1 + L2 losses, backward, update.

        Args:
            current_iter (int): Current iteration; part of the BasicSR model
                interface but unused in this implementation.
        """
        self.optimizer_g.zero_grad()
        self.output = self.net_g(self.lq)

        # total loss is the (weighted) sum of both pixel losses
        l_total = 0
        loss_dict = OrderedDict()
        # l1 loss
        l_l1 = self.l1_pix(self.output, self.gt)
        l_total += l_l1
        loss_dict['l_l1'] = l_l1
        # l2 loss
        l_l2 = self.l2_pix(self.output, self.gt)
        l_total += l_l2
        loss_dict['l_l2'] = l_l2

        l_total.backward()
        self.optimizer_g.step()

        # collect scalar losses for logging (reduce_loss_dict is inherited)
        self.log_dict = self.reduce_loss_dict(loss_dict)

        # keep the EMA copy of net_g up to date
        if self.ema_decay > 0:
            self.model_ema(decay=self.ema_decay)
87 |
--------------------------------------------------------------------------------
/options/example_option.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: BasicSR_example
3 | model_type: ExampleModel
4 | scale: 4
5 | num_gpu: 1 # set num_gpu: 0 for cpu mode
6 | manual_seed: 0
7 |
8 | # dataset and data loader settings
9 | datasets:
10 | train:
11 | name: ExampleBSDS100
12 | type: ExampleDataset
13 | dataroot_gt: datasets/example/BSDS100
14 | io_backend:
15 | type: disk
16 |
17 | gt_size: 128
18 | use_flip: true
19 | use_rot: true
20 |
21 | # data loader
22 | use_shuffle: true
23 | num_worker_per_gpu: 3
24 | batch_size_per_gpu: 16
25 | dataset_enlarge_ratio: 10
26 | prefetch_mode: ~
27 |
28 | val:
29 | name: ExampleSet5
30 | type: ExampleDataset
31 | dataroot_gt: datasets/example/Set5
32 | io_backend:
33 | type: disk
34 |
35 | # network structures
36 | network_g:
37 | type: ExampleArch
38 | num_in_ch: 3
39 | num_out_ch: 3
40 | num_feat: 64
41 | upscale: 4
42 |
43 |
44 | # path
45 | path:
46 | pretrain_network_g: ~
47 | strict_load_g: true
48 | resume_state: ~
49 |
50 | # training settings
51 | train:
52 | optim_g:
53 | type: Adam
54 | lr: !!float 2e-4
55 | weight_decay: 0
56 | betas: [0.9, 0.99]
57 |
58 | scheduler:
59 | type: MultiStepLR
60 | milestones: [50000]
61 | gamma: 0.5
62 |
63 | total_iter: 100000
64 | warmup_iter: -1 # no warm up
65 |
66 | # losses
67 | l1_opt:
68 | type: ExampleLoss
69 | loss_weight: 1.0
70 |
71 | l2_opt:
72 | type: MSELoss
73 | loss_weight: 1.0
74 | reduction: mean
75 |
76 | # validation settings
77 | val:
78 | val_freq: !!float 5e3
79 | save_img: false
80 |
81 | metrics:
82 | psnr: # metric name, can be arbitrary
83 | type: calculate_psnr
84 | crop_border: 4
85 | test_y_channel: false
86 | niqe:
87 | type: calculate_niqe
88 | crop_border: 4
89 |
90 | # logging settings
91 | logger:
92 | print_freq: 100
93 | save_checkpoint_freq: !!float 5e3
94 | use_tb_logger: true
95 | wandb:
96 | project: ~
97 | resume_id: ~
98 |
99 | # dist training settings
100 | dist_params:
101 | backend: nccl
102 | port: 29500
103 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | basicsr
2 | numpy
3 | opencv-python
4 | requests
5 | torch
6 | torchvision
7 |
--------------------------------------------------------------------------------
/scripts/prepare_example_data.py:
--------------------------------------------------------------------------------
import os
import zipfile

import requests
3 |
4 |
def main(url, dataset):
    """Download a zip archive and extract it under ``datasets/example``.

    Args:
        url (str): Direct download URL of the zip archive.
        dataset (str): Dataset name; used for both the saved zip filename
            (``datasets/example/<dataset>.zip``) and the extraction folder
            (``datasets/example/<dataset>``).

    Raises:
        requests.HTTPError: If the server does not return a success status.
    """
    # download
    print(f'Download {url} ...')
    response = requests.get(url, timeout=60)
    # fail loudly on a bad response instead of silently writing an HTML
    # error page into the .zip file, which would break extraction below
    response.raise_for_status()
    zip_path = f'datasets/example/{dataset}.zip'
    with open(zip_path, 'wb') as f:
        f.write(response.content)

    # unzip
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(f'datasets/example/{dataset}')
16 |
17 |
if __name__ == '__main__':
    # This script downloads and prepares the example data:
    #   1. BSDS100 for training
    #   2. Set5 for testing
    os.makedirs('datasets/example', exist_ok=True)

    downloads = [
        ('https://github.com/xinntao/BasicSR-examples/releases/download/0.0.0/BSDS100.zip', 'BSDS100'),
        ('https://github.com/xinntao/BasicSR-examples/releases/download/0.0.0/Set5.zip', 'Set5'),
    ]
    for url, dataset in downloads:
        main(url, dataset)
32 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | ignore =
3 | # line break before binary operator (W503)
4 | W503,
5 | # line break after binary operator (W504)
6 | W504,
7 | max-line-length=120
8 |
9 | [yapf]
10 | based_on_style = pep8
11 | column_limit = 120
12 | blank_line_before_nested_class_or_def = true
13 | split_before_expression_after_opening_paren = true
14 |
15 | [isort]
16 | line_length = 120
17 | multi_line_output = 0
18 | known_standard_library = pkg_resources,setuptools
19 | known_first_party = basicsr
20 | known_third_party = cv2,requests,torch,torchvision
21 | no_lines_before = STDLIB,LOCALFOLDER
22 | default_section = THIRDPARTY
23 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
# flake8: noqa
import os.path as osp

# importing these packages runs their __init__ scanners, which register the
# custom archs / datasets / losses / models with the basicsr registries
import archs
import data
import losses
import models
from basicsr.train import train_pipeline

if __name__ == '__main__':
    # the directory containing this file is the project root for training
    root_path = osp.dirname(osp.abspath(__file__))
    train_pipeline(root_path)
13 |
--------------------------------------------------------------------------------