├── .gitignore ├── LICENSE ├── README.md ├── configs ├── test │ ├── test-b100-2.yaml │ ├── test-b100-3.yaml │ ├── test-b100-4.yaml │ ├── test-b100-6.yaml │ ├── test-b100-8.yaml │ ├── test-celebAHQ-32-256.yaml │ ├── test-celebAHQ-64-128.yaml │ ├── test-div2k-12.yaml │ ├── test-div2k-18.yaml │ ├── test-div2k-2.yaml │ ├── test-div2k-24.yaml │ ├── test-div2k-3.yaml │ ├── test-div2k-30.yaml │ ├── test-div2k-4.yaml │ ├── test-div2k-6.yaml │ ├── test-set14-2.yaml │ ├── test-set14-3.yaml │ ├── test-set14-4.yaml │ ├── test-set14-6.yaml │ ├── test-set14-8.yaml │ ├── test-set5-2.yaml │ ├── test-set5-3.yaml │ ├── test-set5-4.yaml │ ├── test-set5-6.yaml │ ├── test-set5-8.yaml │ ├── test-urban100-2.yaml │ ├── test-urban100-3.yaml │ ├── test-urban100-4.yaml │ ├── test-urban100-6.yaml │ └── test-urban100-8.yaml ├── train-celebAHQ │ ├── train_celebAHQ-32-256_liif.yaml │ ├── train_celebAHQ-32-256_upsample.yaml │ ├── train_celebAHQ-64-128_liif.yaml │ └── train_celebAHQ-64-128_upsample.yaml └── train-div2k │ ├── ablation │ ├── train_edsr-baseline-liif-c.yaml │ ├── train_edsr-baseline-liif-d.yaml │ ├── train_edsr-baseline-liif-e.yaml │ ├── train_edsr-baseline-liif-u.yaml │ ├── train_edsr-baseline-liif-x2.yaml │ ├── train_edsr-baseline-liif-x3.yaml │ └── train_edsr-baseline-liif-x4.yaml │ ├── train_edsr-baseline-liif.yaml │ ├── train_edsr-baseline-metasr.yaml │ ├── train_edsr-baseline-x2.yaml │ ├── train_rdn-liif.yaml │ ├── train_rdn-metasr.yaml │ └── train_rdn-x2.yaml ├── datasets ├── __init__.py ├── datasets.py ├── image_folder.py └── wrappers.py ├── demo.py ├── models ├── __init__.py ├── edsr.py ├── liif.py ├── misc.py ├── mlp.py ├── models.py ├── rcan.py └── rdn.py ├── scripts ├── resize.py ├── test-benchmark.sh └── test-div2k.sh ├── test.py ├── train_liif.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | folders 3 | load 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, Yinbo Chen 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LIIF 2 | 3 | This repository contains the official implementation for LIIF introduced in the following paper: 4 | 5 | [**Learning Continuous Image Representation with Local Implicit Image Function**](https://arxiv.org/abs/2012.09161) 6 |
7 | [Yinbo Chen](https://yinboc.github.io/), [Sifei Liu](https://www.sifeiliu.net/), [Xiaolong Wang](https://xiaolonw.github.io/) 8 |
9 | CVPR 2021 (Oral) 10 | 11 | The project page with video is at https://yinboc.github.io/liif/. 12 | 13 | 14 | 15 | ### Citation 16 | 17 | If you find our work useful in your research, please cite: 18 | 19 | ``` 20 | @inproceedings{chen2021learning, 21 | title={Learning continuous image representation with local implicit image function}, 22 | author={Chen, Yinbo and Liu, Sifei and Wang, Xiaolong}, 23 | booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition}, 24 | pages={8628--8638}, 25 | year={2021} 26 | } 27 | ``` 28 | 29 | ### Environment 30 | - Python 3 31 | - PyTorch 1.6.0 32 | - TensorboardX 33 | - yaml, numpy, tqdm, imageio 34 | 35 | ## Quick Start 36 | 37 | 1. Download a DIV2K pre-trained model. 38 | 39 | Model|File size|Download 40 | :-:|:-:|:-: 41 | EDSR-baseline-LIIF|18M|[Dropbox](https://www.dropbox.com/s/6f402wcn4v83w2v/edsr-baseline-liif.pth?dl=0) | [Google Drive](https://drive.google.com/file/d/1wBHSrgPLOHL_QVhPAIAcDC30KSJLf67x/view?usp=sharing) 42 | RDN-LIIF|256M|[Dropbox](https://www.dropbox.com/s/mzha6ll9kb9bwy0/rdn-liif.pth?dl=0) | [Google Drive](https://drive.google.com/file/d/1xaAx6lBVVw_PJ3YVp02h3k4HuOAXcUkt/view?usp=sharing) 43 | 44 | 2. Encode your image with LIIF and render it at a given resolution (using GPU 0; `[MODEL_PATH]` denotes the `.pth` file): 45 | 46 | ``` 47 | python demo.py --input xxx.png --model [MODEL_PATH] --resolution [HEIGHT],[WIDTH] --output output.png --gpu 0 48 | ``` 49 |
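The same steps can also be scripted directly in Python; the sketch below mirrors what `demo.py` in this repository does (the input/output paths, checkpoint name and target resolution are placeholders):

```
from PIL import Image
import torch
from torchvision import transforms

import models
from utils import make_coord
from test import batched_predict

# load the input image as a CHW tensor in [0, 1]
img = transforms.ToTensor()(Image.open('xxx.png').convert('RGB'))

# restore a pre-trained LIIF model from its .pth checkpoint
model = models.make(torch.load('edsr-baseline-liif.pth')['model'], load_sd=True).cuda()

# build query coordinates and cell sizes for the target resolution
h, w = 512, 512
coord = make_coord((h, w)).cuda()
cell = torch.ones_like(coord)
cell[:, 0] *= 2 / h
cell[:, 1] *= 2 / w

# query the continuous representation in batches, de-normalize and save
pred = batched_predict(model, ((img - 0.5) / 0.5).cuda().unsqueeze(0),
                       coord.unsqueeze(0), cell.unsqueeze(0), bsize=30000)[0]
pred = (pred * 0.5 + 0.5).clamp(0, 1).view(h, w, 3).permute(2, 0, 1).cpu()
transforms.ToPILImage()(pred).save('output.png')
```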
50 | ## Reproducing Experiments 51 | 52 | ### Data 53 | 54 | Run `mkdir load` to create the folder that holds the dataset folders. 55 | 56 | - **DIV2K**: `mkdir` and `cd` into `load/div2k`. Download HR images and bicubic validation LR images from the [DIV2K website](https://data.vision.ee.ethz.ch/cvl/DIV2K/) (i.e. [Train_HR](http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip), [Valid_HR](http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_HR.zip), [Valid_LR_X2](http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic_X2.zip), [Valid_LR_X3](http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic_X3.zip), [Valid_LR_X4](http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_valid_LR_bicubic_X4.zip)). `unzip` these files to get the image folders. 57 | 58 | - **benchmark datasets**: `cd` into `load/`. Download and `tar -xf` the [benchmark datasets](https://cv.snu.ac.kr/research/EDSR/benchmark.tar) (provided by [this repo](https://github.com/thstkdgus35/EDSR-PyTorch)) to get a `load/benchmark` folder with sub-folders `Set5/, Set14/, B100/, Urban100/`. 59 | 60 | - **celebAHQ**: `mkdir load/celebAHQ` and `cp scripts/resize.py load/celebAHQ/`, then `cd load/celebAHQ/`. Download and `unzip` data1024x1024.zip from the [Google Drive link](https://drive.google.com/drive/folders/11Vz0fqHS2rXDb5pprgTjpD7S2BAJhi1P?usp=sharing) (provided by [this repo](https://github.com/suvojit-0x55aa/celebA-HQ-dataset-download)). Run `python resize.py` to get the image folders `256/, 128/, 64/, 32/`. Download the [split.json](https://www.dropbox.com/s/2qeijojdjzvp3b9/split.json?dl=0). 61 | 62 | ### Running the code 63 | 64 | **0. Preliminaries** 65 | 66 | - For `train_liif.py` or `test.py`, use `--gpu [GPU]` to specify the GPUs (e.g. `--gpu 0` or `--gpu 0,1`). 67 | 68 | - For `train_liif.py`, the save folder defaults to `save/_[CONFIG_NAME]`. Use `--name` to specify a different name if needed. 69 | 70 | - For the dataset args in configs, `cache: in_memory` denotes pre-loading into memory (may require large memory, e.g. ~40GB for DIV2K), `cache: bin` denotes creating binary files (in a sibling folder) on the first run, and `cache: none` denotes direct loading. Modify it according to the available hardware before running the training scripts. 71 | 72 | **1. DIV2K experiments** 73 | 74 | **Train**: `python train_liif.py --config configs/train-div2k/train_edsr-baseline-liif.yaml` (with the EDSR-baseline backbone; for RDN, replace `edsr-baseline` with `rdn`). We use 1 GPU for training EDSR-baseline-LIIF and 4 GPUs for RDN-LIIF. 75 | 76 | **Test**: `bash scripts/test-div2k.sh [MODEL_PATH] [GPU]` for the DIV2K validation set, `bash scripts/test-benchmark.sh [MODEL_PATH] [GPU]` for the benchmark datasets. `[MODEL_PATH]` is the path to a `.pth` file; we use `epoch-last.pth` in the corresponding save folder. 77 | 78 | **2. celebAHQ experiments** 79 | 80 | **Train**: `python train_liif.py --config configs/train-celebAHQ/[CONFIG_NAME].yaml`. 81 | 82 | **Test**: `python test.py --config configs/test/test-celebAHQ-32-256.yaml --model [MODEL_PATH]` (or `test-celebAHQ-64-128.yaml` for the other task). We use `epoch-best.pth` in the corresponding save folder. 83 | -------------------------------------------------------------------------------- /configs/test/test-b100-2.yaml: -------------------------------------------------------------------------------- 1 | test_dataset: 2 | dataset: 3 | name: paired-image-folders 4 | args: 5 | root_path_1: ./load/benchmark/B100/LR_bicubic/X2 6 | root_path_2: ./load/benchmark/B100/HR 7 | wrapper: 8 | name: sr-implicit-paired 9 | args: {} 10 | batch_size: 1 11 | eval_type: benchmark-2 12 | eval_bsize: 30000 13 | 14 | data_norm: 15 | inp: {sub: [0.5], div: [0.5]} 16 | gt: {sub: [0.5], div: [0.5]} -------------------------------------------------------------------------------- /configs/test/test-b100-3.yaml: -------------------------------------------------------------------------------- 1 | test_dataset: 2 | dataset: 3 | name: paired-image-folders 4 | args: 5 | root_path_1: ./load/benchmark/B100/LR_bicubic/X3 6 | root_path_2: ./load/benchmark/B100/HR 7 | wrapper: 8 | name: sr-implicit-paired 9 | args: {} 10 | batch_size: 1 11 | eval_type: benchmark-3 12 | eval_bsize: 30000 13 | 14 | data_norm: 15 | inp: {sub: [0.5], div: [0.5]} 16 | gt: {sub: [0.5], div: [0.5]} -------------------------------------------------------------------------------- /configs/test/test-b100-4.yaml: -------------------------------------------------------------------------------- 1 | test_dataset: 2 | dataset: 3 | name: paired-image-folders 4 | args: 5 | root_path_1: ./load/benchmark/B100/LR_bicubic/X4 6 | root_path_2: ./load/benchmark/B100/HR 7 | wrapper: 8 | name: sr-implicit-paired 9 | args: {} 10 | batch_size: 1 11 | eval_type: benchmark-4 12 | eval_bsize: 30000 13 | 14 | data_norm: 15 | inp: {sub: [0.5], div: [0.5]} 16 | gt: {sub: [0.5], div: [0.5]} -------------------------------------------------------------------------------- /configs/test/test-b100-6.yaml: -------------------------------------------------------------------------------- 1 | test_dataset: 2 | dataset: 3 | name: image-folder 4 | args: 5 | root_path: ./load/benchmark/B100/HR 6 | wrapper: 7 | name: sr-implicit-downsampled 8 | args: 9 | scale_min: 6 10 | batch_size: 1 11 | eval_type: benchmark-6 12 | eval_bsize: 30000 13 | 14 | data_norm: 15 | inp: {sub: [0.5], div: [0.5]} 16 | gt: {sub: [0.5], div: [0.5]} -------------------------------------------------------------------------------- /configs/test/test-b100-8.yaml:
-------------------------------------------------------------------------------- 1 | test_dataset: 2 | dataset: 3 | name: image-folder 4 | args: 5 | root_path: ./load/benchmark/B100/HR 6 | wrapper: 7 | name: sr-implicit-downsampled 8 | args: 9 | scale_min: 8 10 | batch_size: 1 11 | eval_type: benchmark-8 12 | eval_bsize: 30000 13 | 14 | data_norm: 15 | inp: {sub: [0.5], div: [0.5]} 16 | gt: {sub: [0.5], div: [0.5]} -------------------------------------------------------------------------------- /configs/test/test-celebAHQ-32-256.yaml: -------------------------------------------------------------------------------- 1 | test_dataset: 2 | dataset: 3 | name: paired-image-folders 4 | args: 5 | root_path_1: ./load/celebAHQ/32 6 | root_path_2: ./load/celebAHQ/256 7 | split_file: ./load/celebAHQ/split.json 8 | split_key: val 9 | cache: bin 10 | wrapper: 11 | name: sr-implicit-paired 12 | args: {} 13 | batch_size: 1 14 | 15 | data_norm: 16 | inp: {sub: [0.5], div: [0.5]} 17 | gt: {sub: [0.5], div: [0.5]} -------------------------------------------------------------------------------- /configs/test/test-celebAHQ-64-128.yaml: -------------------------------------------------------------------------------- 1 | test_dataset: 2 | dataset: 3 | name: paired-image-folders 4 | args: 5 | root_path_1: ./load/celebAHQ/64 6 | root_path_2: ./load/celebAHQ/128 7 | split_file: ./load/celebAHQ/split.json 8 | split_key: val 9 | cache: bin 10 | wrapper: 11 | name: sr-implicit-paired 12 | args: {} 13 | batch_size: 1 14 | 15 | data_norm: 16 | inp: {sub: [0.5], div: [0.5]} 17 | gt: {sub: [0.5], div: [0.5]} -------------------------------------------------------------------------------- /configs/test/test-div2k-12.yaml: -------------------------------------------------------------------------------- 1 | test_dataset: 2 | dataset: 3 | name: image-folder 4 | args: 5 | root_path: ./load/div2k/DIV2K_valid_HR 6 | wrapper: 7 | name: sr-implicit-downsampled 8 | args: 9 | scale_min: 12 10 | batch_size: 1 11 | eval_type: div2k-12 12 | eval_bsize: 30000 13 | 14 | data_norm: 15 | inp: {sub: [0.5], div: [0.5]} 16 | gt: {sub: [0.5], div: [0.5]} -------------------------------------------------------------------------------- /configs/test/test-div2k-18.yaml: -------------------------------------------------------------------------------- 1 | test_dataset: 2 | dataset: 3 | name: image-folder 4 | args: 5 | root_path: ./load/div2k/DIV2K_valid_HR 6 | wrapper: 7 | name: sr-implicit-downsampled 8 | args: 9 | scale_min: 18 10 | batch_size: 1 11 | eval_type: div2k-18 12 | eval_bsize: 30000 13 | 14 | data_norm: 15 | inp: {sub: [0.5], div: [0.5]} 16 | gt: {sub: [0.5], div: [0.5]} -------------------------------------------------------------------------------- /configs/test/test-div2k-2.yaml: -------------------------------------------------------------------------------- 1 | test_dataset: 2 | dataset: 3 | name: paired-image-folders 4 | args: 5 | root_path_1: ./load/div2k/DIV2K_valid_LR_bicubic/X2 6 | root_path_2: ./load/div2k/DIV2K_valid_HR 7 | wrapper: 8 | name: sr-implicit-paired 9 | args: {} 10 | batch_size: 1 11 | eval_type: div2k-2 12 | eval_bsize: 30000 13 | 14 | data_norm: 15 | inp: {sub: [0.5], div: [0.5]} 16 | gt: {sub: [0.5], div: [0.5]} -------------------------------------------------------------------------------- /configs/test/test-div2k-24.yaml: -------------------------------------------------------------------------------- 1 | test_dataset: 2 | dataset: 3 | name: image-folder 4 | args: 5 | root_path: ./load/div2k/DIV2K_valid_HR 
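# DIV2K provides bicubic LR images only for x2/x3/x4, so this x24 config (like the other
# large-scale test configs) points at the HR folder and lets the sr-implicit-downsampled
# wrapper below generate the LR input on the fly; with scale_max unset, the wrapper fixes
# the scale at scale_min.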
6 | wrapper: 7 | name: sr-implicit-downsampled 8 | args: 9 | scale_min: 24 10 | batch_size: 1 11 | eval_type: div2k-24 12 | eval_bsize: 30000 13 | 14 | data_norm: 15 | inp: {sub: [0.5], div: [0.5]} 16 | gt: {sub: [0.5], div: [0.5]} -------------------------------------------------------------------------------- /configs/test/test-div2k-3.yaml: -------------------------------------------------------------------------------- 1 | test_dataset: 2 | dataset: 3 | name: paired-image-folders 4 | args: 5 | root_path_1: ./load/div2k/DIV2K_valid_LR_bicubic/X3 6 | root_path_2: ./load/div2k/DIV2K_valid_HR 7 | wrapper: 8 | name: sr-implicit-paired 9 | args: {} 10 | batch_size: 1 11 | eval_type: div2k-3 12 | eval_bsize: 30000 13 | 14 | data_norm: 15 | inp: {sub: [0.5], div: [0.5]} 16 | gt: {sub: [0.5], div: [0.5]} -------------------------------------------------------------------------------- /configs/test/test-div2k-30.yaml: -------------------------------------------------------------------------------- 1 | test_dataset: 2 | dataset: 3 | name: image-folder 4 | args: 5 | root_path: ./load/div2k/DIV2K_valid_HR 6 | wrapper: 7 | name: sr-implicit-downsampled 8 | args: 9 | scale_min: 30 10 | batch_size: 1 11 | eval_type: div2k-30 12 | eval_bsize: 30000 13 | 14 | data_norm: 15 | inp: {sub: [0.5], div: [0.5]} 16 | gt: {sub: [0.5], div: [0.5]} -------------------------------------------------------------------------------- /configs/test/test-div2k-4.yaml: -------------------------------------------------------------------------------- 1 | test_dataset: 2 | dataset: 3 | name: paired-image-folders 4 | args: 5 | root_path_1: ./load/div2k/DIV2K_valid_LR_bicubic/X4 6 | root_path_2: ./load/div2k/DIV2K_valid_HR 7 | wrapper: 8 | name: sr-implicit-paired 9 | args: {} 10 | batch_size: 1 11 | eval_type: div2k-4 12 | eval_bsize: 30000 13 | 14 | data_norm: 15 | inp: {sub: [0.5], div: [0.5]} 16 | gt: {sub: [0.5], div: [0.5]} -------------------------------------------------------------------------------- /configs/test/test-div2k-6.yaml: -------------------------------------------------------------------------------- 1 | test_dataset: 2 | dataset: 3 | name: image-folder 4 | args: 5 | root_path: ./load/div2k/DIV2K_valid_HR 6 | wrapper: 7 | name: sr-implicit-downsampled 8 | args: 9 | scale_min: 6 10 | batch_size: 1 11 | eval_type: div2k-6 12 | eval_bsize: 30000 13 | 14 | data_norm: 15 | inp: {sub: [0.5], div: [0.5]} 16 | gt: {sub: [0.5], div: [0.5]} -------------------------------------------------------------------------------- /configs/test/test-set14-2.yaml: -------------------------------------------------------------------------------- 1 | test_dataset: 2 | dataset: 3 | name: paired-image-folders 4 | args: 5 | root_path_1: ./load/benchmark/Set14/LR_bicubic/X2 6 | root_path_2: ./load/benchmark/Set14/HR 7 | wrapper: 8 | name: sr-implicit-paired 9 | args: {} 10 | batch_size: 1 11 | eval_type: benchmark-2 12 | eval_bsize: 30000 13 | 14 | data_norm: 15 | inp: {sub: [0.5], div: [0.5]} 16 | gt: {sub: [0.5], div: [0.5]} -------------------------------------------------------------------------------- /configs/test/test-set14-3.yaml: -------------------------------------------------------------------------------- 1 | test_dataset: 2 | dataset: 3 | name: paired-image-folders 4 | args: 5 | root_path_1: ./load/benchmark/Set14/LR_bicubic/X3 6 | root_path_2: ./load/benchmark/Set14/HR 7 | wrapper: 8 | name: sr-implicit-paired 9 | args: {} 10 | batch_size: 1 11 | eval_type: benchmark-3 12 | eval_bsize: 30000 13 | 14 | data_norm: 15 | 
inp: {sub: [0.5], div: [0.5]} 16 | gt: {sub: [0.5], div: [0.5]} -------------------------------------------------------------------------------- /configs/test/test-set14-4.yaml: -------------------------------------------------------------------------------- 1 | test_dataset: 2 | dataset: 3 | name: paired-image-folders 4 | args: 5 | root_path_1: ./load/benchmark/Set14/LR_bicubic/X4 6 | root_path_2: ./load/benchmark/Set14/HR 7 | wrapper: 8 | name: sr-implicit-paired 9 | args: {} 10 | batch_size: 1 11 | eval_type: benchmark-4 12 | eval_bsize: 30000 13 | 14 | data_norm: 15 | inp: {sub: [0.5], div: [0.5]} 16 | gt: {sub: [0.5], div: [0.5]} -------------------------------------------------------------------------------- /configs/test/test-set14-6.yaml: -------------------------------------------------------------------------------- 1 | test_dataset: 2 | dataset: 3 | name: image-folder 4 | args: 5 | root_path: ./load/benchmark/Set14/HR 6 | wrapper: 7 | name: sr-implicit-downsampled 8 | args: 9 | scale_min: 6 10 | batch_size: 1 11 | eval_type: benchmark-6 12 | eval_bsize: 30000 13 | 14 | data_norm: 15 | inp: {sub: [0.5], div: [0.5]} 16 | gt: {sub: [0.5], div: [0.5]} -------------------------------------------------------------------------------- /configs/test/test-set14-8.yaml: -------------------------------------------------------------------------------- 1 | test_dataset: 2 | dataset: 3 | name: image-folder 4 | args: 5 | root_path: ./load/benchmark/Set14/HR 6 | wrapper: 7 | name: sr-implicit-downsampled 8 | args: 9 | scale_min: 8 10 | batch_size: 1 11 | eval_type: benchmark-8 12 | eval_bsize: 30000 13 | 14 | data_norm: 15 | inp: {sub: [0.5], div: [0.5]} 16 | gt: {sub: [0.5], div: [0.5]} -------------------------------------------------------------------------------- /configs/test/test-set5-2.yaml: -------------------------------------------------------------------------------- 1 | test_dataset: 2 | dataset: 3 | name: paired-image-folders 4 | args: 5 | root_path_1: ./load/benchmark/Set5/LR_bicubic/X2 6 | root_path_2: ./load/benchmark/Set5/HR 7 | wrapper: 8 | name: sr-implicit-paired 9 | args: {} 10 | batch_size: 1 11 | eval_type: benchmark-2 12 | eval_bsize: 30000 13 | 14 | data_norm: 15 | inp: {sub: [0.5], div: [0.5]} 16 | gt: {sub: [0.5], div: [0.5]} -------------------------------------------------------------------------------- /configs/test/test-set5-3.yaml: -------------------------------------------------------------------------------- 1 | test_dataset: 2 | dataset: 3 | name: paired-image-folders 4 | args: 5 | root_path_1: ./load/benchmark/Set5/LR_bicubic/X3 6 | root_path_2: ./load/benchmark/Set5/HR 7 | wrapper: 8 | name: sr-implicit-paired 9 | args: {} 10 | batch_size: 1 11 | eval_type: benchmark-3 12 | eval_bsize: 30000 13 | 14 | data_norm: 15 | inp: {sub: [0.5], div: [0.5]} 16 | gt: {sub: [0.5], div: [0.5]} -------------------------------------------------------------------------------- /configs/test/test-set5-4.yaml: -------------------------------------------------------------------------------- 1 | test_dataset: 2 | dataset: 3 | name: paired-image-folders 4 | args: 5 | root_path_1: ./load/benchmark/Set5/LR_bicubic/X4 6 | root_path_2: ./load/benchmark/Set5/HR 7 | wrapper: 8 | name: sr-implicit-paired 9 | args: {} 10 | batch_size: 1 11 | eval_type: benchmark-4 12 | eval_bsize: 30000 13 | 14 | data_norm: 15 | inp: {sub: [0.5], div: [0.5]} 16 | gt: {sub: [0.5], div: [0.5]} -------------------------------------------------------------------------------- /configs/test/test-set5-6.yaml: 
-------------------------------------------------------------------------------- 1 | test_dataset: 2 | dataset: 3 | name: image-folder 4 | args: 5 | root_path: ./load/benchmark/Set5/HR 6 | wrapper: 7 | name: sr-implicit-downsampled 8 | args: 9 | scale_min: 6 10 | batch_size: 1 11 | eval_type: benchmark-6 12 | eval_bsize: 30000 13 | 14 | data_norm: 15 | inp: {sub: [0.5], div: [0.5]} 16 | gt: {sub: [0.5], div: [0.5]} -------------------------------------------------------------------------------- /configs/test/test-set5-8.yaml: -------------------------------------------------------------------------------- 1 | test_dataset: 2 | dataset: 3 | name: image-folder 4 | args: 5 | root_path: ./load/benchmark/Set5/HR 6 | wrapper: 7 | name: sr-implicit-downsampled 8 | args: 9 | scale_min: 8 10 | batch_size: 1 11 | eval_type: benchmark-8 12 | eval_bsize: 30000 13 | 14 | data_norm: 15 | inp: {sub: [0.5], div: [0.5]} 16 | gt: {sub: [0.5], div: [0.5]} -------------------------------------------------------------------------------- /configs/test/test-urban100-2.yaml: -------------------------------------------------------------------------------- 1 | test_dataset: 2 | dataset: 3 | name: paired-image-folders 4 | args: 5 | root_path_1: ./load/benchmark/Urban100/LR_bicubic/X2 6 | root_path_2: ./load/benchmark/Urban100/HR 7 | wrapper: 8 | name: sr-implicit-paired 9 | args: {} 10 | batch_size: 1 11 | eval_type: benchmark-2 12 | eval_bsize: 30000 13 | 14 | data_norm: 15 | inp: {sub: [0.5], div: [0.5]} 16 | gt: {sub: [0.5], div: [0.5]} -------------------------------------------------------------------------------- /configs/test/test-urban100-3.yaml: -------------------------------------------------------------------------------- 1 | test_dataset: 2 | dataset: 3 | name: paired-image-folders 4 | args: 5 | root_path_1: ./load/benchmark/Urban100/LR_bicubic/X3 6 | root_path_2: ./load/benchmark/Urban100/HR 7 | wrapper: 8 | name: sr-implicit-paired 9 | args: {} 10 | batch_size: 1 11 | eval_type: benchmark-3 12 | eval_bsize: 30000 13 | 14 | data_norm: 15 | inp: {sub: [0.5], div: [0.5]} 16 | gt: {sub: [0.5], div: [0.5]} -------------------------------------------------------------------------------- /configs/test/test-urban100-4.yaml: -------------------------------------------------------------------------------- 1 | test_dataset: 2 | dataset: 3 | name: paired-image-folders 4 | args: 5 | root_path_1: ./load/benchmark/Urban100/LR_bicubic/X4 6 | root_path_2: ./load/benchmark/Urban100/HR 7 | wrapper: 8 | name: sr-implicit-paired 9 | args: {} 10 | batch_size: 1 11 | eval_type: benchmark-4 12 | eval_bsize: 30000 13 | 14 | data_norm: 15 | inp: {sub: [0.5], div: [0.5]} 16 | gt: {sub: [0.5], div: [0.5]} -------------------------------------------------------------------------------- /configs/test/test-urban100-6.yaml: -------------------------------------------------------------------------------- 1 | test_dataset: 2 | dataset: 3 | name: image-folder 4 | args: 5 | root_path: ./load/benchmark/Urban100/HR 6 | wrapper: 7 | name: sr-implicit-downsampled 8 | args: 9 | scale_min: 6 10 | batch_size: 1 11 | eval_type: benchmark-6 12 | eval_bsize: 30000 13 | 14 | data_norm: 15 | inp: {sub: [0.5], div: [0.5]} 16 | gt: {sub: [0.5], div: [0.5]} -------------------------------------------------------------------------------- /configs/test/test-urban100-8.yaml: -------------------------------------------------------------------------------- 1 | test_dataset: 2 | dataset: 3 | name: image-folder 4 | args: 5 | root_path: 
./load/benchmark/Urban100/HR 6 | wrapper: 7 | name: sr-implicit-downsampled 8 | args: 9 | scale_min: 8 10 | batch_size: 1 11 | eval_type: benchmark-8 12 | eval_bsize: 30000 13 | 14 | data_norm: 15 | inp: {sub: [0.5], div: [0.5]} 16 | gt: {sub: [0.5], div: [0.5]} -------------------------------------------------------------------------------- /configs/train-celebAHQ/train_celebAHQ-32-256_liif.yaml: -------------------------------------------------------------------------------- 1 | train_dataset: 2 | dataset: 3 | name: paired-image-folders 4 | args: 5 | root_path_1: ./load/celebAHQ/32 6 | root_path_2: ./load/celebAHQ/256 7 | split_file: ./load/celebAHQ/split.json 8 | split_key: train 9 | cache: bin 10 | wrapper: 11 | name: sr-implicit-uniform-varied 12 | args: 13 | size_min: 32 14 | size_max: 256 15 | sample_q: 1024 16 | augment: true 17 | batch_size: 16 18 | 19 | val_dataset: 20 | dataset: 21 | name: paired-image-folders 22 | args: 23 | root_path_1: ./load/celebAHQ/32 24 | root_path_2: ./load/celebAHQ/256 25 | split_file: ./load/celebAHQ/split.json 26 | split_key: val 27 | first_k: 100 28 | cache: bin 29 | wrapper: 30 | name: sr-implicit-paired 31 | args: 32 | sample_q: 1024 33 | batch_size: 16 34 | 35 | data_norm: 36 | inp: {sub: [0.5], div: [0.5]} 37 | gt: {sub: [0.5], div: [0.5]} 38 | 39 | model: 40 | name: liif 41 | args: 42 | encoder_spec: 43 | name: edsr-baseline 44 | args: 45 | no_upsampling: true 46 | imnet_spec: 47 | name: mlp 48 | args: 49 | out_dim: 3 50 | hidden_list: [256, 256, 256, 256] 51 | 52 | optimizer: 53 | name: adam 54 | args: 55 | lr: 1.e-4 56 | epoch_max: 200 57 | multi_step_lr: 58 | milestones: [100] 59 | gamma: 0.1 60 | 61 | epoch_val: 1 62 | epoch_save: 100 63 | -------------------------------------------------------------------------------- /configs/train-celebAHQ/train_celebAHQ-32-256_upsample.yaml: -------------------------------------------------------------------------------- 1 | train_dataset: 2 | dataset: 3 | name: paired-image-folders 4 | args: 5 | root_path_1: ./load/celebAHQ/32 6 | root_path_2: ./load/celebAHQ/256 7 | split_file: ./load/celebAHQ/split.json 8 | split_key: train 9 | cache: bin 10 | wrapper: 11 | name: sr-implicit-uniform-varied 12 | args: 13 | size_min: 32 14 | size_max: 256 15 | gt_resize: 256 16 | sample_q: 1024 17 | augment: true 18 | batch_size: 16 19 | 20 | val_dataset: 21 | dataset: 22 | name: paired-image-folders 23 | args: 24 | root_path_1: ./load/celebAHQ/32 25 | root_path_2: ./load/celebAHQ/256 26 | split_file: ./load/celebAHQ/split.json 27 | split_key: val 28 | first_k: 100 29 | cache: bin 30 | wrapper: 31 | name: sr-implicit-paired 32 | args: 33 | sample_q: 1024 34 | batch_size: 16 35 | 36 | data_norm: 37 | inp: {sub: [0.5], div: [0.5]} 38 | gt: {sub: [0.5], div: [0.5]} 39 | 40 | model: 41 | name: liif 42 | args: 43 | encoder_spec: 44 | name: edsr-baseline 45 | args: 46 | scale: 8 47 | 48 | optimizer: 49 | name: adam 50 | args: 51 | lr: 1.e-4 52 | epoch_max: 200 53 | multi_step_lr: 54 | milestones: [100] 55 | gamma: 0.1 56 | 57 | epoch_val: 1 58 | epoch_save: 100 59 | -------------------------------------------------------------------------------- /configs/train-celebAHQ/train_celebAHQ-64-128_liif.yaml: -------------------------------------------------------------------------------- 1 | train_dataset: 2 | dataset: 3 | name: paired-image-folders 4 | args: 5 | root_path_1: ./load/celebAHQ/64 6 | root_path_2: ./load/celebAHQ/128 7 | split_file: ./load/celebAHQ/split.json 8 | split_key: train 9 | cache: bin 10 | wrapper: 11 | name: 
sr-implicit-uniform-varied 12 | args: 13 | size_min: 64 14 | size_max: 128 15 | sample_q: 1024 16 | augment: true 17 | batch_size: 16 18 | 19 | val_dataset: 20 | dataset: 21 | name: paired-image-folders 22 | args: 23 | root_path_1: ./load/celebAHQ/64 24 | root_path_2: ./load/celebAHQ/128 25 | split_file: ./load/celebAHQ/split.json 26 | split_key: val 27 | first_k: 100 28 | cache: bin 29 | wrapper: 30 | name: sr-implicit-paired 31 | args: 32 | sample_q: 1024 33 | batch_size: 16 34 | 35 | data_norm: 36 | inp: {sub: [0.5], div: [0.5]} 37 | gt: {sub: [0.5], div: [0.5]} 38 | 39 | model: 40 | name: liif 41 | args: 42 | encoder_spec: 43 | name: edsr-baseline 44 | args: 45 | no_upsampling: true 46 | imnet_spec: 47 | name: mlp 48 | args: 49 | out_dim: 3 50 | hidden_list: [256, 256, 256, 256] 51 | 52 | optimizer: 53 | name: adam 54 | args: 55 | lr: 1.e-4 56 | epoch_max: 200 57 | multi_step_lr: 58 | milestones: [100] 59 | gamma: 0.1 60 | 61 | epoch_val: 1 62 | epoch_save: 100 63 | -------------------------------------------------------------------------------- /configs/train-celebAHQ/train_celebAHQ-64-128_upsample.yaml: -------------------------------------------------------------------------------- 1 | train_dataset: 2 | dataset: 3 | name: paired-image-folders 4 | args: 5 | root_path_1: ./load/celebAHQ/64 6 | root_path_2: ./load/celebAHQ/128 7 | split_file: ./load/celebAHQ/split.json 8 | split_key: train 9 | cache: bin 10 | wrapper: 11 | name: sr-implicit-uniform-varied 12 | args: 13 | size_min: 64 14 | size_max: 128 15 | gt_resize: 128 16 | sample_q: 1024 17 | augment: true 18 | batch_size: 16 19 | 20 | val_dataset: 21 | dataset: 22 | name: paired-image-folders 23 | args: 24 | root_path_1: ./load/celebAHQ/64 25 | root_path_2: ./load/celebAHQ/128 26 | split_file: ./load/celebAHQ/split.json 27 | split_key: val 28 | first_k: 100 29 | cache: bin 30 | wrapper: 31 | name: sr-implicit-paired 32 | args: 33 | sample_q: 1024 34 | batch_size: 16 35 | 36 | data_norm: 37 | inp: {sub: [0.5], div: [0.5]} 38 | gt: {sub: [0.5], div: [0.5]} 39 | 40 | model: 41 | name: liif 42 | args: 43 | encoder_spec: 44 | name: edsr-baseline 45 | args: 46 | scale: 2 47 | 48 | optimizer: 49 | name: adam 50 | args: 51 | lr: 1.e-4 52 | epoch_max: 200 53 | multi_step_lr: 54 | milestones: [100] 55 | gamma: 0.1 56 | 57 | epoch_val: 1 58 | epoch_save: 100 59 | -------------------------------------------------------------------------------- /configs/train-div2k/ablation/train_edsr-baseline-liif-c.yaml: -------------------------------------------------------------------------------- 1 | train_dataset: 2 | dataset: 3 | name: image-folder 4 | args: 5 | root_path: ./load/div2k/DIV2K_train_HR 6 | repeat: 20 7 | cache: in_memory 8 | wrapper: 9 | name: sr-implicit-downsampled 10 | args: 11 | inp_size: 48 12 | scale_max: 4 13 | augment: true 14 | sample_q: 2304 15 | batch_size: 16 16 | 17 | val_dataset: 18 | dataset: 19 | name: image-folder 20 | args: 21 | root_path: ./load/div2k/DIV2K_valid_HR 22 | first_k: 10 23 | repeat: 160 24 | cache: in_memory 25 | wrapper: 26 | name: sr-implicit-downsampled 27 | args: 28 | inp_size: 48 29 | scale_max: 4 30 | sample_q: 2304 31 | batch_size: 16 32 | 33 | data_norm: 34 | inp: {sub: [0.5], div: [0.5]} 35 | gt: {sub: [0.5], div: [0.5]} 36 | 37 | model: 38 | name: liif 39 | args: 40 | encoder_spec: 41 | name: edsr-baseline 42 | args: 43 | no_upsampling: true 44 | imnet_spec: 45 | name: mlp 46 | args: 47 | out_dim: 3 48 | hidden_list: [256, 256, 256, 256] 49 | cell_decode: false 50 | 51 | optimizer: 52 | name: 
adam 53 | args: 54 | lr: 1.e-4 55 | epoch_max: 1000 56 | multi_step_lr: 57 | milestones: [200, 400, 600, 800] 58 | gamma: 0.5 59 | 60 | epoch_val: 1 61 | epoch_save: 100 62 | -------------------------------------------------------------------------------- /configs/train-div2k/ablation/train_edsr-baseline-liif-d.yaml: -------------------------------------------------------------------------------- 1 | train_dataset: 2 | dataset: 3 | name: image-folder 4 | args: 5 | root_path: ./load/div2k/DIV2K_train_HR 6 | repeat: 20 7 | cache: in_memory 8 | wrapper: 9 | name: sr-implicit-downsampled 10 | args: 11 | inp_size: 48 12 | scale_max: 4 13 | augment: true 14 | sample_q: 2304 15 | batch_size: 16 16 | 17 | val_dataset: 18 | dataset: 19 | name: image-folder 20 | args: 21 | root_path: ./load/div2k/DIV2K_valid_HR 22 | first_k: 10 23 | repeat: 160 24 | cache: in_memory 25 | wrapper: 26 | name: sr-implicit-downsampled 27 | args: 28 | inp_size: 48 29 | scale_max: 4 30 | sample_q: 2304 31 | batch_size: 16 32 | 33 | data_norm: 34 | inp: {sub: [0.5], div: [0.5]} 35 | gt: {sub: [0.5], div: [0.5]} 36 | 37 | model: 38 | name: liif 39 | args: 40 | encoder_spec: 41 | name: edsr-baseline 42 | args: 43 | no_upsampling: true 44 | imnet_spec: 45 | name: mlp 46 | args: 47 | out_dim: 3 48 | hidden_list: [256, 256] 49 | 50 | optimizer: 51 | name: adam 52 | args: 53 | lr: 1.e-4 54 | epoch_max: 1000 55 | multi_step_lr: 56 | milestones: [200, 400, 600, 800] 57 | gamma: 0.5 58 | 59 | epoch_val: 1 60 | epoch_save: 100 61 | -------------------------------------------------------------------------------- /configs/train-div2k/ablation/train_edsr-baseline-liif-e.yaml: -------------------------------------------------------------------------------- 1 | train_dataset: 2 | dataset: 3 | name: image-folder 4 | args: 5 | root_path: ./load/div2k/DIV2K_train_HR 6 | repeat: 20 7 | cache: in_memory 8 | wrapper: 9 | name: sr-implicit-downsampled 10 | args: 11 | inp_size: 48 12 | scale_max: 4 13 | augment: true 14 | sample_q: 2304 15 | batch_size: 16 16 | 17 | val_dataset: 18 | dataset: 19 | name: image-folder 20 | args: 21 | root_path: ./load/div2k/DIV2K_valid_HR 22 | first_k: 10 23 | repeat: 160 24 | cache: in_memory 25 | wrapper: 26 | name: sr-implicit-downsampled 27 | args: 28 | inp_size: 48 29 | scale_max: 4 30 | sample_q: 2304 31 | batch_size: 16 32 | 33 | data_norm: 34 | inp: {sub: [0.5], div: [0.5]} 35 | gt: {sub: [0.5], div: [0.5]} 36 | 37 | model: 38 | name: liif 39 | args: 40 | encoder_spec: 41 | name: edsr-baseline 42 | args: 43 | no_upsampling: true 44 | imnet_spec: 45 | name: mlp 46 | args: 47 | out_dim: 3 48 | hidden_list: [256, 256, 256, 256] 49 | local_ensemble: false 50 | 51 | optimizer: 52 | name: adam 53 | args: 54 | lr: 1.e-4 55 | epoch_max: 1000 56 | multi_step_lr: 57 | milestones: [200, 400, 600, 800] 58 | gamma: 0.5 59 | 60 | epoch_val: 1 61 | epoch_save: 100 62 | -------------------------------------------------------------------------------- /configs/train-div2k/ablation/train_edsr-baseline-liif-u.yaml: -------------------------------------------------------------------------------- 1 | train_dataset: 2 | dataset: 3 | name: image-folder 4 | args: 5 | root_path: ./load/div2k/DIV2K_train_HR 6 | repeat: 20 7 | cache: in_memory 8 | wrapper: 9 | name: sr-implicit-downsampled 10 | args: 11 | inp_size: 48 12 | scale_max: 4 13 | augment: true 14 | sample_q: 2304 15 | batch_size: 16 16 | 17 | val_dataset: 18 | dataset: 19 | name: image-folder 20 | args: 21 | root_path: ./load/div2k/DIV2K_valid_HR 22 | first_k: 10 23 | 
repeat: 160 24 | cache: in_memory 25 | wrapper: 26 | name: sr-implicit-downsampled 27 | args: 28 | inp_size: 48 29 | scale_max: 4 30 | sample_q: 2304 31 | batch_size: 16 32 | 33 | data_norm: 34 | inp: {sub: [0.5], div: [0.5]} 35 | gt: {sub: [0.5], div: [0.5]} 36 | 37 | model: 38 | name: liif 39 | args: 40 | encoder_spec: 41 | name: edsr-baseline 42 | args: 43 | no_upsampling: true 44 | imnet_spec: 45 | name: mlp 46 | args: 47 | out_dim: 3 48 | hidden_list: [256, 256, 256, 256] 49 | feat_unfold: false 50 | 51 | optimizer: 52 | name: adam 53 | args: 54 | lr: 1.e-4 55 | epoch_max: 1000 56 | multi_step_lr: 57 | milestones: [200, 400, 600, 800] 58 | gamma: 0.5 59 | 60 | epoch_val: 1 61 | epoch_save: 100 62 | -------------------------------------------------------------------------------- /configs/train-div2k/ablation/train_edsr-baseline-liif-x2.yaml: -------------------------------------------------------------------------------- 1 | train_dataset: 2 | dataset: 3 | name: image-folder 4 | args: 5 | root_path: ./load/div2k/DIV2K_train_HR 6 | repeat: 20 7 | cache: in_memory 8 | wrapper: 9 | name: sr-implicit-downsampled 10 | args: 11 | inp_size: 48 12 | scale_min: 2 13 | augment: true 14 | sample_q: 2304 15 | batch_size: 16 16 | 17 | val_dataset: 18 | dataset: 19 | name: image-folder 20 | args: 21 | root_path: ./load/div2k/DIV2K_valid_HR 22 | first_k: 10 23 | repeat: 160 24 | cache: in_memory 25 | wrapper: 26 | name: sr-implicit-downsampled 27 | args: 28 | inp_size: 48 29 | scale_min: 2 30 | sample_q: 2304 31 | batch_size: 16 32 | 33 | data_norm: 34 | inp: {sub: [0.5], div: [0.5]} 35 | gt: {sub: [0.5], div: [0.5]} 36 | 37 | model: 38 | name: liif 39 | args: 40 | encoder_spec: 41 | name: edsr-baseline 42 | args: 43 | no_upsampling: true 44 | imnet_spec: 45 | name: mlp 46 | args: 47 | out_dim: 3 48 | hidden_list: [256, 256, 256, 256] 49 | 50 | optimizer: 51 | name: adam 52 | args: 53 | lr: 1.e-4 54 | epoch_max: 1000 55 | multi_step_lr: 56 | milestones: [200, 400, 600, 800] 57 | gamma: 0.5 58 | 59 | epoch_val: 1 60 | epoch_save: 100 61 | -------------------------------------------------------------------------------- /configs/train-div2k/ablation/train_edsr-baseline-liif-x3.yaml: -------------------------------------------------------------------------------- 1 | train_dataset: 2 | dataset: 3 | name: image-folder 4 | args: 5 | root_path: ./load/div2k/DIV2K_train_HR 6 | repeat: 20 7 | cache: in_memory 8 | wrapper: 9 | name: sr-implicit-downsampled 10 | args: 11 | inp_size: 48 12 | scale_min: 3 13 | augment: true 14 | sample_q: 2304 15 | batch_size: 16 16 | 17 | val_dataset: 18 | dataset: 19 | name: image-folder 20 | args: 21 | root_path: ./load/div2k/DIV2K_valid_HR 22 | first_k: 10 23 | repeat: 160 24 | cache: in_memory 25 | wrapper: 26 | name: sr-implicit-downsampled 27 | args: 28 | inp_size: 48 29 | scale_min: 3 30 | sample_q: 2304 31 | batch_size: 16 32 | 33 | data_norm: 34 | inp: {sub: [0.5], div: [0.5]} 35 | gt: {sub: [0.5], div: [0.5]} 36 | 37 | model: 38 | name: liif 39 | args: 40 | encoder_spec: 41 | name: edsr-baseline 42 | args: 43 | no_upsampling: true 44 | imnet_spec: 45 | name: mlp 46 | args: 47 | out_dim: 3 48 | hidden_list: [256, 256, 256, 256] 49 | 50 | optimizer: 51 | name: adam 52 | args: 53 | lr: 1.e-4 54 | epoch_max: 1000 55 | multi_step_lr: 56 | milestones: [200, 400, 600, 800] 57 | gamma: 0.5 58 | 59 | epoch_val: 1 60 | epoch_save: 100 61 | -------------------------------------------------------------------------------- 
/configs/train-div2k/ablation/train_edsr-baseline-liif-x4.yaml: -------------------------------------------------------------------------------- 1 | train_dataset: 2 | dataset: 3 | name: image-folder 4 | args: 5 | root_path: ./load/div2k/DIV2K_train_HR 6 | repeat: 20 7 | cache: in_memory 8 | wrapper: 9 | name: sr-implicit-downsampled 10 | args: 11 | inp_size: 48 12 | scale_min: 4 13 | augment: true 14 | sample_q: 2304 15 | batch_size: 16 16 | 17 | val_dataset: 18 | dataset: 19 | name: image-folder 20 | args: 21 | root_path: ./load/div2k/DIV2K_valid_HR 22 | first_k: 10 23 | repeat: 160 24 | cache: in_memory 25 | wrapper: 26 | name: sr-implicit-downsampled 27 | args: 28 | inp_size: 48 29 | scale_min: 4 30 | sample_q: 2304 31 | batch_size: 16 32 | 33 | data_norm: 34 | inp: {sub: [0.5], div: [0.5]} 35 | gt: {sub: [0.5], div: [0.5]} 36 | 37 | model: 38 | name: liif 39 | args: 40 | encoder_spec: 41 | name: edsr-baseline 42 | args: 43 | no_upsampling: true 44 | imnet_spec: 45 | name: mlp 46 | args: 47 | out_dim: 3 48 | hidden_list: [256, 256, 256, 256] 49 | 50 | optimizer: 51 | name: adam 52 | args: 53 | lr: 1.e-4 54 | epoch_max: 1000 55 | multi_step_lr: 56 | milestones: [200, 400, 600, 800] 57 | gamma: 0.5 58 | 59 | epoch_val: 1 60 | epoch_save: 100 61 | -------------------------------------------------------------------------------- /configs/train-div2k/train_edsr-baseline-liif.yaml: -------------------------------------------------------------------------------- 1 | train_dataset: 2 | dataset: 3 | name: image-folder 4 | args: 5 | root_path: ./load/div2k/DIV2K_train_HR 6 | repeat: 20 7 | cache: in_memory 8 | wrapper: 9 | name: sr-implicit-downsampled 10 | args: 11 | inp_size: 48 12 | scale_max: 4 13 | augment: true 14 | sample_q: 2304 15 | batch_size: 16 16 | 17 | val_dataset: 18 | dataset: 19 | name: image-folder 20 | args: 21 | root_path: ./load/div2k/DIV2K_valid_HR 22 | first_k: 10 23 | repeat: 160 24 | cache: in_memory 25 | wrapper: 26 | name: sr-implicit-downsampled 27 | args: 28 | inp_size: 48 29 | scale_max: 4 30 | sample_q: 2304 31 | batch_size: 16 32 | 33 | data_norm: 34 | inp: {sub: [0.5], div: [0.5]} 35 | gt: {sub: [0.5], div: [0.5]} 36 | 37 | model: 38 | name: liif 39 | args: 40 | encoder_spec: 41 | name: edsr-baseline 42 | args: 43 | no_upsampling: true 44 | imnet_spec: 45 | name: mlp 46 | args: 47 | out_dim: 3 48 | hidden_list: [256, 256, 256, 256] 49 | 50 | optimizer: 51 | name: adam 52 | args: 53 | lr: 1.e-4 54 | epoch_max: 1000 55 | multi_step_lr: 56 | milestones: [200, 400, 600, 800] 57 | gamma: 0.5 58 | 59 | epoch_val: 1 60 | epoch_save: 100 61 | -------------------------------------------------------------------------------- /configs/train-div2k/train_edsr-baseline-metasr.yaml: -------------------------------------------------------------------------------- 1 | train_dataset: 2 | dataset: 3 | name: image-folder 4 | args: 5 | root_path: ./load/div2k/DIV2K_train_HR 6 | repeat: 20 7 | cache: in_memory 8 | wrapper: 9 | name: sr-implicit-downsampled 10 | args: 11 | inp_size: 48 12 | scale_max: 4 13 | augment: true 14 | sample_q: 2304 15 | batch_size: 16 16 | 17 | val_dataset: 18 | dataset: 19 | name: image-folder 20 | args: 21 | root_path: ./load/div2k/DIV2K_valid_HR 22 | first_k: 10 23 | repeat: 160 24 | cache: in_memory 25 | wrapper: 26 | name: sr-implicit-downsampled 27 | args: 28 | inp_size: 48 29 | scale_max: 4 30 | sample_q: 2304 31 | batch_size: 16 32 | 33 | data_norm: 34 | inp: {sub: [0.5], div: [0.5]} 35 | gt: {sub: [0.5], div: [0.5]} 36 | 37 | model: 38 | name: metasr 
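# 'metasr' is the Meta-SR baseline (Meta-SR's meta-upscale head rather than the LIIF
# decoder), trained here on the same EDSR-baseline encoder for comparison; the model is
# registered under this name in models/.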
39 | args: 40 | encoder_spec: 41 | name: edsr-baseline 42 | args: 43 | no_upsampling: true 44 | 45 | optimizer: 46 | name: adam 47 | args: 48 | lr: 1.e-4 49 | epoch_max: 1000 50 | multi_step_lr: 51 | milestones: [200, 400, 600, 800] 52 | gamma: 0.5 53 | 54 | epoch_val: 1 55 | epoch_save: 100 56 | -------------------------------------------------------------------------------- /configs/train-div2k/train_edsr-baseline-x2.yaml: -------------------------------------------------------------------------------- 1 | train_dataset: 2 | dataset: 3 | name: paired-image-folders 4 | args: 5 | root_path_1: ./load/div2k/DIV2K_train_LR_bicubic/X2 6 | root_path_2: ./load/div2k/DIV2K_train_HR 7 | repeat: 20 8 | cache: in_memory 9 | wrapper: 10 | name: sr-implicit-paired 11 | args: 12 | inp_size: 48 13 | augment: true 14 | batch_size: 16 15 | 16 | val_dataset: 17 | dataset: 18 | name: paired-image-folders 19 | args: 20 | root_path_1: ./load/div2k/DIV2K_valid_LR_bicubic/X2 21 | root_path_2: ./load/div2k/DIV2K_valid_HR 22 | first_k: 10 23 | repeat: 160 24 | cache: in_memory 25 | wrapper: 26 | name: sr-implicit-paired 27 | args: 28 | inp_size: 48 29 | batch_size: 16 30 | 31 | data_norm: 32 | inp: {sub: [0.5], div: [0.5]} 33 | gt: {sub: [0.5], div: [0.5]} 34 | 35 | model: 36 | name: liif 37 | args: 38 | encoder_spec: 39 | name: edsr-baseline 40 | args: 41 | scale: 2 42 | 43 | optimizer: 44 | name: adam 45 | args: 46 | lr: 1.e-4 47 | epoch_max: 1000 48 | multi_step_lr: 49 | milestones: [200, 400, 600, 800] 50 | gamma: 0.5 51 | 52 | epoch_val: 1 53 | epoch_save: 100 54 | -------------------------------------------------------------------------------- /configs/train-div2k/train_rdn-liif.yaml: -------------------------------------------------------------------------------- 1 | train_dataset: 2 | dataset: 3 | name: image-folder 4 | args: 5 | root_path: ./load/div2k/DIV2K_train_HR 6 | repeat: 20 7 | cache: in_memory 8 | wrapper: 9 | name: sr-implicit-downsampled 10 | args: 11 | inp_size: 48 12 | scale_max: 4 13 | augment: true 14 | sample_q: 2304 15 | batch_size: 16 16 | 17 | val_dataset: 18 | dataset: 19 | name: image-folder 20 | args: 21 | root_path: ./load/div2k/DIV2K_valid_HR 22 | first_k: 10 23 | repeat: 160 24 | cache: in_memory 25 | wrapper: 26 | name: sr-implicit-downsampled 27 | args: 28 | inp_size: 48 29 | scale_max: 4 30 | sample_q: 2304 31 | batch_size: 16 32 | 33 | data_norm: 34 | inp: {sub: [0.5], div: [0.5]} 35 | gt: {sub: [0.5], div: [0.5]} 36 | 37 | model: 38 | name: liif 39 | args: 40 | encoder_spec: 41 | name: rdn 42 | args: 43 | no_upsampling: true 44 | imnet_spec: 45 | name: mlp 46 | args: 47 | out_dim: 3 48 | hidden_list: [256, 256, 256, 256] 49 | 50 | optimizer: 51 | name: adam 52 | args: 53 | lr: 1.e-4 54 | epoch_max: 1000 55 | multi_step_lr: 56 | milestones: [200, 400, 600, 800] 57 | gamma: 0.5 58 | 59 | epoch_val: 1 60 | epoch_save: 100 61 | -------------------------------------------------------------------------------- /configs/train-div2k/train_rdn-metasr.yaml: -------------------------------------------------------------------------------- 1 | train_dataset: 2 | dataset: 3 | name: image-folder 4 | args: 5 | root_path: ./load/div2k/DIV2K_train_HR 6 | repeat: 20 7 | cache: in_memory 8 | wrapper: 9 | name: sr-implicit-downsampled 10 | args: 11 | inp_size: 48 12 | scale_max: 4 13 | augment: true 14 | sample_q: 2304 15 | batch_size: 16 16 | 17 | val_dataset: 18 | dataset: 19 | name: image-folder 20 | args: 21 | root_path: ./load/div2k/DIV2K_valid_HR 22 | first_k: 10 23 | repeat: 160 24 | 
cache: in_memory 25 | wrapper: 26 | name: sr-implicit-downsampled 27 | args: 28 | inp_size: 48 29 | scale_max: 4 30 | sample_q: 2304 31 | batch_size: 16 32 | 33 | data_norm: 34 | inp: {sub: [0.5], div: [0.5]} 35 | gt: {sub: [0.5], div: [0.5]} 36 | 37 | model: 38 | name: metasr 39 | args: 40 | encoder_spec: 41 | name: rdn 42 | args: 43 | no_upsampling: true 44 | 45 | optimizer: 46 | name: adam 47 | args: 48 | lr: 1.e-4 49 | epoch_max: 1000 50 | multi_step_lr: 51 | milestones: [200, 400, 600, 800] 52 | gamma: 0.5 53 | 54 | epoch_val: 1 55 | epoch_save: 100 56 | -------------------------------------------------------------------------------- /configs/train-div2k/train_rdn-x2.yaml: -------------------------------------------------------------------------------- 1 | train_dataset: 2 | dataset: 3 | name: paired-image-folders 4 | args: 5 | root_path_1: ./load/div2k/DIV2K_train_LR_bicubic/X2 6 | root_path_2: ./load/div2k/DIV2K_train_HR 7 | repeat: 20 8 | cache: in_memory 9 | wrapper: 10 | name: sr-implicit-paired 11 | args: 12 | inp_size: 48 13 | augment: true 14 | batch_size: 16 15 | 16 | val_dataset: 17 | dataset: 18 | name: paired-image-folders 19 | args: 20 | root_path_1: ./load/div2k/DIV2K_valid_LR_bicubic/X2 21 | root_path_2: ./load/div2k/DIV2K_valid_HR 22 | first_k: 10 23 | repeat: 160 24 | cache: in_memory 25 | wrapper: 26 | name: sr-implicit-paired 27 | args: 28 | inp_size: 48 29 | batch_size: 16 30 | 31 | data_norm: 32 | inp: {sub: [0.5], div: [0.5]} 33 | gt: {sub: [0.5], div: [0.5]} 34 | 35 | model: 36 | name: liif 37 | args: 38 | encoder_spec: 39 | name: rdn 40 | args: 41 | scale: 2 42 | 43 | optimizer: 44 | name: adam 45 | args: 46 | lr: 1.e-4 47 | epoch_max: 1000 48 | multi_step_lr: 49 | milestones: [200, 400, 600, 800] 50 | gamma: 0.5 51 | 52 | epoch_val: 1 53 | epoch_save: 100 54 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .datasets import register, make 2 | from . import image_folder 3 | from . 
import wrappers 4 | -------------------------------------------------------------------------------- /datasets/datasets.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | 4 | datasets = {} 5 | 6 | 7 | def register(name): 8 | def decorator(cls): 9 | datasets[name] = cls 10 | return cls 11 | return decorator 12 | 13 | 14 | def make(dataset_spec, args=None): 15 | if args is not None: 16 | dataset_args = copy.deepcopy(dataset_spec['args']) 17 | dataset_args.update(args) 18 | else: 19 | dataset_args = dataset_spec['args'] 20 | dataset = datasets[dataset_spec['name']](**dataset_args) 21 | return dataset 22 | -------------------------------------------------------------------------------- /datasets/image_folder.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from PIL import Image 4 | 5 | import pickle 6 | import imageio 7 | import numpy as np 8 | import torch 9 | from torch.utils.data import Dataset 10 | from torchvision import transforms 11 | 12 | from datasets import register 13 | 14 | 15 | @register('image-folder') 16 | class ImageFolder(Dataset): 17 | 18 | def __init__(self, root_path, split_file=None, split_key=None, first_k=None, 19 | repeat=1, cache='none'): 20 | self.repeat = repeat 21 | self.cache = cache 22 | 23 | if split_file is None: 24 | filenames = sorted(os.listdir(root_path)) 25 | else: 26 | with open(split_file, 'r') as f: 27 | filenames = json.load(f)[split_key] 28 | if first_k is not None: 29 | filenames = filenames[:first_k] 30 | 31 | self.files = [] 32 | for filename in filenames: 33 | file = os.path.join(root_path, filename) 34 | 35 | if cache == 'none': 36 | self.files.append(file) 37 | 38 | elif cache == 'bin': 39 | bin_root = os.path.join(os.path.dirname(root_path), 40 | '_bin_' + os.path.basename(root_path)) 41 | if not os.path.exists(bin_root): 42 | os.mkdir(bin_root) 43 | print('mkdir', bin_root) 44 | bin_file = os.path.join( 45 | bin_root, filename.split('.')[0] + '.pkl') 46 | if not os.path.exists(bin_file): 47 | with open(bin_file, 'wb') as f: 48 | pickle.dump(imageio.imread(file), f) 49 | print('dump', bin_file) 50 | self.files.append(bin_file) 51 | 52 | elif cache == 'in_memory': 53 | self.files.append(transforms.ToTensor()( 54 | Image.open(file).convert('RGB'))) 55 | 56 | def __len__(self): 57 | return len(self.files) * self.repeat 58 | 59 | def __getitem__(self, idx): 60 | x = self.files[idx % len(self.files)] 61 | 62 | if self.cache == 'none': 63 | return transforms.ToTensor()(Image.open(x).convert('RGB')) 64 | 65 | elif self.cache == 'bin': 66 | with open(x, 'rb') as f: 67 | x = pickle.load(f) 68 | x = np.ascontiguousarray(x.transpose(2, 0, 1)) 69 | x = torch.from_numpy(x).float() / 255 70 | return x 71 | 72 | elif self.cache == 'in_memory': 73 | return x 74 | 75 | 76 | @register('paired-image-folders') 77 | class PairedImageFolders(Dataset): 78 | 79 | def __init__(self, root_path_1, root_path_2, **kwargs): 80 | self.dataset_1 = ImageFolder(root_path_1, **kwargs) 81 | self.dataset_2 = ImageFolder(root_path_2, **kwargs) 82 | 83 | def __len__(self): 84 | return len(self.dataset_1) 85 | 86 | def __getitem__(self, idx): 87 | return self.dataset_1[idx], self.dataset_2[idx] 88 | -------------------------------------------------------------------------------- /datasets/wrappers.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import random 3 | import math 4 | from PIL import 
Image 5 | 6 | import numpy as np 7 | import torch 8 | from torch.utils.data import Dataset 9 | from torchvision import transforms 10 | 11 | from datasets import register 12 | from utils import to_pixel_samples 13 | 14 | 15 | @register('sr-implicit-paired') 16 | class SRImplicitPaired(Dataset): 17 | 18 | def __init__(self, dataset, inp_size=None, augment=False, sample_q=None): 19 | self.dataset = dataset 20 | self.inp_size = inp_size 21 | self.augment = augment 22 | self.sample_q = sample_q 23 | 24 | def __len__(self): 25 | return len(self.dataset) 26 | 27 | def __getitem__(self, idx): 28 | img_lr, img_hr = self.dataset[idx] 29 | 30 | s = img_hr.shape[-2] // img_lr.shape[-2] # assume int scale 31 | if self.inp_size is None: 32 | h_lr, w_lr = img_lr.shape[-2:] 33 | img_hr = img_hr[:, :h_lr * s, :w_lr * s] 34 | crop_lr, crop_hr = img_lr, img_hr 35 | else: 36 | w_lr = self.inp_size 37 | x0 = random.randint(0, img_lr.shape[-2] - w_lr) 38 | y0 = random.randint(0, img_lr.shape[-1] - w_lr) 39 | crop_lr = img_lr[:, x0: x0 + w_lr, y0: y0 + w_lr] 40 | w_hr = w_lr * s 41 | x1 = x0 * s 42 | y1 = y0 * s 43 | crop_hr = img_hr[:, x1: x1 + w_hr, y1: y1 + w_hr] 44 | 45 | if self.augment: 46 | hflip = random.random() < 0.5 47 | vflip = random.random() < 0.5 48 | dflip = random.random() < 0.5 49 | 50 | def augment(x): 51 | if hflip: 52 | x = x.flip(-2) 53 | if vflip: 54 | x = x.flip(-1) 55 | if dflip: 56 | x = x.transpose(-2, -1) 57 | return x 58 | 59 | crop_lr = augment(crop_lr) 60 | crop_hr = augment(crop_hr) 61 | 62 | hr_coord, hr_rgb = to_pixel_samples(crop_hr.contiguous()) 63 | 64 | if self.sample_q is not None: 65 | sample_lst = np.random.choice( 66 | len(hr_coord), self.sample_q, replace=False) 67 | hr_coord = hr_coord[sample_lst] 68 | hr_rgb = hr_rgb[sample_lst] 69 | 70 | cell = torch.ones_like(hr_coord) 71 | cell[:, 0] *= 2 / crop_hr.shape[-2] 72 | cell[:, 1] *= 2 / crop_hr.shape[-1] 73 | 74 | return { 75 | 'inp': crop_lr, 76 | 'coord': hr_coord, 77 | 'cell': cell, 78 | 'gt': hr_rgb 79 | } 80 | 81 | 82 | def resize_fn(img, size): 83 | return transforms.ToTensor()( 84 | transforms.Resize(size, Image.BICUBIC)( 85 | transforms.ToPILImage()(img))) 86 | 87 | 88 | @register('sr-implicit-downsampled') 89 | class SRImplicitDownsampled(Dataset): 90 | 91 | def __init__(self, dataset, inp_size=None, scale_min=1, scale_max=None, 92 | augment=False, sample_q=None): 93 | self.dataset = dataset 94 | self.inp_size = inp_size 95 | self.scale_min = scale_min 96 | if scale_max is None: 97 | scale_max = scale_min 98 | self.scale_max = scale_max 99 | self.augment = augment 100 | self.sample_q = sample_q 101 | 102 | def __len__(self): 103 | return len(self.dataset) 104 | 105 | def __getitem__(self, idx): 106 | img = self.dataset[idx] 107 | s = random.uniform(self.scale_min, self.scale_max) 108 | 109 | if self.inp_size is None: 110 | h_lr = math.floor(img.shape[-2] / s + 1e-9) 111 | w_lr = math.floor(img.shape[-1] / s + 1e-9) 112 | img = img[:, :round(h_lr * s), :round(w_lr * s)] # assume round int 113 | img_down = resize_fn(img, (h_lr, w_lr)) 114 | crop_lr, crop_hr = img_down, img 115 | else: 116 | w_lr = self.inp_size 117 | w_hr = round(w_lr * s) 118 | x0 = random.randint(0, img.shape[-2] - w_hr) 119 | y0 = random.randint(0, img.shape[-1] - w_hr) 120 | crop_hr = img[:, x0: x0 + w_hr, y0: y0 + w_hr] 121 | crop_lr = resize_fn(crop_hr, w_lr) 122 | 123 | if self.augment: 124 | hflip = random.random() < 0.5 125 | vflip = random.random() < 0.5 126 | dflip = random.random() < 0.5 127 | 128 | def augment(x): 129 | if hflip: 130 
| x = x.flip(-2) 131 | if vflip: 132 | x = x.flip(-1) 133 | if dflip: 134 | x = x.transpose(-2, -1) 135 | return x 136 | 137 | crop_lr = augment(crop_lr) 138 | crop_hr = augment(crop_hr) 139 | 140 | hr_coord, hr_rgb = to_pixel_samples(crop_hr.contiguous()) 141 | 142 | if self.sample_q is not None: 143 | sample_lst = np.random.choice( 144 | len(hr_coord), self.sample_q, replace=False) 145 | hr_coord = hr_coord[sample_lst] 146 | hr_rgb = hr_rgb[sample_lst] 147 | 148 | cell = torch.ones_like(hr_coord) 149 | cell[:, 0] *= 2 / crop_hr.shape[-2] 150 | cell[:, 1] *= 2 / crop_hr.shape[-1] 151 | 152 | return { 153 | 'inp': crop_lr, 154 | 'coord': hr_coord, 155 | 'cell': cell, 156 | 'gt': hr_rgb 157 | } 158 | 159 | 160 | @register('sr-implicit-uniform-varied') 161 | class SRImplicitUniformVaried(Dataset): 162 | 163 | def __init__(self, dataset, size_min, size_max=None, 164 | augment=False, gt_resize=None, sample_q=None): 165 | self.dataset = dataset 166 | self.size_min = size_min 167 | if size_max is None: 168 | size_max = size_min 169 | self.size_max = size_max 170 | self.augment = augment 171 | self.gt_resize = gt_resize 172 | self.sample_q = sample_q 173 | 174 | def __len__(self): 175 | return len(self.dataset) 176 | 177 | def __getitem__(self, idx): 178 | img_lr, img_hr = self.dataset[idx] 179 | p = idx / (len(self.dataset) - 1) 180 | w_hr = round(self.size_min + (self.size_max - self.size_min) * p) 181 | img_hr = resize_fn(img_hr, w_hr) 182 | 183 | if self.augment: 184 | if random.random() < 0.5: 185 | img_lr = img_lr.flip(-1) 186 | img_hr = img_hr.flip(-1) 187 | 188 | if self.gt_resize is not None: 189 | img_hr = resize_fn(img_hr, self.gt_resize) 190 | 191 | hr_coord, hr_rgb = to_pixel_samples(img_hr) 192 | 193 | if self.sample_q is not None: 194 | sample_lst = np.random.choice( 195 | len(hr_coord), self.sample_q, replace=False) 196 | hr_coord = hr_coord[sample_lst] 197 | hr_rgb = hr_rgb[sample_lst] 198 | 199 | cell = torch.ones_like(hr_coord) 200 | cell[:, 0] *= 2 / img_hr.shape[-2] 201 | cell[:, 1] *= 2 / img_hr.shape[-1] 202 | 203 | return { 204 | 'inp': img_lr, 205 | 'coord': hr_coord, 206 | 'cell': cell, 207 | 'gt': hr_rgb 208 | } 209 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from PIL import Image 4 | 5 | import torch 6 | from torchvision import transforms 7 | 8 | import models 9 | from utils import make_coord 10 | from test import batched_predict 11 | 12 | 13 | if __name__ == '__main__': 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--input', default='input.png') 16 | parser.add_argument('--model') 17 | parser.add_argument('--resolution') 18 | parser.add_argument('--output', default='output.png') 19 | parser.add_argument('--gpu', default='0') 20 | args = parser.parse_args() 21 | 22 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 23 | 24 | img = transforms.ToTensor()(Image.open(args.input).convert('RGB')) 25 | 26 | model = models.make(torch.load(args.model)['model'], load_sd=True).cuda() 27 | 28 | h, w = list(map(int, args.resolution.split(','))) 29 | coord = make_coord((h, w)).cuda() 30 | cell = torch.ones_like(coord) 31 | cell[:, 0] *= 2 / h 32 | cell[:, 1] *= 2 / w 33 | pred = batched_predict(model, ((img - 0.5) / 0.5).cuda().unsqueeze(0), 34 | coord.unsqueeze(0), cell.unsqueeze(0), bsize=30000)[0] 35 | pred = (pred * 0.5 + 0.5).clamp(0, 1).view(h, w, 3).permute(2, 0, 1).cpu() 36 | 
transforms.ToPILImage()(pred).save(args.output) 37 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import register, make 2 | from . import edsr, rdn, rcan 3 | from . import mlp 4 | from . import liif 5 | from . import misc 6 | -------------------------------------------------------------------------------- /models/edsr.py: -------------------------------------------------------------------------------- 1 | # modified from: https://github.com/thstkdgus35/EDSR-PyTorch 2 | 3 | import math 4 | from argparse import Namespace 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from models import register 11 | 12 | 13 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 14 | return nn.Conv2d( 15 | in_channels, out_channels, kernel_size, 16 | padding=(kernel_size//2), bias=bias) 17 | 18 | class MeanShift(nn.Conv2d): 19 | def __init__( 20 | self, rgb_range, 21 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): 22 | 23 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 24 | std = torch.Tensor(rgb_std) 25 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) 26 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std 27 | for p in self.parameters(): 28 | p.requires_grad = False 29 | 30 | class ResBlock(nn.Module): 31 | def __init__( 32 | self, conv, n_feats, kernel_size, 33 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 34 | 35 | super(ResBlock, self).__init__() 36 | m = [] 37 | for i in range(2): 38 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 39 | if bn: 40 | m.append(nn.BatchNorm2d(n_feats)) 41 | if i == 0: 42 | m.append(act) 43 | 44 | self.body = nn.Sequential(*m) 45 | self.res_scale = res_scale 46 | 47 | def forward(self, x): 48 | res = self.body(x).mul(self.res_scale) 49 | res += x 50 | 51 | return res 52 | 53 | class Upsampler(nn.Sequential): 54 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): 55 | 56 | m = [] 57 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
58 | for _ in range(int(math.log(scale, 2))): 59 | m.append(conv(n_feats, 4 * n_feats, 3, bias)) 60 | m.append(nn.PixelShuffle(2)) 61 | if bn: 62 | m.append(nn.BatchNorm2d(n_feats)) 63 | if act == 'relu': 64 | m.append(nn.ReLU(True)) 65 | elif act == 'prelu': 66 | m.append(nn.PReLU(n_feats)) 67 | 68 | elif scale == 3: 69 | m.append(conv(n_feats, 9 * n_feats, 3, bias)) 70 | m.append(nn.PixelShuffle(3)) 71 | if bn: 72 | m.append(nn.BatchNorm2d(n_feats)) 73 | if act == 'relu': 74 | m.append(nn.ReLU(True)) 75 | elif act == 'prelu': 76 | m.append(nn.PReLU(n_feats)) 77 | else: 78 | raise NotImplementedError 79 | 80 | super(Upsampler, self).__init__(*m) 81 | 82 | 83 | url = { 84 | 'r16f64x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x2-1bc95232.pt', 85 | 'r16f64x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x3-abf2a44e.pt', 86 | 'r16f64x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x4-6b446fab.pt', 87 | 'r32f256x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x2-0edfb8a3.pt', 88 | 'r32f256x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x3-ea3ef2c6.pt', 89 | 'r32f256x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x4-4f62e9ef.pt' 90 | } 91 | 92 | class EDSR(nn.Module): 93 | def __init__(self, args, conv=default_conv): 94 | super(EDSR, self).__init__() 95 | self.args = args 96 | n_resblocks = args.n_resblocks 97 | n_feats = args.n_feats 98 | kernel_size = 3 99 | scale = args.scale[0] 100 | act = nn.ReLU(True) 101 | url_name = 'r{}f{}x{}'.format(n_resblocks, n_feats, scale) 102 | if url_name in url: 103 | self.url = url[url_name] 104 | else: 105 | self.url = None 106 | self.sub_mean = MeanShift(args.rgb_range) 107 | self.add_mean = MeanShift(args.rgb_range, sign=1) 108 | 109 | # define head module 110 | m_head = [conv(args.n_colors, n_feats, kernel_size)] 111 | 112 | # define body module 113 | m_body = [ 114 | ResBlock( 115 | conv, n_feats, kernel_size, act=act, res_scale=args.res_scale 116 | ) for _ in range(n_resblocks) 117 | ] 118 | m_body.append(conv(n_feats, n_feats, kernel_size)) 119 | 120 | self.head = nn.Sequential(*m_head) 121 | self.body = nn.Sequential(*m_body) 122 | 123 | if args.no_upsampling: 124 | self.out_dim = n_feats 125 | else: 126 | self.out_dim = args.n_colors 127 | # define tail module 128 | m_tail = [ 129 | Upsampler(conv, scale, n_feats, act=False), 130 | conv(n_feats, args.n_colors, kernel_size) 131 | ] 132 | self.tail = nn.Sequential(*m_tail) 133 | 134 | def forward(self, x): 135 | #x = self.sub_mean(x) 136 | x = self.head(x) 137 | 138 | res = self.body(x) 139 | res += x 140 | 141 | if self.args.no_upsampling: 142 | x = res 143 | else: 144 | x = self.tail(res) 145 | #x = self.add_mean(x) 146 | return x 147 | 148 | def load_state_dict(self, state_dict, strict=True): 149 | own_state = self.state_dict() 150 | for name, param in state_dict.items(): 151 | if name in own_state: 152 | if isinstance(param, nn.Parameter): 153 | param = param.data 154 | try: 155 | own_state[name].copy_(param) 156 | except Exception: 157 | if name.find('tail') == -1: 158 | raise RuntimeError('While copying the parameter named {}, ' 159 | 'whose dimensions in the model are {} and ' 160 | 'whose dimensions in the checkpoint are {}.' 
161 | .format(name, own_state[name].size(), param.size())) 162 | elif strict: 163 | if name.find('tail') == -1: 164 | raise KeyError('unexpected key "{}" in state_dict' 165 | .format(name)) 166 | 167 | 168 | @register('edsr-baseline') 169 | def make_edsr_baseline(n_resblocks=16, n_feats=64, res_scale=1, 170 | scale=2, no_upsampling=False, rgb_range=1): 171 | args = Namespace() 172 | args.n_resblocks = n_resblocks 173 | args.n_feats = n_feats 174 | args.res_scale = res_scale 175 | 176 | args.scale = [scale] 177 | args.no_upsampling = no_upsampling 178 | 179 | args.rgb_range = rgb_range 180 | args.n_colors = 3 181 | return EDSR(args) 182 | 183 | 184 | @register('edsr') 185 | def make_edsr(n_resblocks=32, n_feats=256, res_scale=0.1, 186 | scale=2, no_upsampling=False, rgb_range=1): 187 | args = Namespace() 188 | args.n_resblocks = n_resblocks 189 | args.n_feats = n_feats 190 | args.res_scale = res_scale 191 | 192 | args.scale = [scale] 193 | args.no_upsampling = no_upsampling 194 | 195 | args.rgb_range = rgb_range 196 | args.n_colors = 3 197 | return EDSR(args) 198 | -------------------------------------------------------------------------------- /models/liif.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import models 6 | from models import register 7 | from utils import make_coord 8 | 9 | 10 | @register('liif') 11 | class LIIF(nn.Module): 12 | 13 | def __init__(self, encoder_spec, imnet_spec=None, 14 | local_ensemble=True, feat_unfold=True, cell_decode=True): 15 | super().__init__() 16 | self.local_ensemble = local_ensemble 17 | self.feat_unfold = feat_unfold 18 | self.cell_decode = cell_decode 19 | 20 | self.encoder = models.make(encoder_spec) 21 | 22 | if imnet_spec is not None: 23 | imnet_in_dim = self.encoder.out_dim 24 | if self.feat_unfold: 25 | imnet_in_dim *= 9 26 | imnet_in_dim += 2 # attach coord 27 | if self.cell_decode: 28 | imnet_in_dim += 2 29 | self.imnet = models.make(imnet_spec, args={'in_dim': imnet_in_dim}) 30 | else: 31 | self.imnet = None 32 | 33 | def gen_feat(self, inp): 34 | self.feat = self.encoder(inp) 35 | return self.feat 36 | 37 | def query_rgb(self, coord, cell=None): 38 | feat = self.feat 39 | 40 | if self.imnet is None: 41 | ret = F.grid_sample(feat, coord.flip(-1).unsqueeze(1), 42 | mode='nearest', align_corners=False)[:, :, 0, :] \ 43 | .permute(0, 2, 1) 44 | return ret 45 | 46 | if self.feat_unfold: 47 | feat = F.unfold(feat, 3, padding=1).view( 48 | feat.shape[0], feat.shape[1] * 9, feat.shape[2], feat.shape[3]) 49 | 50 | if self.local_ensemble: 51 | vx_lst = [-1, 1] 52 | vy_lst = [-1, 1] 53 | eps_shift = 1e-6 54 | else: 55 | vx_lst, vy_lst, eps_shift = [0], [0], 0 56 | 57 | # field radius (global: [-1, 1]) 58 | rx = 2 / feat.shape[-2] / 2 59 | ry = 2 / feat.shape[-1] / 2 60 | 61 | feat_coord = make_coord(feat.shape[-2:], flatten=False).cuda() \ 62 | .permute(2, 0, 1) \ 63 | .unsqueeze(0).expand(feat.shape[0], 2, *feat.shape[-2:]) 64 | 65 | preds = [] 66 | areas = [] 67 | for vx in vx_lst: 68 | for vy in vy_lst: 69 | coord_ = coord.clone() 70 | coord_[:, :, 0] += vx * rx + eps_shift 71 | coord_[:, :, 1] += vy * ry + eps_shift 72 | coord_.clamp_(-1 + 1e-6, 1 - 1e-6) 73 | q_feat = F.grid_sample( 74 | feat, coord_.flip(-1).unsqueeze(1), 75 | mode='nearest', align_corners=False)[:, :, 0, :] \ 76 | .permute(0, 2, 1) 77 | q_coord = F.grid_sample( 78 | feat_coord, coord_.flip(-1).unsqueeze(1), 79 | mode='nearest', align_corners=False)[:, 
:, 0, :] \ 80 | .permute(0, 2, 1) 81 | rel_coord = coord - q_coord 82 | rel_coord[:, :, 0] *= feat.shape[-2] 83 | rel_coord[:, :, 1] *= feat.shape[-1] 84 | inp = torch.cat([q_feat, rel_coord], dim=-1) 85 | 86 | if self.cell_decode: 87 | rel_cell = cell.clone() 88 | rel_cell[:, :, 0] *= feat.shape[-2] 89 | rel_cell[:, :, 1] *= feat.shape[-1] 90 | inp = torch.cat([inp, rel_cell], dim=-1) 91 | 92 | bs, q = coord.shape[:2] 93 | pred = self.imnet(inp.view(bs * q, -1)).view(bs, q, -1) 94 | preds.append(pred) 95 | 96 | area = torch.abs(rel_coord[:, :, 0] * rel_coord[:, :, 1]) 97 | areas.append(area + 1e-9) 98 | 99 | tot_area = torch.stack(areas).sum(dim=0) 100 | if self.local_ensemble: 101 | t = areas[0]; areas[0] = areas[3]; areas[3] = t 102 | t = areas[1]; areas[1] = areas[2]; areas[2] = t 103 | ret = 0 104 | for pred, area in zip(preds, areas): 105 | ret = ret + pred * (area / tot_area).unsqueeze(-1) 106 | return ret 107 | 108 | def forward(self, inp, coord, cell): 109 | self.gen_feat(inp) 110 | return self.query_rgb(coord, cell) 111 | -------------------------------------------------------------------------------- /models/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | import models 6 | from models import register 7 | from utils import make_coord 8 | 9 | 10 | @register('metasr') 11 | class MetaSR(nn.Module): 12 | 13 | def __init__(self, encoder_spec): 14 | super().__init__() 15 | 16 | self.encoder = models.make(encoder_spec) 17 | imnet_spec = { 18 | 'name': 'mlp', 19 | 'args': { 20 | 'in_dim': 3, 21 | 'out_dim': self.encoder.out_dim * 9 * 3, 22 | 'hidden_list': [256] 23 | } 24 | } 25 | self.imnet = models.make(imnet_spec) 26 | 27 | def gen_feat(self, inp): 28 | self.feat = self.encoder(inp) 29 | return self.feat 30 | 31 | def query_rgb(self, coord, cell=None): 32 | feat = self.feat 33 | feat = F.unfold(feat, 3, padding=1).view( 34 | feat.shape[0], feat.shape[1] * 9, feat.shape[2], feat.shape[3]) 35 | 36 | feat_coord = make_coord(feat.shape[-2:], flatten=False).cuda() 37 | feat_coord[:, :, 0] -= (2 / feat.shape[-2]) / 2 38 | feat_coord[:, :, 1] -= (2 / feat.shape[-1]) / 2 39 | feat_coord = feat_coord.permute(2, 0, 1) \ 40 | .unsqueeze(0).expand(feat.shape[0], 2, *feat.shape[-2:]) 41 | 42 | coord_ = coord.clone() 43 | coord_[:, :, 0] -= cell[:, :, 0] / 2 44 | coord_[:, :, 1] -= cell[:, :, 1] / 2 45 | coord_q = (coord_ + 1e-6).clamp(-1 + 1e-6, 1 - 1e-6) 46 | q_feat = F.grid_sample( 47 | feat, coord_q.flip(-1).unsqueeze(1), 48 | mode='nearest', align_corners=False)[:, :, 0, :] \ 49 | .permute(0, 2, 1) 50 | q_coord = F.grid_sample( 51 | feat_coord, coord_q.flip(-1).unsqueeze(1), 52 | mode='nearest', align_corners=False)[:, :, 0, :] \ 53 | .permute(0, 2, 1) 54 | 55 | rel_coord = coord_ - q_coord 56 | rel_coord[:, :, 0] *= feat.shape[-2] / 2 57 | rel_coord[:, :, 1] *= feat.shape[-1] / 2 58 | 59 | r_rev = cell[:, :, 0] * (feat.shape[-2] / 2) 60 | inp = torch.cat([rel_coord, r_rev.unsqueeze(-1)], dim=-1) 61 | 62 | bs, q = coord.shape[:2] 63 | pred = self.imnet(inp.view(bs * q, -1)).view(bs * q, feat.shape[1], 3) 64 | pred = torch.bmm(q_feat.contiguous().view(bs * q, 1, -1), pred) 65 | pred = pred.view(bs, q, 3) 66 | return pred 67 | 68 | def forward(self, inp, coord, cell): 69 | self.gen_feat(inp) 70 | return self.query_rgb(coord, cell) 71 | -------------------------------------------------------------------------------- /models/mlp.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from models import register 4 | 5 | 6 | @register('mlp') 7 | class MLP(nn.Module): 8 | 9 | def __init__(self, in_dim, out_dim, hidden_list): 10 | super().__init__() 11 | layers = [] 12 | lastv = in_dim 13 | for hidden in hidden_list: 14 | layers.append(nn.Linear(lastv, hidden)) 15 | layers.append(nn.ReLU()) 16 | lastv = hidden 17 | layers.append(nn.Linear(lastv, out_dim)) 18 | self.layers = nn.Sequential(*layers) 19 | 20 | def forward(self, x): 21 | shape = x.shape[:-1] 22 | x = self.layers(x.view(-1, x.shape[-1])) 23 | return x.view(*shape, -1) 24 | -------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | 4 | models = {} 5 | 6 | 7 | def register(name): 8 | def decorator(cls): 9 | models[name] = cls 10 | return cls 11 | return decorator 12 | 13 | 14 | def make(model_spec, args=None, load_sd=False): 15 | if args is not None: 16 | model_args = copy.deepcopy(model_spec['args']) 17 | model_args.update(args) 18 | else: 19 | model_args = model_spec['args'] 20 | model = models[model_spec['name']](**model_args) 21 | if load_sd: 22 | model.load_state_dict(model_spec['sd']) 23 | return model 24 | -------------------------------------------------------------------------------- /models/rcan.py: -------------------------------------------------------------------------------- 1 | import math 2 | from argparse import Namespace 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from models import register 8 | 9 | 10 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 11 | return nn.Conv2d( 12 | in_channels, out_channels, kernel_size, 13 | padding=(kernel_size//2), bias=bias) 14 | 15 | class MeanShift(nn.Conv2d): 16 | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1): 17 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 18 | std = torch.Tensor(rgb_std) 19 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) 20 | self.weight.data.div_(std.view(3, 1, 1, 1)) 21 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) 22 | self.bias.data.div_(std) 23 | self.requires_grad = False 24 | 25 | class Upsampler(nn.Sequential): 26 | def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True): 27 | 28 | m = [] 29 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
30 | for _ in range(int(math.log(scale, 2))): 31 | m.append(conv(n_feat, 4 * n_feat, 3, bias)) 32 | m.append(nn.PixelShuffle(2)) 33 | if bn: m.append(nn.BatchNorm2d(n_feat)) 34 | if act: m.append(act()) 35 | elif scale == 3: 36 | m.append(conv(n_feat, 9 * n_feat, 3, bias)) 37 | m.append(nn.PixelShuffle(3)) 38 | if bn: m.append(nn.BatchNorm2d(n_feat)) 39 | if act: m.append(act()) 40 | else: 41 | raise NotImplementedError 42 | 43 | super(Upsampler, self).__init__(*m) 44 | 45 | ## Channel Attention (CA) Layer 46 | class CALayer(nn.Module): 47 | def __init__(self, channel, reduction=16): 48 | super(CALayer, self).__init__() 49 | # global average pooling: feature --> point 50 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 51 | # feature channel downscale and upscale --> channel weight 52 | self.conv_du = nn.Sequential( 53 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 54 | nn.ReLU(inplace=True), 55 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 56 | nn.Sigmoid() 57 | ) 58 | 59 | def forward(self, x): 60 | y = self.avg_pool(x) 61 | y = self.conv_du(y) 62 | return x * y 63 | 64 | ## Residual Channel Attention Block (RCAB) 65 | class RCAB(nn.Module): 66 | def __init__( 67 | self, conv, n_feat, kernel_size, reduction, 68 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 69 | 70 | super(RCAB, self).__init__() 71 | modules_body = [] 72 | for i in range(2): 73 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 74 | if bn: modules_body.append(nn.BatchNorm2d(n_feat)) 75 | if i == 0: modules_body.append(act) 76 | modules_body.append(CALayer(n_feat, reduction)) 77 | self.body = nn.Sequential(*modules_body) 78 | self.res_scale = res_scale 79 | 80 | def forward(self, x): 81 | res = self.body(x) 82 | #res = self.body(x).mul(self.res_scale) 83 | res += x 84 | return res 85 | 86 | ## Residual Group (RG) 87 | class ResidualGroup(nn.Module): 88 | def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): 89 | super(ResidualGroup, self).__init__() 90 | modules_body = [] 91 | modules_body = [ 92 | RCAB( 93 | conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ 94 | for _ in range(n_resblocks)] 95 | modules_body.append(conv(n_feat, n_feat, kernel_size)) 96 | self.body = nn.Sequential(*modules_body) 97 | 98 | def forward(self, x): 99 | res = self.body(x) 100 | res += x 101 | return res 102 | 103 | ## Residual Channel Attention Network (RCAN) 104 | class RCAN(nn.Module): 105 | def __init__(self, args, conv=default_conv): 106 | super(RCAN, self).__init__() 107 | self.args = args 108 | 109 | n_resgroups = args.n_resgroups 110 | n_resblocks = args.n_resblocks 111 | n_feats = args.n_feats 112 | kernel_size = 3 113 | reduction = args.reduction 114 | scale = args.scale[0] 115 | act = nn.ReLU(True) 116 | 117 | # RGB mean for DIV2K 118 | rgb_mean = (0.4488, 0.4371, 0.4040) 119 | rgb_std = (1.0, 1.0, 1.0) 120 | self.sub_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std) 121 | 122 | # define head module 123 | modules_head = [conv(args.n_colors, n_feats, kernel_size)] 124 | 125 | # define body module 126 | modules_body = [ 127 | ResidualGroup( 128 | conv, n_feats, kernel_size, reduction, act=act, res_scale=args.res_scale, n_resblocks=n_resblocks) \ 129 | for _ in range(n_resgroups)] 130 | 131 | modules_body.append(conv(n_feats, n_feats, kernel_size)) 132 | 133 | self.add_mean = MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) 134 | 135 | self.head = nn.Sequential(*modules_head) 136 | self.body = 
nn.Sequential(*modules_body) 137 | 138 | if args.no_upsampling: 139 | self.out_dim = n_feats 140 | else: 141 | self.out_dim = args.n_colors 142 | # define tail module 143 | modules_tail = [ 144 | Upsampler(conv, scale, n_feats, act=False), 145 | conv(n_feats, args.n_colors, kernel_size)] 146 | self.tail = nn.Sequential(*modules_tail) 147 | 148 | def forward(self, x): 149 | #x = self.sub_mean(x) 150 | x = self.head(x) 151 | 152 | res = self.body(x) 153 | res += x 154 | 155 | if self.args.no_upsampling: 156 | x = res 157 | else: 158 | x = self.tail(res) 159 | #x = self.add_mean(x) 160 | return x 161 | 162 | def load_state_dict(self, state_dict, strict=False): 163 | own_state = self.state_dict() 164 | for name, param in state_dict.items(): 165 | if name in own_state: 166 | if isinstance(param, nn.Parameter): 167 | param = param.data 168 | try: 169 | own_state[name].copy_(param) 170 | except Exception: 171 | if name.find('tail') >= 0: 172 | print('Replace pre-trained upsampler to new one...') 173 | else: 174 | raise RuntimeError('While copying the parameter named {}, ' 175 | 'whose dimensions in the model are {} and ' 176 | 'whose dimensions in the checkpoint are {}.' 177 | .format(name, own_state[name].size(), param.size())) 178 | elif strict: 179 | if name.find('tail') == -1: 180 | raise KeyError('unexpected key "{}" in state_dict' 181 | .format(name)) 182 | 183 | if strict: 184 | missing = set(own_state.keys()) - set(state_dict.keys()) 185 | if len(missing) > 0: 186 | raise KeyError('missing keys in state_dict: "{}"'.format(missing)) 187 | 188 | 189 | @register('rcan') 190 | def make_rcan(n_resgroups=10, n_resblocks=20, n_feats=64, reduction=16, 191 | scale=2, no_upsampling=False, rgb_range=1): 192 | args = Namespace() 193 | args.n_resgroups = n_resgroups 194 | args.n_resblocks = n_resblocks 195 | args.n_feats = n_feats 196 | args.reduction = reduction 197 | 198 | args.scale = [scale] 199 | args.no_upsampling = no_upsampling 200 | 201 | args.rgb_range = rgb_range 202 | args.res_scale = 1 203 | args.n_colors = 3 204 | return RCAN(args) 205 | -------------------------------------------------------------------------------- /models/rdn.py: -------------------------------------------------------------------------------- 1 | # Residual Dense Network for Image Super-Resolution 2 | # https://arxiv.org/abs/1802.08797 3 | # modified from: https://github.com/thstkdgus35/EDSR-PyTorch 4 | 5 | from argparse import Namespace 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | from models import register 11 | 12 | 13 | class RDB_Conv(nn.Module): 14 | def __init__(self, inChannels, growRate, kSize=3): 15 | super(RDB_Conv, self).__init__() 16 | Cin = inChannels 17 | G = growRate 18 | self.conv = nn.Sequential(*[ 19 | nn.Conv2d(Cin, G, kSize, padding=(kSize-1)//2, stride=1), 20 | nn.ReLU() 21 | ]) 22 | 23 | def forward(self, x): 24 | out = self.conv(x) 25 | return torch.cat((x, out), 1) 26 | 27 | class RDB(nn.Module): 28 | def __init__(self, growRate0, growRate, nConvLayers, kSize=3): 29 | super(RDB, self).__init__() 30 | G0 = growRate0 31 | G = growRate 32 | C = nConvLayers 33 | 34 | convs = [] 35 | for c in range(C): 36 | convs.append(RDB_Conv(G0 + c*G, G)) 37 | self.convs = nn.Sequential(*convs) 38 | 39 | # Local Feature Fusion 40 | self.LFF = nn.Conv2d(G0 + C*G, G0, 1, padding=0, stride=1) 41 | 42 | def forward(self, x): 43 | return self.LFF(self.convs(x)) + x 44 | 45 | class RDN(nn.Module): 46 | def __init__(self, args): 47 | super(RDN, self).__init__() 48 | self.args = args 49 | r = args.scale[0] 50 
| G0 = args.G0 51 | kSize = args.RDNkSize 52 | 53 | # number of RDB blocks, conv layers, out channels 54 | self.D, C, G = { 55 | 'A': (20, 6, 32), 56 | 'B': (16, 8, 64), 57 | }[args.RDNconfig] 58 | 59 | # Shallow feature extraction net 60 | self.SFENet1 = nn.Conv2d(args.n_colors, G0, kSize, padding=(kSize-1)//2, stride=1) 61 | self.SFENet2 = nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1) 62 | 63 | # Redidual dense blocks and dense feature fusion 64 | self.RDBs = nn.ModuleList() 65 | for i in range(self.D): 66 | self.RDBs.append( 67 | RDB(growRate0 = G0, growRate = G, nConvLayers = C) 68 | ) 69 | 70 | # Global Feature Fusion 71 | self.GFF = nn.Sequential(*[ 72 | nn.Conv2d(self.D * G0, G0, 1, padding=0, stride=1), 73 | nn.Conv2d(G0, G0, kSize, padding=(kSize-1)//2, stride=1) 74 | ]) 75 | 76 | if args.no_upsampling: 77 | self.out_dim = G0 78 | else: 79 | self.out_dim = args.n_colors 80 | # Up-sampling net 81 | if r == 2 or r == 3: 82 | self.UPNet = nn.Sequential(*[ 83 | nn.Conv2d(G0, G * r * r, kSize, padding=(kSize-1)//2, stride=1), 84 | nn.PixelShuffle(r), 85 | nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1) 86 | ]) 87 | elif r == 4: 88 | self.UPNet = nn.Sequential(*[ 89 | nn.Conv2d(G0, G * 4, kSize, padding=(kSize-1)//2, stride=1), 90 | nn.PixelShuffle(2), 91 | nn.Conv2d(G, G * 4, kSize, padding=(kSize-1)//2, stride=1), 92 | nn.PixelShuffle(2), 93 | nn.Conv2d(G, args.n_colors, kSize, padding=(kSize-1)//2, stride=1) 94 | ]) 95 | else: 96 | raise ValueError("scale must be 2 or 3 or 4.") 97 | 98 | def forward(self, x): 99 | f__1 = self.SFENet1(x) 100 | x = self.SFENet2(f__1) 101 | 102 | RDBs_out = [] 103 | for i in range(self.D): 104 | x = self.RDBs[i](x) 105 | RDBs_out.append(x) 106 | 107 | x = self.GFF(torch.cat(RDBs_out,1)) 108 | x += f__1 109 | 110 | if self.args.no_upsampling: 111 | return x 112 | else: 113 | return self.UPNet(x) 114 | 115 | 116 | @register('rdn') 117 | def make_rdn(G0=64, RDNkSize=3, RDNconfig='B', 118 | scale=2, no_upsampling=False): 119 | args = Namespace() 120 | args.G0 = G0 121 | args.RDNkSize = RDNkSize 122 | args.RDNconfig = RDNconfig 123 | 124 | args.scale = [scale] 125 | args.no_upsampling = no_upsampling 126 | 127 | args.n_colors = 3 128 | return RDN(args) 129 | -------------------------------------------------------------------------------- /scripts/resize.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from tqdm import tqdm 4 | 5 | for size in [256, 128, 64, 32]: 6 | if size == 256: 7 | inp = './data1024x1024' 8 | else: 9 | inp = './256' 10 | print(size) 11 | os.mkdir(str(size)) 12 | filenames = os.listdir(inp) 13 | for filename in tqdm(filenames): 14 | Image.open(os.path.join(inp, filename)) \ 15 | .resize((size, size), Image.BICUBIC) \ 16 | .save(os.path.join('.', str(size), filename.split('.')[0] + '.png')) 17 | -------------------------------------------------------------------------------- /scripts/test-benchmark.sh: -------------------------------------------------------------------------------- 1 | echo 'set5' && 2 | echo 'x2' && 3 | python test.py --config ./configs/test/test-set5-2.yaml --model $1 --gpu $2 && 4 | echo 'x3' && 5 | python test.py --config ./configs/test/test-set5-3.yaml --model $1 --gpu $2 && 6 | echo 'x4' && 7 | python test.py --config ./configs/test/test-set5-4.yaml --model $1 --gpu $2 && 8 | echo 'x6*' && 9 | python test.py --config ./configs/test/test-set5-6.yaml --model $1 --gpu $2 && 10 | echo 'x8*' && 11 | python test.py 
--config ./configs/test/test-set5-8.yaml --model $1 --gpu $2 && 12 | 13 | echo 'set14' && 14 | echo 'x2' && 15 | python test.py --config ./configs/test/test-set14-2.yaml --model $1 --gpu $2 && 16 | echo 'x3' && 17 | python test.py --config ./configs/test/test-set14-3.yaml --model $1 --gpu $2 && 18 | echo 'x4' && 19 | python test.py --config ./configs/test/test-set14-4.yaml --model $1 --gpu $2 && 20 | echo 'x6*' && 21 | python test.py --config ./configs/test/test-set14-6.yaml --model $1 --gpu $2 && 22 | echo 'x8*' && 23 | python test.py --config ./configs/test/test-set14-8.yaml --model $1 --gpu $2 && 24 | 25 | echo 'b100' && 26 | echo 'x2' && 27 | python test.py --config ./configs/test/test-b100-2.yaml --model $1 --gpu $2 && 28 | echo 'x3' && 29 | python test.py --config ./configs/test/test-b100-3.yaml --model $1 --gpu $2 && 30 | echo 'x4' && 31 | python test.py --config ./configs/test/test-b100-4.yaml --model $1 --gpu $2 && 32 | echo 'x6*' && 33 | python test.py --config ./configs/test/test-b100-6.yaml --model $1 --gpu $2 && 34 | echo 'x8*' && 35 | python test.py --config ./configs/test/test-b100-8.yaml --model $1 --gpu $2 && 36 | 37 | echo 'urban100' && 38 | echo 'x2' && 39 | python test.py --config ./configs/test/test-urban100-2.yaml --model $1 --gpu $2 && 40 | echo 'x3' && 41 | python test.py --config ./configs/test/test-urban100-3.yaml --model $1 --gpu $2 && 42 | echo 'x4' && 43 | python test.py --config ./configs/test/test-urban100-4.yaml --model $1 --gpu $2 && 44 | echo 'x6*' && 45 | python test.py --config ./configs/test/test-urban100-6.yaml --model $1 --gpu $2 && 46 | echo 'x8*' && 47 | python test.py --config ./configs/test/test-urban100-8.yaml --model $1 --gpu $2 && 48 | 49 | true 50 | -------------------------------------------------------------------------------- /scripts/test-div2k.sh: -------------------------------------------------------------------------------- 1 | echo 'div2k-x2' && 2 | python test.py --config ./configs/test/test-div2k-2.yaml --model $1 --gpu $2 && 3 | echo 'div2k-x3' && 4 | python test.py --config ./configs/test/test-div2k-3.yaml --model $1 --gpu $2 && 5 | echo 'div2k-x4' && 6 | python test.py --config ./configs/test/test-div2k-4.yaml --model $1 --gpu $2 && 7 | 8 | echo 'div2k-x6*' && 9 | python test.py --config ./configs/test/test-div2k-6.yaml --model $1 --gpu $2 && 10 | echo 'div2k-x12*' && 11 | python test.py --config ./configs/test/test-div2k-12.yaml --model $1 --gpu $2 && 12 | echo 'div2k-x18*' && 13 | python test.py --config ./configs/test/test-div2k-18.yaml --model $1 --gpu $2 && 14 | echo 'div2k-x24*' && 15 | python test.py --config ./configs/test/test-div2k-24.yaml --model $1 --gpu $2 && 16 | echo 'div2k-x30*' && 17 | python test.py --config ./configs/test/test-div2k-30.yaml --model $1 --gpu $2 && 18 | 19 | true 20 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import math 4 | from functools import partial 5 | 6 | import yaml 7 | import torch 8 | from torch.utils.data import DataLoader 9 | from tqdm import tqdm 10 | 11 | import datasets 12 | import models 13 | import utils 14 | 15 | 16 | def batched_predict(model, inp, coord, cell, bsize): 17 | with torch.no_grad(): 18 | model.gen_feat(inp) 19 | n = coord.shape[1] 20 | ql = 0 21 | preds = [] 22 | while ql < n: 23 | qr = min(ql + bsize, n) 24 | pred = model.query_rgb(coord[:, ql: qr, :], cell[:, ql: qr, :]) 25 | preds.append(pred) 26 
| ql = qr 27 | pred = torch.cat(preds, dim=1) 28 | return pred 29 | 30 | 31 | def eval_psnr(loader, model, data_norm=None, eval_type=None, eval_bsize=None, 32 | verbose=False): 33 | model.eval() 34 | 35 | if data_norm is None: 36 | data_norm = { 37 | 'inp': {'sub': [0], 'div': [1]}, 38 | 'gt': {'sub': [0], 'div': [1]} 39 | } 40 | t = data_norm['inp'] 41 | inp_sub = torch.FloatTensor(t['sub']).view(1, -1, 1, 1).cuda() 42 | inp_div = torch.FloatTensor(t['div']).view(1, -1, 1, 1).cuda() 43 | t = data_norm['gt'] 44 | gt_sub = torch.FloatTensor(t['sub']).view(1, 1, -1).cuda() 45 | gt_div = torch.FloatTensor(t['div']).view(1, 1, -1).cuda() 46 | 47 | if eval_type is None: 48 | metric_fn = utils.calc_psnr 49 | elif eval_type.startswith('div2k'): 50 | scale = int(eval_type.split('-')[1]) 51 | metric_fn = partial(utils.calc_psnr, dataset='div2k', scale=scale) 52 | elif eval_type.startswith('benchmark'): 53 | scale = int(eval_type.split('-')[1]) 54 | metric_fn = partial(utils.calc_psnr, dataset='benchmark', scale=scale) 55 | else: 56 | raise NotImplementedError 57 | 58 | val_res = utils.Averager() 59 | 60 | pbar = tqdm(loader, leave=False, desc='val') 61 | for batch in pbar: 62 | for k, v in batch.items(): 63 | batch[k] = v.cuda() 64 | 65 | inp = (batch['inp'] - inp_sub) / inp_div 66 | if eval_bsize is None: 67 | with torch.no_grad(): 68 | pred = model(inp, batch['coord'], batch['cell']) 69 | else: 70 | pred = batched_predict(model, inp, 71 | batch['coord'], batch['cell'], eval_bsize) 72 | pred = pred * gt_div + gt_sub 73 | pred.clamp_(0, 1) 74 | 75 | if eval_type is not None: # reshape for shaving-eval 76 | ih, iw = batch['inp'].shape[-2:] 77 | s = math.sqrt(batch['coord'].shape[1] / (ih * iw)) 78 | shape = [batch['inp'].shape[0], round(ih * s), round(iw * s), 3] 79 | pred = pred.view(*shape) \ 80 | .permute(0, 3, 1, 2).contiguous() 81 | batch['gt'] = batch['gt'].view(*shape) \ 82 | .permute(0, 3, 1, 2).contiguous() 83 | 84 | res = metric_fn(pred, batch['gt']) 85 | val_res.add(res.item(), inp.shape[0]) 86 | 87 | if verbose: 88 | pbar.set_description('val {:.4f}'.format(val_res.item())) 89 | 90 | return val_res.item() 91 | 92 | 93 | if __name__ == '__main__': 94 | parser = argparse.ArgumentParser() 95 | parser.add_argument('--config') 96 | parser.add_argument('--model') 97 | parser.add_argument('--gpu', default='0') 98 | args = parser.parse_args() 99 | 100 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 101 | 102 | with open(args.config, 'r') as f: 103 | config = yaml.load(f, Loader=yaml.FullLoader) 104 | 105 | spec = config['test_dataset'] 106 | dataset = datasets.make(spec['dataset']) 107 | dataset = datasets.make(spec['wrapper'], args={'dataset': dataset}) 108 | loader = DataLoader(dataset, batch_size=spec['batch_size'], 109 | num_workers=8, pin_memory=True) 110 | 111 | model_spec = torch.load(args.model)['model'] 112 | model = models.make(model_spec, load_sd=True).cuda() 113 | 114 | res = eval_psnr(loader, model, 115 | data_norm=config.get('data_norm'), 116 | eval_type=config.get('eval_type'), 117 | eval_bsize=config.get('eval_bsize'), 118 | verbose=True) 119 | print('result: {:.4f}'.format(res)) 120 | -------------------------------------------------------------------------------- /train_liif.py: -------------------------------------------------------------------------------- 1 | """ Train for generating LIIF, from image to implicit representation. 
2 | 3 | Config: 4 | train_dataset: 5 | dataset: $spec; wrapper: $spec; batch_size: 6 | val_dataset: 7 | dataset: $spec; wrapper: $spec; batch_size: 8 | (data_norm): 9 | inp: {sub: []; div: []} 10 | gt: {sub: []; div: []} 11 | (eval_type): 12 | (eval_bsize): 13 | 14 | model: $spec 15 | optimizer: $spec 16 | epoch_max: 17 | (multi_step_lr): 18 | milestones: []; gamma: 0.5 19 | (resume): *.pth 20 | 21 | (epoch_val): ; (epoch_save): 22 | """ 23 | 24 | import argparse 25 | import os 26 | 27 | import yaml 28 | import torch 29 | import torch.nn as nn 30 | from tqdm import tqdm 31 | from torch.utils.data import DataLoader 32 | from torch.optim.lr_scheduler import MultiStepLR 33 | 34 | import datasets 35 | import models 36 | import utils 37 | from test import eval_psnr 38 | 39 | 40 | def make_data_loader(spec, tag=''): 41 | if spec is None: 42 | return None 43 | 44 | dataset = datasets.make(spec['dataset']) 45 | dataset = datasets.make(spec['wrapper'], args={'dataset': dataset}) 46 | 47 | log('{} dataset: size={}'.format(tag, len(dataset))) 48 | for k, v in dataset[0].items(): 49 | log(' {}: shape={}'.format(k, tuple(v.shape))) 50 | 51 | loader = DataLoader(dataset, batch_size=spec['batch_size'], 52 | shuffle=(tag == 'train'), num_workers=8, pin_memory=True) 53 | return loader 54 | 55 | 56 | def make_data_loaders(): 57 | train_loader = make_data_loader(config.get('train_dataset'), tag='train') 58 | val_loader = make_data_loader(config.get('val_dataset'), tag='val') 59 | return train_loader, val_loader 60 | 61 | 62 | def prepare_training(): 63 | if config.get('resume') is not None: 64 | sv_file = torch.load(config['resume']) 65 | model = models.make(sv_file['model'], load_sd=True).cuda() 66 | optimizer = utils.make_optimizer( 67 | model.parameters(), sv_file['optimizer'], load_sd=True) 68 | epoch_start = sv_file['epoch'] + 1 69 | if config.get('multi_step_lr') is None: 70 | lr_scheduler = None 71 | else: 72 | lr_scheduler = MultiStepLR(optimizer, **config['multi_step_lr']) 73 | for _ in range(epoch_start - 1): 74 | lr_scheduler.step() 75 | else: 76 | model = models.make(config['model']).cuda() 77 | optimizer = utils.make_optimizer( 78 | model.parameters(), config['optimizer']) 79 | epoch_start = 1 80 | if config.get('multi_step_lr') is None: 81 | lr_scheduler = None 82 | else: 83 | lr_scheduler = MultiStepLR(optimizer, **config['multi_step_lr']) 84 | 85 | log('model: #params={}'.format(utils.compute_num_params(model, text=True))) 86 | return model, optimizer, epoch_start, lr_scheduler 87 | 88 | 89 | def train(train_loader, model, optimizer): 90 | model.train() 91 | loss_fn = nn.L1Loss() 92 | train_loss = utils.Averager() 93 | 94 | data_norm = config['data_norm'] 95 | t = data_norm['inp'] 96 | inp_sub = torch.FloatTensor(t['sub']).view(1, -1, 1, 1).cuda() 97 | inp_div = torch.FloatTensor(t['div']).view(1, -1, 1, 1).cuda() 98 | t = data_norm['gt'] 99 | gt_sub = torch.FloatTensor(t['sub']).view(1, 1, -1).cuda() 100 | gt_div = torch.FloatTensor(t['div']).view(1, 1, -1).cuda() 101 | 102 | for batch in tqdm(train_loader, leave=False, desc='train'): 103 | for k, v in batch.items(): 104 | batch[k] = v.cuda() 105 | 106 | inp = (batch['inp'] - inp_sub) / inp_div 107 | pred = model(inp, batch['coord'], batch['cell']) 108 | 109 | gt = (batch['gt'] - gt_sub) / gt_div 110 | loss = loss_fn(pred, gt) 111 | 112 | train_loss.add(loss.item()) 113 | 114 | optimizer.zero_grad() 115 | loss.backward() 116 | optimizer.step() 117 | 118 | pred = None; loss = None 119 | 120 | return train_loss.item() 121 | 122 | 123 | def 
main(config_, save_path): 124 | global config, log, writer 125 | config = config_ 126 | log, writer = utils.set_save_path(save_path) 127 | with open(os.path.join(save_path, 'config.yaml'), 'w') as f: 128 | yaml.dump(config, f, sort_keys=False) 129 | 130 | train_loader, val_loader = make_data_loaders() 131 | if config.get('data_norm') is None: 132 | config['data_norm'] = { 133 | 'inp': {'sub': [0], 'div': [1]}, 134 | 'gt': {'sub': [0], 'div': [1]} 135 | } 136 | 137 | model, optimizer, epoch_start, lr_scheduler = prepare_training() 138 | 139 | n_gpus = len(os.environ['CUDA_VISIBLE_DEVICES'].split(',')) 140 | if n_gpus > 1: 141 | model = nn.parallel.DataParallel(model) 142 | 143 | epoch_max = config['epoch_max'] 144 | epoch_val = config.get('epoch_val') 145 | epoch_save = config.get('epoch_save') 146 | max_val_v = -1e18 147 | 148 | timer = utils.Timer() 149 | 150 | for epoch in range(epoch_start, epoch_max + 1): 151 | t_epoch_start = timer.t() 152 | log_info = ['epoch {}/{}'.format(epoch, epoch_max)] 153 | 154 | writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch) 155 | 156 | train_loss = train(train_loader, model, optimizer) 157 | if lr_scheduler is not None: 158 | lr_scheduler.step() 159 | 160 | log_info.append('train: loss={:.4f}'.format(train_loss)) 161 | writer.add_scalars('loss', {'train': train_loss}, epoch) 162 | 163 | if n_gpus > 1: 164 | model_ = model.module 165 | else: 166 | model_ = model 167 | model_spec = config['model'] 168 | model_spec['sd'] = model_.state_dict() 169 | optimizer_spec = config['optimizer'] 170 | optimizer_spec['sd'] = optimizer.state_dict() 171 | sv_file = { 172 | 'model': model_spec, 173 | 'optimizer': optimizer_spec, 174 | 'epoch': epoch 175 | } 176 | 177 | torch.save(sv_file, os.path.join(save_path, 'epoch-last.pth')) 178 | 179 | if (epoch_save is not None) and (epoch % epoch_save == 0): 180 | torch.save(sv_file, 181 | os.path.join(save_path, 'epoch-{}.pth'.format(epoch))) 182 | 183 | if (epoch_val is not None) and (epoch % epoch_val == 0): 184 | if n_gpus > 1 and (config.get('eval_bsize') is not None): 185 | model_ = model.module 186 | else: 187 | model_ = model 188 | val_res = eval_psnr(val_loader, model_, 189 | data_norm=config['data_norm'], 190 | eval_type=config.get('eval_type'), 191 | eval_bsize=config.get('eval_bsize')) 192 | 193 | log_info.append('val: psnr={:.4f}'.format(val_res)) 194 | writer.add_scalars('psnr', {'val': val_res}, epoch) 195 | if val_res > max_val_v: 196 | max_val_v = val_res 197 | torch.save(sv_file, os.path.join(save_path, 'epoch-best.pth')) 198 | 199 | t = timer.t() 200 | prog = (epoch - epoch_start + 1) / (epoch_max - epoch_start + 1) 201 | t_epoch = utils.time_text(t - t_epoch_start) 202 | t_elapsed, t_all = utils.time_text(t), utils.time_text(t / prog) 203 | log_info.append('{} {}/{}'.format(t_epoch, t_elapsed, t_all)) 204 | 205 | log(', '.join(log_info)) 206 | writer.flush() 207 | 208 | 209 | if __name__ == '__main__': 210 | parser = argparse.ArgumentParser() 211 | parser.add_argument('--config') 212 | parser.add_argument('--name', default=None) 213 | parser.add_argument('--tag', default=None) 214 | parser.add_argument('--gpu', default='0') 215 | args = parser.parse_args() 216 | 217 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 218 | 219 | with open(args.config, 'r') as f: 220 | config = yaml.load(f, Loader=yaml.FullLoader) 221 | print('config loaded.') 222 | 223 | save_name = args.name 224 | if save_name is None: 225 | save_name = '_' + args.config.split('/')[-1][:-len('.yaml')] 226 | if args.tag is not None: 
227 | save_name += '_' + args.tag 228 | save_path = os.path.join('./save', save_name) 229 | 230 | main(config, save_path) 231 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import shutil 4 | import math 5 | 6 | import torch 7 | import numpy as np 8 | from torch.optim import SGD, Adam 9 | from tensorboardX import SummaryWriter 10 | 11 | 12 | class Averager(): 13 | 14 | def __init__(self): 15 | self.n = 0.0 16 | self.v = 0.0 17 | 18 | def add(self, v, n=1.0): 19 | self.v = (self.v * self.n + v * n) / (self.n + n) 20 | self.n += n 21 | 22 | def item(self): 23 | return self.v 24 | 25 | 26 | class Timer(): 27 | 28 | def __init__(self): 29 | self.v = time.time() 30 | 31 | def s(self): 32 | self.v = time.time() 33 | 34 | def t(self): 35 | return time.time() - self.v 36 | 37 | 38 | def time_text(t): 39 | if t >= 3600: 40 | return '{:.1f}h'.format(t / 3600) 41 | elif t >= 60: 42 | return '{:.1f}m'.format(t / 60) 43 | else: 44 | return '{:.1f}s'.format(t) 45 | 46 | 47 | _log_path = None 48 | 49 | 50 | def set_log_path(path): 51 | global _log_path 52 | _log_path = path 53 | 54 | 55 | def log(obj, filename='log.txt'): 56 | print(obj) 57 | if _log_path is not None: 58 | with open(os.path.join(_log_path, filename), 'a') as f: 59 | print(obj, file=f) 60 | 61 | 62 | def ensure_path(path, remove=True): 63 | basename = os.path.basename(path.rstrip('/')) 64 | if os.path.exists(path): 65 | if remove and (basename.startswith('_') 66 | or input('{} exists, remove? (y/[n]): '.format(path)) == 'y'): 67 | shutil.rmtree(path) 68 | os.makedirs(path) 69 | else: 70 | os.makedirs(path) 71 | 72 | 73 | def set_save_path(save_path, remove=True): 74 | ensure_path(save_path, remove=remove) 75 | set_log_path(save_path) 76 | writer = SummaryWriter(os.path.join(save_path, 'tensorboard')) 77 | return log, writer 78 | 79 | 80 | def compute_num_params(model, text=False): 81 | tot = int(sum([np.prod(p.shape) for p in model.parameters()])) 82 | if text: 83 | if tot >= 1e6: 84 | return '{:.1f}M'.format(tot / 1e6) 85 | else: 86 | return '{:.1f}K'.format(tot / 1e3) 87 | else: 88 | return tot 89 | 90 | 91 | def make_optimizer(param_list, optimizer_spec, load_sd=False): 92 | Optimizer = { 93 | 'sgd': SGD, 94 | 'adam': Adam 95 | }[optimizer_spec['name']] 96 | optimizer = Optimizer(param_list, **optimizer_spec['args']) 97 | if load_sd: 98 | optimizer.load_state_dict(optimizer_spec['sd']) 99 | return optimizer 100 | 101 | 102 | def make_coord(shape, ranges=None, flatten=True): 103 | """ Make coordinates at grid centers. 104 | """ 105 | coord_seqs = [] 106 | for i, n in enumerate(shape): 107 | if ranges is None: 108 | v0, v1 = -1, 1 109 | else: 110 | v0, v1 = ranges[i] 111 | r = (v1 - v0) / (2 * n) 112 | seq = v0 + r + (2 * r) * torch.arange(n).float() 113 | coord_seqs.append(seq) 114 | ret = torch.stack(torch.meshgrid(*coord_seqs), dim=-1) 115 | if flatten: 116 | ret = ret.view(-1, ret.shape[-1]) 117 | return ret 118 | 119 | 120 | def to_pixel_samples(img): 121 | """ Convert the image to coord-RGB pairs. 
122 | img: Tensor, (3, H, W) 123 | """ 124 | coord = make_coord(img.shape[-2:]) 125 | rgb = img.view(3, -1).permute(1, 0) 126 | return coord, rgb 127 | 128 | 129 | def calc_psnr(sr, hr, dataset=None, scale=1, rgb_range=1): 130 | diff = (sr - hr) / rgb_range 131 | if dataset is not None: 132 | if dataset == 'benchmark': 133 | shave = scale 134 | if diff.size(1) > 1: 135 | gray_coeffs = [65.738, 129.057, 25.064] 136 | convert = diff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256 137 | diff = diff.mul(convert).sum(dim=1) 138 | elif dataset == 'div2k': 139 | shave = scale + 6 140 | else: 141 | raise NotImplementedError 142 | valid = diff[..., shave:-shave, shave:-shave] 143 | else: 144 | valid = diff 145 | mse = valid.pow(2).mean() 146 | return -10 * torch.log10(mse) 147 | --------------------------------------------------------------------------------
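A minimal usage sketch (illustration only, not a file in the repository) of the coordinate/cell convention shared by the dataset wrappers, demo.py, and LIIF.query_rgb above: make_coord places one coordinate at the center of each pixel in [-1, 1], to_pixel_samples pairs those coordinates with RGB values, and each cell entry stores the per-pixel extent 2/H and 2/W along the two axes. It assumes it is run from the repository root so that utils.py is importable.

import torch
from utils import make_coord, to_pixel_samples

img = torch.rand(3, 2, 3)            # a toy (C, H, W) image
coord, rgb = to_pixel_samples(img)   # coord: (H*W, 2) pixel-center coords in [-1, 1]; rgb: (H*W, 3)

cell = torch.ones_like(coord)        # per-query pixel size, as built in wrappers.py and demo.py
cell[:, 0] *= 2 / img.shape[-2]      # height of one pixel in the [-1, 1] coordinate system
cell[:, 1] *= 2 / img.shape[-1]      # width of one pixel

print(coord[:3])                     # first three pixel centers: (-0.5, -2/3), (-0.5, 0.0), (-0.5, 2/3)
print(cell[0])                       # (1.0, 2/3) for this 2x3 image

These (coord, cell) pairs are exactly what the wrappers return under 'coord' and 'cell' and what LIIF.forward consumes, so the same code path serves any output resolution.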