├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── test
│   │   ├── test-b100-2.yaml
│   │   ├── test-b100-3.yaml
│   │   ├── test-b100-4.yaml
│   │   ├── test-b100-6.yaml
│   │   ├── test-b100-8.yaml
│   │   ├── test-celebAHQ-32-256.yaml
│   │   ├── test-celebAHQ-64-128.yaml
│   │   ├── test-div2k-12.yaml
│   │   ├── test-div2k-18.yaml
│   │   ├── test-div2k-2.yaml
│   │   ├── test-div2k-24.yaml
│   │   ├── test-div2k-3.yaml
│   │   ├── test-div2k-30.yaml
│   │   ├── test-div2k-4.yaml
│   │   ├── test-div2k-6.yaml
│   │   ├── test-set14-2.yaml
│   │   ├── test-set14-3.yaml
│   │   ├── test-set14-4.yaml
│   │   ├── test-set14-6.yaml
│   │   ├── test-set14-8.yaml
│   │   ├── test-set5-2.yaml
│   │   ├── test-set5-3.yaml
│   │   ├── test-set5-4.yaml
│   │   ├── test-set5-6.yaml
│   │   ├── test-set5-8.yaml
│   │   ├── test-urban100-2.yaml
│   │   ├── test-urban100-3.yaml
│   │   ├── test-urban100-4.yaml
│   │   ├── test-urban100-6.yaml
│   │   └── test-urban100-8.yaml
│   ├── train-celebAHQ
│   │   ├── train_celebAHQ-32-256_liif.yaml
│   │   ├── train_celebAHQ-32-256_upsample.yaml
│   │   ├── train_celebAHQ-64-128_liif.yaml
│   │   └── train_celebAHQ-64-128_upsample.yaml
│   └── train-div2k
│       ├── ablation
│       │   ├── train_edsr-baseline-liif-c.yaml
│       │   ├── train_edsr-baseline-liif-d.yaml
│       │   ├── train_edsr-baseline-liif-e.yaml
│       │   ├── train_edsr-baseline-liif-u.yaml
│       │   ├── train_edsr-baseline-liif-x2.yaml
│       │   ├── train_edsr-baseline-liif-x3.yaml
│       │   └── train_edsr-baseline-liif-x4.yaml
│       ├── train_edsr-baseline-liif.yaml
│       ├── train_edsr-baseline-metasr.yaml
│       ├── train_edsr-baseline-x2.yaml
│       ├── train_rdn-liif.yaml
│       ├── train_rdn-metasr.yaml
│       └── train_rdn-x2.yaml
├── datasets
│   ├── __init__.py
│   ├── datasets.py
│   ├── image_folder.py
│   └── wrappers.py
├── demo.py
├── models
│   ├── __init__.py
│   ├── edsr.py
│   ├── liif.py
│   ├── misc.py
│   ├── mlp.py
│   ├── models.py
│   ├── rcan.py
│   └── rdn.py
├── scripts
│   ├── resize.py
│   ├── test-benchmark.sh
│   └── test-div2k.sh
├── test.py
├── train_liif.py
└── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 | folders
3 | load
4 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | BSD 3-Clause License
2 |
3 | Copyright (c) 2020, Yinbo Chen
4 | All rights reserved.
5 |
6 | Redistribution and use in source and binary forms, with or without
7 | modification, are permitted provided that the following conditions are met:
8 |
9 | 1. Redistributions of source code must retain the above copyright notice, this
10 | list of conditions and the following disclaimer.
11 |
12 | 2. Redistributions in binary form must reproduce the above copyright notice,
13 | this list of conditions and the following disclaimer in the documentation
14 | and/or other materials provided with the distribution.
15 |
16 | 3. Neither the name of the copyright holder nor the names of its
17 | contributors may be used to endorse or promote products derived from
18 | this software without specific prior written permission.
19 |
20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # LIIF
2 |
3 | This repository contains the official implementation for LIIF introduced in the following paper:
4 |
5 | [**Learning Continuous Image Representation with Local Implicit Image Function**](https://arxiv.org/abs/2012.09161)
6 |
7 | [Yinbo Chen](https://yinboc.github.io/), [Sifei Liu](https://www.sifeiliu.net/), [Xiaolong Wang](https://xiaolonw.github.io/)
8 |
9 | CVPR 2021 (Oral)
10 |
11 | The project page with video is at https://yinboc.github.io/liif/.
12 |
13 |
14 |
15 | ### Citation
16 |
17 | If you find our work useful in your research, please cite:
18 |
19 | ```
20 | @inproceedings{chen2021learning,
21 | title={Learning continuous image representation with local implicit image function},
22 | author={Chen, Yinbo and Liu, Sifei and Wang, Xiaolong},
23 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
24 | pages={8628--8638},
25 | year={2021}
26 | }
27 | ```
28 |
29 | ### Environment
30 | - Python 3
31 | - PyTorch 1.6.0
32 | - TensorboardX
33 | - yaml, numpy, tqdm, imageio
34 |
35 | ## Quick Start
36 |
37 | 1. Download a DIV2K pre-trained model.
38 |
39 | Model|File size|Download
40 | :-:|:-:|:-:
41 | EDSR-baseline-LIIF|18M|[Dropbox](https://www.dropbox.com/s/6f402wcn4v83w2v/edsr-baseline-liif.pth?dl=0) | [Google Drive](https://drive.google.com/file/d/1wBHSrgPLOHL_QVhPAIAcDC30KSJLf67x/view?usp=sharing)
42 | RDN-LIIF|256M|[Dropbox](https://www.dropbox.com/s/mzha6ll9kb9bwy0/rdn-liif.pth?dl=0) | [Google Drive](https://drive.google.com/file/d/1xaAx6lBVVw_PJ3YVp02h3k4HuOAXcUkt/view?usp=sharing)
43 |
44 | 2. Convert your image to a LIIF representation and render it at a given resolution (with GPU 0; `[MODEL_PATH]` denotes the downloaded `.pth` file):
45 |
46 | ```
47 | python demo.py --input xxx.png --model [MODEL_PATH] --resolution [HEIGHT],[WIDTH] --output output.png --gpu 0
48 | ```
49 |
50 | ## Reproducing Experiments
51 |
52 | ### Data
53 |
54 | Run `mkdir load` to create the folder that will hold the dataset folders.
55 |
56 | - **DIV2K**: `mkdir` and `cd` into `load/div2k`. Download HR images and bicubic validation LR images from [DIV2K website](https://data.vision.ee.ethz.ch/cvl/DIV2K/) (i.e. [Train_HR](http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip), [Valid_HR](http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip), [Valid_LR_X2](http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic_X2.zip), [Valid_LR_X3](http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic_X3.zip), [Valid_LR_X4](http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic_X4.zip)). `unzip` these files to get the image folders.
57 |
58 | - **benchmark datasets**: `cd` into `load/`. Download and `tar -xf` the [benchmark datasets](https://cv.snu.ac.kr/research/EDSR/benchmark.tar) (provided by [this repo](https://github.com/thstkdgus35/EDSR-PyTorch)) to get a `load/benchmark` folder with sub-folders `Set5/, Set14/, B100/, Urban100/`.
59 |
60 | - **celebAHQ**: `mkdir load/celebAHQ` and `cp scripts/resize.py load/celebAHQ/`, then `cd load/celebAHQ/`. Download and `unzip` `data1024x1024.zip` from the [Google Drive link](https://drive.google.com/drive/folders/11Vz0fqHS2rXDb5pprgTjpD7S2BAJhi1P?usp=sharing) (provided by [this repo](https://github.com/suvojit-0x55aa/celebA-HQ-dataset-download)). Run `python resize.py` to get image folders `256/, 128/, 64/, 32/`. Download [split.json](https://www.dropbox.com/s/2qeijojdjzvp3b9/split.json?dl=0) into `load/celebAHQ/` (the configs expect `./load/celebAHQ/split.json`).
61 |
62 | ### Running the code
63 |
64 | **0. Preliminaries**
65 |
66 | - For `train_liif.py` or `test.py`, use `--gpu [GPU]` to specify the GPUs (e.g. `--gpu 0` or `--gpu 0,1`).
67 |
68 | - For `train_liif.py`, by default, the save folder is at `save/_[CONFIG_NAME]`. We can use `--name` to specify a name if needed.
69 |
70 | - For dataset args in configs, `cache: in_memory` denotes pre-loading all images into memory (may require large memory, e.g. ~40GB for DIV2K), `cache: bin` denotes caching decoded images as binary files (in a sibling folder) on first use, and `cache: none` denotes loading images directly from disk. We can adjust it to the available hardware before running the training scripts.
71 |
72 | **1. DIV2K experiments**
73 |
74 | **Train**: `python train_liif.py --config configs/train-div2k/train_edsr-baseline-liif.yaml` (with the EDSR-baseline backbone; for RDN, replace `edsr-baseline` with `rdn`). We use 1 GPU for training EDSR-baseline-LIIF and 4 GPUs for RDN-LIIF.
75 |
76 | **Test**: `bash scripts/test-div2k.sh [MODEL_PATH] [GPU]` for the DIV2K validation set, `bash scripts/test-benchmark.sh [MODEL_PATH] [GPU]` for the benchmark datasets. `[MODEL_PATH]` is the path to a `.pth` file; we use `epoch-last.pth` in the corresponding save folder.
77 |
78 | **2. celebAHQ experiments**
79 |
80 | **Train**: `python train_liif.py --config configs/train-celebAHQ/[CONFIG_NAME].yaml`.
81 |
82 | **Test**: `python test.py --config configs/test/test-celebAHQ-32-256.yaml --model [MODEL_PATH]` (or `test-celebAHQ-64-128.yaml` for another task). We use `epoch-best.pth` in the corresponding save folder.
83 |
--------------------------------------------------------------------------------
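The Preliminaries in the README above suggest adjusting `cache:` to the available hardware. As a minimal sketch (assuming PyYAML is installed, that the config path below exists as in this repo, and with an arbitrary output filename of my own choosing), one way to switch every dataset in a training config to `bin` caching without editing the YAML by hand:

```python
# Sketch: rewrite a training config to use 'bin' caching instead of 'in_memory'.
import yaml

cfg_path = 'configs/train-div2k/train_edsr-baseline-liif.yaml'
with open(cfg_path, 'r') as f:
    cfg = yaml.safe_load(f)

# 'in_memory' needs ~40GB of RAM for DIV2K (per the README); 'bin' trades that
# for one-time binary dumps written next to the image folders.
for split in ('train_dataset', 'val_dataset'):
    cfg[split]['dataset']['args']['cache'] = 'bin'

with open('train_edsr-baseline-liif_bin-cache.yaml', 'w') as f:
    yaml.safe_dump(cfg, f)
# Then: python train_liif.py --config train_edsr-baseline-liif_bin-cache.yaml --gpu 0
```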
/configs/test/test-b100-2.yaml:
--------------------------------------------------------------------------------
1 | test_dataset:
2 | dataset:
3 | name: paired-image-folders
4 | args:
5 | root_path_1: ./load/benchmark/B100/LR_bicubic/X2
6 | root_path_2: ./load/benchmark/B100/HR
7 | wrapper:
8 | name: sr-implicit-paired
9 | args: {}
10 | batch_size: 1
11 | eval_type: benchmark-2
12 | eval_bsize: 30000
13 |
14 | data_norm:
15 | inp: {sub: [0.5], div: [0.5]}
16 | gt: {sub: [0.5], div: [0.5]}
--------------------------------------------------------------------------------
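A note on the `data_norm` block, which is identical in all of the test and train configs in this repo: `sub: [0.5], div: [0.5]` maps pixel values from [0, 1] to [-1, 1] before they reach the model, and predictions are mapped back the same way (as `demo.py` does with `(pred * 0.5 + 0.5)`). A minimal illustration:

```python
# Illustration of the {sub: [0.5], div: [0.5]} normalization used by the configs.
import torch

sub, div = 0.5, 0.5
x = torch.rand(3, 8, 8)                    # an image tensor with values in [0, 1]
x_norm = (x - sub) / div                   # normalized to [-1, 1], what the model sees
x_back = (x_norm * div + sub).clamp(0, 1)  # inverse mapping, as applied to predictions
assert torch.allclose(x, x_back, atol=1e-6)
```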
/configs/test/test-b100-3.yaml:
--------------------------------------------------------------------------------
1 | test_dataset:
2 | dataset:
3 | name: paired-image-folders
4 | args:
5 | root_path_1: ./load/benchmark/B100/LR_bicubic/X3
6 | root_path_2: ./load/benchmark/B100/HR
7 | wrapper:
8 | name: sr-implicit-paired
9 | args: {}
10 | batch_size: 1
11 | eval_type: benchmark-3
12 | eval_bsize: 30000
13 |
14 | data_norm:
15 | inp: {sub: [0.5], div: [0.5]}
16 | gt: {sub: [0.5], div: [0.5]}
--------------------------------------------------------------------------------
/configs/test/test-b100-4.yaml:
--------------------------------------------------------------------------------
1 | test_dataset:
2 | dataset:
3 | name: paired-image-folders
4 | args:
5 | root_path_1: ./load/benchmark/B100/LR_bicubic/X4
6 | root_path_2: ./load/benchmark/B100/HR
7 | wrapper:
8 | name: sr-implicit-paired
9 | args: {}
10 | batch_size: 1
11 | eval_type: benchmark-4
12 | eval_bsize: 30000
13 |
14 | data_norm:
15 | inp: {sub: [0.5], div: [0.5]}
16 | gt: {sub: [0.5], div: [0.5]}
--------------------------------------------------------------------------------
/configs/test/test-b100-6.yaml:
--------------------------------------------------------------------------------
1 | test_dataset:
2 | dataset:
3 | name: image-folder
4 | args:
5 | root_path: ./load/benchmark/B100/HR
6 | wrapper:
7 | name: sr-implicit-downsampled
8 | args:
9 | scale_min: 6
10 | batch_size: 1
11 | eval_type: benchmark-6
12 | eval_bsize: 30000
13 |
14 | data_norm:
15 | inp: {sub: [0.5], div: [0.5]}
16 | gt: {sub: [0.5], div: [0.5]}
--------------------------------------------------------------------------------
/configs/test/test-b100-8.yaml:
--------------------------------------------------------------------------------
1 | test_dataset:
2 | dataset:
3 | name: image-folder
4 | args:
5 | root_path: ./load/benchmark/B100/HR
6 | wrapper:
7 | name: sr-implicit-downsampled
8 | args:
9 | scale_min: 8
10 | batch_size: 1
11 | eval_type: benchmark-8
12 | eval_bsize: 30000
13 |
14 | data_norm:
15 | inp: {sub: [0.5], div: [0.5]}
16 | gt: {sub: [0.5], div: [0.5]}
--------------------------------------------------------------------------------
/configs/test/test-celebAHQ-32-256.yaml:
--------------------------------------------------------------------------------
1 | test_dataset:
2 | dataset:
3 | name: paired-image-folders
4 | args:
5 | root_path_1: ./load/celebAHQ/32
6 | root_path_2: ./load/celebAHQ/256
7 | split_file: ./load/celebAHQ/split.json
8 | split_key: val
9 | cache: bin
10 | wrapper:
11 | name: sr-implicit-paired
12 | args: {}
13 | batch_size: 1
14 |
15 | data_norm:
16 | inp: {sub: [0.5], div: [0.5]}
17 | gt: {sub: [0.5], div: [0.5]}
--------------------------------------------------------------------------------
/configs/test/test-celebAHQ-64-128.yaml:
--------------------------------------------------------------------------------
1 | test_dataset:
2 | dataset:
3 | name: paired-image-folders
4 | args:
5 | root_path_1: ./load/celebAHQ/64
6 | root_path_2: ./load/celebAHQ/128
7 | split_file: ./load/celebAHQ/split.json
8 | split_key: val
9 | cache: bin
10 | wrapper:
11 | name: sr-implicit-paired
12 | args: {}
13 | batch_size: 1
14 |
15 | data_norm:
16 | inp: {sub: [0.5], div: [0.5]}
17 | gt: {sub: [0.5], div: [0.5]}
--------------------------------------------------------------------------------
/configs/test/test-div2k-12.yaml:
--------------------------------------------------------------------------------
1 | test_dataset:
2 | dataset:
3 | name: image-folder
4 | args:
5 | root_path: ./load/div2k/DIV2K_valid_HR
6 | wrapper:
7 | name: sr-implicit-downsampled
8 | args:
9 | scale_min: 12
10 | batch_size: 1
11 | eval_type: div2k-12
12 | eval_bsize: 30000
13 |
14 | data_norm:
15 | inp: {sub: [0.5], div: [0.5]}
16 | gt: {sub: [0.5], div: [0.5]}
--------------------------------------------------------------------------------
/configs/test/test-div2k-18.yaml:
--------------------------------------------------------------------------------
1 | test_dataset:
2 | dataset:
3 | name: image-folder
4 | args:
5 | root_path: ./load/div2k/DIV2K_valid_HR
6 | wrapper:
7 | name: sr-implicit-downsampled
8 | args:
9 | scale_min: 18
10 | batch_size: 1
11 | eval_type: div2k-18
12 | eval_bsize: 30000
13 |
14 | data_norm:
15 | inp: {sub: [0.5], div: [0.5]}
16 | gt: {sub: [0.5], div: [0.5]}
--------------------------------------------------------------------------------
/configs/test/test-div2k-2.yaml:
--------------------------------------------------------------------------------
1 | test_dataset:
2 | dataset:
3 | name: paired-image-folders
4 | args:
5 | root_path_1: ./load/div2k/DIV2K_valid_LR_bicubic/X2
6 | root_path_2: ./load/div2k/DIV2K_valid_HR
7 | wrapper:
8 | name: sr-implicit-paired
9 | args: {}
10 | batch_size: 1
11 | eval_type: div2k-2
12 | eval_bsize: 30000
13 |
14 | data_norm:
15 | inp: {sub: [0.5], div: [0.5]}
16 | gt: {sub: [0.5], div: [0.5]}
--------------------------------------------------------------------------------
/configs/test/test-div2k-24.yaml:
--------------------------------------------------------------------------------
1 | test_dataset:
2 | dataset:
3 | name: image-folder
4 | args:
5 | root_path: ./load/div2k/DIV2K_valid_HR
6 | wrapper:
7 | name: sr-implicit-downsampled
8 | args:
9 | scale_min: 24
10 | batch_size: 1
11 | eval_type: div2k-24
12 | eval_bsize: 30000
13 |
14 | data_norm:
15 | inp: {sub: [0.5], div: [0.5]}
16 | gt: {sub: [0.5], div: [0.5]}
--------------------------------------------------------------------------------
/configs/test/test-div2k-3.yaml:
--------------------------------------------------------------------------------
1 | test_dataset:
2 | dataset:
3 | name: paired-image-folders
4 | args:
5 | root_path_1: ./load/div2k/DIV2K_valid_LR_bicubic/X3
6 | root_path_2: ./load/div2k/DIV2K_valid_HR
7 | wrapper:
8 | name: sr-implicit-paired
9 | args: {}
10 | batch_size: 1
11 | eval_type: div2k-3
12 | eval_bsize: 30000
13 |
14 | data_norm:
15 | inp: {sub: [0.5], div: [0.5]}
16 | gt: {sub: [0.5], div: [0.5]}
--------------------------------------------------------------------------------
/configs/test/test-div2k-30.yaml:
--------------------------------------------------------------------------------
1 | test_dataset:
2 | dataset:
3 | name: image-folder
4 | args:
5 | root_path: ./load/div2k/DIV2K_valid_HR
6 | wrapper:
7 | name: sr-implicit-downsampled
8 | args:
9 | scale_min: 30
10 | batch_size: 1
11 | eval_type: div2k-30
12 | eval_bsize: 30000
13 |
14 | data_norm:
15 | inp: {sub: [0.5], div: [0.5]}
16 | gt: {sub: [0.5], div: [0.5]}
--------------------------------------------------------------------------------
/configs/test/test-div2k-4.yaml:
--------------------------------------------------------------------------------
1 | test_dataset:
2 | dataset:
3 | name: paired-image-folders
4 | args:
5 | root_path_1: ./load/div2k/DIV2K_valid_LR_bicubic/X4
6 | root_path_2: ./load/div2k/DIV2K_valid_HR
7 | wrapper:
8 | name: sr-implicit-paired
9 | args: {}
10 | batch_size: 1
11 | eval_type: div2k-4
12 | eval_bsize: 30000
13 |
14 | data_norm:
15 | inp: {sub: [0.5], div: [0.5]}
16 | gt: {sub: [0.5], div: [0.5]}
--------------------------------------------------------------------------------
/configs/test/test-div2k-6.yaml:
--------------------------------------------------------------------------------
1 | test_dataset:
2 | dataset:
3 | name: image-folder
4 | args:
5 | root_path: ./load/div2k/DIV2K_valid_HR
6 | wrapper:
7 | name: sr-implicit-downsampled
8 | args:
9 | scale_min: 6
10 | batch_size: 1
11 | eval_type: div2k-6
12 | eval_bsize: 30000
13 |
14 | data_norm:
15 | inp: {sub: [0.5], div: [0.5]}
16 | gt: {sub: [0.5], div: [0.5]}
--------------------------------------------------------------------------------
/configs/test/test-set14-2.yaml:
--------------------------------------------------------------------------------
1 | test_dataset:
2 | dataset:
3 | name: paired-image-folders
4 | args:
5 | root_path_1: ./load/benchmark/Set14/LR_bicubic/X2
6 | root_path_2: ./load/benchmark/Set14/HR
7 | wrapper:
8 | name: sr-implicit-paired
9 | args: {}
10 | batch_size: 1
11 | eval_type: benchmark-2
12 | eval_bsize: 30000
13 |
14 | data_norm:
15 | inp: {sub: [0.5], div: [0.5]}
16 | gt: {sub: [0.5], div: [0.5]}
--------------------------------------------------------------------------------
/configs/test/test-set14-3.yaml:
--------------------------------------------------------------------------------
1 | test_dataset:
2 | dataset:
3 | name: paired-image-folders
4 | args:
5 | root_path_1: ./load/benchmark/Set14/LR_bicubic/X3
6 | root_path_2: ./load/benchmark/Set14/HR
7 | wrapper:
8 | name: sr-implicit-paired
9 | args: {}
10 | batch_size: 1
11 | eval_type: benchmark-3
12 | eval_bsize: 30000
13 |
14 | data_norm:
15 | inp: {sub: [0.5], div: [0.5]}
16 | gt: {sub: [0.5], div: [0.5]}
--------------------------------------------------------------------------------
/configs/test/test-set14-4.yaml:
--------------------------------------------------------------------------------
1 | test_dataset:
2 | dataset:
3 | name: paired-image-folders
4 | args:
5 | root_path_1: ./load/benchmark/Set14/LR_bicubic/X4
6 | root_path_2: ./load/benchmark/Set14/HR
7 | wrapper:
8 | name: sr-implicit-paired
9 | args: {}
10 | batch_size: 1
11 | eval_type: benchmark-4
12 | eval_bsize: 30000
13 |
14 | data_norm:
15 | inp: {sub: [0.5], div: [0.5]}
16 | gt: {sub: [0.5], div: [0.5]}
--------------------------------------------------------------------------------
/configs/test/test-set14-6.yaml:
--------------------------------------------------------------------------------
1 | test_dataset:
2 | dataset:
3 | name: image-folder
4 | args:
5 | root_path: ./load/benchmark/Set14/HR
6 | wrapper:
7 | name: sr-implicit-downsampled
8 | args:
9 | scale_min: 6
10 | batch_size: 1
11 | eval_type: benchmark-6
12 | eval_bsize: 30000
13 |
14 | data_norm:
15 | inp: {sub: [0.5], div: [0.5]}
16 | gt: {sub: [0.5], div: [0.5]}
--------------------------------------------------------------------------------
/configs/test/test-set14-8.yaml:
--------------------------------------------------------------------------------
1 | test_dataset:
2 | dataset:
3 | name: image-folder
4 | args:
5 | root_path: ./load/benchmark/Set14/HR
6 | wrapper:
7 | name: sr-implicit-downsampled
8 | args:
9 | scale_min: 8
10 | batch_size: 1
11 | eval_type: benchmark-8
12 | eval_bsize: 30000
13 |
14 | data_norm:
15 | inp: {sub: [0.5], div: [0.5]}
16 | gt: {sub: [0.5], div: [0.5]}
--------------------------------------------------------------------------------
/configs/test/test-set5-2.yaml:
--------------------------------------------------------------------------------
1 | test_dataset:
2 | dataset:
3 | name: paired-image-folders
4 | args:
5 | root_path_1: ./load/benchmark/Set5/LR_bicubic/X2
6 | root_path_2: ./load/benchmark/Set5/HR
7 | wrapper:
8 | name: sr-implicit-paired
9 | args: {}
10 | batch_size: 1
11 | eval_type: benchmark-2
12 | eval_bsize: 30000
13 |
14 | data_norm:
15 | inp: {sub: [0.5], div: [0.5]}
16 | gt: {sub: [0.5], div: [0.5]}
--------------------------------------------------------------------------------
/configs/test/test-set5-3.yaml:
--------------------------------------------------------------------------------
1 | test_dataset:
2 | dataset:
3 | name: paired-image-folders
4 | args:
5 | root_path_1: ./load/benchmark/Set5/LR_bicubic/X3
6 | root_path_2: ./load/benchmark/Set5/HR
7 | wrapper:
8 | name: sr-implicit-paired
9 | args: {}
10 | batch_size: 1
11 | eval_type: benchmark-3
12 | eval_bsize: 30000
13 |
14 | data_norm:
15 | inp: {sub: [0.5], div: [0.5]}
16 | gt: {sub: [0.5], div: [0.5]}
--------------------------------------------------------------------------------
/configs/test/test-set5-4.yaml:
--------------------------------------------------------------------------------
1 | test_dataset:
2 | dataset:
3 | name: paired-image-folders
4 | args:
5 | root_path_1: ./load/benchmark/Set5/LR_bicubic/X4
6 | root_path_2: ./load/benchmark/Set5/HR
7 | wrapper:
8 | name: sr-implicit-paired
9 | args: {}
10 | batch_size: 1
11 | eval_type: benchmark-4
12 | eval_bsize: 30000
13 |
14 | data_norm:
15 | inp: {sub: [0.5], div: [0.5]}
16 | gt: {sub: [0.5], div: [0.5]}
--------------------------------------------------------------------------------
/configs/test/test-set5-6.yaml:
--------------------------------------------------------------------------------
1 | test_dataset:
2 | dataset:
3 | name: image-folder
4 | args:
5 | root_path: ./load/benchmark/Set5/HR
6 | wrapper:
7 | name: sr-implicit-downsampled
8 | args:
9 | scale_min: 6
10 | batch_size: 1
11 | eval_type: benchmark-6
12 | eval_bsize: 30000
13 |
14 | data_norm:
15 | inp: {sub: [0.5], div: [0.5]}
16 | gt: {sub: [0.5], div: [0.5]}
--------------------------------------------------------------------------------
/configs/test/test-set5-8.yaml:
--------------------------------------------------------------------------------
1 | test_dataset:
2 | dataset:
3 | name: image-folder
4 | args:
5 | root_path: ./load/benchmark/Set5/HR
6 | wrapper:
7 | name: sr-implicit-downsampled
8 | args:
9 | scale_min: 8
10 | batch_size: 1
11 | eval_type: benchmark-8
12 | eval_bsize: 30000
13 |
14 | data_norm:
15 | inp: {sub: [0.5], div: [0.5]}
16 | gt: {sub: [0.5], div: [0.5]}
--------------------------------------------------------------------------------
/configs/test/test-urban100-2.yaml:
--------------------------------------------------------------------------------
1 | test_dataset:
2 | dataset:
3 | name: paired-image-folders
4 | args:
5 | root_path_1: ./load/benchmark/Urban100/LR_bicubic/X2
6 | root_path_2: ./load/benchmark/Urban100/HR
7 | wrapper:
8 | name: sr-implicit-paired
9 | args: {}
10 | batch_size: 1
11 | eval_type: benchmark-2
12 | eval_bsize: 30000
13 |
14 | data_norm:
15 | inp: {sub: [0.5], div: [0.5]}
16 | gt: {sub: [0.5], div: [0.5]}
--------------------------------------------------------------------------------
/configs/test/test-urban100-3.yaml:
--------------------------------------------------------------------------------
1 | test_dataset:
2 | dataset:
3 | name: paired-image-folders
4 | args:
5 | root_path_1: ./load/benchmark/Urban100/LR_bicubic/X3
6 | root_path_2: ./load/benchmark/Urban100/HR
7 | wrapper:
8 | name: sr-implicit-paired
9 | args: {}
10 | batch_size: 1
11 | eval_type: benchmark-3
12 | eval_bsize: 30000
13 |
14 | data_norm:
15 | inp: {sub: [0.5], div: [0.5]}
16 | gt: {sub: [0.5], div: [0.5]}
--------------------------------------------------------------------------------
/configs/test/test-urban100-4.yaml:
--------------------------------------------------------------------------------
1 | test_dataset:
2 | dataset:
3 | name: paired-image-folders
4 | args:
5 | root_path_1: ./load/benchmark/Urban100/LR_bicubic/X4
6 | root_path_2: ./load/benchmark/Urban100/HR
7 | wrapper:
8 | name: sr-implicit-paired
9 | args: {}
10 | batch_size: 1
11 | eval_type: benchmark-4
12 | eval_bsize: 30000
13 |
14 | data_norm:
15 | inp: {sub: [0.5], div: [0.5]}
16 | gt: {sub: [0.5], div: [0.5]}
--------------------------------------------------------------------------------
/configs/test/test-urban100-6.yaml:
--------------------------------------------------------------------------------
1 | test_dataset:
2 | dataset:
3 | name: image-folder
4 | args:
5 | root_path: ./load/benchmark/Urban100/HR
6 | wrapper:
7 | name: sr-implicit-downsampled
8 | args:
9 | scale_min: 6
10 | batch_size: 1
11 | eval_type: benchmark-6
12 | eval_bsize: 30000
13 |
14 | data_norm:
15 | inp: {sub: [0.5], div: [0.5]}
16 | gt: {sub: [0.5], div: [0.5]}
--------------------------------------------------------------------------------
/configs/test/test-urban100-8.yaml:
--------------------------------------------------------------------------------
1 | test_dataset:
2 | dataset:
3 | name: image-folder
4 | args:
5 | root_path: ./load/benchmark/Urban100/HR
6 | wrapper:
7 | name: sr-implicit-downsampled
8 | args:
9 | scale_min: 8
10 | batch_size: 1
11 | eval_type: benchmark-8
12 | eval_bsize: 30000
13 |
14 | data_norm:
15 | inp: {sub: [0.5], div: [0.5]}
16 | gt: {sub: [0.5], div: [0.5]}
--------------------------------------------------------------------------------
/configs/train-celebAHQ/train_celebAHQ-32-256_liif.yaml:
--------------------------------------------------------------------------------
1 | train_dataset:
2 | dataset:
3 | name: paired-image-folders
4 | args:
5 | root_path_1: ./load/celebAHQ/32
6 | root_path_2: ./load/celebAHQ/256
7 | split_file: ./load/celebAHQ/split.json
8 | split_key: train
9 | cache: bin
10 | wrapper:
11 | name: sr-implicit-uniform-varied
12 | args:
13 | size_min: 32
14 | size_max: 256
15 | sample_q: 1024
16 | augment: true
17 | batch_size: 16
18 |
19 | val_dataset:
20 | dataset:
21 | name: paired-image-folders
22 | args:
23 | root_path_1: ./load/celebAHQ/32
24 | root_path_2: ./load/celebAHQ/256
25 | split_file: ./load/celebAHQ/split.json
26 | split_key: val
27 | first_k: 100
28 | cache: bin
29 | wrapper:
30 | name: sr-implicit-paired
31 | args:
32 | sample_q: 1024
33 | batch_size: 16
34 |
35 | data_norm:
36 | inp: {sub: [0.5], div: [0.5]}
37 | gt: {sub: [0.5], div: [0.5]}
38 |
39 | model:
40 | name: liif
41 | args:
42 | encoder_spec:
43 | name: edsr-baseline
44 | args:
45 | no_upsampling: true
46 | imnet_spec:
47 | name: mlp
48 | args:
49 | out_dim: 3
50 | hidden_list: [256, 256, 256, 256]
51 |
52 | optimizer:
53 | name: adam
54 | args:
55 | lr: 1.e-4
56 | epoch_max: 200
57 | multi_step_lr:
58 | milestones: [100]
59 | gamma: 0.1
60 |
61 | epoch_val: 1
62 | epoch_save: 100
63 |
--------------------------------------------------------------------------------
/configs/train-celebAHQ/train_celebAHQ-32-256_upsample.yaml:
--------------------------------------------------------------------------------
1 | train_dataset:
2 | dataset:
3 | name: paired-image-folders
4 | args:
5 | root_path_1: ./load/celebAHQ/32
6 | root_path_2: ./load/celebAHQ/256
7 | split_file: ./load/celebAHQ/split.json
8 | split_key: train
9 | cache: bin
10 | wrapper:
11 | name: sr-implicit-uniform-varied
12 | args:
13 | size_min: 32
14 | size_max: 256
15 | gt_resize: 256
16 | sample_q: 1024
17 | augment: true
18 | batch_size: 16
19 |
20 | val_dataset:
21 | dataset:
22 | name: paired-image-folders
23 | args:
24 | root_path_1: ./load/celebAHQ/32
25 | root_path_2: ./load/celebAHQ/256
26 | split_file: ./load/celebAHQ/split.json
27 | split_key: val
28 | first_k: 100
29 | cache: bin
30 | wrapper:
31 | name: sr-implicit-paired
32 | args:
33 | sample_q: 1024
34 | batch_size: 16
35 |
36 | data_norm:
37 | inp: {sub: [0.5], div: [0.5]}
38 | gt: {sub: [0.5], div: [0.5]}
39 |
40 | model:
41 | name: liif
42 | args:
43 | encoder_spec:
44 | name: edsr-baseline
45 | args:
46 | scale: 8
47 |
48 | optimizer:
49 | name: adam
50 | args:
51 | lr: 1.e-4
52 | epoch_max: 200
53 | multi_step_lr:
54 | milestones: [100]
55 | gamma: 0.1
56 |
57 | epoch_val: 1
58 | epoch_save: 100
59 |
--------------------------------------------------------------------------------
/configs/train-celebAHQ/train_celebAHQ-64-128_liif.yaml:
--------------------------------------------------------------------------------
1 | train_dataset:
2 | dataset:
3 | name: paired-image-folders
4 | args:
5 | root_path_1: ./load/celebAHQ/64
6 | root_path_2: ./load/celebAHQ/128
7 | split_file: ./load/celebAHQ/split.json
8 | split_key: train
9 | cache: bin
10 | wrapper:
11 | name: sr-implicit-uniform-varied
12 | args:
13 | size_min: 64
14 | size_max: 128
15 | sample_q: 1024
16 | augment: true
17 | batch_size: 16
18 |
19 | val_dataset:
20 | dataset:
21 | name: paired-image-folders
22 | args:
23 | root_path_1: ./load/celebAHQ/64
24 | root_path_2: ./load/celebAHQ/128
25 | split_file: ./load/celebAHQ/split.json
26 | split_key: val
27 | first_k: 100
28 | cache: bin
29 | wrapper:
30 | name: sr-implicit-paired
31 | args:
32 | sample_q: 1024
33 | batch_size: 16
34 |
35 | data_norm:
36 | inp: {sub: [0.5], div: [0.5]}
37 | gt: {sub: [0.5], div: [0.5]}
38 |
39 | model:
40 | name: liif
41 | args:
42 | encoder_spec:
43 | name: edsr-baseline
44 | args:
45 | no_upsampling: true
46 | imnet_spec:
47 | name: mlp
48 | args:
49 | out_dim: 3
50 | hidden_list: [256, 256, 256, 256]
51 |
52 | optimizer:
53 | name: adam
54 | args:
55 | lr: 1.e-4
56 | epoch_max: 200
57 | multi_step_lr:
58 | milestones: [100]
59 | gamma: 0.1
60 |
61 | epoch_val: 1
62 | epoch_save: 100
63 |
--------------------------------------------------------------------------------
/configs/train-celebAHQ/train_celebAHQ-64-128_upsample.yaml:
--------------------------------------------------------------------------------
1 | train_dataset:
2 | dataset:
3 | name: paired-image-folders
4 | args:
5 | root_path_1: ./load/celebAHQ/64
6 | root_path_2: ./load/celebAHQ/128
7 | split_file: ./load/celebAHQ/split.json
8 | split_key: train
9 | cache: bin
10 | wrapper:
11 | name: sr-implicit-uniform-varied
12 | args:
13 | size_min: 64
14 | size_max: 128
15 | gt_resize: 128
16 | sample_q: 1024
17 | augment: true
18 | batch_size: 16
19 |
20 | val_dataset:
21 | dataset:
22 | name: paired-image-folders
23 | args:
24 | root_path_1: ./load/celebAHQ/64
25 | root_path_2: ./load/celebAHQ/128
26 | split_file: ./load/celebAHQ/split.json
27 | split_key: val
28 | first_k: 100
29 | cache: bin
30 | wrapper:
31 | name: sr-implicit-paired
32 | args:
33 | sample_q: 1024
34 | batch_size: 16
35 |
36 | data_norm:
37 | inp: {sub: [0.5], div: [0.5]}
38 | gt: {sub: [0.5], div: [0.5]}
39 |
40 | model:
41 | name: liif
42 | args:
43 | encoder_spec:
44 | name: edsr-baseline
45 | args:
46 | scale: 2
47 |
48 | optimizer:
49 | name: adam
50 | args:
51 | lr: 1.e-4
52 | epoch_max: 200
53 | multi_step_lr:
54 | milestones: [100]
55 | gamma: 0.1
56 |
57 | epoch_val: 1
58 | epoch_save: 100
59 |
--------------------------------------------------------------------------------
/configs/train-div2k/ablation/train_edsr-baseline-liif-c.yaml:
--------------------------------------------------------------------------------
1 | train_dataset:
2 | dataset:
3 | name: image-folder
4 | args:
5 | root_path: ./load/div2k/DIV2K_train_HR
6 | repeat: 20
7 | cache: in_memory
8 | wrapper:
9 | name: sr-implicit-downsampled
10 | args:
11 | inp_size: 48
12 | scale_max: 4
13 | augment: true
14 | sample_q: 2304
15 | batch_size: 16
16 |
17 | val_dataset:
18 | dataset:
19 | name: image-folder
20 | args:
21 | root_path: ./load/div2k/DIV2K_valid_HR
22 | first_k: 10
23 | repeat: 160
24 | cache: in_memory
25 | wrapper:
26 | name: sr-implicit-downsampled
27 | args:
28 | inp_size: 48
29 | scale_max: 4
30 | sample_q: 2304
31 | batch_size: 16
32 |
33 | data_norm:
34 | inp: {sub: [0.5], div: [0.5]}
35 | gt: {sub: [0.5], div: [0.5]}
36 |
37 | model:
38 | name: liif
39 | args:
40 | encoder_spec:
41 | name: edsr-baseline
42 | args:
43 | no_upsampling: true
44 | imnet_spec:
45 | name: mlp
46 | args:
47 | out_dim: 3
48 | hidden_list: [256, 256, 256, 256]
49 | cell_decode: false
50 |
51 | optimizer:
52 | name: adam
53 | args:
54 | lr: 1.e-4
55 | epoch_max: 1000
56 | multi_step_lr:
57 | milestones: [200, 400, 600, 800]
58 | gamma: 0.5
59 |
60 | epoch_val: 1
61 | epoch_save: 100
62 |
--------------------------------------------------------------------------------
/configs/train-div2k/ablation/train_edsr-baseline-liif-d.yaml:
--------------------------------------------------------------------------------
1 | train_dataset:
2 | dataset:
3 | name: image-folder
4 | args:
5 | root_path: ./load/div2k/DIV2K_train_HR
6 | repeat: 20
7 | cache: in_memory
8 | wrapper:
9 | name: sr-implicit-downsampled
10 | args:
11 | inp_size: 48
12 | scale_max: 4
13 | augment: true
14 | sample_q: 2304
15 | batch_size: 16
16 |
17 | val_dataset:
18 | dataset:
19 | name: image-folder
20 | args:
21 | root_path: ./load/div2k/DIV2K_valid_HR
22 | first_k: 10
23 | repeat: 160
24 | cache: in_memory
25 | wrapper:
26 | name: sr-implicit-downsampled
27 | args:
28 | inp_size: 48
29 | scale_max: 4
30 | sample_q: 2304
31 | batch_size: 16
32 |
33 | data_norm:
34 | inp: {sub: [0.5], div: [0.5]}
35 | gt: {sub: [0.5], div: [0.5]}
36 |
37 | model:
38 | name: liif
39 | args:
40 | encoder_spec:
41 | name: edsr-baseline
42 | args:
43 | no_upsampling: true
44 | imnet_spec:
45 | name: mlp
46 | args:
47 | out_dim: 3
48 | hidden_list: [256, 256]
49 |
50 | optimizer:
51 | name: adam
52 | args:
53 | lr: 1.e-4
54 | epoch_max: 1000
55 | multi_step_lr:
56 | milestones: [200, 400, 600, 800]
57 | gamma: 0.5
58 |
59 | epoch_val: 1
60 | epoch_save: 100
61 |
--------------------------------------------------------------------------------
/configs/train-div2k/ablation/train_edsr-baseline-liif-e.yaml:
--------------------------------------------------------------------------------
1 | train_dataset:
2 | dataset:
3 | name: image-folder
4 | args:
5 | root_path: ./load/div2k/DIV2K_train_HR
6 | repeat: 20
7 | cache: in_memory
8 | wrapper:
9 | name: sr-implicit-downsampled
10 | args:
11 | inp_size: 48
12 | scale_max: 4
13 | augment: true
14 | sample_q: 2304
15 | batch_size: 16
16 |
17 | val_dataset:
18 | dataset:
19 | name: image-folder
20 | args:
21 | root_path: ./load/div2k/DIV2K_valid_HR
22 | first_k: 10
23 | repeat: 160
24 | cache: in_memory
25 | wrapper:
26 | name: sr-implicit-downsampled
27 | args:
28 | inp_size: 48
29 | scale_max: 4
30 | sample_q: 2304
31 | batch_size: 16
32 |
33 | data_norm:
34 | inp: {sub: [0.5], div: [0.5]}
35 | gt: {sub: [0.5], div: [0.5]}
36 |
37 | model:
38 | name: liif
39 | args:
40 | encoder_spec:
41 | name: edsr-baseline
42 | args:
43 | no_upsampling: true
44 | imnet_spec:
45 | name: mlp
46 | args:
47 | out_dim: 3
48 | hidden_list: [256, 256, 256, 256]
49 | local_ensemble: false
50 |
51 | optimizer:
52 | name: adam
53 | args:
54 | lr: 1.e-4
55 | epoch_max: 1000
56 | multi_step_lr:
57 | milestones: [200, 400, 600, 800]
58 | gamma: 0.5
59 |
60 | epoch_val: 1
61 | epoch_save: 100
62 |
--------------------------------------------------------------------------------
/configs/train-div2k/ablation/train_edsr-baseline-liif-u.yaml:
--------------------------------------------------------------------------------
1 | train_dataset:
2 | dataset:
3 | name: image-folder
4 | args:
5 | root_path: ./load/div2k/DIV2K_train_HR
6 | repeat: 20
7 | cache: in_memory
8 | wrapper:
9 | name: sr-implicit-downsampled
10 | args:
11 | inp_size: 48
12 | scale_max: 4
13 | augment: true
14 | sample_q: 2304
15 | batch_size: 16
16 |
17 | val_dataset:
18 | dataset:
19 | name: image-folder
20 | args:
21 | root_path: ./load/div2k/DIV2K_valid_HR
22 | first_k: 10
23 | repeat: 160
24 | cache: in_memory
25 | wrapper:
26 | name: sr-implicit-downsampled
27 | args:
28 | inp_size: 48
29 | scale_max: 4
30 | sample_q: 2304
31 | batch_size: 16
32 |
33 | data_norm:
34 | inp: {sub: [0.5], div: [0.5]}
35 | gt: {sub: [0.5], div: [0.5]}
36 |
37 | model:
38 | name: liif
39 | args:
40 | encoder_spec:
41 | name: edsr-baseline
42 | args:
43 | no_upsampling: true
44 | imnet_spec:
45 | name: mlp
46 | args:
47 | out_dim: 3
48 | hidden_list: [256, 256, 256, 256]
49 | feat_unfold: false
50 |
51 | optimizer:
52 | name: adam
53 | args:
54 | lr: 1.e-4
55 | epoch_max: 1000
56 | multi_step_lr:
57 | milestones: [200, 400, 600, 800]
58 | gamma: 0.5
59 |
60 | epoch_val: 1
61 | epoch_save: 100
62 |
--------------------------------------------------------------------------------
/configs/train-div2k/ablation/train_edsr-baseline-liif-x2.yaml:
--------------------------------------------------------------------------------
1 | train_dataset:
2 | dataset:
3 | name: image-folder
4 | args:
5 | root_path: ./load/div2k/DIV2K_train_HR
6 | repeat: 20
7 | cache: in_memory
8 | wrapper:
9 | name: sr-implicit-downsampled
10 | args:
11 | inp_size: 48
12 | scale_min: 2
13 | augment: true
14 | sample_q: 2304
15 | batch_size: 16
16 |
17 | val_dataset:
18 | dataset:
19 | name: image-folder
20 | args:
21 | root_path: ./load/div2k/DIV2K_valid_HR
22 | first_k: 10
23 | repeat: 160
24 | cache: in_memory
25 | wrapper:
26 | name: sr-implicit-downsampled
27 | args:
28 | inp_size: 48
29 | scale_min: 2
30 | sample_q: 2304
31 | batch_size: 16
32 |
33 | data_norm:
34 | inp: {sub: [0.5], div: [0.5]}
35 | gt: {sub: [0.5], div: [0.5]}
36 |
37 | model:
38 | name: liif
39 | args:
40 | encoder_spec:
41 | name: edsr-baseline
42 | args:
43 | no_upsampling: true
44 | imnet_spec:
45 | name: mlp
46 | args:
47 | out_dim: 3
48 | hidden_list: [256, 256, 256, 256]
49 |
50 | optimizer:
51 | name: adam
52 | args:
53 | lr: 1.e-4
54 | epoch_max: 1000
55 | multi_step_lr:
56 | milestones: [200, 400, 600, 800]
57 | gamma: 0.5
58 |
59 | epoch_val: 1
60 | epoch_save: 100
61 |
--------------------------------------------------------------------------------
/configs/train-div2k/ablation/train_edsr-baseline-liif-x3.yaml:
--------------------------------------------------------------------------------
1 | train_dataset:
2 | dataset:
3 | name: image-folder
4 | args:
5 | root_path: ./load/div2k/DIV2K_train_HR
6 | repeat: 20
7 | cache: in_memory
8 | wrapper:
9 | name: sr-implicit-downsampled
10 | args:
11 | inp_size: 48
12 | scale_min: 3
13 | augment: true
14 | sample_q: 2304
15 | batch_size: 16
16 |
17 | val_dataset:
18 | dataset:
19 | name: image-folder
20 | args:
21 | root_path: ./load/div2k/DIV2K_valid_HR
22 | first_k: 10
23 | repeat: 160
24 | cache: in_memory
25 | wrapper:
26 | name: sr-implicit-downsampled
27 | args:
28 | inp_size: 48
29 | scale_min: 3
30 | sample_q: 2304
31 | batch_size: 16
32 |
33 | data_norm:
34 | inp: {sub: [0.5], div: [0.5]}
35 | gt: {sub: [0.5], div: [0.5]}
36 |
37 | model:
38 | name: liif
39 | args:
40 | encoder_spec:
41 | name: edsr-baseline
42 | args:
43 | no_upsampling: true
44 | imnet_spec:
45 | name: mlp
46 | args:
47 | out_dim: 3
48 | hidden_list: [256, 256, 256, 256]
49 |
50 | optimizer:
51 | name: adam
52 | args:
53 | lr: 1.e-4
54 | epoch_max: 1000
55 | multi_step_lr:
56 | milestones: [200, 400, 600, 800]
57 | gamma: 0.5
58 |
59 | epoch_val: 1
60 | epoch_save: 100
61 |
--------------------------------------------------------------------------------
/configs/train-div2k/ablation/train_edsr-baseline-liif-x4.yaml:
--------------------------------------------------------------------------------
1 | train_dataset:
2 | dataset:
3 | name: image-folder
4 | args:
5 | root_path: ./load/div2k/DIV2K_train_HR
6 | repeat: 20
7 | cache: in_memory
8 | wrapper:
9 | name: sr-implicit-downsampled
10 | args:
11 | inp_size: 48
12 | scale_min: 4
13 | augment: true
14 | sample_q: 2304
15 | batch_size: 16
16 |
17 | val_dataset:
18 | dataset:
19 | name: image-folder
20 | args:
21 | root_path: ./load/div2k/DIV2K_valid_HR
22 | first_k: 10
23 | repeat: 160
24 | cache: in_memory
25 | wrapper:
26 | name: sr-implicit-downsampled
27 | args:
28 | inp_size: 48
29 | scale_min: 4
30 | sample_q: 2304
31 | batch_size: 16
32 |
33 | data_norm:
34 | inp: {sub: [0.5], div: [0.5]}
35 | gt: {sub: [0.5], div: [0.5]}
36 |
37 | model:
38 | name: liif
39 | args:
40 | encoder_spec:
41 | name: edsr-baseline
42 | args:
43 | no_upsampling: true
44 | imnet_spec:
45 | name: mlp
46 | args:
47 | out_dim: 3
48 | hidden_list: [256, 256, 256, 256]
49 |
50 | optimizer:
51 | name: adam
52 | args:
53 | lr: 1.e-4
54 | epoch_max: 1000
55 | multi_step_lr:
56 | milestones: [200, 400, 600, 800]
57 | gamma: 0.5
58 |
59 | epoch_val: 1
60 | epoch_save: 100
61 |
--------------------------------------------------------------------------------
/configs/train-div2k/train_edsr-baseline-liif.yaml:
--------------------------------------------------------------------------------
1 | train_dataset:
2 | dataset:
3 | name: image-folder
4 | args:
5 | root_path: ./load/div2k/DIV2K_train_HR
6 | repeat: 20
7 | cache: in_memory
8 | wrapper:
9 | name: sr-implicit-downsampled
10 | args:
11 | inp_size: 48
12 | scale_max: 4
13 | augment: true
14 | sample_q: 2304
15 | batch_size: 16
16 |
17 | val_dataset:
18 | dataset:
19 | name: image-folder
20 | args:
21 | root_path: ./load/div2k/DIV2K_valid_HR
22 | first_k: 10
23 | repeat: 160
24 | cache: in_memory
25 | wrapper:
26 | name: sr-implicit-downsampled
27 | args:
28 | inp_size: 48
29 | scale_max: 4
30 | sample_q: 2304
31 | batch_size: 16
32 |
33 | data_norm:
34 | inp: {sub: [0.5], div: [0.5]}
35 | gt: {sub: [0.5], div: [0.5]}
36 |
37 | model:
38 | name: liif
39 | args:
40 | encoder_spec:
41 | name: edsr-baseline
42 | args:
43 | no_upsampling: true
44 | imnet_spec:
45 | name: mlp
46 | args:
47 | out_dim: 3
48 | hidden_list: [256, 256, 256, 256]
49 |
50 | optimizer:
51 | name: adam
52 | args:
53 | lr: 1.e-4
54 | epoch_max: 1000
55 | multi_step_lr:
56 | milestones: [200, 400, 600, 800]
57 | gamma: 0.5
58 |
59 | epoch_val: 1
60 | epoch_save: 100
61 |
--------------------------------------------------------------------------------
/configs/train-div2k/train_edsr-baseline-metasr.yaml:
--------------------------------------------------------------------------------
1 | train_dataset:
2 | dataset:
3 | name: image-folder
4 | args:
5 | root_path: ./load/div2k/DIV2K_train_HR
6 | repeat: 20
7 | cache: in_memory
8 | wrapper:
9 | name: sr-implicit-downsampled
10 | args:
11 | inp_size: 48
12 | scale_max: 4
13 | augment: true
14 | sample_q: 2304
15 | batch_size: 16
16 |
17 | val_dataset:
18 | dataset:
19 | name: image-folder
20 | args:
21 | root_path: ./load/div2k/DIV2K_valid_HR
22 | first_k: 10
23 | repeat: 160
24 | cache: in_memory
25 | wrapper:
26 | name: sr-implicit-downsampled
27 | args:
28 | inp_size: 48
29 | scale_max: 4
30 | sample_q: 2304
31 | batch_size: 16
32 |
33 | data_norm:
34 | inp: {sub: [0.5], div: [0.5]}
35 | gt: {sub: [0.5], div: [0.5]}
36 |
37 | model:
38 | name: metasr
39 | args:
40 | encoder_spec:
41 | name: edsr-baseline
42 | args:
43 | no_upsampling: true
44 |
45 | optimizer:
46 | name: adam
47 | args:
48 | lr: 1.e-4
49 | epoch_max: 1000
50 | multi_step_lr:
51 | milestones: [200, 400, 600, 800]
52 | gamma: 0.5
53 |
54 | epoch_val: 1
55 | epoch_save: 100
56 |
--------------------------------------------------------------------------------
/configs/train-div2k/train_edsr-baseline-x2.yaml:
--------------------------------------------------------------------------------
1 | train_dataset:
2 | dataset:
3 | name: paired-image-folders
4 | args:
5 | root_path_1: ./load/div2k/DIV2K_train_LR_bicubic/X2
6 | root_path_2: ./load/div2k/DIV2K_train_HR
7 | repeat: 20
8 | cache: in_memory
9 | wrapper:
10 | name: sr-implicit-paired
11 | args:
12 | inp_size: 48
13 | augment: true
14 | batch_size: 16
15 |
16 | val_dataset:
17 | dataset:
18 | name: paired-image-folders
19 | args:
20 | root_path_1: ./load/div2k/DIV2K_valid_LR_bicubic/X2
21 | root_path_2: ./load/div2k/DIV2K_valid_HR
22 | first_k: 10
23 | repeat: 160
24 | cache: in_memory
25 | wrapper:
26 | name: sr-implicit-paired
27 | args:
28 | inp_size: 48
29 | batch_size: 16
30 |
31 | data_norm:
32 | inp: {sub: [0.5], div: [0.5]}
33 | gt: {sub: [0.5], div: [0.5]}
34 |
35 | model:
36 | name: liif
37 | args:
38 | encoder_spec:
39 | name: edsr-baseline
40 | args:
41 | scale: 2
42 |
43 | optimizer:
44 | name: adam
45 | args:
46 | lr: 1.e-4
47 | epoch_max: 1000
48 | multi_step_lr:
49 | milestones: [200, 400, 600, 800]
50 | gamma: 0.5
51 |
52 | epoch_val: 1
53 | epoch_save: 100
54 |
--------------------------------------------------------------------------------
/configs/train-div2k/train_rdn-liif.yaml:
--------------------------------------------------------------------------------
1 | train_dataset:
2 | dataset:
3 | name: image-folder
4 | args:
5 | root_path: ./load/div2k/DIV2K_train_HR
6 | repeat: 20
7 | cache: in_memory
8 | wrapper:
9 | name: sr-implicit-downsampled
10 | args:
11 | inp_size: 48
12 | scale_max: 4
13 | augment: true
14 | sample_q: 2304
15 | batch_size: 16
16 |
17 | val_dataset:
18 | dataset:
19 | name: image-folder
20 | args:
21 | root_path: ./load/div2k/DIV2K_valid_HR
22 | first_k: 10
23 | repeat: 160
24 | cache: in_memory
25 | wrapper:
26 | name: sr-implicit-downsampled
27 | args:
28 | inp_size: 48
29 | scale_max: 4
30 | sample_q: 2304
31 | batch_size: 16
32 |
33 | data_norm:
34 | inp: {sub: [0.5], div: [0.5]}
35 | gt: {sub: [0.5], div: [0.5]}
36 |
37 | model:
38 | name: liif
39 | args:
40 | encoder_spec:
41 | name: rdn
42 | args:
43 | no_upsampling: true
44 | imnet_spec:
45 | name: mlp
46 | args:
47 | out_dim: 3
48 | hidden_list: [256, 256, 256, 256]
49 |
50 | optimizer:
51 | name: adam
52 | args:
53 | lr: 1.e-4
54 | epoch_max: 1000
55 | multi_step_lr:
56 | milestones: [200, 400, 600, 800]
57 | gamma: 0.5
58 |
59 | epoch_val: 1
60 | epoch_save: 100
61 |
--------------------------------------------------------------------------------
/configs/train-div2k/train_rdn-metasr.yaml:
--------------------------------------------------------------------------------
1 | train_dataset:
2 | dataset:
3 | name: image-folder
4 | args:
5 | root_path: ./load/div2k/DIV2K_train_HR
6 | repeat: 20
7 | cache: in_memory
8 | wrapper:
9 | name: sr-implicit-downsampled
10 | args:
11 | inp_size: 48
12 | scale_max: 4
13 | augment: true
14 | sample_q: 2304
15 | batch_size: 16
16 |
17 | val_dataset:
18 | dataset:
19 | name: image-folder
20 | args:
21 | root_path: ./load/div2k/DIV2K_valid_HR
22 | first_k: 10
23 | repeat: 160
24 | cache: in_memory
25 | wrapper:
26 | name: sr-implicit-downsampled
27 | args:
28 | inp_size: 48
29 | scale_max: 4
30 | sample_q: 2304
31 | batch_size: 16
32 |
33 | data_norm:
34 | inp: {sub: [0.5], div: [0.5]}
35 | gt: {sub: [0.5], div: [0.5]}
36 |
37 | model:
38 | name: metasr
39 | args:
40 | encoder_spec:
41 | name: rdn
42 | args:
43 | no_upsampling: true
44 |
45 | optimizer:
46 | name: adam
47 | args:
48 | lr: 1.e-4
49 | epoch_max: 1000
50 | multi_step_lr:
51 | milestones: [200, 400, 600, 800]
52 | gamma: 0.5
53 |
54 | epoch_val: 1
55 | epoch_save: 100
56 |
--------------------------------------------------------------------------------
/configs/train-div2k/train_rdn-x2.yaml:
--------------------------------------------------------------------------------
1 | train_dataset:
2 | dataset:
3 | name: paired-image-folders
4 | args:
5 | root_path_1: ./load/div2k/DIV2K_train_LR_bicubic/X2
6 | root_path_2: ./load/div2k/DIV2K_train_HR
7 | repeat: 20
8 | cache: in_memory
9 | wrapper:
10 | name: sr-implicit-paired
11 | args:
12 | inp_size: 48
13 | augment: true
14 | batch_size: 16
15 |
16 | val_dataset:
17 | dataset:
18 | name: paired-image-folders
19 | args:
20 | root_path_1: ./load/div2k/DIV2K_valid_LR_bicubic/X2
21 | root_path_2: ./load/div2k/DIV2K_valid_HR
22 | first_k: 10
23 | repeat: 160
24 | cache: in_memory
25 | wrapper:
26 | name: sr-implicit-paired
27 | args:
28 | inp_size: 48
29 | batch_size: 16
30 |
31 | data_norm:
32 | inp: {sub: [0.5], div: [0.5]}
33 | gt: {sub: [0.5], div: [0.5]}
34 |
35 | model:
36 | name: liif
37 | args:
38 | encoder_spec:
39 | name: rdn
40 | args:
41 | scale: 2
42 |
43 | optimizer:
44 | name: adam
45 | args:
46 | lr: 1.e-4
47 | epoch_max: 1000
48 | multi_step_lr:
49 | milestones: [200, 400, 600, 800]
50 | gamma: 0.5
51 |
52 | epoch_val: 1
53 | epoch_save: 100
54 |
--------------------------------------------------------------------------------
/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .datasets import register, make
2 | from . import image_folder
3 | from . import wrappers
4 |
--------------------------------------------------------------------------------
/datasets/datasets.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 |
4 | datasets = {}
5 |
6 |
7 | def register(name):
8 | def decorator(cls):
9 | datasets[name] = cls
10 | return cls
11 | return decorator
12 |
13 |
14 | def make(dataset_spec, args=None):
15 | if args is not None:
16 | dataset_args = copy.deepcopy(dataset_spec['args'])
17 | dataset_args.update(args)
18 | else:
19 | dataset_args = dataset_spec['args']
20 | dataset = datasets[dataset_spec['name']](**dataset_args)
21 | return dataset
22 |
--------------------------------------------------------------------------------
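A short usage sketch of the registry above: `@register('name')` stores a class in the module-level `datasets` dict, and `make(spec)` instantiates it from a spec shaped like the `dataset:` blocks in the YAML configs. `ToyDataset` below is hypothetical, used only to show the pattern:

```python
# Hypothetical example of the register/make pattern from datasets/datasets.py.
from datasets import register, make

@register('toy-dataset')          # stores ToyDataset under the key 'toy-dataset'
class ToyDataset:
    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

spec = {'name': 'toy-dataset', 'args': {'n': 4}}   # same shape as a YAML `dataset:` block
ds = make(spec)
assert len(ds) == 4
```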
/datasets/image_folder.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from PIL import Image
4 |
5 | import pickle
6 | import imageio
7 | import numpy as np
8 | import torch
9 | from torch.utils.data import Dataset
10 | from torchvision import transforms
11 |
12 | from datasets import register
13 |
14 |
15 | @register('image-folder')
16 | class ImageFolder(Dataset):
17 |
18 | def __init__(self, root_path, split_file=None, split_key=None, first_k=None,
19 | repeat=1, cache='none'):
20 | self.repeat = repeat
21 | self.cache = cache
22 |
23 | if split_file is None:
24 | filenames = sorted(os.listdir(root_path))
25 | else:
26 | with open(split_file, 'r') as f:
27 | filenames = json.load(f)[split_key]
28 | if first_k is not None:
29 | filenames = filenames[:first_k]
30 |
31 | self.files = []
32 | for filename in filenames:
33 | file = os.path.join(root_path, filename)
34 |
35 | if cache == 'none':
36 | self.files.append(file)
37 |
38 | elif cache == 'bin':
39 | bin_root = os.path.join(os.path.dirname(root_path),
40 | '_bin_' + os.path.basename(root_path))
41 | if not os.path.exists(bin_root):
42 | os.mkdir(bin_root)
43 | print('mkdir', bin_root)
44 | bin_file = os.path.join(
45 | bin_root, filename.split('.')[0] + '.pkl')
46 | if not os.path.exists(bin_file):
47 | with open(bin_file, 'wb') as f:
48 | pickle.dump(imageio.imread(file), f)
49 | print('dump', bin_file)
50 | self.files.append(bin_file)
51 |
52 | elif cache == 'in_memory':
53 | self.files.append(transforms.ToTensor()(
54 | Image.open(file).convert('RGB')))
55 |
56 | def __len__(self):
57 | return len(self.files) * self.repeat
58 |
59 | def __getitem__(self, idx):
60 | x = self.files[idx % len(self.files)]
61 |
62 | if self.cache == 'none':
63 | return transforms.ToTensor()(Image.open(x).convert('RGB'))
64 |
65 | elif self.cache == 'bin':
66 | with open(x, 'rb') as f:
67 | x = pickle.load(f)
68 | x = np.ascontiguousarray(x.transpose(2, 0, 1))
69 | x = torch.from_numpy(x).float() / 255
70 | return x
71 |
72 | elif self.cache == 'in_memory':
73 | return x
74 |
75 |
76 | @register('paired-image-folders')
77 | class PairedImageFolders(Dataset):
78 |
79 | def __init__(self, root_path_1, root_path_2, **kwargs):
80 | self.dataset_1 = ImageFolder(root_path_1, **kwargs)
81 | self.dataset_2 = ImageFolder(root_path_2, **kwargs)
82 |
83 | def __len__(self):
84 | return len(self.dataset_1)
85 |
86 | def __getitem__(self, idx):
87 | return self.dataset_1[idx], self.dataset_2[idx]
88 |
--------------------------------------------------------------------------------
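A sketch of how the three `cache` modes above behave when constructing `ImageFolder` directly; the path is the DIV2K validation folder used by the configs and assumes the data has been downloaded as described in the README:

```python
# Sketch of ImageFolder's cache modes (requires the DIV2K data under ./load):
#   cache='none'      keep file paths and decode the image on every __getitem__
#   cache='bin'       on first use, pickle decoded arrays into a sibling
#                     '_bin_<folder>' directory, then read the pickles afterwards
#   cache='in_memory' decode every image into a tensor up front (large RAM cost)
from datasets.image_folder import ImageFolder

ds = ImageFolder('./load/div2k/DIV2K_valid_HR', repeat=1, cache='none')
x = ds[0]    # float tensor of shape (3, H, W) with values in [0, 1]
```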
/datasets/wrappers.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import random
3 | import math
4 | from PIL import Image
5 |
6 | import numpy as np
7 | import torch
8 | from torch.utils.data import Dataset
9 | from torchvision import transforms
10 |
11 | from datasets import register
12 | from utils import to_pixel_samples
13 |
14 |
15 | @register('sr-implicit-paired')
16 | class SRImplicitPaired(Dataset):
17 |
18 | def __init__(self, dataset, inp_size=None, augment=False, sample_q=None):
19 | self.dataset = dataset
20 | self.inp_size = inp_size
21 | self.augment = augment
22 | self.sample_q = sample_q
23 |
24 | def __len__(self):
25 | return len(self.dataset)
26 |
27 | def __getitem__(self, idx):
28 | img_lr, img_hr = self.dataset[idx]
29 |
30 | s = img_hr.shape[-2] // img_lr.shape[-2] # assume int scale
31 | if self.inp_size is None:
32 | h_lr, w_lr = img_lr.shape[-2:]
33 | img_hr = img_hr[:, :h_lr * s, :w_lr * s]
34 | crop_lr, crop_hr = img_lr, img_hr
35 | else:
36 | w_lr = self.inp_size
37 | x0 = random.randint(0, img_lr.shape[-2] - w_lr)
38 | y0 = random.randint(0, img_lr.shape[-1] - w_lr)
39 | crop_lr = img_lr[:, x0: x0 + w_lr, y0: y0 + w_lr]
40 | w_hr = w_lr * s
41 | x1 = x0 * s
42 | y1 = y0 * s
43 | crop_hr = img_hr[:, x1: x1 + w_hr, y1: y1 + w_hr]
44 |
45 | if self.augment:
46 | hflip = random.random() < 0.5
47 | vflip = random.random() < 0.5
48 | dflip = random.random() < 0.5
49 |
50 | def augment(x):
51 | if hflip:
52 | x = x.flip(-2)
53 | if vflip:
54 | x = x.flip(-1)
55 | if dflip:
56 | x = x.transpose(-2, -1)
57 | return x
58 |
59 | crop_lr = augment(crop_lr)
60 | crop_hr = augment(crop_hr)
61 |
62 | hr_coord, hr_rgb = to_pixel_samples(crop_hr.contiguous())
63 |
64 | if self.sample_q is not None:
65 | sample_lst = np.random.choice(
66 | len(hr_coord), self.sample_q, replace=False)
67 | hr_coord = hr_coord[sample_lst]
68 | hr_rgb = hr_rgb[sample_lst]
69 |
70 | cell = torch.ones_like(hr_coord)
71 | cell[:, 0] *= 2 / crop_hr.shape[-2]
72 | cell[:, 1] *= 2 / crop_hr.shape[-1]
73 |
74 | return {
75 | 'inp': crop_lr,
76 | 'coord': hr_coord,
77 | 'cell': cell,
78 | 'gt': hr_rgb
79 | }
80 |
81 |
82 | def resize_fn(img, size):
83 | return transforms.ToTensor()(
84 | transforms.Resize(size, Image.BICUBIC)(
85 | transforms.ToPILImage()(img)))
86 |
87 |
88 | @register('sr-implicit-downsampled')
89 | class SRImplicitDownsampled(Dataset):
90 |
91 | def __init__(self, dataset, inp_size=None, scale_min=1, scale_max=None,
92 | augment=False, sample_q=None):
93 | self.dataset = dataset
94 | self.inp_size = inp_size
95 | self.scale_min = scale_min
96 | if scale_max is None:
97 | scale_max = scale_min
98 | self.scale_max = scale_max
99 | self.augment = augment
100 | self.sample_q = sample_q
101 |
102 | def __len__(self):
103 | return len(self.dataset)
104 |
105 | def __getitem__(self, idx):
106 | img = self.dataset[idx]
107 | s = random.uniform(self.scale_min, self.scale_max)
108 |
109 | if self.inp_size is None:
110 | h_lr = math.floor(img.shape[-2] / s + 1e-9)
111 | w_lr = math.floor(img.shape[-1] / s + 1e-9)
112 | img = img[:, :round(h_lr * s), :round(w_lr * s)] # assume round int
113 | img_down = resize_fn(img, (h_lr, w_lr))
114 | crop_lr, crop_hr = img_down, img
115 | else:
116 | w_lr = self.inp_size
117 | w_hr = round(w_lr * s)
118 | x0 = random.randint(0, img.shape[-2] - w_hr)
119 | y0 = random.randint(0, img.shape[-1] - w_hr)
120 | crop_hr = img[:, x0: x0 + w_hr, y0: y0 + w_hr]
121 | crop_lr = resize_fn(crop_hr, w_lr)
122 |
123 | if self.augment:
124 | hflip = random.random() < 0.5
125 | vflip = random.random() < 0.5
126 | dflip = random.random() < 0.5
127 |
128 | def augment(x):
129 | if hflip:
130 | x = x.flip(-2)
131 | if vflip:
132 | x = x.flip(-1)
133 | if dflip:
134 | x = x.transpose(-2, -1)
135 | return x
136 |
137 | crop_lr = augment(crop_lr)
138 | crop_hr = augment(crop_hr)
139 |
140 | hr_coord, hr_rgb = to_pixel_samples(crop_hr.contiguous())
141 |
142 | if self.sample_q is not None:
143 | sample_lst = np.random.choice(
144 | len(hr_coord), self.sample_q, replace=False)
145 | hr_coord = hr_coord[sample_lst]
146 | hr_rgb = hr_rgb[sample_lst]
147 |
148 | cell = torch.ones_like(hr_coord)
149 | cell[:, 0] *= 2 / crop_hr.shape[-2]
150 | cell[:, 1] *= 2 / crop_hr.shape[-1]
151 |
152 | return {
153 | 'inp': crop_lr,
154 | 'coord': hr_coord,
155 | 'cell': cell,
156 | 'gt': hr_rgb
157 | }
158 |
159 |
160 | @register('sr-implicit-uniform-varied')
161 | class SRImplicitUniformVaried(Dataset):
162 |
163 | def __init__(self, dataset, size_min, size_max=None,
164 | augment=False, gt_resize=None, sample_q=None):
165 | self.dataset = dataset
166 | self.size_min = size_min
167 | if size_max is None:
168 | size_max = size_min
169 | self.size_max = size_max
170 | self.augment = augment
171 | self.gt_resize = gt_resize
172 | self.sample_q = sample_q
173 |
174 | def __len__(self):
175 | return len(self.dataset)
176 |
177 | def __getitem__(self, idx):
178 | img_lr, img_hr = self.dataset[idx]
179 | p = idx / (len(self.dataset) - 1)
180 | w_hr = round(self.size_min + (self.size_max - self.size_min) * p)
181 | img_hr = resize_fn(img_hr, w_hr)
182 |
183 | if self.augment:
184 | if random.random() < 0.5:
185 | img_lr = img_lr.flip(-1)
186 | img_hr = img_hr.flip(-1)
187 |
188 | if self.gt_resize is not None:
189 | img_hr = resize_fn(img_hr, self.gt_resize)
190 |
191 | hr_coord, hr_rgb = to_pixel_samples(img_hr)
192 |
193 | if self.sample_q is not None:
194 | sample_lst = np.random.choice(
195 | len(hr_coord), self.sample_q, replace=False)
196 | hr_coord = hr_coord[sample_lst]
197 | hr_rgb = hr_rgb[sample_lst]
198 |
199 | cell = torch.ones_like(hr_coord)
200 | cell[:, 0] *= 2 / img_hr.shape[-2]
201 | cell[:, 1] *= 2 / img_hr.shape[-1]
202 |
203 | return {
204 | 'inp': img_lr,
205 | 'coord': hr_coord,
206 | 'cell': cell,
207 | 'gt': hr_rgb
208 | }
209 |
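Editorial note on the sample format returned by these wrappers: cell stores the height and width of one HR pixel in the [-1, 1] coordinate system produced by to_pixel_samples. A minimal sketch, assuming an illustrative 96x96 HR crop (e.g. a 48-pixel LR crop at scale 2); the sizes are not taken from any config file:

    import torch
    from utils import to_pixel_samples

    crop_hr = torch.rand(3, 96, 96)               # stand-in HR crop
    hr_coord, hr_rgb = to_pixel_samples(crop_hr)  # (9216, 2) coords in [-1, 1], (9216, 3) RGB
    cell = torch.ones_like(hr_coord)
    cell[:, 0] *= 2 / crop_hr.shape[-2]           # pixel height: 2 / 96
    cell[:, 1] *= 2 / crop_hr.shape[-1]           # pixel width:  2 / 96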
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from PIL import Image
4 |
5 | import torch
6 | from torchvision import transforms
7 |
8 | import models
9 | from utils import make_coord
10 | from test import batched_predict
11 |
12 |
13 | if __name__ == '__main__':
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--input', default='input.png')
16 | parser.add_argument('--model')
17 | parser.add_argument('--resolution')
18 | parser.add_argument('--output', default='output.png')
19 | parser.add_argument('--gpu', default='0')
20 | args = parser.parse_args()
21 |
22 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
23 |
24 | img = transforms.ToTensor()(Image.open(args.input).convert('RGB'))
25 |
26 | model = models.make(torch.load(args.model)['model'], load_sd=True).cuda()
27 |
28 | h, w = list(map(int, args.resolution.split(',')))
29 | coord = make_coord((h, w)).cuda()
30 | cell = torch.ones_like(coord)
31 | cell[:, 0] *= 2 / h
32 | cell[:, 1] *= 2 / w
33 | pred = batched_predict(model, ((img - 0.5) / 0.5).cuda().unsqueeze(0),
34 | coord.unsqueeze(0), cell.unsqueeze(0), bsize=30000)[0]
35 | pred = (pred * 0.5 + 0.5).clamp(0, 1).view(h, w, 3).permute(2, 0, 1).cpu()
36 | transforms.ToPILImage()(pred).save(args.output)
37 |
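Editorial usage note: demo.py upsamples a single image to an arbitrary target resolution with a trained checkpoint. The flags below come straight from the argument parser above; the checkpoint path is only illustrative.

    python demo.py --input input.png --model path/to/epoch-last.pth --resolution 256,256 --output output.png --gpu 0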
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import register, make
2 | from . import edsr, rdn, rcan
3 | from . import mlp
4 | from . import liif
5 | from . import misc
6 |
--------------------------------------------------------------------------------
/models/edsr.py:
--------------------------------------------------------------------------------
1 | # modified from: https://github.com/thstkdgus35/EDSR-PyTorch
2 |
3 | import math
4 | from argparse import Namespace
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | from models import register
11 |
12 |
13 | def default_conv(in_channels, out_channels, kernel_size, bias=True):
14 | return nn.Conv2d(
15 | in_channels, out_channels, kernel_size,
16 | padding=(kernel_size//2), bias=bias)
17 |
18 | class MeanShift(nn.Conv2d):
19 | def __init__(
20 | self, rgb_range,
21 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
22 |
23 | super(MeanShift, self).__init__(3, 3, kernel_size=1)
24 | std = torch.Tensor(rgb_std)
25 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
26 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
27 | for p in self.parameters():
28 | p.requires_grad = False
29 |
30 | class ResBlock(nn.Module):
31 | def __init__(
32 | self, conv, n_feats, kernel_size,
33 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
34 |
35 | super(ResBlock, self).__init__()
36 | m = []
37 | for i in range(2):
38 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
39 | if bn:
40 | m.append(nn.BatchNorm2d(n_feats))
41 | if i == 0:
42 | m.append(act)
43 |
44 | self.body = nn.Sequential(*m)
45 | self.res_scale = res_scale
46 |
47 | def forward(self, x):
48 | res = self.body(x).mul(self.res_scale)
49 | res += x
50 |
51 | return res
52 |
53 | class Upsampler(nn.Sequential):
54 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
55 |
56 | m = []
57 | if (scale & (scale - 1)) == 0: # Is scale = 2^n?
58 | for _ in range(int(math.log(scale, 2))):
59 | m.append(conv(n_feats, 4 * n_feats, 3, bias))
60 | m.append(nn.PixelShuffle(2))
61 | if bn:
62 | m.append(nn.BatchNorm2d(n_feats))
63 | if act == 'relu':
64 | m.append(nn.ReLU(True))
65 | elif act == 'prelu':
66 | m.append(nn.PReLU(n_feats))
67 |
68 | elif scale == 3:
69 | m.append(conv(n_feats, 9 * n_feats, 3, bias))
70 | m.append(nn.PixelShuffle(3))
71 | if bn:
72 | m.append(nn.BatchNorm2d(n_feats))
73 | if act == 'relu':
74 | m.append(nn.ReLU(True))
75 | elif act == 'prelu':
76 | m.append(nn.PReLU(n_feats))
77 | else:
78 | raise NotImplementedError
79 |
80 | super(Upsampler, self).__init__(*m)
81 |
82 |
83 | url = {
84 | 'r16f64x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x2-1bc95232.pt',
85 | 'r16f64x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x3-abf2a44e.pt',
86 | 'r16f64x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x4-6b446fab.pt',
87 | 'r32f256x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x2-0edfb8a3.pt',
88 | 'r32f256x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x3-ea3ef2c6.pt',
89 | 'r32f256x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x4-4f62e9ef.pt'
90 | }
91 |
92 | class EDSR(nn.Module):
93 | def __init__(self, args, conv=default_conv):
94 | super(EDSR, self).__init__()
95 | self.args = args
96 | n_resblocks = args.n_resblocks
97 | n_feats = args.n_feats
98 | kernel_size = 3
99 | scale = args.scale[0]
100 | act = nn.ReLU(True)
101 | url_name = 'r{}f{}x{}'.format(n_resblocks, n_feats, scale)
102 | if url_name in url:
103 | self.url = url[url_name]
104 | else:
105 | self.url = None
106 | self.sub_mean = MeanShift(args.rgb_range)
107 | self.add_mean = MeanShift(args.rgb_range, sign=1)
108 |
109 | # define head module
110 | m_head = [conv(args.n_colors, n_feats, kernel_size)]
111 |
112 | # define body module
113 | m_body = [
114 | ResBlock(
115 | conv, n_feats, kernel_size, act=act, res_scale=args.res_scale
116 | ) for _ in range(n_resblocks)
117 | ]
118 | m_body.append(conv(n_feats, n_feats, kernel_size))
119 |
120 | self.head = nn.Sequential(*m_head)
121 | self.body = nn.Sequential(*m_body)
122 |
123 | if args.no_upsampling:
124 | self.out_dim = n_feats
125 | else:
126 | self.out_dim = args.n_colors
127 | # define tail module
128 | m_tail = [
129 | Upsampler(conv, scale, n_feats, act=False),
130 | conv(n_feats, args.n_colors, kernel_size)
131 | ]
132 | self.tail = nn.Sequential(*m_tail)
133 |
134 | def forward(self, x):
135 | #x = self.sub_mean(x)
136 | x = self.head(x)
137 |
138 | res = self.body(x)
139 | res += x
140 |
141 | if self.args.no_upsampling:
142 | x = res
143 | else:
144 | x = self.tail(res)
145 | #x = self.add_mean(x)
146 | return x
147 |
148 | def load_state_dict(self, state_dict, strict=True):
149 | own_state = self.state_dict()
150 | for name, param in state_dict.items():
151 | if name in own_state:
152 | if isinstance(param, nn.Parameter):
153 | param = param.data
154 | try:
155 | own_state[name].copy_(param)
156 | except Exception:
157 | if name.find('tail') == -1:
158 | raise RuntimeError('While copying the parameter named {}, '
159 | 'whose dimensions in the model are {} and '
160 | 'whose dimensions in the checkpoint are {}.'
161 | .format(name, own_state[name].size(), param.size()))
162 | elif strict:
163 | if name.find('tail') == -1:
164 | raise KeyError('unexpected key "{}" in state_dict'
165 | .format(name))
166 |
167 |
168 | @register('edsr-baseline')
169 | def make_edsr_baseline(n_resblocks=16, n_feats=64, res_scale=1,
170 | scale=2, no_upsampling=False, rgb_range=1):
171 | args = Namespace()
172 | args.n_resblocks = n_resblocks
173 | args.n_feats = n_feats
174 | args.res_scale = res_scale
175 |
176 | args.scale = [scale]
177 | args.no_upsampling = no_upsampling
178 |
179 | args.rgb_range = rgb_range
180 | args.n_colors = 3
181 | return EDSR(args)
182 |
183 |
184 | @register('edsr')
185 | def make_edsr(n_resblocks=32, n_feats=256, res_scale=0.1,
186 | scale=2, no_upsampling=False, rgb_range=1):
187 | args = Namespace()
188 | args.n_resblocks = n_resblocks
189 | args.n_feats = n_feats
190 | args.res_scale = res_scale
191 |
192 | args.scale = [scale]
193 | args.no_upsampling = no_upsampling
194 |
195 | args.rgb_range = rgb_range
196 | args.n_colors = 3
197 | return EDSR(args)
198 |
--------------------------------------------------------------------------------
/models/liif.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import models
6 | from models import register
7 | from utils import make_coord
8 |
9 |
10 | @register('liif')
11 | class LIIF(nn.Module):
12 |
13 | def __init__(self, encoder_spec, imnet_spec=None,
14 | local_ensemble=True, feat_unfold=True, cell_decode=True):
15 | super().__init__()
16 | self.local_ensemble = local_ensemble
17 | self.feat_unfold = feat_unfold
18 | self.cell_decode = cell_decode
19 |
20 | self.encoder = models.make(encoder_spec)
21 |
22 | if imnet_spec is not None:
23 | imnet_in_dim = self.encoder.out_dim
24 | if self.feat_unfold:
25 | imnet_in_dim *= 9
26 | imnet_in_dim += 2 # attach coord
27 | if self.cell_decode:
28 | imnet_in_dim += 2
29 | self.imnet = models.make(imnet_spec, args={'in_dim': imnet_in_dim})
30 | else:
31 | self.imnet = None
32 |
33 | def gen_feat(self, inp):
34 | self.feat = self.encoder(inp)
35 | return self.feat
36 |
37 | def query_rgb(self, coord, cell=None):
38 | feat = self.feat
39 |
40 | if self.imnet is None:
41 | ret = F.grid_sample(feat, coord.flip(-1).unsqueeze(1),
42 | mode='nearest', align_corners=False)[:, :, 0, :] \
43 | .permute(0, 2, 1)
44 | return ret
45 |
46 | if self.feat_unfold:
47 | feat = F.unfold(feat, 3, padding=1).view(
48 | feat.shape[0], feat.shape[1] * 9, feat.shape[2], feat.shape[3])
49 |
50 | if self.local_ensemble:
51 | vx_lst = [-1, 1]
52 | vy_lst = [-1, 1]
53 | eps_shift = 1e-6
54 | else:
55 | vx_lst, vy_lst, eps_shift = [0], [0], 0
56 |
57 | # field radius (global: [-1, 1])
58 | rx = 2 / feat.shape[-2] / 2
59 | ry = 2 / feat.shape[-1] / 2
60 |
61 | feat_coord = make_coord(feat.shape[-2:], flatten=False).cuda() \
62 | .permute(2, 0, 1) \
63 | .unsqueeze(0).expand(feat.shape[0], 2, *feat.shape[-2:])
64 |
65 | preds = []
66 | areas = []
67 | for vx in vx_lst:
68 | for vy in vy_lst:
69 | coord_ = coord.clone()
70 | coord_[:, :, 0] += vx * rx + eps_shift
71 | coord_[:, :, 1] += vy * ry + eps_shift
72 | coord_.clamp_(-1 + 1e-6, 1 - 1e-6)
73 | q_feat = F.grid_sample(
74 | feat, coord_.flip(-1).unsqueeze(1),
75 | mode='nearest', align_corners=False)[:, :, 0, :] \
76 | .permute(0, 2, 1)
77 | q_coord = F.grid_sample(
78 | feat_coord, coord_.flip(-1).unsqueeze(1),
79 | mode='nearest', align_corners=False)[:, :, 0, :] \
80 | .permute(0, 2, 1)
81 | rel_coord = coord - q_coord
82 | rel_coord[:, :, 0] *= feat.shape[-2]
83 | rel_coord[:, :, 1] *= feat.shape[-1]
84 | inp = torch.cat([q_feat, rel_coord], dim=-1)
85 |
86 | if self.cell_decode:
87 | rel_cell = cell.clone()
88 | rel_cell[:, :, 0] *= feat.shape[-2]
89 | rel_cell[:, :, 1] *= feat.shape[-1]
90 | inp = torch.cat([inp, rel_cell], dim=-1)
91 |
92 | bs, q = coord.shape[:2]
93 | pred = self.imnet(inp.view(bs * q, -1)).view(bs, q, -1)
94 | preds.append(pred)
95 |
96 | area = torch.abs(rel_coord[:, :, 0] * rel_coord[:, :, 1])
97 | areas.append(area + 1e-9)
98 |
99 | tot_area = torch.stack(areas).sum(dim=0)
100 | if self.local_ensemble:
101 | t = areas[0]; areas[0] = areas[3]; areas[3] = t
102 | t = areas[1]; areas[1] = areas[2]; areas[2] = t
103 | ret = 0
104 | for pred, area in zip(preds, areas):
105 | ret = ret + pred * (area / tot_area).unsqueeze(-1)
106 | return ret
107 |
108 | def forward(self, inp, coord, cell):
109 | self.gen_feat(inp)
110 | return self.query_rgb(coord, cell)
111 |
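A minimal editorial sketch of building this model through the registry; the encoder/imnet specs mirror the structure used by the training configs, but the exact hyper-parameters shown here are illustrative rather than copied from any config file.

    import models  # registers 'edsr-baseline', 'mlp', 'liif', ...

    liif = models.make({
        'name': 'liif',
        'args': {
            'encoder_spec': {'name': 'edsr-baseline', 'args': {'no_upsampling': True}},
            'imnet_spec': {'name': 'mlp',
                           'args': {'out_dim': 3, 'hidden_list': [256, 256, 256, 256]}},
        }
    })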
--------------------------------------------------------------------------------
/models/misc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | import models
6 | from models import register
7 | from utils import make_coord
8 |
9 |
10 | @register('metasr')
11 | class MetaSR(nn.Module):
12 |
13 | def __init__(self, encoder_spec):
14 | super().__init__()
15 |
16 | self.encoder = models.make(encoder_spec)
17 | imnet_spec = {
18 | 'name': 'mlp',
19 | 'args': {
20 | 'in_dim': 3,
21 | 'out_dim': self.encoder.out_dim * 9 * 3,
22 | 'hidden_list': [256]
23 | }
24 | }
25 | self.imnet = models.make(imnet_spec)
26 |
27 | def gen_feat(self, inp):
28 | self.feat = self.encoder(inp)
29 | return self.feat
30 |
31 | def query_rgb(self, coord, cell=None):
32 | feat = self.feat
33 | feat = F.unfold(feat, 3, padding=1).view(
34 | feat.shape[0], feat.shape[1] * 9, feat.shape[2], feat.shape[3])
35 |
36 | feat_coord = make_coord(feat.shape[-2:], flatten=False).cuda()
37 | feat_coord[:, :, 0] -= (2 / feat.shape[-2]) / 2
38 | feat_coord[:, :, 1] -= (2 / feat.shape[-1]) / 2
39 | feat_coord = feat_coord.permute(2, 0, 1) \
40 | .unsqueeze(0).expand(feat.shape[0], 2, *feat.shape[-2:])
41 |
42 | coord_ = coord.clone()
43 | coord_[:, :, 0] -= cell[:, :, 0] / 2
44 | coord_[:, :, 1] -= cell[:, :, 1] / 2
45 | coord_q = (coord_ + 1e-6).clamp(-1 + 1e-6, 1 - 1e-6)
46 | q_feat = F.grid_sample(
47 | feat, coord_q.flip(-1).unsqueeze(1),
48 | mode='nearest', align_corners=False)[:, :, 0, :] \
49 | .permute(0, 2, 1)
50 | q_coord = F.grid_sample(
51 | feat_coord, coord_q.flip(-1).unsqueeze(1),
52 | mode='nearest', align_corners=False)[:, :, 0, :] \
53 | .permute(0, 2, 1)
54 |
55 | rel_coord = coord_ - q_coord
56 | rel_coord[:, :, 0] *= feat.shape[-2] / 2
57 | rel_coord[:, :, 1] *= feat.shape[-1] / 2
58 |
59 | r_rev = cell[:, :, 0] * (feat.shape[-2] / 2)
60 | inp = torch.cat([rel_coord, r_rev.unsqueeze(-1)], dim=-1)
61 |
62 | bs, q = coord.shape[:2]
63 | pred = self.imnet(inp.view(bs * q, -1)).view(bs * q, feat.shape[1], 3)
64 | pred = torch.bmm(q_feat.contiguous().view(bs * q, 1, -1), pred)
65 | pred = pred.view(bs, q, 3)
66 | return pred
67 |
68 | def forward(self, inp, coord, cell):
69 | self.gen_feat(inp)
70 | return self.query_rgb(coord, cell)
71 |
--------------------------------------------------------------------------------
/models/mlp.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from models import register
4 |
5 |
6 | @register('mlp')
7 | class MLP(nn.Module):
8 |
9 | def __init__(self, in_dim, out_dim, hidden_list):
10 | super().__init__()
11 | layers = []
12 | lastv = in_dim
13 | for hidden in hidden_list:
14 | layers.append(nn.Linear(lastv, hidden))
15 | layers.append(nn.ReLU())
16 | lastv = hidden
17 | layers.append(nn.Linear(lastv, out_dim))
18 | self.layers = nn.Sequential(*layers)
19 |
20 | def forward(self, x):
21 | shape = x.shape[:-1]
22 | x = self.layers(x.view(-1, x.shape[-1]))
23 | return x.view(*shape, -1)
24 |
--------------------------------------------------------------------------------
/models/models.py:
--------------------------------------------------------------------------------
1 | import copy
2 |
3 |
4 | models = {}
5 |
6 |
7 | def register(name):
8 | def decorator(cls):
9 | models[name] = cls
10 | return cls
11 | return decorator
12 |
13 |
14 | def make(model_spec, args=None, load_sd=False):
15 | if args is not None:
16 | model_args = copy.deepcopy(model_spec['args'])
17 | model_args.update(args)
18 | else:
19 | model_args = model_spec['args']
20 | model = models[model_spec['name']](**model_args)
21 | if load_sd:
22 | model.load_state_dict(model_spec['sd'])
23 | return model
24 |
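A short editorial sketch of the register/make pattern defined above; the model name 'identity' is made up for illustration and is not used anywhere in the repository.

    import torch.nn as nn
    from models import register, make

    @register('identity')
    class Identity(nn.Module):
        def forward(self, x):
            return x

    net = make({'name': 'identity', 'args': {}})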
--------------------------------------------------------------------------------
/models/rcan.py:
--------------------------------------------------------------------------------
1 | import math
2 | from argparse import Namespace
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 | from models import register
8 |
9 |
10 | def default_conv(in_channels, out_channels, kernel_size, bias=True):
11 | return nn.Conv2d(
12 | in_channels, out_channels, kernel_size,
13 | padding=(kernel_size//2), bias=bias)
14 |
15 | class MeanShift(nn.Conv2d):
16 | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
17 | super(MeanShift, self).__init__(3, 3, kernel_size=1)
18 | std = torch.Tensor(rgb_std)
19 | self.weight.data = torch.eye(3).view(3, 3, 1, 1)
20 | self.weight.data.div_(std.view(3, 1, 1, 1))
21 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
22 | self.bias.data.div_(std)
23 | for p in self.parameters(): p.requires_grad = False  # 'self.requires_grad = False' on a Module does not freeze parameters
24 |
25 | class Upsampler(nn.Sequential):
26 | def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True):
27 |
28 | m = []
29 | if (scale & (scale - 1)) == 0: # Is scale = 2^n?
30 | for _ in range(int(math.log(scale, 2))):
31 | m.append(conv(n_feat, 4 * n_feat, 3, bias))
32 | m.append(nn.PixelShuffle(2))
33 | if bn: m.append(nn.BatchNorm2d(n_feat))
34 | if act: m.append(act())
35 | elif scale == 3:
36 | m.append(conv(n_feat, 9 * n_feat, 3, bias))
37 | m.append(nn.PixelShuffle(3))
38 | if bn: m.append(nn.BatchNorm2d(n_feat))
39 | if act: m.append(act())
40 | else:
41 | raise NotImplementedError
42 |
43 | super(Upsampler, self).__init__(*m)
44 |
45 | ## Channel Attention (CA) Layer
46 | class CALayer(nn.Module):
47 | def __init__(self, channel, reduction=16):
48 | super(CALayer, self).__init__()
49 | # global average pooling: feature --> point
50 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
51 | # feature channel downscale and upscale --> channel weight
52 | self.conv_du = nn.Sequential(
53 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
54 | nn.ReLU(inplace=True),
55 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
56 | nn.Sigmoid()
57 | )
58 |
59 | def forward(self, x):
60 | y = self.avg_pool(x)
61 | y = self.conv_du(y)
62 | return x * y
63 |
64 | ## Residual Channel Attention Block (RCAB)
65 | class RCAB(nn.Module):
66 | def __init__(
67 | self, conv, n_feat, kernel_size, reduction,
68 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
69 |
70 | super(RCAB, self).__init__()
71 | modules_body = []
72 | for i in range(2):
73 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias))
74 | if bn: modules_body.append(nn.BatchNorm2d(n_feat))
75 | if i == 0: modules_body.append(act)
76 | modules_body.append(CALayer(n_feat, reduction))
77 | self.body = nn.Sequential(*modules_body)
78 | self.res_scale = res_scale
79 |
80 | def forward(self, x):
81 | res = self.body(x)
82 | #res = self.body(x).mul(self.res_scale)
83 | res += x
84 | return res
85 |
86 | ## Residual Group (RG)
87 | class ResidualGroup(nn.Module):
88 | def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks):
89 | super(ResidualGroup, self).__init__()
90 | modules_body = []
91 | modules_body = [
92 | RCAB(
93 | conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \
94 | for _ in range(n_resblocks)]
95 | modules_body.append(conv(n_feat, n_feat, kernel_size))
96 | self.body = nn.Sequential(*modules_body)
97 |
98 | def forward(self, x):
99 | res = self.body(x)
100 | res += x
101 | return res
102 |
103 | ## Residual Channel Attention Network (RCAN)
104 | class RCAN(nn.Module):
105 | def __init__(self, args, conv=default_conv):
106 | super(RCAN, self).__init__()
107 | self.args = args
108 |
109 | n_resgroups = args.n_resgroups
110 | n_resblocks = args.n_resblocks
111 | n_feats = args.n_feats
112 | kernel_size = 3
113 | reduction = args.reduction
114 | scale = args.scale[0]
115 | act = nn.ReLU(True)
116 |
117 | # RGB mean for DIV2K
118 | rgb_mean = (0.4488, 0.4371, 0.4040)
119 | rgb_std = (1.0, 1.0, 1.0)
120 | self.sub_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std)
121 |
122 | # define head module
123 | modules_head = [conv(args.n_colors, n_feats, kernel_size)]
124 |
125 | # define body module
126 | modules_body = [
127 | ResidualGroup(
128 | conv, n_feats, kernel_size, reduction, act=act, res_scale=args.res_scale, n_resblocks=n_resblocks) \
129 | for _ in range(n_resgroups)]
130 |
131 | modules_body.append(conv(n_feats, n_feats, kernel_size))
132 |
133 | self.add_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std, 1)
134 |
135 | self.head = nn.Sequential(*modules_head)
136 | self.body = nn.Sequential(*modules_body)
137 |
138 | if args.no_upsampling:
139 | self.out_dim = n_feats
140 | else:
141 | self.out_dim = args.n_colors
142 | # define tail module
143 | modules_tail = [
144 | Upsampler(conv, scale, n_feats, act=False),
145 | conv(n_feats, args.n_colors, kernel_size)]
146 | self.tail = nn.Sequential(*modules_tail)
147 |
148 | def forward(self, x):
149 | #x = self.sub_mean(x)
150 | x = self.head(x)
151 |
152 | res = self.body(x)
153 | res += x
154 |
155 | if self.args.no_upsampling:
156 | x = res
157 | else:
158 | x = self.tail(res)
159 | #x = self.add_mean(x)
160 | return x
161 |
162 | def load_state_dict(self, state_dict, strict=False):
163 | own_state = self.state_dict()
164 | for name, param in state_dict.items():
165 | if name in own_state:
166 | if isinstance(param, nn.Parameter):
167 | param = param.data
168 | try:
169 | own_state[name].copy_(param)
170 | except Exception:
171 | if name.find('tail') >= 0:
172 | print('Replacing the pre-trained upsampler with a new one...')
173 | else:
174 | raise RuntimeError('While copying the parameter named {}, '
175 | 'whose dimensions in the model are {} and '
176 | 'whose dimensions in the checkpoint are {}.'
177 | .format(name, own_state[name].size(), param.size()))
178 | elif strict:
179 | if name.find('tail') == -1:
180 | raise KeyError('unexpected key "{}" in state_dict'
181 | .format(name))
182 |
183 | if strict:
184 | missing = set(own_state.keys()) - set(state_dict.keys())
185 | if len(missing) > 0:
186 | raise KeyError('missing keys in state_dict: "{}"'.format(missing))
187 |
188 |
189 | @register('rcan')
190 | def make_rcan(n_resgroups=10, n_resblocks=20, n_feats=64, reduction=16,
191 | scale=2, no_upsampling=False, rgb_range=1):
192 | args = Namespace()
193 | args.n_resgroups = n_resgroups
194 | args.n_resblocks = n_resblocks
195 | args.n_feats = n_feats
196 | args.reduction = reduction
197 |
198 | args.scale = [scale]
199 | args.no_upsampling = no_upsampling
200 |
201 | args.rgb_range = rgb_range
202 | args.res_scale = 1
203 | args.n_colors = 3
204 | return RCAN(args)
205 |
--------------------------------------------------------------------------------
/models/rdn.py:
--------------------------------------------------------------------------------
1 | # Residual Dense Network for Image Super-Resolution
2 | # https://arxiv.org/abs/1802.08797
3 | # modified from: https://github.com/thstkdgus35/EDSR-PyTorch
4 |
5 | from argparse import Namespace
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | from models import register
11 |
12 |
13 | class RDB_Conv(nn.Module):
14 | def __init__(self, inChannels, growRate, kSize=3):
15 | super(RDB_Conv, self).__init__()
16 | Cin = inChannels
17 | G = growRate
18 | self.conv = nn.Sequential(*[
19 | nn.Conv2d(Cin, G, kSize, padding=(kSize-1)//2, stride=1),
20 | nn.ReLU()
21 | ])
22 |
23 | def forward(self, x):
24 | out = self.conv(x)
25 | return torch.cat((x, out), 1)
26 |
27 | class RDB(nn.Module):
28 | def __init__(self, growRate0, growRate, nConvLayers, kSize=3):
29 | super(RDB, self).__init__()
30 | G0 = growRate0
31 | G = growRate
32 | C = nConvLayers
33 |
34 | convs = []
35 | for c in range(C):
36 | convs.append(RDB_Conv(G0 + c*G, G))
37 | self.convs = nn.Sequential(*convs)
38 |
39 | # Local Feature Fusion
40 | self.LFF = nn.Conv2d(G0 + C*G, G0, 1, padding=0, stride=1)
41 |
42 | def forward(self, x):
43 | return self.LFF(self.convs(x)) + x
44 |
45 | class RDN(nn.Module):
46 | def __init__(self, args):
47 | super(RDN, self).__init__()
48 | self.args = args
49 | r = args.scale[0]
50 | G0 = args.G0
51 | kSize = args.RDNkSize
52 |
53 | # number of RDB blocks, conv layers, out channels
54 | self.D, C, G = {
55 | 'A': (20, 6, 32),
56 | 'B': (16, 8, 64),
57 | }[args.RDNconfig]
58 |
59 | # Shallow feature extraction net
60 | self.SFENet1 = nn.Conv2d(args.n_colors, G0, kSize, padding=(kSize-1)//2, stride=1)
61 | self.SFENet2 = nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1)
62 |
63 | # Residual dense blocks and dense feature fusion
64 | self.RDBs = nn.ModuleList()
65 | for i in range(self.D):
66 | self.RDBs.append(
67 | RDB(growRate0 = G0, growRate = G, nConvLayers = C)
68 | )
69 |
70 | # Global Feature Fusion
71 | self.GFF = nn.Sequential(*[
72 | nn.Conv2d(self.D * G0, G0, 1, padding=0, stride=1),
73 | nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1)
74 | ])
75 |
76 | if args.no_upsampling:
77 | self.out_dim = G0
78 | else:
79 | self.out_dim = args.n_colors
80 | # Up-sampling net
81 | if r == 2 or r == 3:
82 | self.UPNet = nn.Sequential(*[
83 | nn.Conv2d(G0, G * r * r, kSize, padding=(kSize-1)//2, stride=1),
84 | nn.PixelShuffle(r),
85 | nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1)
86 | ])
87 | elif r == 4:
88 | self.UPNet = nn.Sequential(*[
89 | nn.Conv2d(G0, G * 4, kSize, padding=(kSize-1)//2, stride=1),
90 | nn.PixelShuffle(2),
91 | nn.Conv2d(G, G * 4, kSize, padding=(kSize-1)//2, stride=1),
92 | nn.PixelShuffle(2),
93 | nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1)
94 | ])
95 | else:
96 | raise ValueError("scale must be 2 or 3 or 4.")
97 |
98 | def forward(self, x):
99 | f__1 = self.SFENet1(x)
100 | x = self.SFENet2(f__1)
101 |
102 | RDBs_out = []
103 | for i in range(self.D):
104 | x = self.RDBs[i](x)
105 | RDBs_out.append(x)
106 |
107 | x = self.GFF(torch.cat(RDBs_out,1))
108 | x += f__1
109 |
110 | if self.args.no_upsampling:
111 | return x
112 | else:
113 | return self.UPNet(x)
114 |
115 |
116 | @register('rdn')
117 | def make_rdn(G0=64, RDNkSize=3, RDNconfig='B',
118 | scale=2, no_upsampling=False):
119 | args = Namespace()
120 | args.G0 = G0
121 | args.RDNkSize = RDNkSize
122 | args.RDNconfig = RDNconfig
123 |
124 | args.scale = [scale]
125 | args.no_upsampling = no_upsampling
126 |
127 | args.n_colors = 3
128 | return RDN(args)
129 |
--------------------------------------------------------------------------------
/scripts/resize.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | from tqdm import tqdm
4 |
5 | for size in [256, 128, 64, 32]:
6 | if size == 256:
7 | inp = './data1024x1024'
8 | else:
9 | inp = './256'
10 | print(size)
11 | os.mkdir(str(size))
12 | filenames = os.listdir(inp)
13 | for filename in tqdm(filenames):
14 | Image.open(os.path.join(inp, filename)) \
15 | .resize((size, size), Image.BICUBIC) \
16 | .save(os.path.join('.', str(size), filename.split('.')[0] + '.png'))
17 |
--------------------------------------------------------------------------------
/scripts/test-benchmark.sh:
--------------------------------------------------------------------------------
1 | echo 'set5' &&
2 | echo 'x2' &&
3 | python test.py --config ./configs/test/test-set5-2.yaml --model $1 --gpu $2 &&
4 | echo 'x3' &&
5 | python test.py --config ./configs/test/test-set5-3.yaml --model $1 --gpu $2 &&
6 | echo 'x4' &&
7 | python test.py --config ./configs/test/test-set5-4.yaml --model $1 --gpu $2 &&
8 | echo 'x6*' &&
9 | python test.py --config ./configs/test/test-set5-6.yaml --model $1 --gpu $2 &&
10 | echo 'x8*' &&
11 | python test.py --config ./configs/test/test-set5-8.yaml --model $1 --gpu $2 &&
12 |
13 | echo 'set14' &&
14 | echo 'x2' &&
15 | python test.py --config ./configs/test/test-set14-2.yaml --model $1 --gpu $2 &&
16 | echo 'x3' &&
17 | python test.py --config ./configs/test/test-set14-3.yaml --model $1 --gpu $2 &&
18 | echo 'x4' &&
19 | python test.py --config ./configs/test/test-set14-4.yaml --model $1 --gpu $2 &&
20 | echo 'x6*' &&
21 | python test.py --config ./configs/test/test-set14-6.yaml --model $1 --gpu $2 &&
22 | echo 'x8*' &&
23 | python test.py --config ./configs/test/test-set14-8.yaml --model $1 --gpu $2 &&
24 |
25 | echo 'b100' &&
26 | echo 'x2' &&
27 | python test.py --config ./configs/test/test-b100-2.yaml --model $1 --gpu $2 &&
28 | echo 'x3' &&
29 | python test.py --config ./configs/test/test-b100-3.yaml --model $1 --gpu $2 &&
30 | echo 'x4' &&
31 | python test.py --config ./configs/test/test-b100-4.yaml --model $1 --gpu $2 &&
32 | echo 'x6*' &&
33 | python test.py --config ./configs/test/test-b100-6.yaml --model $1 --gpu $2 &&
34 | echo 'x8*' &&
35 | python test.py --config ./configs/test/test-b100-8.yaml --model $1 --gpu $2 &&
36 |
37 | echo 'urban100' &&
38 | echo 'x2' &&
39 | python test.py --config ./configs/test/test-urban100-2.yaml --model $1 --gpu $2 &&
40 | echo 'x3' &&
41 | python test.py --config ./configs/test/test-urban100-3.yaml --model $1 --gpu $2 &&
42 | echo 'x4' &&
43 | python test.py --config ./configs/test/test-urban100-4.yaml --model $1 --gpu $2 &&
44 | echo 'x6*' &&
45 | python test.py --config ./configs/test/test-urban100-6.yaml --model $1 --gpu $2 &&
46 | echo 'x8*' &&
47 | python test.py --config ./configs/test/test-urban100-8.yaml --model $1 --gpu $2 &&
48 |
49 | true
50 |
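Editorial usage note: both benchmark scripts take a checkpoint path as $1 and a GPU id as $2; the checkpoint path below is illustrative.

    bash scripts/test-benchmark.sh path/to/epoch-best.pth 0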
--------------------------------------------------------------------------------
/scripts/test-div2k.sh:
--------------------------------------------------------------------------------
1 | echo 'div2k-x2' &&
2 | python test.py --config ./configs/test/test-div2k-2.yaml --model $1 --gpu $2 &&
3 | echo 'div2k-x3' &&
4 | python test.py --config ./configs/test/test-div2k-3.yaml --model $1 --gpu $2 &&
5 | echo 'div2k-x4' &&
6 | python test.py --config ./configs/test/test-div2k-4.yaml --model $1 --gpu $2 &&
7 |
8 | echo 'div2k-x6*' &&
9 | python test.py --config ./configs/test/test-div2k-6.yaml --model $1 --gpu $2 &&
10 | echo 'div2k-x12*' &&
11 | python test.py --config ./configs/test/test-div2k-12.yaml --model $1 --gpu $2 &&
12 | echo 'div2k-x18*' &&
13 | python test.py --config ./configs/test/test-div2k-18.yaml --model $1 --gpu $2 &&
14 | echo 'div2k-x24*' &&
15 | python test.py --config ./configs/test/test-div2k-24.yaml --model $1 --gpu $2 &&
16 | echo 'div2k-x30*' &&
17 | python test.py --config ./configs/test/test-div2k-30.yaml --model $1 --gpu $2 &&
18 |
19 | true
20 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import math
4 | from functools import partial
5 |
6 | import yaml
7 | import torch
8 | from torch.utils.data import DataLoader
9 | from tqdm import tqdm
10 |
11 | import datasets
12 | import models
13 | import utils
14 |
15 |
16 | def batched_predict(model, inp, coord, cell, bsize):
17 | with torch.no_grad():
18 | model.gen_feat(inp)
19 | n = coord.shape[1]
20 | ql = 0
21 | preds = []
22 | while ql < n:
23 | qr = min(ql + bsize, n)
24 | pred = model.query_rgb(coord[:, ql: qr, :], cell[:, ql: qr, :])
25 | preds.append(pred)
26 | ql = qr
27 | pred = torch.cat(preds, dim=1)
28 | return pred
29 |
30 |
31 | def eval_psnr(loader, model, data_norm=None, eval_type=None, eval_bsize=None,
32 | verbose=False):
33 | model.eval()
34 |
35 | if data_norm is None:
36 | data_norm = {
37 | 'inp': {'sub': [0], 'div': [1]},
38 | 'gt': {'sub': [0], 'div': [1]}
39 | }
40 | t = data_norm['inp']
41 | inp_sub = torch.FloatTensor(t['sub']).view(1, -1, 1, 1).cuda()
42 | inp_div = torch.FloatTensor(t['div']).view(1, -1, 1, 1).cuda()
43 | t = data_norm['gt']
44 | gt_sub = torch.FloatTensor(t['sub']).view(1, 1, -1).cuda()
45 | gt_div = torch.FloatTensor(t['div']).view(1, 1, -1).cuda()
46 |
47 | if eval_type is None:
48 | metric_fn = utils.calc_psnr
49 | elif eval_type.startswith('div2k'):
50 | scale = int(eval_type.split('-')[1])
51 | metric_fn = partial(utils.calc_psnr, dataset='div2k', scale=scale)
52 | elif eval_type.startswith('benchmark'):
53 | scale = int(eval_type.split('-')[1])
54 | metric_fn = partial(utils.calc_psnr, dataset='benchmark', scale=scale)
55 | else:
56 | raise NotImplementedError
57 |
58 | val_res = utils.Averager()
59 |
60 | pbar = tqdm(loader, leave=False, desc='val')
61 | for batch in pbar:
62 | for k, v in batch.items():
63 | batch[k] = v.cuda()
64 |
65 | inp = (batch['inp'] - inp_sub) / inp_div
66 | if eval_bsize is None:
67 | with torch.no_grad():
68 | pred = model(inp, batch['coord'], batch['cell'])
69 | else:
70 | pred = batched_predict(model, inp,
71 | batch['coord'], batch['cell'], eval_bsize)
72 | pred = pred * gt_div + gt_sub
73 | pred.clamp_(0, 1)
74 |
75 | if eval_type is not None: # reshape to (B, 3, H, W) for border-shaving evaluation
76 | ih, iw = batch['inp'].shape[-2:]
77 | s = math.sqrt(batch['coord'].shape[1] / (ih * iw))
78 | shape = [batch['inp'].shape[0], round(ih * s), round(iw * s), 3]
79 | pred = pred.view(*shape) \
80 | .permute(0, 3, 1, 2).contiguous()
81 | batch['gt'] = batch['gt'].view(*shape) \
82 | .permute(0, 3, 1, 2).contiguous()
83 |
84 | res = metric_fn(pred, batch['gt'])
85 | val_res.add(res.item(), inp.shape[0])
86 |
87 | if verbose:
88 | pbar.set_description('val {:.4f}'.format(val_res.item()))
89 |
90 | return val_res.item()
91 |
92 |
93 | if __name__ == '__main__':
94 | parser = argparse.ArgumentParser()
95 | parser.add_argument('--config')
96 | parser.add_argument('--model')
97 | parser.add_argument('--gpu', default='0')
98 | args = parser.parse_args()
99 |
100 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
101 |
102 | with open(args.config, 'r') as f:
103 | config = yaml.load(f, Loader=yaml.FullLoader)
104 |
105 | spec = config['test_dataset']
106 | dataset = datasets.make(spec['dataset'])
107 | dataset = datasets.make(spec['wrapper'], args={'dataset': dataset})
108 | loader = DataLoader(dataset, batch_size=spec['batch_size'],
109 | num_workers=8, pin_memory=True)
110 |
111 | model_spec = torch.load(args.model)['model']
112 | model = models.make(model_spec, load_sd=True).cuda()
113 |
114 | res = eval_psnr(loader, model,
115 | data_norm=config.get('data_norm'),
116 | eval_type=config.get('eval_type'),
117 | eval_bsize=config.get('eval_bsize'),
118 | verbose=True)
119 | print('result: {:.4f}'.format(res))
120 |
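Editorial usage note: a single evaluation run pairs one of the YAML files under configs/test with a trained checkpoint; the checkpoint path below is illustrative.

    python test.py --config configs/test/test-div2k-4.yaml --model path/to/epoch-best.pth --gpu 0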
--------------------------------------------------------------------------------
/train_liif.py:
--------------------------------------------------------------------------------
1 | """ Train for generating LIIF, from image to implicit representation.
2 |
3 | Config:
4 | train_dataset:
5 | dataset: $spec; wrapper: $spec; batch_size:
6 | val_dataset:
7 | dataset: $spec; wrapper: $spec; batch_size:
8 | (data_norm):
9 | inp: {sub: []; div: []}
10 | gt: {sub: []; div: []}
11 | (eval_type):
12 | (eval_bsize):
13 |
14 | model: $spec
15 | optimizer: $spec
16 | epoch_max:
17 | (multi_step_lr):
18 | milestones: []; gamma: 0.5
19 | (resume): *.pth
20 |
21 | (epoch_val): ; (epoch_save):
22 | """
23 |
24 | import argparse
25 | import os
26 |
27 | import yaml
28 | import torch
29 | import torch.nn as nn
30 | from tqdm import tqdm
31 | from torch.utils.data import DataLoader
32 | from torch.optim.lr_scheduler import MultiStepLR
33 |
34 | import datasets
35 | import models
36 | import utils
37 | from test import eval_psnr
38 |
39 |
40 | def make_data_loader(spec, tag=''):
41 | if spec is None:
42 | return None
43 |
44 | dataset = datasets.make(spec['dataset'])
45 | dataset = datasets.make(spec['wrapper'], args={'dataset': dataset})
46 |
47 | log('{} dataset: size={}'.format(tag, len(dataset)))
48 | for k, v in dataset[0].items():
49 | log(' {}: shape={}'.format(k, tuple(v.shape)))
50 |
51 | loader = DataLoader(dataset, batch_size=spec['batch_size'],
52 | shuffle=(tag == 'train'), num_workers=8, pin_memory=True)
53 | return loader
54 |
55 |
56 | def make_data_loaders():
57 | train_loader = make_data_loader(config.get('train_dataset'), tag='train')
58 | val_loader = make_data_loader(config.get('val_dataset'), tag='val')
59 | return train_loader, val_loader
60 |
61 |
62 | def prepare_training():
63 | if config.get('resume') is not None:
64 | sv_file = torch.load(config['resume'])
65 | model = models.make(sv_file['model'], load_sd=True).cuda()
66 | optimizer = utils.make_optimizer(
67 | model.parameters(), sv_file['optimizer'], load_sd=True)
68 | epoch_start = sv_file['epoch'] + 1
69 | if config.get('multi_step_lr') is None:
70 | lr_scheduler = None
71 | else:
72 | lr_scheduler = MultiStepLR(optimizer, **config['multi_step_lr'])
73 | for _ in range(epoch_start - 1):
74 | lr_scheduler.step()
75 | else:
76 | model = models.make(config['model']).cuda()
77 | optimizer = utils.make_optimizer(
78 | model.parameters(), config['optimizer'])
79 | epoch_start = 1
80 | if config.get('multi_step_lr') is None:
81 | lr_scheduler = None
82 | else:
83 | lr_scheduler = MultiStepLR(optimizer, **config['multi_step_lr'])
84 |
85 | log('model: #params={}'.format(utils.compute_num_params(model, text=True)))
86 | return model, optimizer, epoch_start, lr_scheduler
87 |
88 |
89 | def train(train_loader, model, optimizer):
90 | model.train()
91 | loss_fn = nn.L1Loss()
92 | train_loss = utils.Averager()
93 |
94 | data_norm = config['data_norm']
95 | t = data_norm['inp']
96 | inp_sub = torch.FloatTensor(t['sub']).view(1, -1, 1, 1).cuda()
97 | inp_div = torch.FloatTensor(t['div']).view(1, -1, 1, 1).cuda()
98 | t = data_norm['gt']
99 | gt_sub = torch.FloatTensor(t['sub']).view(1, 1, -1).cuda()
100 | gt_div = torch.FloatTensor(t['div']).view(1, 1, -1).cuda()
101 |
102 | for batch in tqdm(train_loader, leave=False, desc='train'):
103 | for k, v in batch.items():
104 | batch[k] = v.cuda()
105 |
106 | inp = (batch['inp'] - inp_sub) / inp_div
107 | pred = model(inp, batch['coord'], batch['cell'])
108 |
109 | gt = (batch['gt'] - gt_sub) / gt_div
110 | loss = loss_fn(pred, gt)
111 |
112 | train_loss.add(loss.item())
113 |
114 | optimizer.zero_grad()
115 | loss.backward()
116 | optimizer.step()
117 |
118 | pred = None; loss = None
119 |
120 | return train_loss.item()
121 |
122 |
123 | def main(config_, save_path):
124 | global config, log, writer
125 | config = config_
126 | log, writer = utils.set_save_path(save_path)
127 | with open(os.path.join(save_path, 'config.yaml'), 'w') as f:
128 | yaml.dump(config, f, sort_keys=False)
129 |
130 | train_loader, val_loader = make_data_loaders()
131 | if config.get('data_norm') is None:
132 | config['data_norm'] = {
133 | 'inp': {'sub': [0], 'div': [1]},
134 | 'gt': {'sub': [0], 'div': [1]}
135 | }
136 |
137 | model, optimizer, epoch_start, lr_scheduler = prepare_training()
138 |
139 | n_gpus = len(os.environ['CUDA_VISIBLE_DEVICES'].split(','))
140 | if n_gpus > 1:
141 | model = nn.parallel.DataParallel(model)
142 |
143 | epoch_max = config['epoch_max']
144 | epoch_val = config.get('epoch_val')
145 | epoch_save = config.get('epoch_save')
146 | max_val_v = -1e18
147 |
148 | timer = utils.Timer()
149 |
150 | for epoch in range(epoch_start, epoch_max + 1):
151 | t_epoch_start = timer.t()
152 | log_info = ['epoch {}/{}'.format(epoch, epoch_max)]
153 |
154 | writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)
155 |
156 | train_loss = train(train_loader, model, optimizer)
157 | if lr_scheduler is not None:
158 | lr_scheduler.step()
159 |
160 | log_info.append('train: loss={:.4f}'.format(train_loss))
161 | writer.add_scalars('loss', {'train': train_loss}, epoch)
162 |
163 | if n_gpus > 1:
164 | model_ = model.module
165 | else:
166 | model_ = model
167 | model_spec = config['model']
168 | model_spec['sd'] = model_.state_dict()
169 | optimizer_spec = config['optimizer']
170 | optimizer_spec['sd'] = optimizer.state_dict()
171 | sv_file = {
172 | 'model': model_spec,
173 | 'optimizer': optimizer_spec,
174 | 'epoch': epoch
175 | }
176 |
177 | torch.save(sv_file, os.path.join(save_path, 'epoch-last.pth'))
178 |
179 | if (epoch_save is not None) and (epoch % epoch_save == 0):
180 | torch.save(sv_file,
181 | os.path.join(save_path, 'epoch-{}.pth'.format(epoch)))
182 |
183 | if (epoch_val is not None) and (epoch % epoch_val == 0):
184 | if n_gpus > 1 and (config.get('eval_bsize') is not None):
185 | model_ = model.module
186 | else:
187 | model_ = model
188 | val_res = eval_psnr(val_loader, model_,
189 | data_norm=config['data_norm'],
190 | eval_type=config.get('eval_type'),
191 | eval_bsize=config.get('eval_bsize'))
192 |
193 | log_info.append('val: psnr={:.4f}'.format(val_res))
194 | writer.add_scalars('psnr', {'val': val_res}, epoch)
195 | if val_res > max_val_v:
196 | max_val_v = val_res
197 | torch.save(sv_file, os.path.join(save_path, 'epoch-best.pth'))
198 |
199 | t = timer.t()
200 | prog = (epoch - epoch_start + 1) / (epoch_max - epoch_start + 1)
201 | t_epoch = utils.time_text(t - t_epoch_start)
202 | t_elapsed, t_all = utils.time_text(t), utils.time_text(t / prog)
203 | log_info.append('{} {}/{}'.format(t_epoch, t_elapsed, t_all))
204 |
205 | log(', '.join(log_info))
206 | writer.flush()
207 |
208 |
209 | if __name__ == '__main__':
210 | parser = argparse.ArgumentParser()
211 | parser.add_argument('--config')
212 | parser.add_argument('--name', default=None)
213 | parser.add_argument('--tag', default=None)
214 | parser.add_argument('--gpu', default='0')
215 | args = parser.parse_args()
216 |
217 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
218 |
219 | with open(args.config, 'r') as f:
220 | config = yaml.load(f, Loader=yaml.FullLoader)
221 | print('config loaded.')
222 |
223 | save_name = args.name
224 | if save_name is None:
225 | save_name = '_' + args.config.split('/')[-1][:-len('.yaml')]
226 | if args.tag is not None:
227 | save_name += '_' + args.tag
228 | save_path = os.path.join('./save', save_name)
229 |
230 | main(config, save_path)
231 |
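Editorial usage note: training is launched by pointing the script at one of the configs listed under configs/train-div2k or configs/train-celebAHQ; with --name left unset, outputs go to ./save/_&lt;config-name&gt; following the save_name logic above. For example:

    python train_liif.py --config configs/train-div2k/train_edsr-baseline-liif.yaml --gpu 0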
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import shutil
4 | import math
5 |
6 | import torch
7 | import numpy as np
8 | from torch.optim import SGD, Adam
9 | from tensorboardX import SummaryWriter
10 |
11 |
12 | class Averager():
13 |
14 | def __init__(self):
15 | self.n = 0.0
16 | self.v = 0.0
17 |
18 | def add(self, v, n=1.0):
19 | self.v = (self.v * self.n + v * n) / (self.n + n)
20 | self.n += n
21 |
22 | def item(self):
23 | return self.v
24 |
25 |
26 | class Timer():
27 |
28 | def __init__(self):
29 | self.v = time.time()
30 |
31 | def s(self):
32 | self.v = time.time()
33 |
34 | def t(self):
35 | return time.time() - self.v
36 |
37 |
38 | def time_text(t):
39 | if t >= 3600:
40 | return '{:.1f}h'.format(t / 3600)
41 | elif t >= 60:
42 | return '{:.1f}m'.format(t / 60)
43 | else:
44 | return '{:.1f}s'.format(t)
45 |
46 |
47 | _log_path = None
48 |
49 |
50 | def set_log_path(path):
51 | global _log_path
52 | _log_path = path
53 |
54 |
55 | def log(obj, filename='log.txt'):
56 | print(obj)
57 | if _log_path is not None:
58 | with open(os.path.join(_log_path, filename), 'a') as f:
59 | print(obj, file=f)
60 |
61 |
62 | def ensure_path(path, remove=True):
63 | basename = os.path.basename(path.rstrip('/'))
64 | if os.path.exists(path):
65 | if remove and (basename.startswith('_')
66 | or input('{} exists, remove? (y/[n]): '.format(path)) == 'y'):
67 | shutil.rmtree(path)
68 | os.makedirs(path)
69 | else:
70 | os.makedirs(path)
71 |
72 |
73 | def set_save_path(save_path, remove=True):
74 | ensure_path(save_path, remove=remove)
75 | set_log_path(save_path)
76 | writer = SummaryWriter(os.path.join(save_path, 'tensorboard'))
77 | return log, writer
78 |
79 |
80 | def compute_num_params(model, text=False):
81 | tot = int(sum([np.prod(p.shape) for p in model.parameters()]))
82 | if text:
83 | if tot >= 1e6:
84 | return '{:.1f}M'.format(tot / 1e6)
85 | else:
86 | return '{:.1f}K'.format(tot / 1e3)
87 | else:
88 | return tot
89 |
90 |
91 | def make_optimizer(param_list, optimizer_spec, load_sd=False):
92 | Optimizer = {
93 | 'sgd': SGD,
94 | 'adam': Adam
95 | }[optimizer_spec['name']]
96 | optimizer = Optimizer(param_list, **optimizer_spec['args'])
97 | if load_sd:
98 | optimizer.load_state_dict(optimizer_spec['sd'])
99 | return optimizer
100 |
101 |
102 | def make_coord(shape, ranges=None, flatten=True):
103 | """ Make coordinates at grid centers.
104 | """
105 | coord_seqs = []
106 | for i, n in enumerate(shape):
107 | if ranges is None:
108 | v0, v1 = -1, 1
109 | else:
110 | v0, v1 = ranges[i]
111 | r = (v1 - v0) / (2 * n)
112 | seq = v0 + r + (2 * r) * torch.arange(n).float()
113 | coord_seqs.append(seq)
114 | ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1)
115 | if flatten:
116 | ret = ret.view(-1, ret.shape[-1])
117 | return ret
118 |
119 |
120 | def to_pixel_samples(img):
121 | """ Convert the image to coord-RGB pairs.
122 | img: Tensor, (3, H, W)
123 | """
124 | coord = make_coord(img.shape[-2:])
125 | rgb = img.view(3, -1).permute(1, 0)
126 | return coord, rgb
127 |
128 |
129 | def calc_psnr(sr, hr, dataset=None, scale=1, rgb_range=1):
130 | diff = (sr - hr) / rgb_range
131 | if dataset is not None:
132 | if dataset == 'benchmark':
133 | shave = scale
134 | if diff.size(1) > 1:
135 | gray_coeffs = [65.738, 129.057, 25.064]
136 | convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256
137 | diff = diff.mul(convert).sum(dim=1)
138 | elif dataset == 'div2k':
139 | shave = scale + 6
140 | else:
141 | raise NotImplementedError
142 | valid = diff[..., shave:-shave, shave:-shave]
143 | else:
144 | valid = diff
145 | mse = valid.pow(2).mean()
146 | return -10 * torch.log10(mse)
147 |
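A small editorial example of the coordinate convention used throughout the repository: make_coord places one coordinate at the center of each pixel, so a 2x2 grid maps to +/-0.5 inside the default [-1, 1] range.

    from utils import make_coord

    print(make_coord((2, 2)))
    # tensor([[-0.5000, -0.5000],
    #         [-0.5000,  0.5000],
    #         [ 0.5000, -0.5000],
    #         [ 0.5000,  0.5000]])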
--------------------------------------------------------------------------------