├── .gitignore ├── LICENSE ├── README.md ├── dataset ├── __init__.py └── mix_dataset.py ├── demo ├── SISR │ ├── HR_imgs │ │ └── 0826.png │ └── LR_imgs │ │ └── 0826x4.png └── VSR │ ├── HR_imgs │ ├── 00000000.png │ ├── 00000001.png │ ├── 00000002.png │ ├── 00000003.png │ ├── 00000004.png │ ├── 00000005.png │ └── 00000006.png │ └── LR_imgs │ ├── 00000000.png │ ├── 00000001.png │ ├── 00000002.png │ ├── 00000003.png │ ├── 00000004.png │ ├── 00000005.png │ └── 00000006.png ├── exps ├── BebyGAN │ ├── config.py │ ├── network.py │ ├── train.py │ ├── train.sh │ └── validate.py ├── LAPAR_A_x2 │ ├── config.py │ ├── network.py │ ├── train.py │ ├── train.sh │ └── validate.py ├── LAPAR_A_x3 │ ├── config.py │ ├── network.py │ ├── train.py │ ├── train.sh │ └── validate.py ├── LAPAR_A_x4 │ ├── config.py │ ├── network.py │ ├── train.py │ ├── train.sh │ └── validate.py ├── LAPAR_B_x2 │ ├── config.py │ ├── network.py │ ├── train.py │ ├── train.sh │ └── validate.py ├── LAPAR_B_x3 │ ├── config.py │ ├── network.py │ ├── train.py │ ├── train.sh │ └── validate.py ├── LAPAR_B_x4 │ ├── config.py │ ├── network.py │ ├── train.py │ ├── train.sh │ └── validate.py ├── LAPAR_C_x2 │ ├── config.py │ ├── network.py │ ├── train.py │ ├── train.sh │ └── validate.py ├── LAPAR_C_x3 │ ├── config.py │ ├── network.py │ ├── train.py │ ├── train.sh │ └── validate.py ├── LAPAR_C_x4 │ ├── config.py │ ├── network.py │ ├── train.py │ ├── train.sh │ └── validate.py ├── MuCAN_REDS │ ├── __pycache__ │ │ ├── config.cpython-35.pyc │ │ └── network.cpython-35.pyc │ ├── config.py │ └── network.py └── MuCAN_Vimeo90K │ ├── config.py │ └── network.py ├── kernel ├── kernel_14_k5.pkl ├── kernel_24_k5.pkl └── kernel_72_k5.pkl ├── requirements.txt ├── test_sample.py └── utils ├── common.py ├── data_prep ├── extract_subimage.py └── generate_lr_bic.m ├── dataloader.py ├── loss.py ├── model_opr.py ├── modules ├── discriminator.py ├── lightWeightNet.py ├── module_util.py ├── rrdb.py └── vggNet.py ├── region_seperator.py ├── resizer.py ├── samplers ├── __init__.py ├── distributed.py ├── grouped_batch_sampler.py └── iteration_based_batch_sampler.py └── solver.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | 5 | # Compiled source # 6 | ################### 7 | *.com 8 | *.class 9 | *.dll 10 | *.exe 11 | *.o 12 | *.so 13 | 14 | # Packages # 15 | ############ 16 | # it's better to unpack these files and commit the raw source 17 | # git has its own built in compression methods 18 | *.7z 19 | *.dmg 20 | *.gz 21 | *.iso 22 | *.jar 23 | *.rar 24 | *.tar 25 | *.zip 26 | 27 | # Logs and databases # 28 | ###################### 29 | *.log 30 | *.sql 31 | *.sqlite 32 | 33 | # OS generated files # 34 | ###################### 35 | .DS_Store 36 | .DS_Store? 37 | ._* 38 | .Spotlight-V100 39 | .Trashes 40 | ehthumbs.db 41 | Thumbs.db 42 | 43 | # saved models # 44 | ################ 45 | *.pth 46 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Jia Research Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Simple-SR 2 | The repository includes MuCAN, LAPAR, Beby-GAN and etc. It is designed for simple training and evaluation. 3 | 4 | --- 5 | ### Update 6 | The training code of BebyGAN and LAPAR is released. 7 | 8 | The well-trained models and all visual examples (including BebyGAN and other SOTAs) are available [here](https://drive.google.com/drive/folders/1t9GPQ61MDLOkgk-Lvez3jntciC-ZBYNS?usp=sharing). 9 | 10 | --- 11 | ### Paper 12 | 13 | #### MuCAN: Multi-Correspondence Aggregation Network for Video Super-Resolution 14 | [\[ECCV\]](https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123550341.pdf) [\[arXiv\]](https://arxiv.org/abs/2007.1180) 15 | 16 | #### LAPAR: Linearly-Assembled Pixel-Adaptive Regression Network for Single Image Super-resolution and Beyond 17 | [\[NeurIPS\]](https://papers.nips.cc/paper/2020/file/eaae339c4d89fc102edd9dbdb6a28915-Paper.pdf) [\[arXiv\]](https://arxiv.org/abs/2105.10422) 18 | 19 | #### Best-Buddy GANs for Highly Detailed Image Super-Resolution 20 | [\[arXiv\]](https://arxiv.org/abs/2103.15295) 21 | 22 | 23 | Please find supplementary files of MuCAN and LAPAR [here](https://drive.google.com/drive/folders/1pSFX6kV81slv2vGkboZjewZwQsLkFesU). 24 | 25 | --- 26 | ### Usage 27 | 28 | 1. Clone the repository 29 | ```shell 30 | git clone https://github.com/Jia-Research-Lab/Simple-SR.git 31 | ``` 32 | 2. Install the dependencies 33 | - Python >= 3.6 34 | - PyTorch >= 1.2 35 | - spatial-correlation-sampler 36 | ```shell 37 | pip install spatial-correlation-sampler 38 | ``` 39 | - Other packages 40 | ```shell 41 | pip install -r requirements.txt 42 | ``` 43 | 44 | 3. Download pretrained models from [Google Drive](https://drive.google.com/drive/folders/1c-KUEPJl7pHs9btqHYoUJkcMPKViObgJ?usp=sharing). We re-trained the LAPAR models and their results are slightly different from the ones reported in paper. 45 | - MuCAN 46 | - MuCAN\_REDS.pth: trained on REDS dataset, 5-frame input, x4 scale 47 | - MuCAN\_Vimeo90K.pth: trained on Vimeo90K dataset, 7-frame input, x4 scale 48 | - LAPAR: trained on DIV2K+Flickr2K datasets 49 | | Scale x2 | Scale x3 | Scale x4 | 50 | | :----: | :----: | :----: | 51 | | LAPAR_A_x2.pth | LAPAR_A_x3.pth | LAPAR_A_x4.pth | 52 | | LAPAR_B_x2.pth | LAPAR_B_x3.pth | LAPAR_B_x4.pth | 53 | | LAPAR_C_x2.pth | LAPAR_C_x3.pth | LAPAR_C_x4.pth | 54 | - BebyGAN 55 | - RRDB_warmup.pth: provided for initialization 56 | - BebyGAN_x4.pth: trained on DIV2K+Flickr2K datasets, x4 scale 57 | 58 | 4. Quick test 59 | 60 | You have to define the output_path or gt_path or both. If output_path is given, outputs will be saved. If gt_path is given, PSNR/SSIM will be calculated. If you want to calculate LPIPS results, please install [lpips](https://github.com/richzhang/PerceptualSimilarity) library. 61 | 62 | - For SISR, 63 | ```shell 64 | python3 test_sample.py --sr_type SISR --model_path /model/path --input_path ./demo/SISR/LR_imgs --output_path ./demo/SISR/output --gt_path ./demo/SISR/HR_imgs 65 | ``` 66 | 67 | - For VSR, 68 | ```shell 69 | python3 test_sample.py --sr_type VSR --model_path /model/path --input_path ./demo/VSR/LR_imgs --output_path ./demo/VSR/output --gt_path ./demo/VSR/HR_imgs 70 | ``` 71 | 72 | #### Prepare Data 73 | 1. Training Datasets 74 | 75 | Download [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) and [Flickr2K](https://cv.snu.ac.kr/research/EDSR/Flickr2K.tar). You may crop the HR and LR images to sub-images for fast reading referring to .utils/data\_prep/extract\_subimage.py. 76 | 77 | 2. Evaluation Datasets 78 | 79 | Download Set5, Set14, Urban100, BSDS100 and Manga109 from [Google Drive](https://drive.google.com/drive/folders/1B3DJGQKB6eNdwuQIhdskA64qUuVKLZ9u) uploaded by BasicSR. 80 | 81 | 3. Update the dataset location in .dataset/\_\_init\_\_.py. 82 | 83 | 4. (Optional) You can convert images to lmdb files for fast loading referring to [BasicSR](https://github.com/xinntao/BasicSR/blob/master/docs/DatasetPreparation.md#LMDB-Description). And you need to modify the data reading logics in .dataset/\*dataset.py accordingly. 84 | 85 | #### Train 86 | 1. Create a log folder as 87 | ```shell 88 | mkdir logs 89 | ``` 90 | 91 | 2. Create a new experiment folder in .exps/. You just need to prepare the config.py and network.py, while the train.py and validate.py are universal. For example, for LAPAR\_A\_x2, run 92 | ```shell 93 | cd exps/LAPAR_A_x2/ 94 | bash train.sh $GPU_NUM $PORT 95 | ``` 96 | Please modify the path of initialization model of BebyGAN (found in [Google Drive](https://drive.google.com/drive/folders/1c-KUEPJl7pHs9btqHYoUJkcMPKViObgJ?usp=sharing)) in config.py before training. 97 | 98 | Notice that you can find the checkpoints, log files and visualization images in either .exps/LAPAR\_A\_x2/log/ (a soft link) or .logs/LAPAR\_A\_x2/. 99 | 100 | #### Test 101 | Please refer to validate.py in each experiment folder or quick test above. 102 | 103 | --- 104 | ### Acknowledgement 105 | We refer to [BasicSR](https://github.com/xinntao/BasicSR) for some details. 106 | 107 | --- 108 | ### Bibtex 109 | @inproceedings{li2020mucan, 110 | title={MuCAN: Multi-correspondence Aggregation Network for Video Super-Resolution}, 111 | author={Li, Wenbo and Tao, Xin and Guo, Taian and Qi, Lu and Lu, Jiangbo and Jia, Jiaya}, 112 | booktitle={European Conference on Computer Vision}, 113 | pages={335--351}, 114 | year={2020}, 115 | organization={Springer} 116 | } 117 | @article{li2020lapar, 118 | title={LAPAR: Linearly-Assembled Pixel-Adaptive Regression Network for Single Image Super-resolution and Beyond}, 119 | author={Li, Wenbo and Zhou, Kun and Qi, Lu and Jiang, Nianjuan and Lu, Jiangbo and Jia, Jiaya}, 120 | journal={Advances in Neural Information Processing Systems}, 121 | volume={33}, 122 | year={2020} 123 | } 124 | @article{li2021best, 125 | title={Best-Buddy GANs for Highly Detailed Image Super-Resolution}, 126 | author={Li, Wenbo and Zhou, Kun and Qi, Lu and Lu, Liying and Jiang, Nianjuan and Lu, Jiangbo and Jia, Jiaya}, 127 | journal={arXiv preprint arXiv:2103.15295}, 128 | year={2021} 129 | } 130 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | import importlib 3 | import os 4 | 5 | from utils.common import scandir 6 | 7 | 8 | dataset_root = os.path.dirname(os.path.abspath(__file__)) 9 | dataset_filenames = [ 10 | os.path.splitext(os.path.basename(v))[0] for v in scandir(dataset_root) 11 | if v.endswith('_dataset.py') 12 | ] 13 | _dataset_modules = [ 14 | importlib.import_module(f'dataset.{file_name}') 15 | for file_name in dataset_filenames 16 | ] 17 | 18 | 19 | class DATASET: 20 | LEGAL = ['DIV2K', 'Flickr2K', 'Set5', 'Set14', 'BSDS100', 'Urban100', 'Manga109'] 21 | 22 | # training dataset 23 | DIV2K = edict() 24 | DIV2K.TRAIN = edict() 25 | DIV2K.TRAIN.HRx2 = '/data/liwenbo/datasets/DIV2K/DIV2K_train_HR_sub' # 32208 26 | DIV2K.TRAIN.HRx3 = '/data/liwenbo/datasets/DIV2K/DIV2K_train_HR_sub' # 32208 27 | DIV2K.TRAIN.HRx4 = '/data/liwenbo/datasets/DIV2K/DIV2K_train_HR_sub' # 32208 28 | DIV2K.TRAIN.LRx2 = '/data/liwenbo/datasets/DIV2K/DIV2K_train_LR_bicubic_sub/X2' 29 | DIV2K.TRAIN.LRx3 = '/data/liwenbo/datasets/DIV2K/DIV2K_train_LR_bicubic_sub/X3' 30 | DIV2K.TRAIN.LRx4 = '/data/liwenbo/datasets/DIV2K/DIV2K_train_LR_bicubic_sub/X4' 31 | 32 | Flickr2K = edict() 33 | Flickr2K.TRAIN = edict() 34 | Flickr2K.TRAIN.HRx2 = '/data/liwenbo/datasets/Flickr2K/Flickr2K_HR_sub' # 106641 35 | Flickr2K.TRAIN.HRx3 = '/data/liwenbo/datasets/Flickr2K/Flickr2K_HR_sub' # 106641 36 | Flickr2K.TRAIN.HRx4 = '/data/liwenbo/datasets/Flickr2K/Flickr2K_HR_sub' # 106641 37 | Flickr2K.TRAIN.LRx2 = '/data/liwenbo/datasets/Flickr2K/Flickr2K_LR_bicubic_sub/X2' 38 | Flickr2K.TRAIN.LRx3 = '/data/liwenbo/datasets/Flickr2K/Flickr2K_LR_bicubic_sub/X3' 39 | Flickr2K.TRAIN.LRx4 = '/data/liwenbo/datasets/Flickr2K/Flickr2K_LR_bicubic_sub/X4' 40 | 41 | # testing dataset 42 | Set5 = edict() 43 | Set5.VAL = edict() 44 | Set5.VAL.HRx2 = None 45 | Set5.VAL.HRx3 = None 46 | Set5.VAL.HRx4 = None 47 | Set5.VAL.LRx2 = None 48 | Set5.VAL.LRx3 = None 49 | Set5.VAL.LRx4 = None 50 | 51 | Set14 = edict() 52 | Set14.VAL = edict() 53 | Set14.VAL.HRx2 = None 54 | Set14.VAL.HRx3 = None 55 | Set14.VAL.HRx4 = None 56 | Set14.VAL.LRx2 = None 57 | Set14.VAL.LRx3 = None 58 | Set14.VAL.LRx4 = None 59 | 60 | BSDS100 = edict() 61 | BSDS100.VAL = edict() 62 | BSDS100.VAL.HRx2 = '/data/liwenbo/datasets/benchmark_SR/BSDS100/HR/modX2' 63 | BSDS100.VAL.HRx3 = '/data/liwenbo/datasets/benchmark_SR/BSDS100/HR/modX3' 64 | BSDS100.VAL.HRx4 = '/data/liwenbo/datasets/benchmark_SR/BSDS100/HR/modX4' 65 | BSDS100.VAL.LRx2 = '/data/liwenbo/datasets/benchmark_SR/BSDS100/LR_bicubic/X2' 66 | BSDS100.VAL.LRx3 = '/data/liwenbo/datasets/benchmark_SR/BSDS100/LR_bicubic/X3' 67 | BSDS100.VAL.LRx4 = '/data/liwenbo/datasets/benchmark_SR/BSDS100/LR_bicubic/X4' 68 | 69 | Urban100 = edict() 70 | Urban100.VAL = edict() 71 | Urban100.VAL.HRx2 = None 72 | Urban100.VAL.HRx3 = None 73 | Urban100.VAL.HRx4 = None 74 | Urban100.VAL.LRx2 = None 75 | Urban100.VAL.LRx3 = None 76 | Urban100.VAL.LRx4 = None 77 | 78 | Manga109 = edict() 79 | Manga109.VAL = dict() 80 | Manga109.VAL.HRx2 = None 81 | Manga109.VAL.HRx3 = None 82 | Manga109.VAL.HRx4 = None 83 | Manga109.VAL.LRx2 = None 84 | Manga109.VAL.LRx3 = None 85 | Manga109.VAL.LRx4 = None 86 | 87 | 88 | def get_dataset(config): 89 | dataset_type = config.TYPE 90 | dataset_cls = None 91 | for module in _dataset_modules: 92 | dataset_cls = getattr(module, dataset_type, None) 93 | if dataset_cls is not None: 94 | break 95 | if dataset_cls is None: 96 | raise ValueError(f'Dataset {dataset_type} is not found.') 97 | 98 | hr_paths = [] 99 | lr_paths = [] 100 | D = DATASET() 101 | 102 | for dataset, split in zip(config.DATASETS, config.SPLITS): 103 | if dataset not in D.LEGAL or split not in eval('D.%s' % dataset): 104 | raise ValueError('Illegal dataset.') 105 | hr_paths.append(eval('D.%s.%s.HRx%d' % (dataset, split, config.SCALE))) 106 | lr_paths.append(eval('D.%s.%s.LRx%d' % (dataset, split, config.SCALE))) 107 | 108 | return dataset_cls(hr_paths, lr_paths, config) 109 | 110 | -------------------------------------------------------------------------------- /dataset/mix_dataset.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import os 4 | import random 5 | 6 | import torch 7 | from torch.utils.data import Dataset 8 | 9 | 10 | class MixDataset(Dataset): 11 | def __init__(self, hr_paths, lr_paths, config): 12 | self.hr_paths = hr_paths 13 | self.lr_paths = lr_paths 14 | self.phase = config.PHASE 15 | self.input_width, self.input_height = config.INPUT_WIDTH, config.INPUT_HEIGHT 16 | self.scale = config.SCALE 17 | self.repeat = config.REPEAT 18 | self.value_range = config.VALUE_RANGE 19 | 20 | self._load_data() 21 | 22 | def _load_data(self): 23 | assert len(self.lr_paths) == len(self.hr_paths), 'Illegal hr-lr dataset mappings.' 24 | 25 | self.hr_list = [] 26 | self.lr_list = [] 27 | for hr_path in self.hr_paths: 28 | hr_imgs = sorted(os.listdir(hr_path)) 29 | for hr_img in hr_imgs: 30 | self.hr_list.append(os.path.join(hr_path, hr_img)) 31 | for lr_path in self.lr_paths: 32 | lr_imgs = sorted(os.listdir(lr_path)) 33 | for lr_img in lr_imgs: 34 | self.lr_list.append(os.path.join(lr_path, lr_img)) 35 | 36 | assert len(self.hr_list) == len(self.lr_list), 'Illegal hr-lr mappings.' 37 | 38 | self.data_len = len(self.hr_list) 39 | self.full_len = self.data_len * self.repeat 40 | 41 | def __len__(self): 42 | return self.full_len 43 | 44 | def __getitem__(self, index): 45 | idx = index % self.data_len 46 | 47 | url_hr = self.hr_list[idx] 48 | url_lr = self.lr_list[idx] 49 | 50 | img_hr = cv2.imread(url_hr, cv2.IMREAD_COLOR) 51 | img_lr = cv2.imread(url_lr, cv2.IMREAD_COLOR) 52 | 53 | if self.phase == 'train': 54 | h, w = img_lr.shape[:2] 55 | s = self.scale 56 | 57 | # random cropping 58 | y = random.randint(0, h - self.input_height) 59 | x = random.randint(0, w - self.input_width) 60 | img_lr = img_lr[y: y + self.input_height, x: x + self.input_width, :] 61 | img_hr = img_hr[y * s: (y + self.input_height) * s, 62 | x * s: (x + self.input_width) * s, :] 63 | 64 | # horizontal flip 65 | if random.random() > 0.5: 66 | cv2.flip(img_lr, 1, img_lr) 67 | cv2.flip(img_hr, 1, img_hr) 68 | # vertical flip 69 | if random.random() > 0.5: 70 | cv2.flip(img_lr, 0, img_lr) 71 | cv2.flip(img_hr, 0, img_hr) 72 | # rotation 90 degree 73 | if random.random() > 0.5: 74 | img_lr = img_lr.transpose(1, 0, 2) 75 | img_hr = img_hr.transpose(1, 0, 2) 76 | 77 | # BGR to RGB, HWC to CHW, uint8 to float32 78 | img_lr = np.transpose(img_lr[:, :, [2, 1, 0]], (2, 0, 1)).astype(np.float32) 79 | img_hr = np.transpose(img_hr[:, :, [2, 1, 0]], (2, 0, 1)).astype(np.float32) 80 | 81 | # numpy array to tensor, [0, 255] to [0, 1] 82 | img_lr = torch.from_numpy(img_lr).float() / self.value_range 83 | img_hr = torch.from_numpy(img_hr).float() / self.value_range 84 | 85 | return img_lr, img_hr 86 | 87 | 88 | if __name__ == '__main__': 89 | from easydict import EasyDict as edict 90 | config = edict() 91 | config.PHASE = 'train' 92 | config.INPUT_WIDTH = config.INPUT_HEIGHT = 64 93 | config.SCALE = 4 94 | config.REPEAT = 1 95 | config.VALUE_RANGE = 255.0 96 | 97 | D = MixDataset(hr_paths=['/data/liwenbo/datasets/DIV2K/DIV2K_train_HR_sub'], 98 | lr_paths=['/data/liwenbo/datasets/DIV2K/DIV2K_train_LR_bicubic_sub/X4'], 99 | config=config) 100 | print(D.data_len, D.full_len) 101 | lr, hr = D.__getitem__(5) 102 | print(lr.size(), hr.size()) 103 | print('Done') 104 | -------------------------------------------------------------------------------- /demo/SISR/HR_imgs/0826.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/Simple-SR/08c71e9e46ba781df50893f0476ecd0fc004aa45/demo/SISR/HR_imgs/0826.png -------------------------------------------------------------------------------- /demo/SISR/LR_imgs/0826x4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/Simple-SR/08c71e9e46ba781df50893f0476ecd0fc004aa45/demo/SISR/LR_imgs/0826x4.png -------------------------------------------------------------------------------- /demo/VSR/HR_imgs/00000000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/Simple-SR/08c71e9e46ba781df50893f0476ecd0fc004aa45/demo/VSR/HR_imgs/00000000.png -------------------------------------------------------------------------------- /demo/VSR/HR_imgs/00000001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/Simple-SR/08c71e9e46ba781df50893f0476ecd0fc004aa45/demo/VSR/HR_imgs/00000001.png -------------------------------------------------------------------------------- /demo/VSR/HR_imgs/00000002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/Simple-SR/08c71e9e46ba781df50893f0476ecd0fc004aa45/demo/VSR/HR_imgs/00000002.png -------------------------------------------------------------------------------- /demo/VSR/HR_imgs/00000003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/Simple-SR/08c71e9e46ba781df50893f0476ecd0fc004aa45/demo/VSR/HR_imgs/00000003.png -------------------------------------------------------------------------------- /demo/VSR/HR_imgs/00000004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/Simple-SR/08c71e9e46ba781df50893f0476ecd0fc004aa45/demo/VSR/HR_imgs/00000004.png -------------------------------------------------------------------------------- /demo/VSR/HR_imgs/00000005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/Simple-SR/08c71e9e46ba781df50893f0476ecd0fc004aa45/demo/VSR/HR_imgs/00000005.png -------------------------------------------------------------------------------- /demo/VSR/HR_imgs/00000006.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/Simple-SR/08c71e9e46ba781df50893f0476ecd0fc004aa45/demo/VSR/HR_imgs/00000006.png -------------------------------------------------------------------------------- /demo/VSR/LR_imgs/00000000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/Simple-SR/08c71e9e46ba781df50893f0476ecd0fc004aa45/demo/VSR/LR_imgs/00000000.png -------------------------------------------------------------------------------- /demo/VSR/LR_imgs/00000001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/Simple-SR/08c71e9e46ba781df50893f0476ecd0fc004aa45/demo/VSR/LR_imgs/00000001.png -------------------------------------------------------------------------------- /demo/VSR/LR_imgs/00000002.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/Simple-SR/08c71e9e46ba781df50893f0476ecd0fc004aa45/demo/VSR/LR_imgs/00000002.png -------------------------------------------------------------------------------- /demo/VSR/LR_imgs/00000003.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/Simple-SR/08c71e9e46ba781df50893f0476ecd0fc004aa45/demo/VSR/LR_imgs/00000003.png -------------------------------------------------------------------------------- /demo/VSR/LR_imgs/00000004.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/Simple-SR/08c71e9e46ba781df50893f0476ecd0fc004aa45/demo/VSR/LR_imgs/00000004.png -------------------------------------------------------------------------------- /demo/VSR/LR_imgs/00000005.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/Simple-SR/08c71e9e46ba781df50893f0476ecd0fc004aa45/demo/VSR/LR_imgs/00000005.png -------------------------------------------------------------------------------- /demo/VSR/LR_imgs/00000006.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/Simple-SR/08c71e9e46ba781df50893f0476ecd0fc004aa45/demo/VSR/LR_imgs/00000006.png -------------------------------------------------------------------------------- /exps/BebyGAN/config.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | 4 | class Config: 5 | # dataset 6 | DATASET = edict() 7 | DATASET.TYPE = 'MixDataset' 8 | DATASET.DATASETS = ['DIV2K', 'Flickr2K'] 9 | DATASET.SPLITS = ['TRAIN', 'TRAIN'] 10 | DATASET.PHASE = 'train' 11 | DATASET.INPUT_HEIGHT = 48 12 | DATASET.INPUT_WIDTH = 48 13 | DATASET.SCALE = 4 14 | DATASET.REPEAT = 1 15 | DATASET.VALUE_RANGE = 255.0 16 | DATASET.SEED = 100 17 | 18 | # dataloader 19 | DATALOADER = edict() 20 | DATALOADER.IMG_PER_GPU = 8 21 | DATALOADER.NUM_WORKERS = 4 22 | 23 | # model 24 | MODEL = edict() 25 | MODEL.FLAT_KSIZE = 11 26 | MODEL.FLAT_STD = 0.025 27 | # generator 28 | MODEL.G = edict() 29 | MODEL.G.IN_CHANNEL = 3 30 | MODEL.G.OUT_CHANNEL = 3 31 | MODEL.G.N_CHANNEL = 64 32 | MODEL.G.N_BLOCK = 23 33 | MODEL.G.N_GROWTH_CHANNEL = 32 34 | # discriminator 35 | MODEL.D = edict() 36 | MODEL.D.IN_CHANNEL = 3 37 | MODEL.D.N_CHANNEL = 32 38 | MODEL.D.LOSS_TYPE = 'vanilla' # vanilla | lsgan | wgan | wgan_softplus | hinge 39 | # best buddy loss, adversarial loss, back projection loss 40 | MODEL.BBL_WEIGHT = 1.0 41 | MODEL.BBL_ALPHA = 1.0 42 | MODEL.BBL_BETA = 1.0 43 | MODEL.BBL_KSIZE = 3 44 | MODEL.BBL_PAD = 0 45 | MODEL.BBL_STRIDE = 3 46 | MODEL.BBL_TYPE = 'l1' 47 | MODEL.ADV_LOSS_WEIGHT = 0.005 48 | MODEL.BACK_PROJECTION_LOSS_WEIGHT = 1.0 49 | # Perceptual loss 50 | MODEL.USE_PCP_LOSS = True 51 | MODEL.USE_STYLE_LOSS = False 52 | MODEL.PCP_LOSS_WEIGHT = 1.0 53 | MODEL.STYLE_LOSS_WEIGHT = 0 54 | MODEL.PCP_LOSS_TYPE = 'l1' # l1 | l2 | fro 55 | MODEL.VGG_TYPE = 'vgg19' 56 | MODEL.VGG_LAYER_WEIGHTS = dict(conv3_4=1/8, conv4_4=1/4, conv5_4=1/2) # before relu 57 | MODEL.NORM_IMG = False 58 | MODEL.USE_INPUT_NORM = True 59 | # others 60 | MODEL.SCALE = DATASET.SCALE 61 | MODEL.DOWN = 1 62 | MODEL.DEVICE = 'cuda' 63 | 64 | # solver 65 | SOLVER = edict() 66 | # generator 67 | SOLVER.G_OPTIMIZER = 'Adam' 68 | SOLVER.G_BASE_LR = 1e-4 69 | SOLVER.G_BETA1 = 0.9 70 | SOLVER.G_BETA2 = 0.999 71 | SOLVER.G_WEIGHT_DECAY = 0 72 | SOLVER.G_MOMENTUM = 0 73 | SOLVER.G_STEP_ITER = 1 74 | SOLVER.G_PREPARE_ITER = 0 75 | # discriminator 76 | SOLVER.D_OPTIMIZER = 'Adam' 77 | SOLVER.D_BASE_LR = 1e-4 78 | SOLVER.D_BETA1 = 0.9 79 | SOLVER.D_BETA2 = 0.999 80 | SOLVER.D_WEIGHT_DECAY = 0 81 | SOLVER.D_MOMENTUM = 0 82 | SOLVER.D_STEP_ITER = 1 83 | # both G and D 84 | SOLVER.WARM_UP_ITER = 2000 85 | SOLVER.WARM_UP_FACTOR = 0.1 86 | SOLVER.T_PERIOD = [200000, 400000, 600000] 87 | SOLVER.MAX_ITER = SOLVER.T_PERIOD[-1] 88 | 89 | # initialization 90 | CONTINUE_ITER = None 91 | G_INIT_MODEL = '/data/liwenbo/sisr/bebygan/pretrained/RRDB_warmup.pth' 92 | D_INIT_MODEL = None 93 | 94 | # log and save 95 | LOG_PERIOD = 20 96 | SAVE_PERIOD = 10000 97 | 98 | # validation 99 | VAL = edict() 100 | VAL.PERIOD = 10000 101 | VAL.TYPE = 'MixDataset' 102 | VAL.DATASETS = ['BSDS100'] 103 | VAL.SPLITS = ['VAL'] 104 | VAL.PHASE = 'val' 105 | VAL.INPUT_HEIGHT = None 106 | VAL.INPUT_WIDTH = None 107 | VAL.SCALE = DATASET.SCALE 108 | VAL.REPEAT = 1 109 | VAL.VALUE_RANGE = 255.0 110 | VAL.IMG_PER_GPU = 1 111 | VAL.NUM_WORKERS = 1 112 | VAL.SAVE_IMG = False 113 | VAL.TO_Y = True 114 | VAL.CROP_BORDER = VAL.SCALE 115 | 116 | 117 | config = Config() 118 | 119 | 120 | 121 | -------------------------------------------------------------------------------- /exps/BebyGAN/network.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import numpy as np 3 | import os 4 | import cv2 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | from utils.modules.rrdb import RRDBNet 11 | from utils.loss import AdversarialLoss, PerceptualLoss, BBL 12 | from utils.modules.discriminator import Discriminator_VGG_192 13 | 14 | 15 | class Generator(RRDBNet): 16 | def __init__(self, config): 17 | super(Generator, self).__init__(in_nc=config.MODEL.G.IN_CHANNEL, 18 | out_nc=config.MODEL.G.OUT_CHANNEL, 19 | nf=config.MODEL.G.N_CHANNEL, 20 | nb=config.MODEL.G.N_BLOCK, 21 | gc=config.MODEL.G.N_GROWTH_CHANNEL) 22 | 23 | 24 | class Discriminator(Discriminator_VGG_192): 25 | def __init__(self, config): 26 | super(Discriminator, self).__init__(in_chl=config.MODEL.D.IN_CHANNEL, 27 | nf=config.MODEL.D.N_CHANNEL) 28 | 29 | 30 | class Network: 31 | def __init__(self, config): 32 | self.G = Generator(config) 33 | self.D = Discriminator(config) 34 | 35 | self.recon_loss_weight = config.MODEL.BBL_WEIGHT 36 | self.adv_loss_weight = config.MODEL.ADV_LOSS_WEIGHT 37 | self.bp_loss_weight = config.MODEL.BACK_PROJECTION_LOSS_WEIGHT 38 | self.use_pcp = config.MODEL.USE_PCP_LOSS 39 | self.recon_criterion = BBL(alpha=config.MODEL.BBL_ALPHA, 40 | beta=config.MODEL.BBL_BETA, 41 | ksize=config.MODEL.BBL_KSIZE, 42 | pad=config.MODEL.BBL_PAD, 43 | stride=config.MODEL.BBL_STRIDE, 44 | criterion=config.MODEL.BBL_TYPE) 45 | self.adv_criterion = AdversarialLoss(gan_type=config.MODEL.D.LOSS_TYPE) 46 | self.bp_criterion = nn.L1Loss(reduction='mean') 47 | if self.use_pcp: 48 | self.pcp_criterion = PerceptualLoss(layer_weights=config.MODEL.VGG_LAYER_WEIGHTS, 49 | vgg_type=config.MODEL.VGG_TYPE, 50 | use_input_norm=config.MODEL.USE_INPUT_NORM, 51 | use_pcp_loss=config.MODEL.USE_PCP_LOSS, 52 | use_style_loss=config.MODEL.USE_STYLE_LOSS, 53 | norm_img=config.MODEL.NORM_IMG, 54 | criterion=config.MODEL.PCP_LOSS_TYPE) 55 | self.pcp_loss_weight = config.MODEL.PCP_LOSS_WEIGHT 56 | self.style_loss_weight = config.MODEL.STYLE_LOSS_WEIGHT 57 | 58 | def set_device(self, device): 59 | self.G = self.G.to(device) 60 | self.D = self.D.to(device) 61 | self.recon_criterion = self.recon_criterion.to(device) 62 | self.adv_criterion = self.adv_criterion.to(device) 63 | self.bp_criterion = self.bp_criterion.to(device) 64 | if self.use_pcp: 65 | self.pcp_criterion = self.pcp_criterion.to(device) 66 | 67 | 68 | if __name__ == '__main__': 69 | from config import config 70 | 71 | net = Network(config) 72 | print("model have {:.3f}M paramerters in total".format(sum(x.numel() for x in net.G.parameters())/1e6)) 73 | 74 | -------------------------------------------------------------------------------- /exps/BebyGAN/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import sys 5 | sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + '/../..') 6 | import logging 7 | from collections import OrderedDict 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | import torch.distributed as dist 12 | import torch.multiprocessing as mp 13 | from tensorboardX import SummaryWriter 14 | 15 | from config import config 16 | from utils import common, dataloader, solver, model_opr, resizer, region_seperator 17 | from dataset import get_dataset 18 | from network import Network 19 | from validate import validate 20 | 21 | 22 | def init_dist(local_rank): 23 | if mp.get_start_method(allow_none=True) != 'spawn': 24 | mp.set_start_method('spawn', force=True) 25 | torch.cuda.set_device(local_rank) 26 | dist.init_process_group(backend="nccl", init_method='env://') 27 | dist.barrier() 28 | 29 | 30 | def main(): 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument('--local_rank', type=int, default=0) 33 | args = parser.parse_args() 34 | 35 | # initialization 36 | rank = 0 37 | num_gpu = 1 38 | distributed = False 39 | if 'WORLD_SIZE' in os.environ: 40 | num_gpu = int(os.environ['WORLD_SIZE']) 41 | distributed = num_gpu > 1 42 | if distributed: 43 | rank = args.local_rank 44 | init_dist(rank) 45 | common.init_random_seed(config.DATASET.SEED) 46 | print(rank) 47 | 48 | # set up dirs and log 49 | exp_dir, cur_dir = osp.split(osp.split(osp.realpath(__file__))[0]) 50 | root_dir = osp.split(exp_dir)[0] 51 | log_dir = osp.join(root_dir, 'logs', cur_dir) 52 | model_dir = osp.join(log_dir, 'models') 53 | solver_dir = osp.join(log_dir, 'solvers') 54 | if rank <= 0: 55 | common.mkdir(log_dir) 56 | ln_log_dir = osp.join(exp_dir, cur_dir, 'log') 57 | if not osp.exists(ln_log_dir): 58 | os.system('ln -s %s log' % log_dir) 59 | common.mkdir(model_dir) 60 | common.mkdir(solver_dir) 61 | save_dir = osp.join(log_dir, 'saved_imgs') 62 | common.mkdir(save_dir) 63 | tb_dir = osp.join(log_dir, 'tb_log') 64 | tb_writer = SummaryWriter(tb_dir) 65 | common.setup_logger('base', log_dir, 'train', level=logging.INFO, screen=True, to_file=True) 66 | logger = logging.getLogger('base') 67 | 68 | # dataset 69 | train_dataset = get_dataset(config.DATASET) 70 | train_loader = dataloader.train_loader(train_dataset, config, rank=rank, seed=config.DATASET.SEED, 71 | is_dist=distributed) 72 | if rank <= 0: 73 | val_dataset = get_dataset(config.VAL) 74 | val_loader = dataloader.val_loader(val_dataset, config, rank, 1) 75 | 76 | # model 77 | model = Network(config) 78 | if rank <= 0: 79 | print(model.G) 80 | print(model.D) 81 | 82 | if config.G_INIT_MODEL: 83 | if rank <= 0: 84 | logger.info('[Initing] Generator') 85 | model_opr.load_model(model.G, config.G_INIT_MODEL, strict=False, cpu=True) 86 | if config.D_INIT_MODEL: 87 | if rank <= 0: 88 | logger.info('[Initing] Discriminator') 89 | model_opr.load_model(model.D, config.D_INIT_MODEL, strict=False, cpu=True) 90 | 91 | # load models to continue training 92 | if config.CONTINUE_ITER: 93 | if rank <= 0: 94 | logger.info('[Loading] Iter: %d' % config.CONTINUE_ITER) 95 | G_model_path = osp.join(model_dir, '%d_G.pth' % config.CONTINUE_ITER) 96 | D_model_path = osp.join(model_dir, '%d_D.pth' % config.CONTINUE_ITER) 97 | model_opr.load_model(model.G, G_model_path, strict=True, cpu=True) 98 | model_opr.load_model(model.D, D_model_path, strict=True, cpu=True) 99 | 100 | device = torch.device(config.MODEL.DEVICE) 101 | model.set_device(device) 102 | if distributed: 103 | model.G = torch.nn.parallel.DistributedDataParallel(model.G, device_ids=[torch.cuda.current_device()]) 104 | model.D = torch.nn.parallel.DistributedDataParallel(model.D, device_ids=[torch.cuda.current_device()]) 105 | 106 | # solvers 107 | G_optimizer = solver.make_optimizer_sep(model.G, config.SOLVER.G_BASE_LR, type=config.SOLVER.G_OPTIMIZER, 108 | beta1=config.SOLVER.G_BETA1, beta2=config.SOLVER.G_BETA2, 109 | weight_decay=config.SOLVER.G_WEIGHT_DECAY, 110 | momentum=config.SOLVER.G_MOMENTUM, 111 | num_gpu=None) # lr without X num_gpu 112 | G_lr_scheduler = solver.CosineAnnealingLR_warmup(config, G_optimizer, config.SOLVER.G_BASE_LR) 113 | D_optimizer = solver.make_optimizer_sep(model.D, config.SOLVER.D_BASE_LR, type=config.SOLVER.D_OPTIMIZER, 114 | beta1=config.SOLVER.D_BETA1, beta2=config.SOLVER.D_BETA2, 115 | weight_decay=config.SOLVER.D_WEIGHT_DECAY, 116 | momentum=config.SOLVER.D_MOMENTUM, 117 | num_gpu=None) 118 | D_lr_scheduler = solver.CosineAnnealingLR_warmup(config, D_optimizer, config.SOLVER.D_BASE_LR) 119 | 120 | iteration = 0 121 | 122 | # load solvers to continue training 123 | if config.CONTINUE_ITER: 124 | G_solver_path = osp.join(solver_dir, '%d_G.solver' % config.CONTINUE_ITER) 125 | iteration = model_opr.load_solver(G_optimizer, G_lr_scheduler, G_solver_path) 126 | D_solver_path = osp.join(solver_dir, '%d_D.solver' % config.CONTINUE_ITER) 127 | _ = model_opr.load_solver(D_optimizer, D_lr_scheduler, D_solver_path) 128 | 129 | max_iter = max_psnr = max_ssim = 0 130 | for lr_img, hr_img in train_loader: 131 | iteration = iteration + 1 132 | 133 | model.G.train() 134 | model.D.train() 135 | 136 | lr_img = lr_img.to(device) 137 | hr_img = hr_img.to(device) 138 | 139 | flat_mask = region_seperator.get_flat_mask(lr_img, kernel_size=config.MODEL.FLAT_KSIZE, 140 | std_thresh=config.MODEL.FLAT_STD, 141 | scale=config.DATASET.SCALE) 142 | 143 | # update G first, then D 144 | for p in model.D.parameters(): 145 | p.requires_grad = False 146 | 147 | # from degraded LR to SR 148 | G_optimizer.zero_grad() 149 | output = model.G(lr_img) 150 | output_det = output * (1 - flat_mask) 151 | hr_det = hr_img * (1 - flat_mask) 152 | # degrade SR to LR 153 | bp_lr_img = resizer.imresize(output, scale=1/config.DATASET.SCALE) 154 | 155 | loss_dict = OrderedDict() 156 | 157 | # generator optimization 158 | if iteration % config.SOLVER.G_STEP_ITER == 0: 159 | gen_loss = 0.0 160 | 161 | recon_loss = model.recon_loss_weight * model.recon_criterion(output, hr_img) 162 | loss_dict['G_REC'] = recon_loss 163 | gen_loss += recon_loss 164 | 165 | # back projection loss 166 | bp_loss = model.bp_loss_weight * model.bp_criterion(bp_lr_img, lr_img) 167 | loss_dict['G_BP'] = bp_loss 168 | gen_loss += bp_loss 169 | 170 | if iteration > config.SOLVER.G_PREPARE_ITER: 171 | # perceptual / style loss 172 | if model.use_pcp: 173 | pcp_loss, style_loss,_ = model.pcp_criterion(output, hr_img) 174 | pcp_loss = model.pcp_loss_weight * pcp_loss 175 | loss_dict['G_PCP'] = pcp_loss 176 | gen_loss += pcp_loss 177 | if style_loss is not None: 178 | style_loss = model.style_loss_weight * style_loss 179 | loss_dict['G_STY'] = style_loss 180 | gen_loss += style_loss 181 | 182 | # generator adversarial loss (relativistic gan) 183 | gen_real = model.D(hr_det).detach() 184 | gen_fake = model.D(output_det) 185 | gen_real_loss = model.adv_criterion(gen_real - torch.mean(gen_fake), False, is_disc=False) * 0.5 186 | gen_fake_loss = model.adv_criterion(gen_fake - torch.mean(gen_real), True, is_disc=False) * 0.5 187 | gen_adv_loss = model.adv_loss_weight * (gen_real_loss + gen_fake_loss) 188 | loss_dict['G_ADV'] = gen_adv_loss 189 | gen_loss += gen_adv_loss 190 | 191 | gen_loss.backward() 192 | G_optimizer.step() 193 | G_lr_scheduler.step() 194 | 195 | # discriminator optimization 196 | if iteration % config.SOLVER.D_STEP_ITER == 0 and iteration > config.SOLVER.G_PREPARE_ITER: 197 | for p in model.D.parameters(): 198 | p.requires_grad = True 199 | D_optimizer.zero_grad() 200 | 201 | # discriminator loss 202 | # real 203 | dis_fake = model.D(output_det).detach() 204 | dis_real = model.D(hr_det) 205 | dis_real_loss = model.adv_criterion(dis_real - torch.mean(dis_fake), True, is_disc=True) * 0.5 206 | dis_real_loss.backward() 207 | # fake 208 | dis_fake = model.D(output_det.detach()) 209 | dis_fake_loss = model.adv_criterion(dis_fake - torch.mean(dis_real.detach()), False, is_disc=True) * 0.5 210 | dis_fake_loss.backward() 211 | 212 | loss_dict['D_ADV'] = dis_real_loss + dis_fake_loss 213 | 214 | D_optimizer.step() 215 | D_lr_scheduler.step() 216 | 217 | if rank <= 0: 218 | if iteration % config.LOG_PERIOD == 0 or iteration == config.SOLVER.MAX_ITER: 219 | log_str = 'Iter: %d, LR: %.3e, ' % (iteration, G_optimizer.param_groups[0]['lr']) 220 | for key in loss_dict: 221 | tb_writer.add_scalar(key, loss_dict[key].mean(), global_step=iteration) 222 | log_str += key + ': %.6f, ' % float(loss_dict[key]) 223 | logger.info(log_str) 224 | 225 | if iteration % config.SAVE_PERIOD == 0 or iteration == config.SOLVER.MAX_ITER: 226 | logger.info('[Saving] Iter: %d' % iteration) 227 | G_model_path = osp.join(model_dir, '%d_G.pth' % iteration) 228 | D_model_path = osp.join(model_dir, '%d_D.pth' % iteration) 229 | model_opr.save_model(model.G, G_model_path) 230 | model_opr.save_model(model.D, D_model_path) 231 | G_solver_path = osp.join(solver_dir, '%d_G.solver' % iteration) 232 | D_solver_path = osp.join(solver_dir, '%d_D.solver' % iteration) 233 | model_opr.save_solver(G_optimizer, G_lr_scheduler, iteration, G_solver_path) 234 | model_opr.save_solver(D_optimizer, D_lr_scheduler, iteration, D_solver_path) 235 | 236 | if iteration % config.VAL.PERIOD == 0 or iteration == config.SOLVER.MAX_ITER: 237 | logger.info('[Validating] Iter: %d' % iteration) 238 | model.G.eval() 239 | with torch.no_grad(): 240 | psnr, ssim = validate(model, val_loader, config, device, iteration, save_path=save_dir) 241 | if psnr > max_psnr: 242 | max_psnr, max_ssim, max_iter = psnr, ssim, iteration 243 | logger.info('[Val Result] Iter: %d, PSNR: %.4f, SSIM: %.4f' % (iteration, psnr, ssim)) 244 | logger.info('[Best Result] Iter: %d, PSNR: %.4f, SSIM: %.4f' % (max_iter, max_psnr, max_ssim)) 245 | 246 | if iteration >= config.SOLVER.MAX_ITER: 247 | break 248 | 249 | if rank <= 0: 250 | logger.info('Finish training process!') 251 | logger.info('[Final Best Result] Iter: %d, PSNR: %.4f, SSIM: %.4f' % (max_iter, max_psnr, max_ssim)) 252 | 253 | 254 | if __name__ == '__main__': 255 | main() 256 | -------------------------------------------------------------------------------- /exps/BebyGAN/train.sh: -------------------------------------------------------------------------------- 1 | python3 -m torch.distributed.launch --nproc_per_node=$1 --master_port=$2 train.py 2 | -------------------------------------------------------------------------------- /exps/BebyGAN/validate.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import numpy as np 4 | import sys 5 | sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + '/../..') 6 | 7 | import torch 8 | 9 | from utils.common import tensor2img, calculate_psnr, calculate_ssim, bgr2ycbcr 10 | 11 | 12 | def validate(model, val_loader, config, device, iteration, save_path='.'): 13 | with torch.no_grad(): 14 | psnr_l = [] 15 | ssim_l = [] 16 | 17 | for idx, (lr_img, hr_img) in enumerate(val_loader): 18 | lr_img = lr_img.to(device) 19 | hr_img = hr_img.to(device) 20 | 21 | output = model.G(lr_img) 22 | 23 | output = tensor2img(output) 24 | gt = tensor2img(hr_img) 25 | 26 | if config.VAL.SAVE_IMG: 27 | ipath = os.path.join(save_path, '%d_%03d.png' % (iteration, idx)) 28 | cv2.imwrite(ipath, np.concatenate([output, gt], axis=1)) 29 | 30 | output = output.astype(np.float32) / 255.0 31 | gt = gt.astype(np.float32) / 255.0 32 | 33 | if config.VAL.TO_Y: 34 | output = bgr2ycbcr(output, only_y=True) 35 | gt = bgr2ycbcr(gt, only_y=True) 36 | 37 | if config.VAL.CROP_BORDER != 0: 38 | cb = config.VAL.CROP_BORDER 39 | output = output[cb:-cb, cb:-cb] 40 | gt = gt[cb:-cb, cb:-cb] 41 | 42 | psnr = calculate_psnr(output * 255, gt * 255) 43 | ssim = calculate_ssim(output * 255, gt * 255) 44 | psnr_l.append(psnr) 45 | ssim_l.append(ssim) 46 | 47 | avg_psnr = sum(psnr_l) / len(psnr_l) 48 | avg_ssim = sum(ssim_l) / len(ssim_l) 49 | 50 | return avg_psnr, avg_ssim 51 | 52 | 53 | if __name__ == '__main__': 54 | from config import config 55 | from network import Network 56 | from dataset import get_dataset 57 | from utils import dataloader 58 | from utils.model_opr import load_model 59 | 60 | config.VAL.DATASETS = ['BSDS100'] 61 | config.VAL.SAVE_IMG = True 62 | 63 | model = Network(config) 64 | if torch.cuda.is_available(): 65 | device = torch.device('cuda') 66 | else: 67 | device = torch.device('cpu') 68 | model.G = model.G.to(device) 69 | 70 | model_path = 'log/models/600000_G.pth' 71 | load_model(model.G, model_path, cpu=True) 72 | 73 | val_dataset = get_dataset(config.VAL) 74 | val_loader = dataloader.val_loader(val_dataset, config, 0, 1) 75 | psnr, ssim = validate(model, val_loader, config, device, 0, save_path='gt') 76 | print('PSNR: %.4f, SSIM: %.4f' % (psnr, ssim)) 77 | -------------------------------------------------------------------------------- /exps/LAPAR_A_x2/config.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | 4 | class Config: 5 | # dataset 6 | DATASET = edict() 7 | DATASET.TYPE = 'MixDataset' 8 | DATASET.DATASETS = ['DIV2K', 'Flickr2K'] 9 | DATASET.SPLITS = ['TRAIN', 'TRAIN'] 10 | DATASET.PHASE = 'train' 11 | DATASET.INPUT_HEIGHT = 64 12 | DATASET.INPUT_WIDTH = 64 13 | DATASET.SCALE = 2 14 | DATASET.REPEAT = 1 15 | DATASET.VALUE_RANGE = 255.0 16 | DATASET.SEED = 100 17 | 18 | # dataloader 19 | DATALOADER = edict() 20 | DATALOADER.IMG_PER_GPU = 32 21 | DATALOADER.NUM_WORKERS = 8 22 | 23 | # model 24 | MODEL = edict() 25 | MODEL.SCALE = DATASET.SCALE 26 | MODEL.KERNEL_SIZE = 5 27 | MODEL.KERNEL_PATH = '../../kernel/kernel_72_k5.pkl' 28 | MODEL.IN_CHANNEL = 3 29 | MODEL.N_CHANNEL = 32 30 | MODEL.RES_BLOCK = 4 31 | MODEL.N_WEIGHT = 72 32 | MODEL.DOWN = 1 33 | MODEL.DEVICE = 'cuda' 34 | 35 | # solver 36 | SOLVER = edict() 37 | SOLVER.OPTIMIZER = 'Adam' 38 | SOLVER.BASE_LR = 4e-4 39 | SOLVER.BETA1 = 0.9 40 | SOLVER.BETA2 = 0.999 41 | SOLVER.WEIGHT_DECAY = 0 42 | SOLVER.MOMENTUM = 0 43 | SOLVER.WARM_UP_ITER = 2000 44 | SOLVER.WARM_UP_FACTOR = 0.1 45 | SOLVER.T_PERIOD = [200000, 400000, 600000] 46 | SOLVER.MAX_ITER = SOLVER.T_PERIOD[-1] 47 | 48 | # initialization 49 | CONTINUE_ITER = None 50 | INIT_MODEL = None 51 | 52 | # log and save 53 | LOG_PERIOD = 20 54 | SAVE_PERIOD = 10000 55 | 56 | # validation 57 | VAL = edict() 58 | VAL.PERIOD = 10000 59 | VAL.TYPE = 'MixDataset' 60 | VAL.DATASETS = ['BSDS100'] 61 | VAL.SPLITS = ['VAL'] 62 | VAL.PHASE = 'val' 63 | VAL.INPUT_HEIGHT = None 64 | VAL.INPUT_WIDTH = None 65 | VAL.SCALE = DATASET.SCALE 66 | VAL.REPEAT = 1 67 | VAL.VALUE_RANGE = 255.0 68 | VAL.IMG_PER_GPU = 1 69 | VAL.NUM_WORKERS = 1 70 | VAL.SAVE_IMG = False 71 | VAL.TO_Y = True 72 | VAL.CROP_BORDER = VAL.SCALE 73 | 74 | 75 | config = Config() 76 | 77 | 78 | 79 | -------------------------------------------------------------------------------- /exps/LAPAR_A_x2/network.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.nn.init as init 7 | 8 | from utils.modules.lightWeightNet import WeightNet 9 | 10 | 11 | class ComponentDecConv(nn.Module): 12 | def __init__(self, k_path, k_size): 13 | super(ComponentDecConv, self).__init__() 14 | 15 | kernel = pickle.load(open(k_path, 'rb')) 16 | kernel = torch.from_numpy(kernel).float().view(-1, 1, k_size, k_size) 17 | self.register_buffer('weight', kernel) 18 | 19 | def forward(self, x): 20 | out = F.conv2d(x, weight=self.weight, bias=None, stride=1, padding=0, groups=1) 21 | 22 | return out 23 | 24 | 25 | class Network(nn.Module): 26 | def __init__(self, config): 27 | super(Network, self).__init__() 28 | 29 | self.k_size = config.MODEL.KERNEL_SIZE 30 | self.s = config.MODEL.SCALE 31 | 32 | self.w_conv = WeightNet(config.MODEL) 33 | self.decom_conv = ComponentDecConv(config.MODEL.KERNEL_PATH, self.k_size) 34 | 35 | self.criterion = nn.L1Loss(reduction='mean') 36 | 37 | 38 | def forward(self, x, gt=None): 39 | B, C, H, W = x.size() 40 | 41 | bic = F.interpolate(x, scale_factor=self.s, mode='bicubic', align_corners=False) 42 | pad = self.k_size // 2 43 | x_pad = F.pad(bic, pad=(pad, pad, pad, pad), mode='reflect') 44 | pad_H, pad_W = x_pad.size()[2:] 45 | x_pad = x_pad.view(B * 3, 1, pad_H, pad_W) 46 | x_com = self.decom_conv(x_pad).view(B, 3, -1, self.s * H, self.s * W) # B, 3, N_K, Hs, Ws 47 | 48 | weight = self.w_conv(x) 49 | weight = weight.view(B, 1, -1, self.s * H, self.s * W) # B, 1, N_K, Hs, Ws 50 | 51 | out = torch.sum(weight * x_com, dim=2) 52 | 53 | if gt is not None: 54 | loss_dict = dict(L1=self.criterion(out, gt)) 55 | return loss_dict 56 | else: 57 | return out 58 | 59 | -------------------------------------------------------------------------------- /exps/LAPAR_A_x2/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import sys 5 | sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + '/../..') 6 | import logging 7 | 8 | import torch 9 | import torch.distributed as dist 10 | import torch.multiprocessing as mp 11 | from tensorboardX import SummaryWriter 12 | 13 | from config import config 14 | from utils import common, dataloader, solver, model_opr 15 | from dataset import get_dataset 16 | from network import Network 17 | from validate import validate 18 | 19 | 20 | def init_dist(local_rank): 21 | if mp.get_start_method(allow_none=True) != 'spawn': 22 | mp.set_start_method('spawn', force=True) 23 | torch.cuda.set_device(local_rank) 24 | dist.init_process_group(backend="nccl", init_method='env://') 25 | dist.barrier() 26 | 27 | 28 | def main(): 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--local_rank', type=int, default=0) 31 | args = parser.parse_args() 32 | 33 | # initialization 34 | rank = 0 35 | num_gpu = 1 36 | distributed = False 37 | if 'WORLD_SIZE' in os.environ: 38 | num_gpu = int(os.environ['WORLD_SIZE']) 39 | distributed = num_gpu > 1 40 | if distributed: 41 | rank = args.local_rank 42 | init_dist(rank) 43 | common.init_random_seed(config.DATASET.SEED + rank) 44 | 45 | # set up dirs and log 46 | exp_dir, cur_dir = osp.split(osp.split(osp.realpath(__file__))[0]) 47 | root_dir = osp.split(exp_dir)[0] 48 | log_dir = osp.join(root_dir, 'logs', cur_dir) 49 | model_dir = osp.join(log_dir, 'models') 50 | solver_dir = osp.join(log_dir, 'solvers') 51 | if rank <= 0: 52 | common.mkdir(log_dir) 53 | ln_log_dir = osp.join(exp_dir, cur_dir, 'log') 54 | if not osp.exists(ln_log_dir): 55 | os.system('ln -s %s log' % log_dir) 56 | common.mkdir(model_dir) 57 | common.mkdir(solver_dir) 58 | save_dir = osp.join(log_dir, 'saved_imgs') 59 | common.mkdir(save_dir) 60 | tb_dir = osp.join(log_dir, 'tb_log') 61 | tb_writer = SummaryWriter(tb_dir) 62 | common.setup_logger('base', log_dir, 'train', level=logging.INFO, screen=True, to_file=True) 63 | logger = logging.getLogger('base') 64 | 65 | # dataset 66 | train_dataset = get_dataset(config.DATASET) 67 | train_loader = dataloader.train_loader(train_dataset, config, rank=rank, seed=config.DATASET.SEED, 68 | is_dist=distributed) 69 | if rank <= 0: 70 | val_dataset = get_dataset(config.VAL) 71 | val_loader = dataloader.val_loader(val_dataset, config, rank, 1) 72 | data_len = val_dataset.data_len 73 | 74 | # model 75 | model = Network(config) 76 | if rank <= 0: 77 | print(model) 78 | 79 | if config.CONTINUE_ITER: 80 | model_path = osp.join(model_dir, '%d.pth' % config.CONTINUE_ITER) 81 | if rank <= 0: 82 | logger.info('[Continue] Iter: %d' % config.CONTINUE_ITER) 83 | model_opr.load_model(model, model_path, strict=True, cpu=True) 84 | elif config.INIT_MODEL: 85 | if rank <= 0: 86 | logger.info('[Initialize] Model: %s' % config.INIT_MODEL) 87 | model_opr.load_model(model, config.INIT_MODEL, strict=True, cpu=True) 88 | 89 | device = torch.device(config.MODEL.DEVICE) 90 | model.to(device) 91 | if distributed: 92 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[torch.cuda.current_device()]) 93 | 94 | # solvers 95 | optimizer = solver.make_optimizer(config, model) # lr without X num_gpu 96 | lr_scheduler = solver.CosineAnnealingLR_warmup(config, optimizer, config.SOLVER.BASE_LR) 97 | iteration = 0 98 | 99 | if config.CONTINUE_ITER: 100 | solver_path = osp.join(solver_dir, '%d.solver' % config.CONTINUE_ITER) 101 | iteration = model_opr.load_solver(optimizer, lr_scheduler, solver_path) 102 | 103 | max_iter = max_psnr = max_ssim = 0 104 | for lr_img, hr_img in train_loader: 105 | model.train() 106 | iteration = iteration + 1 107 | 108 | optimizer.zero_grad() 109 | 110 | lr_img = lr_img.to(device) 111 | hr_img = hr_img.to(device) 112 | 113 | loss_dict = model(lr_img, gt=hr_img) 114 | total_loss = sum(loss for loss in loss_dict.values()) 115 | total_loss.backward() 116 | 117 | optimizer.step() 118 | lr_scheduler.step() 119 | 120 | if rank <= 0: 121 | if iteration % config.LOG_PERIOD == 0 or iteration == config.SOLVER.MAX_ITER: 122 | log_str = 'Iter: %d, LR: %.3e, ' % (iteration, optimizer.param_groups[0]['lr']) 123 | for key in loss_dict: 124 | tb_writer.add_scalar(key, loss_dict[key].mean(), global_step=iteration) 125 | log_str += key + ': %.4f, ' % float(loss_dict[key]) 126 | logger.info(log_str) 127 | 128 | if iteration % config.SAVE_PERIOD == 0 or iteration == config.SOLVER.MAX_ITER: 129 | logger.info('[Saving] Iter: %d' % iteration) 130 | model_path = osp.join(model_dir, '%d.pth' % iteration) 131 | solver_path = osp.join(solver_dir, '%d.solver' % iteration) 132 | model_opr.save_model(model, model_path) 133 | model_opr.save_solver(optimizer, lr_scheduler, iteration, solver_path) 134 | 135 | if iteration % config.VAL.PERIOD == 0 or iteration == config.SOLVER.MAX_ITER: 136 | logger.info('[Validating] Iter: %d' % iteration) 137 | model.eval() 138 | with torch.no_grad(): 139 | psnr, ssim = validate(model, val_loader, config, device, iteration, save_path=save_dir) 140 | if psnr > max_psnr: 141 | max_psnr, max_ssim, max_iter = psnr, ssim, iteration 142 | logger.info('[Val Result] Iter: %d, PSNR: %.4f, SSIM: %.4f' % (iteration, psnr, ssim)) 143 | logger.info('[Best Result] Iter: %d, PSNR: %.4f, SSIM: %.4f' % (max_iter, max_psnr, max_ssim)) 144 | 145 | if iteration >= config.SOLVER.MAX_ITER: 146 | break 147 | 148 | if rank <= 0: 149 | logger.info('Finish training process!') 150 | logger.info('[Final Best Result] Iter: %d, PSNR: %.4f, SSIM: %.4f' % (max_iter, max_psnr, max_ssim)) 151 | 152 | 153 | if __name__ == '__main__': 154 | main() 155 | -------------------------------------------------------------------------------- /exps/LAPAR_A_x2/train.sh: -------------------------------------------------------------------------------- 1 | python3 -m torch.distributed.launch --nproc_per_node=$1 --master_port=$2 train.py 2 | -------------------------------------------------------------------------------- /exps/LAPAR_A_x2/validate.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import numpy as np 4 | import sys 5 | sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + '/../..') 6 | 7 | import torch 8 | 9 | from utils.common import tensor2img, calculate_psnr, calculate_ssim, bgr2ycbcr 10 | 11 | 12 | def validate(model, val_loader, config, device, iteration, save_path='.'): 13 | with torch.no_grad(): 14 | psnr_l = [] 15 | ssim_l = [] 16 | 17 | for idx, (lr_img, hr_img) in enumerate(val_loader): 18 | lr_img = lr_img.to(device) 19 | hr_img = hr_img.to(device) 20 | 21 | output = model(lr_img) 22 | 23 | output = tensor2img(output) 24 | gt = tensor2img(hr_img) 25 | 26 | if config.VAL.SAVE_IMG: 27 | ipath = os.path.join(save_path, '%d_%03d.png' % (iteration, idx)) 28 | cv2.imwrite(ipath, np.concatenate([output, gt], axis=1)) 29 | 30 | output = output.astype(np.float32) / 255.0 31 | gt = gt.astype(np.float32) / 255.0 32 | 33 | if config.VAL.TO_Y: 34 | output = bgr2ycbcr(output, only_y=True) 35 | gt = bgr2ycbcr(gt, only_y=True) 36 | 37 | if config.VAL.CROP_BORDER != 0: 38 | cb = config.VAL.CROP_BORDER 39 | output = output[cb:-cb, cb:-cb] 40 | gt = gt[cb:-cb, cb:-cb] 41 | 42 | psnr = calculate_psnr(output * 255, gt * 255) 43 | ssim = calculate_ssim(output * 255, gt * 255) 44 | psnr_l.append(psnr) 45 | ssim_l.append(ssim) 46 | 47 | avg_psnr = sum(psnr_l) / len(psnr_l) 48 | avg_ssim = sum(ssim_l) / len(ssim_l) 49 | 50 | return avg_psnr, avg_ssim 51 | 52 | 53 | if __name__ == '__main__': 54 | from config import config 55 | from network import Network 56 | from dataset import get_dataset 57 | from utils import dataloader 58 | from utils.model_opr import load_model 59 | 60 | config.VAL.DATASET = 'Set5' 61 | 62 | model = Network(config) 63 | if torch.cuda.is_available(): 64 | device = torch.device('cuda') 65 | else: 66 | device = torch.device('cpu') 67 | model = model.to(device) 68 | 69 | model_path = 'log/models/200000.pth' 70 | load_model(model, model_path, cpu=True) 71 | sys.exit() 72 | 73 | val_dataset = get_dataset(config.VAL) 74 | val_loader = dataloader.val_loader(val_dataset, config, 0, 1) 75 | psnr, ssim = validate(model, val_loader, config, device, 0, save_path='.') 76 | print('PSNR: %.4f, SSIM: %.4f' % (psnr, ssim)) 77 | -------------------------------------------------------------------------------- /exps/LAPAR_A_x3/config.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | 4 | class Config: 5 | # dataset 6 | DATASET = edict() 7 | DATASET.TYPE = 'MixDataset' 8 | DATASET.DATASETS = ['DIV2K', 'Flickr2K'] 9 | DATASET.SPLITS = ['TRAIN', 'TRAIN'] 10 | DATASET.PHASE = 'train' 11 | DATASET.INPUT_HEIGHT = 64 12 | DATASET.INPUT_WIDTH = 64 13 | DATASET.SCALE = 3 14 | DATASET.REPEAT = 1 15 | DATASET.VALUE_RANGE = 255.0 16 | DATASET.SEED = 100 17 | 18 | # dataloader 19 | DATALOADER = edict() 20 | DATALOADER.IMG_PER_GPU = 32 21 | DATALOADER.NUM_WORKERS = 8 22 | 23 | # model 24 | MODEL = edict() 25 | MODEL.SCALE = DATASET.SCALE 26 | MODEL.KERNEL_SIZE = 5 27 | MODEL.KERNEL_PATH = '../../kernel/kernel_72_k5.pkl' 28 | MODEL.IN_CHANNEL = 3 29 | MODEL.N_CHANNEL = 32 30 | MODEL.RES_BLOCK = 4 31 | MODEL.N_WEIGHT = 72 32 | MODEL.DOWN = 1 33 | MODEL.DEVICE = 'cuda' 34 | 35 | # solver 36 | SOLVER = edict() 37 | SOLVER.OPTIMIZER = 'Adam' 38 | SOLVER.BASE_LR = 4e-4 39 | SOLVER.BETA1 = 0.9 40 | SOLVER.BETA2 = 0.999 41 | SOLVER.WEIGHT_DECAY = 0 42 | SOLVER.MOMENTUM = 0 43 | SOLVER.WARM_UP_ITER = 2000 44 | SOLVER.WARM_UP_FACTOR = 0.1 45 | SOLVER.T_PERIOD = [200000, 400000, 600000] 46 | SOLVER.MAX_ITER = SOLVER.T_PERIOD[-1] 47 | 48 | # initialization 49 | CONTINUE_ITER = None 50 | INIT_MODEL = None 51 | 52 | # log and save 53 | LOG_PERIOD = 20 54 | SAVE_PERIOD = 10000 55 | 56 | # validation 57 | VAL = edict() 58 | VAL.PERIOD = 10000 59 | VAL.TYPE = 'MixDataset' 60 | VAL.DATASETS = ['BSDS100'] 61 | VAL.SPLITS = ['VAL'] 62 | VAL.PHASE = 'val' 63 | VAL.INPUT_HEIGHT = None 64 | VAL.INPUT_WIDTH = None 65 | VAL.SCALE = DATASET.SCALE 66 | VAL.REPEAT = 1 67 | VAL.VALUE_RANGE = 255.0 68 | VAL.IMG_PER_GPU = 1 69 | VAL.NUM_WORKERS = 1 70 | VAL.SAVE_IMG = False 71 | VAL.TO_Y = True 72 | VAL.CROP_BORDER = VAL.SCALE 73 | 74 | 75 | config = Config() 76 | 77 | 78 | 79 | -------------------------------------------------------------------------------- /exps/LAPAR_A_x3/network.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.nn.init as init 7 | 8 | from utils.modules.lightWeightNet import WeightNet 9 | 10 | 11 | class ComponentDecConv(nn.Module): 12 | def __init__(self, k_path, k_size): 13 | super(ComponentDecConv, self).__init__() 14 | 15 | kernel = pickle.load(open(k_path, 'rb')) 16 | kernel = torch.from_numpy(kernel).float().view(-1, 1, k_size, k_size) 17 | self.register_buffer('weight', kernel) 18 | 19 | def forward(self, x): 20 | out = F.conv2d(x, weight=self.weight, bias=None, stride=1, padding=0, groups=1) 21 | 22 | return out 23 | 24 | 25 | class Network(nn.Module): 26 | def __init__(self, config): 27 | super(Network, self).__init__() 28 | 29 | self.k_size = config.MODEL.KERNEL_SIZE 30 | self.s = config.MODEL.SCALE 31 | 32 | self.w_conv = WeightNet(config.MODEL) 33 | self.decom_conv = ComponentDecConv(config.MODEL.KERNEL_PATH, self.k_size) 34 | 35 | self.criterion = nn.L1Loss(reduction='mean') 36 | 37 | 38 | def forward(self, x, gt=None): 39 | B, C, H, W = x.size() 40 | 41 | bic = F.interpolate(x, scale_factor=self.s, mode='bicubic', align_corners=False) 42 | pad = self.k_size // 2 43 | x_pad = F.pad(bic, pad=(pad, pad, pad, pad), mode='reflect') 44 | pad_H, pad_W = x_pad.size()[2:] 45 | x_pad = x_pad.view(B * 3, 1, pad_H, pad_W) 46 | x_com = self.decom_conv(x_pad).view(B, 3, -1, self.s * H, self.s * W) # B, 3, N_K, Hs, Ws 47 | 48 | weight = self.w_conv(x) 49 | weight = weight.view(B, 1, -1, self.s * H, self.s * W) # B, 1, N_K, Hs, Ws 50 | 51 | out = torch.sum(weight * x_com, dim=2) 52 | 53 | if gt is not None: 54 | loss_dict = dict(L1=self.criterion(out, gt)) 55 | return loss_dict 56 | else: 57 | return out 58 | 59 | -------------------------------------------------------------------------------- /exps/LAPAR_A_x3/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import sys 5 | sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + '/../..') 6 | import logging 7 | 8 | import torch 9 | import torch.distributed as dist 10 | import torch.multiprocessing as mp 11 | from tensorboardX import SummaryWriter 12 | 13 | from config import config 14 | from utils import common, dataloader, solver, model_opr 15 | from dataset import get_dataset 16 | from network import Network 17 | from validate import validate 18 | 19 | 20 | def init_dist(local_rank): 21 | if mp.get_start_method(allow_none=True) != 'spawn': 22 | mp.set_start_method('spawn', force=True) 23 | torch.cuda.set_device(local_rank) 24 | dist.init_process_group(backend="nccl", init_method='env://') 25 | dist.barrier() 26 | 27 | 28 | def main(): 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--local_rank', type=int, default=0) 31 | args = parser.parse_args() 32 | 33 | # initialization 34 | rank = 0 35 | num_gpu = 1 36 | distributed = False 37 | if 'WORLD_SIZE' in os.environ: 38 | num_gpu = int(os.environ['WORLD_SIZE']) 39 | distributed = num_gpu > 1 40 | if distributed: 41 | rank = args.local_rank 42 | init_dist(rank) 43 | common.init_random_seed(config.DATASET.SEED + rank) 44 | 45 | # set up dirs and log 46 | exp_dir, cur_dir = osp.split(osp.split(osp.realpath(__file__))[0]) 47 | root_dir = osp.split(exp_dir)[0] 48 | log_dir = osp.join(root_dir, 'logs', cur_dir) 49 | model_dir = osp.join(log_dir, 'models') 50 | solver_dir = osp.join(log_dir, 'solvers') 51 | if rank <= 0: 52 | common.mkdir(log_dir) 53 | ln_log_dir = osp.join(exp_dir, cur_dir, 'log') 54 | if not osp.exists(ln_log_dir): 55 | os.system('ln -s %s log' % log_dir) 56 | common.mkdir(model_dir) 57 | common.mkdir(solver_dir) 58 | save_dir = osp.join(log_dir, 'saved_imgs') 59 | common.mkdir(save_dir) 60 | tb_dir = osp.join(log_dir, 'tb_log') 61 | tb_writer = SummaryWriter(tb_dir) 62 | common.setup_logger('base', log_dir, 'train', level=logging.INFO, screen=True, to_file=True) 63 | logger = logging.getLogger('base') 64 | 65 | # dataset 66 | train_dataset = get_dataset(config.DATASET) 67 | train_loader = dataloader.train_loader(train_dataset, config, rank=rank, seed=config.DATASET.SEED, 68 | is_dist=distributed) 69 | if rank <= 0: 70 | val_dataset = get_dataset(config.VAL) 71 | val_loader = dataloader.val_loader(val_dataset, config, rank, 1) 72 | data_len = val_dataset.data_len 73 | 74 | # model 75 | model = Network(config) 76 | if rank <= 0: 77 | print(model) 78 | 79 | if config.CONTINUE_ITER: 80 | model_path = osp.join(model_dir, '%d.pth' % config.CONTINUE_ITER) 81 | if rank <= 0: 82 | logger.info('[Continue] Iter: %d' % config.CONTINUE_ITER) 83 | model_opr.load_model(model, model_path, strict=True, cpu=True) 84 | elif config.INIT_MODEL: 85 | if rank <= 0: 86 | logger.info('[Initialize] Model: %s' % config.INIT_MODEL) 87 | model_opr.load_model(model, config.INIT_MODEL, strict=True, cpu=True) 88 | 89 | device = torch.device(config.MODEL.DEVICE) 90 | model.to(device) 91 | if distributed: 92 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[torch.cuda.current_device()]) 93 | 94 | # solvers 95 | optimizer = solver.make_optimizer(config, model) # lr without X num_gpu 96 | lr_scheduler = solver.CosineAnnealingLR_warmup(config, optimizer, config.SOLVER.BASE_LR) 97 | iteration = 0 98 | 99 | if config.CONTINUE_ITER: 100 | solver_path = osp.join(solver_dir, '%d.solver' % config.CONTINUE_ITER) 101 | iteration = model_opr.load_solver(optimizer, lr_scheduler, solver_path) 102 | 103 | max_iter = max_psnr = max_ssim = 0 104 | for lr_img, hr_img in train_loader: 105 | model.train() 106 | iteration = iteration + 1 107 | 108 | optimizer.zero_grad() 109 | 110 | lr_img = lr_img.to(device) 111 | hr_img = hr_img.to(device) 112 | 113 | loss_dict = model(lr_img, gt=hr_img) 114 | total_loss = sum(loss for loss in loss_dict.values()) 115 | total_loss.backward() 116 | 117 | optimizer.step() 118 | lr_scheduler.step() 119 | 120 | if rank <= 0: 121 | if iteration % config.LOG_PERIOD == 0 or iteration == config.SOLVER.MAX_ITER: 122 | log_str = 'Iter: %d, LR: %.3e, ' % (iteration, optimizer.param_groups[0]['lr']) 123 | for key in loss_dict: 124 | tb_writer.add_scalar(key, loss_dict[key].mean(), global_step=iteration) 125 | log_str += key + ': %.4f, ' % float(loss_dict[key]) 126 | logger.info(log_str) 127 | 128 | if iteration % config.SAVE_PERIOD == 0 or iteration == config.SOLVER.MAX_ITER: 129 | logger.info('[Saving] Iter: %d' % iteration) 130 | model_path = osp.join(model_dir, '%d.pth' % iteration) 131 | solver_path = osp.join(solver_dir, '%d.solver' % iteration) 132 | model_opr.save_model(model, model_path) 133 | model_opr.save_solver(optimizer, lr_scheduler, iteration, solver_path) 134 | 135 | if iteration % config.VAL.PERIOD == 0 or iteration == config.SOLVER.MAX_ITER: 136 | logger.info('[Validating] Iter: %d' % iteration) 137 | model.eval() 138 | with torch.no_grad(): 139 | psnr, ssim = validate(model, val_loader, config, device, iteration, save_path=save_dir) 140 | if psnr > max_psnr: 141 | max_psnr, max_ssim, max_iter = psnr, ssim, iteration 142 | logger.info('[Val Result] Iter: %d, PSNR: %.4f, SSIM: %.4f' % (iteration, psnr, ssim)) 143 | logger.info('[Best Result] Iter: %d, PSNR: %.4f, SSIM: %.4f' % (max_iter, max_psnr, max_ssim)) 144 | 145 | if iteration >= config.SOLVER.MAX_ITER: 146 | break 147 | 148 | if rank <= 0: 149 | logger.info('Finish training process!') 150 | logger.info('[Final Best Result] Iter: %d, PSNR: %.4f, SSIM: %.4f' % (max_iter, max_psnr, max_ssim)) 151 | 152 | 153 | if __name__ == '__main__': 154 | main() 155 | -------------------------------------------------------------------------------- /exps/LAPAR_A_x3/train.sh: -------------------------------------------------------------------------------- 1 | python3 -m torch.distributed.launch --nproc_per_node=$1 --master_port=$2 train.py 2 | -------------------------------------------------------------------------------- /exps/LAPAR_A_x3/validate.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import numpy as np 4 | import sys 5 | sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + '/../..') 6 | 7 | import torch 8 | 9 | from utils.common import tensor2img, calculate_psnr, calculate_ssim, bgr2ycbcr 10 | 11 | 12 | def validate(model, val_loader, config, device, iteration, save_path='.'): 13 | with torch.no_grad(): 14 | psnr_l = [] 15 | ssim_l = [] 16 | 17 | for idx, (lr_img, hr_img) in enumerate(val_loader): 18 | lr_img = lr_img.to(device) 19 | hr_img = hr_img.to(device) 20 | 21 | output = model(lr_img) 22 | 23 | output = tensor2img(output) 24 | gt = tensor2img(hr_img) 25 | 26 | if config.VAL.SAVE_IMG: 27 | ipath = os.path.join(save_path, '%d_%03d.png' % (iteration, idx)) 28 | cv2.imwrite(ipath, np.concatenate([output, gt], axis=1)) 29 | 30 | output = output.astype(np.float32) / 255.0 31 | gt = gt.astype(np.float32) / 255.0 32 | 33 | if config.VAL.TO_Y: 34 | output = bgr2ycbcr(output, only_y=True) 35 | gt = bgr2ycbcr(gt, only_y=True) 36 | 37 | if config.VAL.CROP_BORDER != 0: 38 | cb = config.VAL.CROP_BORDER 39 | output = output[cb:-cb, cb:-cb] 40 | gt = gt[cb:-cb, cb:-cb] 41 | 42 | psnr = calculate_psnr(output * 255, gt * 255) 43 | ssim = calculate_ssim(output * 255, gt * 255) 44 | psnr_l.append(psnr) 45 | ssim_l.append(ssim) 46 | 47 | avg_psnr = sum(psnr_l) / len(psnr_l) 48 | avg_ssim = sum(ssim_l) / len(ssim_l) 49 | 50 | return avg_psnr, avg_ssim 51 | 52 | 53 | if __name__ == '__main__': 54 | from config import config 55 | from network import Network 56 | from dataset import get_dataset 57 | from utils import dataloader 58 | from utils.model_opr import load_model 59 | 60 | config.VAL.DATASET = 'Set5' 61 | 62 | model = Network(config) 63 | if torch.cuda.is_available(): 64 | device = torch.device('cuda') 65 | else: 66 | device = torch.device('cpu') 67 | model = model.to(device) 68 | 69 | model_path = 'log/models/200000.pth' 70 | load_model(model, model_path, cpu=True) 71 | sys.exit() 72 | 73 | val_dataset = get_dataset(config.VAL) 74 | val_loader = dataloader.val_loader(val_dataset, config, 0, 1) 75 | psnr, ssim = validate(model, val_loader, config, device, 0, save_path='.') 76 | print('PSNR: %.4f, SSIM: %.4f' % (psnr, ssim)) 77 | -------------------------------------------------------------------------------- /exps/LAPAR_A_x4/config.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | 4 | class Config: 5 | # dataset 6 | DATASET = edict() 7 | DATASET.TYPE = 'MixDataset' 8 | DATASET.DATASETS = ['DIV2K', 'Flickr2K'] 9 | DATASET.SPLITS = ['TRAIN', 'TRAIN'] 10 | DATASET.PHASE = 'train' 11 | DATASET.INPUT_HEIGHT = 64 12 | DATASET.INPUT_WIDTH = 64 13 | DATASET.SCALE = 4 14 | DATASET.REPEAT = 1 15 | DATASET.VALUE_RANGE = 255.0 16 | DATASET.SEED = 100 17 | 18 | # dataloader 19 | DATALOADER = edict() 20 | DATALOADER.IMG_PER_GPU = 32 21 | DATALOADER.NUM_WORKERS = 8 22 | 23 | # model 24 | MODEL = edict() 25 | MODEL.SCALE = DATASET.SCALE 26 | MODEL.KERNEL_SIZE = 5 27 | MODEL.KERNEL_PATH = '../../kernel/kernel_72_k5.pkl' 28 | MODEL.IN_CHANNEL = 3 29 | MODEL.N_CHANNEL = 32 30 | MODEL.RES_BLOCK = 4 31 | MODEL.N_WEIGHT = 72 32 | MODEL.DOWN = 1 33 | MODEL.DEVICE = 'cuda' 34 | 35 | # solver 36 | SOLVER = edict() 37 | SOLVER.OPTIMIZER = 'Adam' 38 | SOLVER.BASE_LR = 4e-4 39 | SOLVER.BETA1 = 0.9 40 | SOLVER.BETA2 = 0.999 41 | SOLVER.WEIGHT_DECAY = 0 42 | SOLVER.MOMENTUM = 0 43 | SOLVER.WARM_UP_ITER = 2000 44 | SOLVER.WARM_UP_FACTOR = 0.1 45 | SOLVER.T_PERIOD = [200000, 400000, 600000] 46 | SOLVER.MAX_ITER = SOLVER.T_PERIOD[-1] 47 | 48 | # initialization 49 | CONTINUE_ITER = None 50 | INIT_MODEL = None 51 | 52 | # log and save 53 | LOG_PERIOD = 20 54 | SAVE_PERIOD = 10000 55 | 56 | # validation 57 | VAL = edict() 58 | VAL.PERIOD = 10000 59 | VAL.TYPE = 'MixDataset' 60 | VAL.DATASETS = ['BSDS100'] 61 | VAL.SPLITS = ['VAL'] 62 | VAL.PHASE = 'val' 63 | VAL.INPUT_HEIGHT = None 64 | VAL.INPUT_WIDTH = None 65 | VAL.SCALE = DATASET.SCALE 66 | VAL.REPEAT = 1 67 | VAL.VALUE_RANGE = 255.0 68 | VAL.IMG_PER_GPU = 1 69 | VAL.NUM_WORKERS = 1 70 | VAL.SAVE_IMG = False 71 | VAL.TO_Y = True 72 | VAL.CROP_BORDER = VAL.SCALE 73 | 74 | 75 | config = Config() 76 | 77 | 78 | 79 | -------------------------------------------------------------------------------- /exps/LAPAR_A_x4/network.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.nn.init as init 7 | 8 | from utils.modules.lightWeightNet import WeightNet 9 | 10 | 11 | class ComponentDecConv(nn.Module): 12 | def __init__(self, k_path, k_size): 13 | super(ComponentDecConv, self).__init__() 14 | 15 | kernel = pickle.load(open(k_path, 'rb')) 16 | kernel = torch.from_numpy(kernel).float().view(-1, 1, k_size, k_size) 17 | self.register_buffer('weight', kernel) 18 | 19 | def forward(self, x): 20 | out = F.conv2d(x, weight=self.weight, bias=None, stride=1, padding=0, groups=1) 21 | 22 | return out 23 | 24 | 25 | class Network(nn.Module): 26 | def __init__(self, config): 27 | super(Network, self).__init__() 28 | 29 | self.k_size = config.MODEL.KERNEL_SIZE 30 | self.s = config.MODEL.SCALE 31 | 32 | self.w_conv = WeightNet(config.MODEL) 33 | self.decom_conv = ComponentDecConv(config.MODEL.KERNEL_PATH, self.k_size) 34 | 35 | self.criterion = nn.L1Loss(reduction='mean') 36 | 37 | 38 | def forward(self, x, gt=None): 39 | B, C, H, W = x.size() 40 | 41 | bic = F.interpolate(x, scale_factor=self.s, mode='bicubic', align_corners=False) 42 | pad = self.k_size // 2 43 | x_pad = F.pad(bic, pad=(pad, pad, pad, pad), mode='reflect') 44 | pad_H, pad_W = x_pad.size()[2:] 45 | x_pad = x_pad.view(B * 3, 1, pad_H, pad_W) 46 | x_com = self.decom_conv(x_pad).view(B, 3, -1, self.s * H, self.s * W) # B, 3, N_K, Hs, Ws 47 | 48 | weight = self.w_conv(x) 49 | weight = weight.view(B, 1, -1, self.s * H, self.s * W) # B, 1, N_K, Hs, Ws 50 | 51 | out = torch.sum(weight * x_com, dim=2) 52 | 53 | if gt is not None: 54 | loss_dict = dict(L1=self.criterion(out, gt)) 55 | return loss_dict 56 | else: 57 | return out 58 | 59 | -------------------------------------------------------------------------------- /exps/LAPAR_A_x4/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import sys 5 | sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + '/../..') 6 | import logging 7 | 8 | import torch 9 | import torch.distributed as dist 10 | import torch.multiprocessing as mp 11 | from tensorboardX import SummaryWriter 12 | 13 | from config import config 14 | from utils import common, dataloader, solver, model_opr 15 | from dataset import get_dataset 16 | from network import Network 17 | from validate import validate 18 | 19 | 20 | def init_dist(local_rank): 21 | if mp.get_start_method(allow_none=True) != 'spawn': 22 | mp.set_start_method('spawn', force=True) 23 | torch.cuda.set_device(local_rank) 24 | dist.init_process_group(backend="nccl", init_method='env://') 25 | dist.barrier() 26 | 27 | 28 | def main(): 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--local_rank', type=int, default=0) 31 | args = parser.parse_args() 32 | 33 | # initialization 34 | rank = 0 35 | num_gpu = 1 36 | distributed = False 37 | if 'WORLD_SIZE' in os.environ: 38 | num_gpu = int(os.environ['WORLD_SIZE']) 39 | distributed = num_gpu > 1 40 | if distributed: 41 | rank = args.local_rank 42 | init_dist(rank) 43 | common.init_random_seed(config.DATASET.SEED + rank) 44 | 45 | # set up dirs and log 46 | exp_dir, cur_dir = osp.split(osp.split(osp.realpath(__file__))[0]) 47 | root_dir = osp.split(exp_dir)[0] 48 | log_dir = osp.join(root_dir, 'logs', cur_dir) 49 | model_dir = osp.join(log_dir, 'models') 50 | solver_dir = osp.join(log_dir, 'solvers') 51 | if rank <= 0: 52 | common.mkdir(log_dir) 53 | ln_log_dir = osp.join(exp_dir, cur_dir, 'log') 54 | if not osp.exists(ln_log_dir): 55 | os.system('ln -s %s log' % log_dir) 56 | common.mkdir(model_dir) 57 | common.mkdir(solver_dir) 58 | save_dir = osp.join(log_dir, 'saved_imgs') 59 | common.mkdir(save_dir) 60 | tb_dir = osp.join(log_dir, 'tb_log') 61 | tb_writer = SummaryWriter(tb_dir) 62 | common.setup_logger('base', log_dir, 'train', level=logging.INFO, screen=True, to_file=True) 63 | logger = logging.getLogger('base') 64 | 65 | # dataset 66 | train_dataset = get_dataset(config.DATASET) 67 | train_loader = dataloader.train_loader(train_dataset, config, rank=rank, seed=config.DATASET.SEED, 68 | is_dist=distributed) 69 | if rank <= 0: 70 | val_dataset = get_dataset(config.VAL) 71 | val_loader = dataloader.val_loader(val_dataset, config, rank, 1) 72 | data_len = val_dataset.data_len 73 | 74 | # model 75 | model = Network(config) 76 | if rank <= 0: 77 | print(model) 78 | 79 | if config.CONTINUE_ITER: 80 | model_path = osp.join(model_dir, '%d.pth' % config.CONTINUE_ITER) 81 | if rank <= 0: 82 | logger.info('[Continue] Iter: %d' % config.CONTINUE_ITER) 83 | model_opr.load_model(model, model_path, strict=True, cpu=True) 84 | elif config.INIT_MODEL: 85 | if rank <= 0: 86 | logger.info('[Initialize] Model: %s' % config.INIT_MODEL) 87 | model_opr.load_model(model, config.INIT_MODEL, strict=True, cpu=True) 88 | 89 | device = torch.device(config.MODEL.DEVICE) 90 | model.to(device) 91 | if distributed: 92 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[torch.cuda.current_device()]) 93 | 94 | # solvers 95 | optimizer = solver.make_optimizer(config, model) # lr without X num_gpu 96 | lr_scheduler = solver.CosineAnnealingLR_warmup(config, optimizer, config.SOLVER.BASE_LR) 97 | iteration = 0 98 | 99 | if config.CONTINUE_ITER: 100 | solver_path = osp.join(solver_dir, '%d.solver' % config.CONTINUE_ITER) 101 | iteration = model_opr.load_solver(optimizer, lr_scheduler, solver_path) 102 | 103 | max_iter = max_psnr = max_ssim = 0 104 | for lr_img, hr_img in train_loader: 105 | model.train() 106 | iteration = iteration + 1 107 | 108 | optimizer.zero_grad() 109 | 110 | lr_img = lr_img.to(device) 111 | hr_img = hr_img.to(device) 112 | 113 | loss_dict = model(lr_img, gt=hr_img) 114 | total_loss = sum(loss for loss in loss_dict.values()) 115 | total_loss.backward() 116 | 117 | optimizer.step() 118 | lr_scheduler.step() 119 | 120 | if rank <= 0: 121 | if iteration % config.LOG_PERIOD == 0 or iteration == config.SOLVER.MAX_ITER: 122 | log_str = 'Iter: %d, LR: %.3e, ' % (iteration, optimizer.param_groups[0]['lr']) 123 | for key in loss_dict: 124 | tb_writer.add_scalar(key, loss_dict[key].mean(), global_step=iteration) 125 | log_str += key + ': %.4f, ' % float(loss_dict[key]) 126 | logger.info(log_str) 127 | 128 | if iteration % config.SAVE_PERIOD == 0 or iteration == config.SOLVER.MAX_ITER: 129 | logger.info('[Saving] Iter: %d' % iteration) 130 | model_path = osp.join(model_dir, '%d.pth' % iteration) 131 | solver_path = osp.join(solver_dir, '%d.solver' % iteration) 132 | model_opr.save_model(model, model_path) 133 | model_opr.save_solver(optimizer, lr_scheduler, iteration, solver_path) 134 | 135 | if iteration % config.VAL.PERIOD == 0 or iteration == config.SOLVER.MAX_ITER: 136 | logger.info('[Validating] Iter: %d' % iteration) 137 | model.eval() 138 | with torch.no_grad(): 139 | psnr, ssim = validate(model, val_loader, config, device, iteration, save_path=save_dir) 140 | if psnr > max_psnr: 141 | max_psnr, max_ssim, max_iter = psnr, ssim, iteration 142 | logger.info('[Val Result] Iter: %d, PSNR: %.4f, SSIM: %.4f' % (iteration, psnr, ssim)) 143 | logger.info('[Best Result] Iter: %d, PSNR: %.4f, SSIM: %.4f' % (max_iter, max_psnr, max_ssim)) 144 | 145 | if iteration >= config.SOLVER.MAX_ITER: 146 | break 147 | 148 | if rank <= 0: 149 | logger.info('Finish training process!') 150 | logger.info('[Final Best Result] Iter: %d, PSNR: %.4f, SSIM: %.4f' % (max_iter, max_psnr, max_ssim)) 151 | 152 | 153 | if __name__ == '__main__': 154 | main() 155 | -------------------------------------------------------------------------------- /exps/LAPAR_A_x4/train.sh: -------------------------------------------------------------------------------- 1 | python3 -m torch.distributed.launch --nproc_per_node=$1 --master_port=$2 train.py 2 | -------------------------------------------------------------------------------- /exps/LAPAR_A_x4/validate.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import numpy as np 4 | import sys 5 | sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + '/../..') 6 | 7 | import torch 8 | 9 | from utils.common import tensor2img, calculate_psnr, calculate_ssim, bgr2ycbcr 10 | 11 | 12 | def validate(model, val_loader, config, device, iteration, save_path='.'): 13 | with torch.no_grad(): 14 | psnr_l = [] 15 | ssim_l = [] 16 | 17 | for idx, (lr_img, hr_img) in enumerate(val_loader): 18 | lr_img = lr_img.to(device) 19 | hr_img = hr_img.to(device) 20 | 21 | output = model(lr_img) 22 | 23 | output = tensor2img(output) 24 | gt = tensor2img(hr_img) 25 | 26 | if config.VAL.SAVE_IMG: 27 | ipath = os.path.join(save_path, '%d_%03d.png' % (iteration, idx)) 28 | cv2.imwrite(ipath, np.concatenate([output, gt], axis=1)) 29 | 30 | output = output.astype(np.float32) / 255.0 31 | gt = gt.astype(np.float32) / 255.0 32 | 33 | if config.VAL.TO_Y: 34 | output = bgr2ycbcr(output, only_y=True) 35 | gt = bgr2ycbcr(gt, only_y=True) 36 | 37 | if config.VAL.CROP_BORDER != 0: 38 | cb = config.VAL.CROP_BORDER 39 | output = output[cb:-cb, cb:-cb] 40 | gt = gt[cb:-cb, cb:-cb] 41 | 42 | psnr = calculate_psnr(output * 255, gt * 255) 43 | ssim = calculate_ssim(output * 255, gt * 255) 44 | psnr_l.append(psnr) 45 | ssim_l.append(ssim) 46 | 47 | avg_psnr = sum(psnr_l) / len(psnr_l) 48 | avg_ssim = sum(ssim_l) / len(ssim_l) 49 | 50 | return avg_psnr, avg_ssim 51 | 52 | 53 | if __name__ == '__main__': 54 | from config import config 55 | from network import Network 56 | from dataset import get_dataset 57 | from utils import dataloader 58 | from utils.model_opr import load_model 59 | 60 | config.VAL.DATASET = 'Set5' 61 | 62 | model = Network(config) 63 | if torch.cuda.is_available(): 64 | device = torch.device('cuda') 65 | else: 66 | device = torch.device('cpu') 67 | model = model.to(device) 68 | 69 | model_path = 'log/models/200000.pth' 70 | load_model(model, model_path, cpu=True) 71 | sys.exit() 72 | 73 | val_dataset = get_dataset(config.VAL) 74 | val_loader = dataloader.val_loader(val_dataset, config, 0, 1) 75 | psnr, ssim = validate(model, val_loader, config, device, 0, save_path='.') 76 | print('PSNR: %.4f, SSIM: %.4f' % (psnr, ssim)) 77 | -------------------------------------------------------------------------------- /exps/LAPAR_B_x2/config.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | 4 | class Config: 5 | # dataset 6 | DATASET = edict() 7 | DATASET.TYPE = 'MixDataset' 8 | DATASET.DATASETS = ['DIV2K', 'Flickr2K'] 9 | DATASET.SPLITS = ['TRAIN', 'TRAIN'] 10 | DATASET.PHASE = 'train' 11 | DATASET.INPUT_HEIGHT = 64 12 | DATASET.INPUT_WIDTH = 64 13 | DATASET.SCALE = 2 14 | DATASET.REPEAT = 1 15 | DATASET.VALUE_RANGE = 255.0 16 | DATASET.SEED = 100 17 | 18 | # dataloader 19 | DATALOADER = edict() 20 | DATALOADER.IMG_PER_GPU = 32 21 | DATALOADER.NUM_WORKERS = 8 22 | 23 | # model 24 | MODEL = edict() 25 | MODEL.SCALE = DATASET.SCALE 26 | MODEL.KERNEL_SIZE = 5 27 | MODEL.KERNEL_PATH = '../../kernel/kernel_72_k5.pkl' 28 | MODEL.IN_CHANNEL = 3 29 | MODEL.N_CHANNEL = 24 30 | MODEL.RES_BLOCK = 3 31 | MODEL.N_WEIGHT = 72 32 | MODEL.DOWN = 1 33 | MODEL.DEVICE = 'cuda' 34 | 35 | # solver 36 | SOLVER = edict() 37 | SOLVER.OPTIMIZER = 'Adam' 38 | SOLVER.BASE_LR = 4e-4 39 | SOLVER.BETA1 = 0.9 40 | SOLVER.BETA2 = 0.999 41 | SOLVER.WEIGHT_DECAY = 0 42 | SOLVER.MOMENTUM = 0 43 | SOLVER.WARM_UP_ITER = 2000 44 | SOLVER.WARM_UP_FACTOR = 0.1 45 | SOLVER.T_PERIOD = [200000, 400000, 600000] 46 | SOLVER.MAX_ITER = SOLVER.T_PERIOD[-1] 47 | 48 | # initialization 49 | CONTINUE_ITER = None 50 | INIT_MODEL = None 51 | 52 | # log and save 53 | LOG_PERIOD = 20 54 | SAVE_PERIOD = 10000 55 | 56 | # validation 57 | VAL = edict() 58 | VAL.PERIOD = 10000 59 | VAL.TYPE = 'MixDataset' 60 | VAL.DATASETS = ['BSDS100'] 61 | VAL.SPLITS = ['VAL'] 62 | VAL.PHASE = 'val' 63 | VAL.INPUT_HEIGHT = None 64 | VAL.INPUT_WIDTH = None 65 | VAL.SCALE = DATASET.SCALE 66 | VAL.REPEAT = 1 67 | VAL.VALUE_RANGE = 255.0 68 | VAL.IMG_PER_GPU = 1 69 | VAL.NUM_WORKERS = 1 70 | VAL.SAVE_IMG = False 71 | VAL.TO_Y = True 72 | VAL.CROP_BORDER = VAL.SCALE 73 | 74 | 75 | config = Config() 76 | 77 | 78 | 79 | -------------------------------------------------------------------------------- /exps/LAPAR_B_x2/network.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.nn.init as init 7 | 8 | from utils.modules.lightWeightNet import WeightNet 9 | 10 | 11 | class ComponentDecConv(nn.Module): 12 | def __init__(self, k_path, k_size): 13 | super(ComponentDecConv, self).__init__() 14 | 15 | kernel = pickle.load(open(k_path, 'rb')) 16 | kernel = torch.from_numpy(kernel).float().view(-1, 1, k_size, k_size) 17 | self.register_buffer('weight', kernel) 18 | 19 | def forward(self, x): 20 | out = F.conv2d(x, weight=self.weight, bias=None, stride=1, padding=0, groups=1) 21 | 22 | return out 23 | 24 | 25 | class Network(nn.Module): 26 | def __init__(self, config): 27 | super(Network, self).__init__() 28 | 29 | self.k_size = config.MODEL.KERNEL_SIZE 30 | self.s = config.MODEL.SCALE 31 | 32 | self.w_conv = WeightNet(config.MODEL) 33 | self.decom_conv = ComponentDecConv(config.MODEL.KERNEL_PATH, self.k_size) 34 | 35 | self.criterion = nn.L1Loss(reduction='mean') 36 | 37 | 38 | def forward(self, x, gt=None): 39 | B, C, H, W = x.size() 40 | 41 | bic = F.interpolate(x, scale_factor=self.s, mode='bicubic', align_corners=False) 42 | pad = self.k_size // 2 43 | x_pad = F.pad(bic, pad=(pad, pad, pad, pad), mode='reflect') 44 | pad_H, pad_W = x_pad.size()[2:] 45 | x_pad = x_pad.view(B * 3, 1, pad_H, pad_W) 46 | x_com = self.decom_conv(x_pad).view(B, 3, -1, self.s * H, self.s * W) # B, 3, N_K, Hs, Ws 47 | 48 | weight = self.w_conv(x) 49 | weight = weight.view(B, 1, -1, self.s * H, self.s * W) # B, 1, N_K, Hs, Ws 50 | 51 | out = torch.sum(weight * x_com, dim=2) 52 | 53 | if gt is not None: 54 | loss_dict = dict(L1=self.criterion(out, gt)) 55 | return loss_dict 56 | else: 57 | return out 58 | 59 | -------------------------------------------------------------------------------- /exps/LAPAR_B_x2/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import sys 5 | sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + '/../..') 6 | import logging 7 | 8 | import torch 9 | import torch.distributed as dist 10 | import torch.multiprocessing as mp 11 | from tensorboardX import SummaryWriter 12 | 13 | from config import config 14 | from utils import common, dataloader, solver, model_opr 15 | from dataset import get_dataset 16 | from network import Network 17 | from validate import validate 18 | 19 | 20 | def init_dist(local_rank): 21 | if mp.get_start_method(allow_none=True) != 'spawn': 22 | mp.set_start_method('spawn', force=True) 23 | torch.cuda.set_device(local_rank) 24 | dist.init_process_group(backend="nccl", init_method='env://') 25 | dist.barrier() 26 | 27 | 28 | def main(): 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--local_rank', type=int, default=0) 31 | args = parser.parse_args() 32 | 33 | # initialization 34 | rank = 0 35 | num_gpu = 1 36 | distributed = False 37 | if 'WORLD_SIZE' in os.environ: 38 | num_gpu = int(os.environ['WORLD_SIZE']) 39 | distributed = num_gpu > 1 40 | if distributed: 41 | rank = args.local_rank 42 | init_dist(rank) 43 | common.init_random_seed(config.DATASET.SEED + rank) 44 | 45 | # set up dirs and log 46 | exp_dir, cur_dir = osp.split(osp.split(osp.realpath(__file__))[0]) 47 | root_dir = osp.split(exp_dir)[0] 48 | log_dir = osp.join(root_dir, 'logs', cur_dir) 49 | model_dir = osp.join(log_dir, 'models') 50 | solver_dir = osp.join(log_dir, 'solvers') 51 | if rank <= 0: 52 | common.mkdir(log_dir) 53 | ln_log_dir = osp.join(exp_dir, cur_dir, 'log') 54 | if not osp.exists(ln_log_dir): 55 | os.system('ln -s %s log' % log_dir) 56 | common.mkdir(model_dir) 57 | common.mkdir(solver_dir) 58 | save_dir = osp.join(log_dir, 'saved_imgs') 59 | common.mkdir(save_dir) 60 | tb_dir = osp.join(log_dir, 'tb_log') 61 | tb_writer = SummaryWriter(tb_dir) 62 | common.setup_logger('base', log_dir, 'train', level=logging.INFO, screen=True, to_file=True) 63 | logger = logging.getLogger('base') 64 | 65 | # dataset 66 | train_dataset = get_dataset(config.DATASET) 67 | train_loader = dataloader.train_loader(train_dataset, config, rank=rank, seed=config.DATASET.SEED, 68 | is_dist=distributed) 69 | if rank <= 0: 70 | val_dataset = get_dataset(config.VAL) 71 | val_loader = dataloader.val_loader(val_dataset, config, rank, 1) 72 | data_len = val_dataset.data_len 73 | 74 | # model 75 | model = Network(config) 76 | if rank <= 0: 77 | print(model) 78 | 79 | if config.CONTINUE_ITER: 80 | model_path = osp.join(model_dir, '%d.pth' % config.CONTINUE_ITER) 81 | if rank <= 0: 82 | logger.info('[Continue] Iter: %d' % config.CONTINUE_ITER) 83 | model_opr.load_model(model, model_path, strict=True, cpu=True) 84 | elif config.INIT_MODEL: 85 | if rank <= 0: 86 | logger.info('[Initialize] Model: %s' % config.INIT_MODEL) 87 | model_opr.load_model(model, config.INIT_MODEL, strict=True, cpu=True) 88 | 89 | device = torch.device(config.MODEL.DEVICE) 90 | model.to(device) 91 | if distributed: 92 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[torch.cuda.current_device()]) 93 | 94 | # solvers 95 | optimizer = solver.make_optimizer(config, model) # lr without X num_gpu 96 | lr_scheduler = solver.CosineAnnealingLR_warmup(config, optimizer, config.SOLVER.BASE_LR) 97 | iteration = 0 98 | 99 | if config.CONTINUE_ITER: 100 | solver_path = osp.join(solver_dir, '%d.solver' % config.CONTINUE_ITER) 101 | iteration = model_opr.load_solver(optimizer, lr_scheduler, solver_path) 102 | 103 | max_iter = max_psnr = max_ssim = 0 104 | for lr_img, hr_img in train_loader: 105 | model.train() 106 | iteration = iteration + 1 107 | 108 | optimizer.zero_grad() 109 | 110 | lr_img = lr_img.to(device) 111 | hr_img = hr_img.to(device) 112 | 113 | loss_dict = model(lr_img, gt=hr_img) 114 | total_loss = sum(loss for loss in loss_dict.values()) 115 | total_loss.backward() 116 | 117 | optimizer.step() 118 | lr_scheduler.step() 119 | 120 | if rank <= 0: 121 | if iteration % config.LOG_PERIOD == 0 or iteration == config.SOLVER.MAX_ITER: 122 | log_str = 'Iter: %d, LR: %.3e, ' % (iteration, optimizer.param_groups[0]['lr']) 123 | for key in loss_dict: 124 | tb_writer.add_scalar(key, loss_dict[key].mean(), global_step=iteration) 125 | log_str += key + ': %.4f, ' % float(loss_dict[key]) 126 | logger.info(log_str) 127 | 128 | if iteration % config.SAVE_PERIOD == 0 or iteration == config.SOLVER.MAX_ITER: 129 | logger.info('[Saving] Iter: %d' % iteration) 130 | model_path = osp.join(model_dir, '%d.pth' % iteration) 131 | solver_path = osp.join(solver_dir, '%d.solver' % iteration) 132 | model_opr.save_model(model, model_path) 133 | model_opr.save_solver(optimizer, lr_scheduler, iteration, solver_path) 134 | 135 | if iteration % config.VAL.PERIOD == 0 or iteration == config.SOLVER.MAX_ITER: 136 | logger.info('[Validating] Iter: %d' % iteration) 137 | model.eval() 138 | with torch.no_grad(): 139 | psnr, ssim = validate(model, val_loader, config, device, iteration, save_path=save_dir) 140 | if psnr > max_psnr: 141 | max_psnr, max_ssim, max_iter = psnr, ssim, iteration 142 | logger.info('[Val Result] Iter: %d, PSNR: %.4f, SSIM: %.4f' % (iteration, psnr, ssim)) 143 | logger.info('[Best Result] Iter: %d, PSNR: %.4f, SSIM: %.4f' % (max_iter, max_psnr, max_ssim)) 144 | 145 | if iteration >= config.SOLVER.MAX_ITER: 146 | break 147 | 148 | if rank <= 0: 149 | logger.info('Finish training process!') 150 | logger.info('[Final Best Result] Iter: %d, PSNR: %.4f, SSIM: %.4f' % (max_iter, max_psnr, max_ssim)) 151 | 152 | 153 | if __name__ == '__main__': 154 | main() 155 | -------------------------------------------------------------------------------- /exps/LAPAR_B_x2/train.sh: -------------------------------------------------------------------------------- 1 | python3 -m torch.distributed.launch --nproc_per_node=$1 --master_port=$2 train.py 2 | -------------------------------------------------------------------------------- /exps/LAPAR_B_x2/validate.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import numpy as np 4 | import sys 5 | sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + '/../..') 6 | 7 | import torch 8 | 9 | from utils.common import tensor2img, calculate_psnr, calculate_ssim, bgr2ycbcr 10 | 11 | 12 | def validate(model, val_loader, config, device, iteration, save_path='.'): 13 | with torch.no_grad(): 14 | psnr_l = [] 15 | ssim_l = [] 16 | 17 | for idx, (lr_img, hr_img) in enumerate(val_loader): 18 | lr_img = lr_img.to(device) 19 | hr_img = hr_img.to(device) 20 | 21 | output = model(lr_img) 22 | 23 | output = tensor2img(output) 24 | gt = tensor2img(hr_img) 25 | 26 | if config.VAL.SAVE_IMG: 27 | ipath = os.path.join(save_path, '%d_%03d.png' % (iteration, idx)) 28 | cv2.imwrite(ipath, np.concatenate([output, gt], axis=1)) 29 | 30 | output = output.astype(np.float32) / 255.0 31 | gt = gt.astype(np.float32) / 255.0 32 | 33 | if config.VAL.TO_Y: 34 | output = bgr2ycbcr(output, only_y=True) 35 | gt = bgr2ycbcr(gt, only_y=True) 36 | 37 | if config.VAL.CROP_BORDER != 0: 38 | cb = config.VAL.CROP_BORDER 39 | output = output[cb:-cb, cb:-cb] 40 | gt = gt[cb:-cb, cb:-cb] 41 | 42 | psnr = calculate_psnr(output * 255, gt * 255) 43 | ssim = calculate_ssim(output * 255, gt * 255) 44 | psnr_l.append(psnr) 45 | ssim_l.append(ssim) 46 | 47 | avg_psnr = sum(psnr_l) / len(psnr_l) 48 | avg_ssim = sum(ssim_l) / len(ssim_l) 49 | 50 | return avg_psnr, avg_ssim 51 | 52 | 53 | if __name__ == '__main__': 54 | from config import config 55 | from network import Network 56 | from dataset import get_dataset 57 | from utils import dataloader 58 | from utils.model_opr import load_model 59 | 60 | config.VAL.DATASET = 'Set5' 61 | 62 | model = Network(config) 63 | if torch.cuda.is_available(): 64 | device = torch.device('cuda') 65 | else: 66 | device = torch.device('cpu') 67 | model = model.to(device) 68 | 69 | model_path = 'log/models/200000.pth' 70 | load_model(model, model_path, cpu=True) 71 | sys.exit() 72 | 73 | val_dataset = get_dataset(config.VAL) 74 | val_loader = dataloader.val_loader(val_dataset, config, 0, 1) 75 | psnr, ssim = validate(model, val_loader, config, device, 0, save_path='.') 76 | print('PSNR: %.4f, SSIM: %.4f' % (psnr, ssim)) 77 | -------------------------------------------------------------------------------- /exps/LAPAR_B_x3/config.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | 4 | class Config: 5 | # dataset 6 | DATASET = edict() 7 | DATASET.TYPE = 'MixDataset' 8 | DATASET.DATASETS = ['DIV2K', 'Flickr2K'] 9 | DATASET.SPLITS = ['TRAIN', 'TRAIN'] 10 | DATASET.PHASE = 'train' 11 | DATASET.INPUT_HEIGHT = 64 12 | DATASET.INPUT_WIDTH = 64 13 | DATASET.SCALE = 3 14 | DATASET.REPEAT = 1 15 | DATASET.VALUE_RANGE = 255.0 16 | DATASET.SEED = 100 17 | 18 | # dataloader 19 | DATALOADER = edict() 20 | DATALOADER.IMG_PER_GPU = 32 21 | DATALOADER.NUM_WORKERS = 8 22 | 23 | # model 24 | MODEL = edict() 25 | MODEL.SCALE = DATASET.SCALE 26 | MODEL.KERNEL_SIZE = 5 27 | MODEL.KERNEL_PATH = '../../kernel/kernel_72_k5.pkl' 28 | MODEL.IN_CHANNEL = 3 29 | MODEL.N_CHANNEL = 24 30 | MODEL.RES_BLOCK = 3 31 | MODEL.N_WEIGHT = 72 32 | MODEL.DOWN = 1 33 | MODEL.DEVICE = 'cuda' 34 | 35 | # solver 36 | SOLVER = edict() 37 | SOLVER.OPTIMIZER = 'Adam' 38 | SOLVER.BASE_LR = 4e-4 39 | SOLVER.BETA1 = 0.9 40 | SOLVER.BETA2 = 0.999 41 | SOLVER.WEIGHT_DECAY = 0 42 | SOLVER.MOMENTUM = 0 43 | SOLVER.WARM_UP_ITER = 2000 44 | SOLVER.WARM_UP_FACTOR = 0.1 45 | SOLVER.T_PERIOD = [200000, 400000, 600000] 46 | SOLVER.MAX_ITER = SOLVER.T_PERIOD[-1] 47 | 48 | # initialization 49 | CONTINUE_ITER = None 50 | INIT_MODEL = None 51 | 52 | # log and save 53 | LOG_PERIOD = 20 54 | SAVE_PERIOD = 10000 55 | 56 | # validation 57 | VAL = edict() 58 | VAL.PERIOD = 10000 59 | VAL.TYPE = 'MixDataset' 60 | VAL.DATASETS = ['BSDS100'] 61 | VAL.SPLITS = ['VAL'] 62 | VAL.PHASE = 'val' 63 | VAL.INPUT_HEIGHT = None 64 | VAL.INPUT_WIDTH = None 65 | VAL.SCALE = DATASET.SCALE 66 | VAL.REPEAT = 1 67 | VAL.VALUE_RANGE = 255.0 68 | VAL.IMG_PER_GPU = 1 69 | VAL.NUM_WORKERS = 1 70 | VAL.SAVE_IMG = False 71 | VAL.TO_Y = True 72 | VAL.CROP_BORDER = VAL.SCALE 73 | 74 | 75 | config = Config() 76 | 77 | 78 | 79 | -------------------------------------------------------------------------------- /exps/LAPAR_B_x3/network.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.nn.init as init 7 | 8 | from utils.modules.lightWeightNet import WeightNet 9 | 10 | 11 | class ComponentDecConv(nn.Module): 12 | def __init__(self, k_path, k_size): 13 | super(ComponentDecConv, self).__init__() 14 | 15 | kernel = pickle.load(open(k_path, 'rb')) 16 | kernel = torch.from_numpy(kernel).float().view(-1, 1, k_size, k_size) 17 | self.register_buffer('weight', kernel) 18 | 19 | def forward(self, x): 20 | out = F.conv2d(x, weight=self.weight, bias=None, stride=1, padding=0, groups=1) 21 | 22 | return out 23 | 24 | 25 | class Network(nn.Module): 26 | def __init__(self, config): 27 | super(Network, self).__init__() 28 | 29 | self.k_size = config.MODEL.KERNEL_SIZE 30 | self.s = config.MODEL.SCALE 31 | 32 | self.w_conv = WeightNet(config.MODEL) 33 | self.decom_conv = ComponentDecConv(config.MODEL.KERNEL_PATH, self.k_size) 34 | 35 | self.criterion = nn.L1Loss(reduction='mean') 36 | 37 | 38 | def forward(self, x, gt=None): 39 | B, C, H, W = x.size() 40 | 41 | bic = F.interpolate(x, scale_factor=self.s, mode='bicubic', align_corners=False) 42 | pad = self.k_size // 2 43 | x_pad = F.pad(bic, pad=(pad, pad, pad, pad), mode='reflect') 44 | pad_H, pad_W = x_pad.size()[2:] 45 | x_pad = x_pad.view(B * 3, 1, pad_H, pad_W) 46 | x_com = self.decom_conv(x_pad).view(B, 3, -1, self.s * H, self.s * W) # B, 3, N_K, Hs, Ws 47 | 48 | weight = self.w_conv(x) 49 | weight = weight.view(B, 1, -1, self.s * H, self.s * W) # B, 1, N_K, Hs, Ws 50 | 51 | out = torch.sum(weight * x_com, dim=2) 52 | 53 | if gt is not None: 54 | loss_dict = dict(L1=self.criterion(out, gt)) 55 | return loss_dict 56 | else: 57 | return out 58 | 59 | -------------------------------------------------------------------------------- /exps/LAPAR_B_x3/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import sys 5 | sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + '/../..') 6 | import logging 7 | 8 | import torch 9 | import torch.distributed as dist 10 | import torch.multiprocessing as mp 11 | from tensorboardX import SummaryWriter 12 | 13 | from config import config 14 | from utils import common, dataloader, solver, model_opr 15 | from dataset import get_dataset 16 | from network import Network 17 | from validate import validate 18 | 19 | 20 | def init_dist(local_rank): 21 | if mp.get_start_method(allow_none=True) != 'spawn': 22 | mp.set_start_method('spawn', force=True) 23 | torch.cuda.set_device(local_rank) 24 | dist.init_process_group(backend="nccl", init_method='env://') 25 | dist.barrier() 26 | 27 | 28 | def main(): 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--local_rank', type=int, default=0) 31 | args = parser.parse_args() 32 | 33 | # initialization 34 | rank = 0 35 | num_gpu = 1 36 | distributed = False 37 | if 'WORLD_SIZE' in os.environ: 38 | num_gpu = int(os.environ['WORLD_SIZE']) 39 | distributed = num_gpu > 1 40 | if distributed: 41 | rank = args.local_rank 42 | init_dist(rank) 43 | common.init_random_seed(config.DATASET.SEED + rank) 44 | 45 | # set up dirs and log 46 | exp_dir, cur_dir = osp.split(osp.split(osp.realpath(__file__))[0]) 47 | root_dir = osp.split(exp_dir)[0] 48 | log_dir = osp.join(root_dir, 'logs', cur_dir) 49 | model_dir = osp.join(log_dir, 'models') 50 | solver_dir = osp.join(log_dir, 'solvers') 51 | if rank <= 0: 52 | common.mkdir(log_dir) 53 | ln_log_dir = osp.join(exp_dir, cur_dir, 'log') 54 | if not osp.exists(ln_log_dir): 55 | os.system('ln -s %s log' % log_dir) 56 | common.mkdir(model_dir) 57 | common.mkdir(solver_dir) 58 | save_dir = osp.join(log_dir, 'saved_imgs') 59 | common.mkdir(save_dir) 60 | tb_dir = osp.join(log_dir, 'tb_log') 61 | tb_writer = SummaryWriter(tb_dir) 62 | common.setup_logger('base', log_dir, 'train', level=logging.INFO, screen=True, to_file=True) 63 | logger = logging.getLogger('base') 64 | 65 | # dataset 66 | train_dataset = get_dataset(config.DATASET) 67 | train_loader = dataloader.train_loader(train_dataset, config, rank=rank, seed=config.DATASET.SEED, 68 | is_dist=distributed) 69 | if rank <= 0: 70 | val_dataset = get_dataset(config.VAL) 71 | val_loader = dataloader.val_loader(val_dataset, config, rank, 1) 72 | data_len = val_dataset.data_len 73 | 74 | # model 75 | model = Network(config) 76 | if rank <= 0: 77 | print(model) 78 | 79 | if config.CONTINUE_ITER: 80 | model_path = osp.join(model_dir, '%d.pth' % config.CONTINUE_ITER) 81 | if rank <= 0: 82 | logger.info('[Continue] Iter: %d' % config.CONTINUE_ITER) 83 | model_opr.load_model(model, model_path, strict=True, cpu=True) 84 | elif config.INIT_MODEL: 85 | if rank <= 0: 86 | logger.info('[Initialize] Model: %s' % config.INIT_MODEL) 87 | model_opr.load_model(model, config.INIT_MODEL, strict=True, cpu=True) 88 | 89 | device = torch.device(config.MODEL.DEVICE) 90 | model.to(device) 91 | if distributed: 92 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[torch.cuda.current_device()]) 93 | 94 | # solvers 95 | optimizer = solver.make_optimizer(config, model) # lr without X num_gpu 96 | lr_scheduler = solver.CosineAnnealingLR_warmup(config, optimizer, config.SOLVER.BASE_LR) 97 | iteration = 0 98 | 99 | if config.CONTINUE_ITER: 100 | solver_path = osp.join(solver_dir, '%d.solver' % config.CONTINUE_ITER) 101 | iteration = model_opr.load_solver(optimizer, lr_scheduler, solver_path) 102 | 103 | max_iter = max_psnr = max_ssim = 0 104 | for lr_img, hr_img in train_loader: 105 | model.train() 106 | iteration = iteration + 1 107 | 108 | optimizer.zero_grad() 109 | 110 | lr_img = lr_img.to(device) 111 | hr_img = hr_img.to(device) 112 | 113 | loss_dict = model(lr_img, gt=hr_img) 114 | total_loss = sum(loss for loss in loss_dict.values()) 115 | total_loss.backward() 116 | 117 | optimizer.step() 118 | lr_scheduler.step() 119 | 120 | if rank <= 0: 121 | if iteration % config.LOG_PERIOD == 0 or iteration == config.SOLVER.MAX_ITER: 122 | log_str = 'Iter: %d, LR: %.3e, ' % (iteration, optimizer.param_groups[0]['lr']) 123 | for key in loss_dict: 124 | tb_writer.add_scalar(key, loss_dict[key].mean(), global_step=iteration) 125 | log_str += key + ': %.4f, ' % float(loss_dict[key]) 126 | logger.info(log_str) 127 | 128 | if iteration % config.SAVE_PERIOD == 0 or iteration == config.SOLVER.MAX_ITER: 129 | logger.info('[Saving] Iter: %d' % iteration) 130 | model_path = osp.join(model_dir, '%d.pth' % iteration) 131 | solver_path = osp.join(solver_dir, '%d.solver' % iteration) 132 | model_opr.save_model(model, model_path) 133 | model_opr.save_solver(optimizer, lr_scheduler, iteration, solver_path) 134 | 135 | if iteration % config.VAL.PERIOD == 0 or iteration == config.SOLVER.MAX_ITER: 136 | logger.info('[Validating] Iter: %d' % iteration) 137 | model.eval() 138 | with torch.no_grad(): 139 | psnr, ssim = validate(model, val_loader, config, device, iteration, save_path=save_dir) 140 | if psnr > max_psnr: 141 | max_psnr, max_ssim, max_iter = psnr, ssim, iteration 142 | logger.info('[Val Result] Iter: %d, PSNR: %.4f, SSIM: %.4f' % (iteration, psnr, ssim)) 143 | logger.info('[Best Result] Iter: %d, PSNR: %.4f, SSIM: %.4f' % (max_iter, max_psnr, max_ssim)) 144 | 145 | if iteration >= config.SOLVER.MAX_ITER: 146 | break 147 | 148 | if rank <= 0: 149 | logger.info('Finish training process!') 150 | logger.info('[Final Best Result] Iter: %d, PSNR: %.4f, SSIM: %.4f' % (max_iter, max_psnr, max_ssim)) 151 | 152 | 153 | if __name__ == '__main__': 154 | main() 155 | -------------------------------------------------------------------------------- /exps/LAPAR_B_x3/train.sh: -------------------------------------------------------------------------------- 1 | python3 -m torch.distributed.launch --nproc_per_node=$1 --master_port=$2 train.py 2 | -------------------------------------------------------------------------------- /exps/LAPAR_B_x3/validate.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import numpy as np 4 | import sys 5 | sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + '/../..') 6 | 7 | import torch 8 | 9 | from utils.common import tensor2img, calculate_psnr, calculate_ssim, bgr2ycbcr 10 | 11 | 12 | def validate(model, val_loader, config, device, iteration, save_path='.'): 13 | with torch.no_grad(): 14 | psnr_l = [] 15 | ssim_l = [] 16 | 17 | for idx, (lr_img, hr_img) in enumerate(val_loader): 18 | lr_img = lr_img.to(device) 19 | hr_img = hr_img.to(device) 20 | 21 | output = model(lr_img) 22 | 23 | output = tensor2img(output) 24 | gt = tensor2img(hr_img) 25 | 26 | if config.VAL.SAVE_IMG: 27 | ipath = os.path.join(save_path, '%d_%03d.png' % (iteration, idx)) 28 | cv2.imwrite(ipath, np.concatenate([output, gt], axis=1)) 29 | 30 | output = output.astype(np.float32) / 255.0 31 | gt = gt.astype(np.float32) / 255.0 32 | 33 | if config.VAL.TO_Y: 34 | output = bgr2ycbcr(output, only_y=True) 35 | gt = bgr2ycbcr(gt, only_y=True) 36 | 37 | if config.VAL.CROP_BORDER != 0: 38 | cb = config.VAL.CROP_BORDER 39 | output = output[cb:-cb, cb:-cb] 40 | gt = gt[cb:-cb, cb:-cb] 41 | 42 | psnr = calculate_psnr(output * 255, gt * 255) 43 | ssim = calculate_ssim(output * 255, gt * 255) 44 | psnr_l.append(psnr) 45 | ssim_l.append(ssim) 46 | 47 | avg_psnr = sum(psnr_l) / len(psnr_l) 48 | avg_ssim = sum(ssim_l) / len(ssim_l) 49 | 50 | return avg_psnr, avg_ssim 51 | 52 | 53 | if __name__ == '__main__': 54 | from config import config 55 | from network import Network 56 | from dataset import get_dataset 57 | from utils import dataloader 58 | from utils.model_opr import load_model 59 | 60 | config.VAL.DATASET = 'Set5' 61 | 62 | model = Network(config) 63 | if torch.cuda.is_available(): 64 | device = torch.device('cuda') 65 | else: 66 | device = torch.device('cpu') 67 | model = model.to(device) 68 | 69 | model_path = 'log/models/200000.pth' 70 | load_model(model, model_path, cpu=True) 71 | sys.exit() 72 | 73 | val_dataset = get_dataset(config.VAL) 74 | val_loader = dataloader.val_loader(val_dataset, config, 0, 1) 75 | psnr, ssim = validate(model, val_loader, config, device, 0, save_path='.') 76 | print('PSNR: %.4f, SSIM: %.4f' % (psnr, ssim)) 77 | -------------------------------------------------------------------------------- /exps/LAPAR_B_x4/config.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | 4 | class Config: 5 | # dataset 6 | DATASET = edict() 7 | DATASET.TYPE = 'MixDataset' 8 | DATASET.DATASETS = ['DIV2K', 'Flickr2K'] 9 | DATASET.SPLITS = ['TRAIN', 'TRAIN'] 10 | DATASET.PHASE = 'train' 11 | DATASET.INPUT_HEIGHT = 64 12 | DATASET.INPUT_WIDTH = 64 13 | DATASET.SCALE = 4 14 | DATASET.REPEAT = 1 15 | DATASET.VALUE_RANGE = 255.0 16 | DATASET.SEED = 100 17 | 18 | # dataloader 19 | DATALOADER = edict() 20 | DATALOADER.IMG_PER_GPU = 32 21 | DATALOADER.NUM_WORKERS = 8 22 | 23 | # model 24 | MODEL = edict() 25 | MODEL.SCALE = DATASET.SCALE 26 | MODEL.KERNEL_SIZE = 5 27 | MODEL.KERNEL_PATH = '../../kernel/kernel_72_k5.pkl' 28 | MODEL.IN_CHANNEL = 3 29 | MODEL.N_CHANNEL = 24 30 | MODEL.RES_BLOCK = 3 31 | MODEL.N_WEIGHT = 72 32 | MODEL.DOWN = 1 33 | MODEL.DEVICE = 'cuda' 34 | 35 | # solver 36 | SOLVER = edict() 37 | SOLVER.OPTIMIZER = 'Adam' 38 | SOLVER.BASE_LR = 4e-4 39 | SOLVER.BETA1 = 0.9 40 | SOLVER.BETA2 = 0.999 41 | SOLVER.WEIGHT_DECAY = 0 42 | SOLVER.MOMENTUM = 0 43 | SOLVER.WARM_UP_ITER = 2000 44 | SOLVER.WARM_UP_FACTOR = 0.1 45 | SOLVER.T_PERIOD = [200000, 400000, 600000] 46 | SOLVER.MAX_ITER = SOLVER.T_PERIOD[-1] 47 | 48 | # initialization 49 | CONTINUE_ITER = None 50 | INIT_MODEL = None 51 | 52 | # log and save 53 | LOG_PERIOD = 20 54 | SAVE_PERIOD = 10000 55 | 56 | # validation 57 | VAL = edict() 58 | VAL.PERIOD = 10000 59 | VAL.TYPE = 'MixDataset' 60 | VAL.DATASETS = ['BSDS100'] 61 | VAL.SPLITS = ['VAL'] 62 | VAL.PHASE = 'val' 63 | VAL.INPUT_HEIGHT = None 64 | VAL.INPUT_WIDTH = None 65 | VAL.SCALE = DATASET.SCALE 66 | VAL.REPEAT = 1 67 | VAL.VALUE_RANGE = 255.0 68 | VAL.IMG_PER_GPU = 1 69 | VAL.NUM_WORKERS = 1 70 | VAL.SAVE_IMG = False 71 | VAL.TO_Y = True 72 | VAL.CROP_BORDER = VAL.SCALE 73 | 74 | 75 | config = Config() 76 | 77 | 78 | 79 | -------------------------------------------------------------------------------- /exps/LAPAR_B_x4/network.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.nn.init as init 7 | 8 | from utils.modules.lightWeightNet import WeightNet 9 | 10 | 11 | class ComponentDecConv(nn.Module): 12 | def __init__(self, k_path, k_size): 13 | super(ComponentDecConv, self).__init__() 14 | 15 | kernel = pickle.load(open(k_path, 'rb')) 16 | kernel = torch.from_numpy(kernel).float().view(-1, 1, k_size, k_size) 17 | self.register_buffer('weight', kernel) 18 | 19 | def forward(self, x): 20 | out = F.conv2d(x, weight=self.weight, bias=None, stride=1, padding=0, groups=1) 21 | 22 | return out 23 | 24 | 25 | class Network(nn.Module): 26 | def __init__(self, config): 27 | super(Network, self).__init__() 28 | 29 | self.k_size = config.MODEL.KERNEL_SIZE 30 | self.s = config.MODEL.SCALE 31 | 32 | self.w_conv = WeightNet(config.MODEL) 33 | self.decom_conv = ComponentDecConv(config.MODEL.KERNEL_PATH, self.k_size) 34 | 35 | self.criterion = nn.L1Loss(reduction='mean') 36 | 37 | 38 | def forward(self, x, gt=None): 39 | B, C, H, W = x.size() 40 | 41 | bic = F.interpolate(x, scale_factor=self.s, mode='bicubic', align_corners=False) 42 | pad = self.k_size // 2 43 | x_pad = F.pad(bic, pad=(pad, pad, pad, pad), mode='reflect') 44 | pad_H, pad_W = x_pad.size()[2:] 45 | x_pad = x_pad.view(B * 3, 1, pad_H, pad_W) 46 | x_com = self.decom_conv(x_pad).view(B, 3, -1, self.s * H, self.s * W) # B, 3, N_K, Hs, Ws 47 | 48 | weight = self.w_conv(x) 49 | weight = weight.view(B, 1, -1, self.s * H, self.s * W) # B, 1, N_K, Hs, Ws 50 | 51 | out = torch.sum(weight * x_com, dim=2) 52 | 53 | if gt is not None: 54 | loss_dict = dict(L1=self.criterion(out, gt)) 55 | return loss_dict 56 | else: 57 | return out 58 | 59 | -------------------------------------------------------------------------------- /exps/LAPAR_B_x4/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import sys 5 | sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + '/../..') 6 | import logging 7 | 8 | import torch 9 | import torch.distributed as dist 10 | import torch.multiprocessing as mp 11 | from tensorboardX import SummaryWriter 12 | 13 | from config import config 14 | from utils import common, dataloader, solver, model_opr 15 | from dataset import get_dataset 16 | from network import Network 17 | from validate import validate 18 | 19 | 20 | def init_dist(local_rank): 21 | if mp.get_start_method(allow_none=True) != 'spawn': 22 | mp.set_start_method('spawn', force=True) 23 | torch.cuda.set_device(local_rank) 24 | dist.init_process_group(backend="nccl", init_method='env://') 25 | dist.barrier() 26 | 27 | 28 | def main(): 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--local_rank', type=int, default=0) 31 | args = parser.parse_args() 32 | 33 | # initialization 34 | rank = 0 35 | num_gpu = 1 36 | distributed = False 37 | if 'WORLD_SIZE' in os.environ: 38 | num_gpu = int(os.environ['WORLD_SIZE']) 39 | distributed = num_gpu > 1 40 | if distributed: 41 | rank = args.local_rank 42 | init_dist(rank) 43 | common.init_random_seed(config.DATASET.SEED + rank) 44 | 45 | # set up dirs and log 46 | exp_dir, cur_dir = osp.split(osp.split(osp.realpath(__file__))[0]) 47 | root_dir = osp.split(exp_dir)[0] 48 | log_dir = osp.join(root_dir, 'logs', cur_dir) 49 | model_dir = osp.join(log_dir, 'models') 50 | solver_dir = osp.join(log_dir, 'solvers') 51 | if rank <= 0: 52 | common.mkdir(log_dir) 53 | ln_log_dir = osp.join(exp_dir, cur_dir, 'log') 54 | if not osp.exists(ln_log_dir): 55 | os.system('ln -s %s log' % log_dir) 56 | common.mkdir(model_dir) 57 | common.mkdir(solver_dir) 58 | save_dir = osp.join(log_dir, 'saved_imgs') 59 | common.mkdir(save_dir) 60 | tb_dir = osp.join(log_dir, 'tb_log') 61 | tb_writer = SummaryWriter(tb_dir) 62 | common.setup_logger('base', log_dir, 'train', level=logging.INFO, screen=True, to_file=True) 63 | logger = logging.getLogger('base') 64 | 65 | # dataset 66 | train_dataset = get_dataset(config.DATASET) 67 | train_loader = dataloader.train_loader(train_dataset, config, rank=rank, seed=config.DATASET.SEED, 68 | is_dist=distributed) 69 | if rank <= 0: 70 | val_dataset = get_dataset(config.VAL) 71 | val_loader = dataloader.val_loader(val_dataset, config, rank, 1) 72 | data_len = val_dataset.data_len 73 | 74 | # model 75 | model = Network(config) 76 | if rank <= 0: 77 | print(model) 78 | 79 | if config.CONTINUE_ITER: 80 | model_path = osp.join(model_dir, '%d.pth' % config.CONTINUE_ITER) 81 | if rank <= 0: 82 | logger.info('[Continue] Iter: %d' % config.CONTINUE_ITER) 83 | model_opr.load_model(model, model_path, strict=True, cpu=True) 84 | elif config.INIT_MODEL: 85 | if rank <= 0: 86 | logger.info('[Initialize] Model: %s' % config.INIT_MODEL) 87 | model_opr.load_model(model, config.INIT_MODEL, strict=True, cpu=True) 88 | 89 | device = torch.device(config.MODEL.DEVICE) 90 | model.to(device) 91 | if distributed: 92 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[torch.cuda.current_device()]) 93 | 94 | # solvers 95 | optimizer = solver.make_optimizer(config, model) # lr without X num_gpu 96 | lr_scheduler = solver.CosineAnnealingLR_warmup(config, optimizer, config.SOLVER.BASE_LR) 97 | iteration = 0 98 | 99 | if config.CONTINUE_ITER: 100 | solver_path = osp.join(solver_dir, '%d.solver' % config.CONTINUE_ITER) 101 | iteration = model_opr.load_solver(optimizer, lr_scheduler, solver_path) 102 | 103 | max_iter = max_psnr = max_ssim = 0 104 | for lr_img, hr_img in train_loader: 105 | model.train() 106 | iteration = iteration + 1 107 | 108 | optimizer.zero_grad() 109 | 110 | lr_img = lr_img.to(device) 111 | hr_img = hr_img.to(device) 112 | 113 | loss_dict = model(lr_img, gt=hr_img) 114 | total_loss = sum(loss for loss in loss_dict.values()) 115 | total_loss.backward() 116 | 117 | optimizer.step() 118 | lr_scheduler.step() 119 | 120 | if rank <= 0: 121 | if iteration % config.LOG_PERIOD == 0 or iteration == config.SOLVER.MAX_ITER: 122 | log_str = 'Iter: %d, LR: %.3e, ' % (iteration, optimizer.param_groups[0]['lr']) 123 | for key in loss_dict: 124 | tb_writer.add_scalar(key, loss_dict[key].mean(), global_step=iteration) 125 | log_str += key + ': %.4f, ' % float(loss_dict[key]) 126 | logger.info(log_str) 127 | 128 | if iteration % config.SAVE_PERIOD == 0 or iteration == config.SOLVER.MAX_ITER: 129 | logger.info('[Saving] Iter: %d' % iteration) 130 | model_path = osp.join(model_dir, '%d.pth' % iteration) 131 | solver_path = osp.join(solver_dir, '%d.solver' % iteration) 132 | model_opr.save_model(model, model_path) 133 | model_opr.save_solver(optimizer, lr_scheduler, iteration, solver_path) 134 | 135 | if iteration % config.VAL.PERIOD == 0 or iteration == config.SOLVER.MAX_ITER: 136 | logger.info('[Validating] Iter: %d' % iteration) 137 | model.eval() 138 | with torch.no_grad(): 139 | psnr, ssim = validate(model, val_loader, config, device, iteration, save_path=save_dir) 140 | if psnr > max_psnr: 141 | max_psnr, max_ssim, max_iter = psnr, ssim, iteration 142 | logger.info('[Val Result] Iter: %d, PSNR: %.4f, SSIM: %.4f' % (iteration, psnr, ssim)) 143 | logger.info('[Best Result] Iter: %d, PSNR: %.4f, SSIM: %.4f' % (max_iter, max_psnr, max_ssim)) 144 | 145 | if iteration >= config.SOLVER.MAX_ITER: 146 | break 147 | 148 | if rank <= 0: 149 | logger.info('Finish training process!') 150 | logger.info('[Final Best Result] Iter: %d, PSNR: %.4f, SSIM: %.4f' % (max_iter, max_psnr, max_ssim)) 151 | 152 | 153 | if __name__ == '__main__': 154 | main() 155 | -------------------------------------------------------------------------------- /exps/LAPAR_B_x4/train.sh: -------------------------------------------------------------------------------- 1 | python3 -m torch.distributed.launch --nproc_per_node=$1 --master_port=$2 train.py 2 | -------------------------------------------------------------------------------- /exps/LAPAR_B_x4/validate.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import numpy as np 4 | import sys 5 | sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + '/../..') 6 | 7 | import torch 8 | 9 | from utils.common import tensor2img, calculate_psnr, calculate_ssim, bgr2ycbcr 10 | 11 | 12 | def validate(model, val_loader, config, device, iteration, save_path='.'): 13 | with torch.no_grad(): 14 | psnr_l = [] 15 | ssim_l = [] 16 | 17 | for idx, (lr_img, hr_img) in enumerate(val_loader): 18 | lr_img = lr_img.to(device) 19 | hr_img = hr_img.to(device) 20 | 21 | output = model(lr_img) 22 | 23 | output = tensor2img(output) 24 | gt = tensor2img(hr_img) 25 | 26 | if config.VAL.SAVE_IMG: 27 | ipath = os.path.join(save_path, '%d_%03d.png' % (iteration, idx)) 28 | cv2.imwrite(ipath, np.concatenate([output, gt], axis=1)) 29 | 30 | output = output.astype(np.float32) / 255.0 31 | gt = gt.astype(np.float32) / 255.0 32 | 33 | if config.VAL.TO_Y: 34 | output = bgr2ycbcr(output, only_y=True) 35 | gt = bgr2ycbcr(gt, only_y=True) 36 | 37 | if config.VAL.CROP_BORDER != 0: 38 | cb = config.VAL.CROP_BORDER 39 | output = output[cb:-cb, cb:-cb] 40 | gt = gt[cb:-cb, cb:-cb] 41 | 42 | psnr = calculate_psnr(output * 255, gt * 255) 43 | ssim = calculate_ssim(output * 255, gt * 255) 44 | psnr_l.append(psnr) 45 | ssim_l.append(ssim) 46 | 47 | avg_psnr = sum(psnr_l) / len(psnr_l) 48 | avg_ssim = sum(ssim_l) / len(ssim_l) 49 | 50 | return avg_psnr, avg_ssim 51 | 52 | 53 | if __name__ == '__main__': 54 | from config import config 55 | from network import Network 56 | from dataset import get_dataset 57 | from utils import dataloader 58 | from utils.model_opr import load_model 59 | 60 | config.VAL.DATASET = 'Set5' 61 | 62 | model = Network(config) 63 | if torch.cuda.is_available(): 64 | device = torch.device('cuda') 65 | else: 66 | device = torch.device('cpu') 67 | model = model.to(device) 68 | 69 | model_path = 'log/models/200000.pth' 70 | load_model(model, model_path, cpu=True) 71 | sys.exit() 72 | 73 | val_dataset = get_dataset(config.VAL) 74 | val_loader = dataloader.val_loader(val_dataset, config, 0, 1) 75 | psnr, ssim = validate(model, val_loader, config, device, 0, save_path='.') 76 | print('PSNR: %.4f, SSIM: %.4f' % (psnr, ssim)) 77 | -------------------------------------------------------------------------------- /exps/LAPAR_C_x2/config.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | 4 | class Config: 5 | # dataset 6 | DATASET = edict() 7 | DATASET.TYPE = 'MixDataset' 8 | DATASET.DATASETS = ['DIV2K', 'Flickr2K'] 9 | DATASET.SPLITS = ['TRAIN', 'TRAIN'] 10 | DATASET.PHASE = 'train' 11 | DATASET.INPUT_HEIGHT = 64 12 | DATASET.INPUT_WIDTH = 64 13 | DATASET.SCALE = 2 14 | DATASET.REPEAT = 1 15 | DATASET.VALUE_RANGE = 255.0 16 | DATASET.SEED = 100 17 | 18 | # dataloader 19 | DATALOADER = edict() 20 | DATALOADER.IMG_PER_GPU = 32 21 | DATALOADER.NUM_WORKERS = 8 22 | 23 | # model 24 | MODEL = edict() 25 | MODEL.SCALE = DATASET.SCALE 26 | MODEL.KERNEL_SIZE = 5 27 | MODEL.KERNEL_PATH = '../../kernel/kernel_72_k5.pkl' 28 | MODEL.IN_CHANNEL = 3 29 | MODEL.N_CHANNEL = 16 30 | MODEL.RES_BLOCK = 2 31 | MODEL.N_WEIGHT = 72 32 | MODEL.DOWN = 1 33 | MODEL.DEVICE = 'cuda' 34 | 35 | # solver 36 | SOLVER = edict() 37 | SOLVER.OPTIMIZER = 'Adam' 38 | SOLVER.BASE_LR = 4e-4 39 | SOLVER.BETA1 = 0.9 40 | SOLVER.BETA2 = 0.999 41 | SOLVER.WEIGHT_DECAY = 0 42 | SOLVER.MOMENTUM = 0 43 | SOLVER.WARM_UP_ITER = 2000 44 | SOLVER.WARM_UP_FACTOR = 0.1 45 | SOLVER.T_PERIOD = [200000, 400000, 600000] 46 | SOLVER.MAX_ITER = SOLVER.T_PERIOD[-1] 47 | 48 | # initialization 49 | CONTINUE_ITER = None 50 | INIT_MODEL = None 51 | 52 | # log and save 53 | LOG_PERIOD = 20 54 | SAVE_PERIOD = 10000 55 | 56 | # validation 57 | VAL = edict() 58 | VAL.PERIOD = 10000 59 | VAL.TYPE = 'MixDataset' 60 | VAL.DATASETS = ['BSDS100'] 61 | VAL.SPLITS = ['VAL'] 62 | VAL.PHASE = 'val' 63 | VAL.INPUT_HEIGHT = None 64 | VAL.INPUT_WIDTH = None 65 | VAL.SCALE = DATASET.SCALE 66 | VAL.REPEAT = 1 67 | VAL.VALUE_RANGE = 255.0 68 | VAL.IMG_PER_GPU = 1 69 | VAL.NUM_WORKERS = 1 70 | VAL.SAVE_IMG = False 71 | VAL.TO_Y = True 72 | VAL.CROP_BORDER = VAL.SCALE 73 | 74 | 75 | config = Config() 76 | 77 | 78 | 79 | -------------------------------------------------------------------------------- /exps/LAPAR_C_x2/network.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.nn.init as init 7 | 8 | from utils.modules.lightWeightNet import WeightNet 9 | 10 | 11 | class ComponentDecConv(nn.Module): 12 | def __init__(self, k_path, k_size): 13 | super(ComponentDecConv, self).__init__() 14 | 15 | kernel = pickle.load(open(k_path, 'rb')) 16 | kernel = torch.from_numpy(kernel).float().view(-1, 1, k_size, k_size) 17 | self.register_buffer('weight', kernel) 18 | 19 | def forward(self, x): 20 | out = F.conv2d(x, weight=self.weight, bias=None, stride=1, padding=0, groups=1) 21 | 22 | return out 23 | 24 | 25 | class Network(nn.Module): 26 | def __init__(self, config): 27 | super(Network, self).__init__() 28 | 29 | self.k_size = config.MODEL.KERNEL_SIZE 30 | self.s = config.MODEL.SCALE 31 | 32 | self.w_conv = WeightNet(config.MODEL) 33 | self.decom_conv = ComponentDecConv(config.MODEL.KERNEL_PATH, self.k_size) 34 | 35 | self.criterion = nn.L1Loss(reduction='mean') 36 | 37 | 38 | def forward(self, x, gt=None): 39 | B, C, H, W = x.size() 40 | 41 | bic = F.interpolate(x, scale_factor=self.s, mode='bicubic', align_corners=False) 42 | pad = self.k_size // 2 43 | x_pad = F.pad(bic, pad=(pad, pad, pad, pad), mode='reflect') 44 | pad_H, pad_W = x_pad.size()[2:] 45 | x_pad = x_pad.view(B * 3, 1, pad_H, pad_W) 46 | x_com = self.decom_conv(x_pad).view(B, 3, -1, self.s * H, self.s * W) # B, 3, N_K, Hs, Ws 47 | 48 | weight = self.w_conv(x) 49 | weight = weight.view(B, 1, -1, self.s * H, self.s * W) # B, 1, N_K, Hs, Ws 50 | 51 | out = torch.sum(weight * x_com, dim=2) 52 | 53 | if gt is not None: 54 | loss_dict = dict(L1=self.criterion(out, gt)) 55 | return loss_dict 56 | else: 57 | return out 58 | 59 | -------------------------------------------------------------------------------- /exps/LAPAR_C_x2/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import sys 5 | sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + '/../..') 6 | import logging 7 | 8 | import torch 9 | import torch.distributed as dist 10 | import torch.multiprocessing as mp 11 | from tensorboardX import SummaryWriter 12 | 13 | from config import config 14 | from utils import common, dataloader, solver, model_opr 15 | from dataset import get_dataset 16 | from network import Network 17 | from validate import validate 18 | 19 | 20 | def init_dist(local_rank): 21 | if mp.get_start_method(allow_none=True) != 'spawn': 22 | mp.set_start_method('spawn', force=True) 23 | torch.cuda.set_device(local_rank) 24 | dist.init_process_group(backend="nccl", init_method='env://') 25 | dist.barrier() 26 | 27 | 28 | def main(): 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--local_rank', type=int, default=0) 31 | args = parser.parse_args() 32 | 33 | # initialization 34 | rank = 0 35 | num_gpu = 1 36 | distributed = False 37 | if 'WORLD_SIZE' in os.environ: 38 | num_gpu = int(os.environ['WORLD_SIZE']) 39 | distributed = num_gpu > 1 40 | if distributed: 41 | rank = args.local_rank 42 | init_dist(rank) 43 | common.init_random_seed(config.DATASET.SEED + rank) 44 | 45 | # set up dirs and log 46 | exp_dir, cur_dir = osp.split(osp.split(osp.realpath(__file__))[0]) 47 | root_dir = osp.split(exp_dir)[0] 48 | log_dir = osp.join(root_dir, 'logs', cur_dir) 49 | model_dir = osp.join(log_dir, 'models') 50 | solver_dir = osp.join(log_dir, 'solvers') 51 | if rank <= 0: 52 | common.mkdir(log_dir) 53 | ln_log_dir = osp.join(exp_dir, cur_dir, 'log') 54 | if not osp.exists(ln_log_dir): 55 | os.system('ln -s %s log' % log_dir) 56 | common.mkdir(model_dir) 57 | common.mkdir(solver_dir) 58 | save_dir = osp.join(log_dir, 'saved_imgs') 59 | common.mkdir(save_dir) 60 | tb_dir = osp.join(log_dir, 'tb_log') 61 | tb_writer = SummaryWriter(tb_dir) 62 | common.setup_logger('base', log_dir, 'train', level=logging.INFO, screen=True, to_file=True) 63 | logger = logging.getLogger('base') 64 | 65 | # dataset 66 | train_dataset = get_dataset(config.DATASET) 67 | train_loader = dataloader.train_loader(train_dataset, config, rank=rank, seed=config.DATASET.SEED, 68 | is_dist=distributed) 69 | if rank <= 0: 70 | val_dataset = get_dataset(config.VAL) 71 | val_loader = dataloader.val_loader(val_dataset, config, rank, 1) 72 | data_len = val_dataset.data_len 73 | 74 | # model 75 | model = Network(config) 76 | if rank <= 0: 77 | print(model) 78 | 79 | if config.CONTINUE_ITER: 80 | model_path = osp.join(model_dir, '%d.pth' % config.CONTINUE_ITER) 81 | if rank <= 0: 82 | logger.info('[Continue] Iter: %d' % config.CONTINUE_ITER) 83 | model_opr.load_model(model, model_path, strict=True, cpu=True) 84 | elif config.INIT_MODEL: 85 | if rank <= 0: 86 | logger.info('[Initialize] Model: %s' % config.INIT_MODEL) 87 | model_opr.load_model(model, config.INIT_MODEL, strict=True, cpu=True) 88 | 89 | device = torch.device(config.MODEL.DEVICE) 90 | model.to(device) 91 | if distributed: 92 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[torch.cuda.current_device()]) 93 | 94 | # solvers 95 | optimizer = solver.make_optimizer(config, model) # lr without X num_gpu 96 | lr_scheduler = solver.CosineAnnealingLR_warmup(config, optimizer, config.SOLVER.BASE_LR) 97 | iteration = 0 98 | 99 | if config.CONTINUE_ITER: 100 | solver_path = osp.join(solver_dir, '%d.solver' % config.CONTINUE_ITER) 101 | iteration = model_opr.load_solver(optimizer, lr_scheduler, solver_path) 102 | 103 | max_iter = max_psnr = max_ssim = 0 104 | for lr_img, hr_img in train_loader: 105 | model.train() 106 | iteration = iteration + 1 107 | 108 | optimizer.zero_grad() 109 | 110 | lr_img = lr_img.to(device) 111 | hr_img = hr_img.to(device) 112 | 113 | loss_dict = model(lr_img, gt=hr_img) 114 | total_loss = sum(loss for loss in loss_dict.values()) 115 | total_loss.backward() 116 | 117 | optimizer.step() 118 | lr_scheduler.step() 119 | 120 | if rank <= 0: 121 | if iteration % config.LOG_PERIOD == 0 or iteration == config.SOLVER.MAX_ITER: 122 | log_str = 'Iter: %d, LR: %.3e, ' % (iteration, optimizer.param_groups[0]['lr']) 123 | for key in loss_dict: 124 | tb_writer.add_scalar(key, loss_dict[key].mean(), global_step=iteration) 125 | log_str += key + ': %.4f, ' % float(loss_dict[key]) 126 | logger.info(log_str) 127 | 128 | if iteration % config.SAVE_PERIOD == 0 or iteration == config.SOLVER.MAX_ITER: 129 | logger.info('[Saving] Iter: %d' % iteration) 130 | model_path = osp.join(model_dir, '%d.pth' % iteration) 131 | solver_path = osp.join(solver_dir, '%d.solver' % iteration) 132 | model_opr.save_model(model, model_path) 133 | model_opr.save_solver(optimizer, lr_scheduler, iteration, solver_path) 134 | 135 | if iteration % config.VAL.PERIOD == 0 or iteration == config.SOLVER.MAX_ITER: 136 | logger.info('[Validating] Iter: %d' % iteration) 137 | model.eval() 138 | with torch.no_grad(): 139 | psnr, ssim = validate(model, val_loader, config, device, iteration, save_path=save_dir) 140 | if psnr > max_psnr: 141 | max_psnr, max_ssim, max_iter = psnr, ssim, iteration 142 | logger.info('[Val Result] Iter: %d, PSNR: %.4f, SSIM: %.4f' % (iteration, psnr, ssim)) 143 | logger.info('[Best Result] Iter: %d, PSNR: %.4f, SSIM: %.4f' % (max_iter, max_psnr, max_ssim)) 144 | 145 | if iteration >= config.SOLVER.MAX_ITER: 146 | break 147 | 148 | if rank <= 0: 149 | logger.info('Finish training process!') 150 | logger.info('[Final Best Result] Iter: %d, PSNR: %.4f, SSIM: %.4f' % (max_iter, max_psnr, max_ssim)) 151 | 152 | 153 | if __name__ == '__main__': 154 | main() 155 | -------------------------------------------------------------------------------- /exps/LAPAR_C_x2/train.sh: -------------------------------------------------------------------------------- 1 | python3 -m torch.distributed.launch --nproc_per_node=$1 --master_port=$2 train.py 2 | -------------------------------------------------------------------------------- /exps/LAPAR_C_x2/validate.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import numpy as np 4 | import sys 5 | sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + '/../..') 6 | 7 | import torch 8 | 9 | from utils.common import tensor2img, calculate_psnr, calculate_ssim, bgr2ycbcr 10 | 11 | 12 | def validate(model, val_loader, config, device, iteration, save_path='.'): 13 | with torch.no_grad(): 14 | psnr_l = [] 15 | ssim_l = [] 16 | 17 | for idx, (lr_img, hr_img) in enumerate(val_loader): 18 | lr_img = lr_img.to(device) 19 | hr_img = hr_img.to(device) 20 | 21 | output = model(lr_img) 22 | 23 | output = tensor2img(output) 24 | gt = tensor2img(hr_img) 25 | 26 | if config.VAL.SAVE_IMG: 27 | ipath = os.path.join(save_path, '%d_%03d.png' % (iteration, idx)) 28 | cv2.imwrite(ipath, np.concatenate([output, gt], axis=1)) 29 | 30 | output = output.astype(np.float32) / 255.0 31 | gt = gt.astype(np.float32) / 255.0 32 | 33 | if config.VAL.TO_Y: 34 | output = bgr2ycbcr(output, only_y=True) 35 | gt = bgr2ycbcr(gt, only_y=True) 36 | 37 | if config.VAL.CROP_BORDER != 0: 38 | cb = config.VAL.CROP_BORDER 39 | output = output[cb:-cb, cb:-cb] 40 | gt = gt[cb:-cb, cb:-cb] 41 | 42 | psnr = calculate_psnr(output * 255, gt * 255) 43 | ssim = calculate_ssim(output * 255, gt * 255) 44 | psnr_l.append(psnr) 45 | ssim_l.append(ssim) 46 | 47 | avg_psnr = sum(psnr_l) / len(psnr_l) 48 | avg_ssim = sum(ssim_l) / len(ssim_l) 49 | 50 | return avg_psnr, avg_ssim 51 | 52 | 53 | if __name__ == '__main__': 54 | from config import config 55 | from network import Network 56 | from dataset import get_dataset 57 | from utils import dataloader 58 | from utils.model_opr import load_model 59 | 60 | config.VAL.DATASET = 'Set5' 61 | 62 | model = Network(config) 63 | if torch.cuda.is_available(): 64 | device = torch.device('cuda') 65 | else: 66 | device = torch.device('cpu') 67 | model = model.to(device) 68 | 69 | model_path = 'log/models/200000.pth' 70 | load_model(model, model_path, cpu=True) 71 | sys.exit() 72 | 73 | val_dataset = get_dataset(config.VAL) 74 | val_loader = dataloader.val_loader(val_dataset, config, 0, 1) 75 | psnr, ssim = validate(model, val_loader, config, device, 0, save_path='.') 76 | print('PSNR: %.4f, SSIM: %.4f' % (psnr, ssim)) 77 | -------------------------------------------------------------------------------- /exps/LAPAR_C_x3/config.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | 4 | class Config: 5 | # dataset 6 | DATASET = edict() 7 | DATASET.TYPE = 'MixDataset' 8 | DATASET.DATASETS = ['DIV2K', 'Flickr2K'] 9 | DATASET.SPLITS = ['TRAIN', 'TRAIN'] 10 | DATASET.PHASE = 'train' 11 | DATASET.INPUT_HEIGHT = 64 12 | DATASET.INPUT_WIDTH = 64 13 | DATASET.SCALE = 3 14 | DATASET.REPEAT = 1 15 | DATASET.VALUE_RANGE = 255.0 16 | DATASET.SEED = 100 17 | 18 | # dataloader 19 | DATALOADER = edict() 20 | DATALOADER.IMG_PER_GPU = 32 21 | DATALOADER.NUM_WORKERS = 8 22 | 23 | # model 24 | MODEL = edict() 25 | MODEL.SCALE = DATASET.SCALE 26 | MODEL.KERNEL_SIZE = 5 27 | MODEL.KERNEL_PATH = '../../kernel/kernel_72_k5.pkl' 28 | MODEL.IN_CHANNEL = 3 29 | MODEL.N_CHANNEL = 16 30 | MODEL.RES_BLOCK = 2 31 | MODEL.N_WEIGHT = 72 32 | MODEL.DOWN = 1 33 | MODEL.DEVICE = 'cuda' 34 | 35 | # solver 36 | SOLVER = edict() 37 | SOLVER.OPTIMIZER = 'Adam' 38 | SOLVER.BASE_LR = 4e-4 39 | SOLVER.BETA1 = 0.9 40 | SOLVER.BETA2 = 0.999 41 | SOLVER.WEIGHT_DECAY = 0 42 | SOLVER.MOMENTUM = 0 43 | SOLVER.WARM_UP_ITER = 2000 44 | SOLVER.WARM_UP_FACTOR = 0.1 45 | SOLVER.T_PERIOD = [200000, 400000, 600000] 46 | SOLVER.MAX_ITER = SOLVER.T_PERIOD[-1] 47 | 48 | # initialization 49 | CONTINUE_ITER = None 50 | INIT_MODEL = None 51 | 52 | # log and save 53 | LOG_PERIOD = 20 54 | SAVE_PERIOD = 10000 55 | 56 | # validation 57 | VAL = edict() 58 | VAL.PERIOD = 10000 59 | VAL.TYPE = 'MixDataset' 60 | VAL.DATASETS = ['BSDS100'] 61 | VAL.SPLITS = ['VAL'] 62 | VAL.PHASE = 'val' 63 | VAL.INPUT_HEIGHT = None 64 | VAL.INPUT_WIDTH = None 65 | VAL.SCALE = DATASET.SCALE 66 | VAL.REPEAT = 1 67 | VAL.VALUE_RANGE = 255.0 68 | VAL.IMG_PER_GPU = 1 69 | VAL.NUM_WORKERS = 1 70 | VAL.SAVE_IMG = False 71 | VAL.TO_Y = True 72 | VAL.CROP_BORDER = VAL.SCALE 73 | 74 | 75 | config = Config() 76 | 77 | 78 | 79 | -------------------------------------------------------------------------------- /exps/LAPAR_C_x3/network.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.nn.init as init 7 | 8 | from utils.modules.lightWeightNet import WeightNet 9 | 10 | 11 | class ComponentDecConv(nn.Module): 12 | def __init__(self, k_path, k_size): 13 | super(ComponentDecConv, self).__init__() 14 | 15 | kernel = pickle.load(open(k_path, 'rb')) 16 | kernel = torch.from_numpy(kernel).float().view(-1, 1, k_size, k_size) 17 | self.register_buffer('weight', kernel) 18 | 19 | def forward(self, x): 20 | out = F.conv2d(x, weight=self.weight, bias=None, stride=1, padding=0, groups=1) 21 | 22 | return out 23 | 24 | 25 | class Network(nn.Module): 26 | def __init__(self, config): 27 | super(Network, self).__init__() 28 | 29 | self.k_size = config.MODEL.KERNEL_SIZE 30 | self.s = config.MODEL.SCALE 31 | 32 | self.w_conv = WeightNet(config.MODEL) 33 | self.decom_conv = ComponentDecConv(config.MODEL.KERNEL_PATH, self.k_size) 34 | 35 | self.criterion = nn.L1Loss(reduction='mean') 36 | 37 | 38 | def forward(self, x, gt=None): 39 | B, C, H, W = x.size() 40 | 41 | bic = F.interpolate(x, scale_factor=self.s, mode='bicubic', align_corners=False) 42 | pad = self.k_size // 2 43 | x_pad = F.pad(bic, pad=(pad, pad, pad, pad), mode='reflect') 44 | pad_H, pad_W = x_pad.size()[2:] 45 | x_pad = x_pad.view(B * 3, 1, pad_H, pad_W) 46 | x_com = self.decom_conv(x_pad).view(B, 3, -1, self.s * H, self.s * W) # B, 3, N_K, Hs, Ws 47 | 48 | weight = self.w_conv(x) 49 | weight = weight.view(B, 1, -1, self.s * H, self.s * W) # B, 1, N_K, Hs, Ws 50 | 51 | out = torch.sum(weight * x_com, dim=2) 52 | 53 | if gt is not None: 54 | loss_dict = dict(L1=self.criterion(out, gt)) 55 | return loss_dict 56 | else: 57 | return out 58 | 59 | -------------------------------------------------------------------------------- /exps/LAPAR_C_x3/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import sys 5 | sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + '/../..') 6 | import logging 7 | 8 | import torch 9 | import torch.distributed as dist 10 | import torch.multiprocessing as mp 11 | from tensorboardX import SummaryWriter 12 | 13 | from config import config 14 | from utils import common, dataloader, solver, model_opr 15 | from dataset import get_dataset 16 | from network import Network 17 | from validate import validate 18 | 19 | 20 | def init_dist(local_rank): 21 | if mp.get_start_method(allow_none=True) != 'spawn': 22 | mp.set_start_method('spawn', force=True) 23 | torch.cuda.set_device(local_rank) 24 | dist.init_process_group(backend="nccl", init_method='env://') 25 | dist.barrier() 26 | 27 | 28 | def main(): 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--local_rank', type=int, default=0) 31 | args = parser.parse_args() 32 | 33 | # initialization 34 | rank = 0 35 | num_gpu = 1 36 | distributed = False 37 | if 'WORLD_SIZE' in os.environ: 38 | num_gpu = int(os.environ['WORLD_SIZE']) 39 | distributed = num_gpu > 1 40 | if distributed: 41 | rank = args.local_rank 42 | init_dist(rank) 43 | common.init_random_seed(config.DATASET.SEED + rank) 44 | 45 | # set up dirs and log 46 | exp_dir, cur_dir = osp.split(osp.split(osp.realpath(__file__))[0]) 47 | root_dir = osp.split(exp_dir)[0] 48 | log_dir = osp.join(root_dir, 'logs', cur_dir) 49 | model_dir = osp.join(log_dir, 'models') 50 | solver_dir = osp.join(log_dir, 'solvers') 51 | if rank <= 0: 52 | common.mkdir(log_dir) 53 | ln_log_dir = osp.join(exp_dir, cur_dir, 'log') 54 | if not osp.exists(ln_log_dir): 55 | os.system('ln -s %s log' % log_dir) 56 | common.mkdir(model_dir) 57 | common.mkdir(solver_dir) 58 | save_dir = osp.join(log_dir, 'saved_imgs') 59 | common.mkdir(save_dir) 60 | tb_dir = osp.join(log_dir, 'tb_log') 61 | tb_writer = SummaryWriter(tb_dir) 62 | common.setup_logger('base', log_dir, 'train', level=logging.INFO, screen=True, to_file=True) 63 | logger = logging.getLogger('base') 64 | 65 | # dataset 66 | train_dataset = get_dataset(config.DATASET) 67 | train_loader = dataloader.train_loader(train_dataset, config, rank=rank, seed=config.DATASET.SEED, 68 | is_dist=distributed) 69 | if rank <= 0: 70 | val_dataset = get_dataset(config.VAL) 71 | val_loader = dataloader.val_loader(val_dataset, config, rank, 1) 72 | data_len = val_dataset.data_len 73 | 74 | # model 75 | model = Network(config) 76 | if rank <= 0: 77 | print(model) 78 | 79 | if config.CONTINUE_ITER: 80 | model_path = osp.join(model_dir, '%d.pth' % config.CONTINUE_ITER) 81 | if rank <= 0: 82 | logger.info('[Continue] Iter: %d' % config.CONTINUE_ITER) 83 | model_opr.load_model(model, model_path, strict=True, cpu=True) 84 | elif config.INIT_MODEL: 85 | if rank <= 0: 86 | logger.info('[Initialize] Model: %s' % config.INIT_MODEL) 87 | model_opr.load_model(model, config.INIT_MODEL, strict=True, cpu=True) 88 | 89 | device = torch.device(config.MODEL.DEVICE) 90 | model.to(device) 91 | if distributed: 92 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[torch.cuda.current_device()]) 93 | 94 | # solvers 95 | optimizer = solver.make_optimizer(config, model) # lr without X num_gpu 96 | lr_scheduler = solver.CosineAnnealingLR_warmup(config, optimizer, config.SOLVER.BASE_LR) 97 | iteration = 0 98 | 99 | if config.CONTINUE_ITER: 100 | solver_path = osp.join(solver_dir, '%d.solver' % config.CONTINUE_ITER) 101 | iteration = model_opr.load_solver(optimizer, lr_scheduler, solver_path) 102 | 103 | max_iter = max_psnr = max_ssim = 0 104 | for lr_img, hr_img in train_loader: 105 | model.train() 106 | iteration = iteration + 1 107 | 108 | optimizer.zero_grad() 109 | 110 | lr_img = lr_img.to(device) 111 | hr_img = hr_img.to(device) 112 | 113 | loss_dict = model(lr_img, gt=hr_img) 114 | total_loss = sum(loss for loss in loss_dict.values()) 115 | total_loss.backward() 116 | 117 | optimizer.step() 118 | lr_scheduler.step() 119 | 120 | if rank <= 0: 121 | if iteration % config.LOG_PERIOD == 0 or iteration == config.SOLVER.MAX_ITER: 122 | log_str = 'Iter: %d, LR: %.3e, ' % (iteration, optimizer.param_groups[0]['lr']) 123 | for key in loss_dict: 124 | tb_writer.add_scalar(key, loss_dict[key].mean(), global_step=iteration) 125 | log_str += key + ': %.4f, ' % float(loss_dict[key]) 126 | logger.info(log_str) 127 | 128 | if iteration % config.SAVE_PERIOD == 0 or iteration == config.SOLVER.MAX_ITER: 129 | logger.info('[Saving] Iter: %d' % iteration) 130 | model_path = osp.join(model_dir, '%d.pth' % iteration) 131 | solver_path = osp.join(solver_dir, '%d.solver' % iteration) 132 | model_opr.save_model(model, model_path) 133 | model_opr.save_solver(optimizer, lr_scheduler, iteration, solver_path) 134 | 135 | if iteration % config.VAL.PERIOD == 0 or iteration == config.SOLVER.MAX_ITER: 136 | logger.info('[Validating] Iter: %d' % iteration) 137 | model.eval() 138 | with torch.no_grad(): 139 | psnr, ssim = validate(model, val_loader, config, device, iteration, save_path=save_dir) 140 | if psnr > max_psnr: 141 | max_psnr, max_ssim, max_iter = psnr, ssim, iteration 142 | logger.info('[Val Result] Iter: %d, PSNR: %.4f, SSIM: %.4f' % (iteration, psnr, ssim)) 143 | logger.info('[Best Result] Iter: %d, PSNR: %.4f, SSIM: %.4f' % (max_iter, max_psnr, max_ssim)) 144 | 145 | if iteration >= config.SOLVER.MAX_ITER: 146 | break 147 | 148 | if rank <= 0: 149 | logger.info('Finish training process!') 150 | logger.info('[Final Best Result] Iter: %d, PSNR: %.4f, SSIM: %.4f' % (max_iter, max_psnr, max_ssim)) 151 | 152 | 153 | if __name__ == '__main__': 154 | main() 155 | -------------------------------------------------------------------------------- /exps/LAPAR_C_x3/train.sh: -------------------------------------------------------------------------------- 1 | python3 -m torch.distributed.launch --nproc_per_node=$1 --master_port=$2 train.py 2 | -------------------------------------------------------------------------------- /exps/LAPAR_C_x3/validate.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import numpy as np 4 | import sys 5 | sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + '/../..') 6 | 7 | import torch 8 | 9 | from utils.common import tensor2img, calculate_psnr, calculate_ssim, bgr2ycbcr 10 | 11 | 12 | def validate(model, val_loader, config, device, iteration, save_path='.'): 13 | with torch.no_grad(): 14 | psnr_l = [] 15 | ssim_l = [] 16 | 17 | for idx, (lr_img, hr_img) in enumerate(val_loader): 18 | lr_img = lr_img.to(device) 19 | hr_img = hr_img.to(device) 20 | 21 | output = model(lr_img) 22 | 23 | output = tensor2img(output) 24 | gt = tensor2img(hr_img) 25 | 26 | if config.VAL.SAVE_IMG: 27 | ipath = os.path.join(save_path, '%d_%03d.png' % (iteration, idx)) 28 | cv2.imwrite(ipath, np.concatenate([output, gt], axis=1)) 29 | 30 | output = output.astype(np.float32) / 255.0 31 | gt = gt.astype(np.float32) / 255.0 32 | 33 | if config.VAL.TO_Y: 34 | output = bgr2ycbcr(output, only_y=True) 35 | gt = bgr2ycbcr(gt, only_y=True) 36 | 37 | if config.VAL.CROP_BORDER != 0: 38 | cb = config.VAL.CROP_BORDER 39 | output = output[cb:-cb, cb:-cb] 40 | gt = gt[cb:-cb, cb:-cb] 41 | 42 | psnr = calculate_psnr(output * 255, gt * 255) 43 | ssim = calculate_ssim(output * 255, gt * 255) 44 | psnr_l.append(psnr) 45 | ssim_l.append(ssim) 46 | 47 | avg_psnr = sum(psnr_l) / len(psnr_l) 48 | avg_ssim = sum(ssim_l) / len(ssim_l) 49 | 50 | return avg_psnr, avg_ssim 51 | 52 | 53 | if __name__ == '__main__': 54 | from config import config 55 | from network import Network 56 | from dataset import get_dataset 57 | from utils import dataloader 58 | from utils.model_opr import load_model 59 | 60 | config.VAL.DATASET = 'Set5' 61 | 62 | model = Network(config) 63 | if torch.cuda.is_available(): 64 | device = torch.device('cuda') 65 | else: 66 | device = torch.device('cpu') 67 | model = model.to(device) 68 | 69 | model_path = 'log/models/200000.pth' 70 | load_model(model, model_path, cpu=True) 71 | sys.exit() 72 | 73 | val_dataset = get_dataset(config.VAL) 74 | val_loader = dataloader.val_loader(val_dataset, config, 0, 1) 75 | psnr, ssim = validate(model, val_loader, config, device, 0, save_path='.') 76 | print('PSNR: %.4f, SSIM: %.4f' % (psnr, ssim)) 77 | -------------------------------------------------------------------------------- /exps/LAPAR_C_x4/config.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | 4 | class Config: 5 | # dataset 6 | DATASET = edict() 7 | DATASET.TYPE = 'MixDataset' 8 | DATASET.DATASETS = ['DIV2K', 'Flickr2K'] 9 | DATASET.SPLITS = ['TRAIN', 'TRAIN'] 10 | DATASET.PHASE = 'train' 11 | DATASET.INPUT_HEIGHT = 64 12 | DATASET.INPUT_WIDTH = 64 13 | DATASET.SCALE = 4 14 | DATASET.REPEAT = 1 15 | DATASET.VALUE_RANGE = 255.0 16 | DATASET.SEED = 100 17 | 18 | # dataloader 19 | DATALOADER = edict() 20 | DATALOADER.IMG_PER_GPU = 32 21 | DATALOADER.NUM_WORKERS = 8 22 | 23 | # model 24 | MODEL = edict() 25 | MODEL.SCALE = DATASET.SCALE 26 | MODEL.KERNEL_SIZE = 5 27 | MODEL.KERNEL_PATH = '../../kernel/kernel_72_k5.pkl' 28 | MODEL.IN_CHANNEL = 3 29 | MODEL.N_CHANNEL = 16 30 | MODEL.RES_BLOCK = 2 31 | MODEL.N_WEIGHT = 72 32 | MODEL.DOWN = 1 33 | MODEL.DEVICE = 'cuda' 34 | 35 | # solver 36 | SOLVER = edict() 37 | SOLVER.OPTIMIZER = 'Adam' 38 | SOLVER.BASE_LR = 4e-4 39 | SOLVER.BETA1 = 0.9 40 | SOLVER.BETA2 = 0.999 41 | SOLVER.WEIGHT_DECAY = 0 42 | SOLVER.MOMENTUM = 0 43 | SOLVER.WARM_UP_ITER = 2000 44 | SOLVER.WARM_UP_FACTOR = 0.1 45 | SOLVER.T_PERIOD = [200000, 400000, 600000] 46 | SOLVER.MAX_ITER = SOLVER.T_PERIOD[-1] 47 | 48 | # initialization 49 | CONTINUE_ITER = None 50 | INIT_MODEL = None 51 | 52 | # log and save 53 | LOG_PERIOD = 20 54 | SAVE_PERIOD = 10000 55 | 56 | # validation 57 | VAL = edict() 58 | VAL.PERIOD = 10000 59 | VAL.TYPE = 'MixDataset' 60 | VAL.DATASETS = ['BSDS100'] 61 | VAL.SPLITS = ['VAL'] 62 | VAL.PHASE = 'val' 63 | VAL.INPUT_HEIGHT = None 64 | VAL.INPUT_WIDTH = None 65 | VAL.SCALE = DATASET.SCALE 66 | VAL.REPEAT = 1 67 | VAL.VALUE_RANGE = 255.0 68 | VAL.IMG_PER_GPU = 1 69 | VAL.NUM_WORKERS = 1 70 | VAL.SAVE_IMG = False 71 | VAL.TO_Y = True 72 | VAL.CROP_BORDER = VAL.SCALE 73 | 74 | 75 | config = Config() 76 | 77 | 78 | 79 | -------------------------------------------------------------------------------- /exps/LAPAR_C_x4/network.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.nn.init as init 7 | 8 | from utils.modules.lightWeightNet import WeightNet 9 | 10 | 11 | class ComponentDecConv(nn.Module): 12 | def __init__(self, k_path, k_size): 13 | super(ComponentDecConv, self).__init__() 14 | 15 | kernel = pickle.load(open(k_path, 'rb')) 16 | kernel = torch.from_numpy(kernel).float().view(-1, 1, k_size, k_size) 17 | self.register_buffer('weight', kernel) 18 | 19 | def forward(self, x): 20 | out = F.conv2d(x, weight=self.weight, bias=None, stride=1, padding=0, groups=1) 21 | 22 | return out 23 | 24 | 25 | class Network(nn.Module): 26 | def __init__(self, config): 27 | super(Network, self).__init__() 28 | 29 | self.k_size = config.MODEL.KERNEL_SIZE 30 | self.s = config.MODEL.SCALE 31 | 32 | self.w_conv = WeightNet(config.MODEL) 33 | self.decom_conv = ComponentDecConv(config.MODEL.KERNEL_PATH, self.k_size) 34 | 35 | self.criterion = nn.L1Loss(reduction='mean') 36 | 37 | 38 | def forward(self, x, gt=None): 39 | B, C, H, W = x.size() 40 | 41 | bic = F.interpolate(x, scale_factor=self.s, mode='bicubic', align_corners=False) 42 | pad = self.k_size // 2 43 | x_pad = F.pad(bic, pad=(pad, pad, pad, pad), mode='reflect') 44 | pad_H, pad_W = x_pad.size()[2:] 45 | x_pad = x_pad.view(B * 3, 1, pad_H, pad_W) 46 | x_com = self.decom_conv(x_pad).view(B, 3, -1, self.s * H, self.s * W) # B, 3, N_K, Hs, Ws 47 | 48 | weight = self.w_conv(x) 49 | weight = weight.view(B, 1, -1, self.s * H, self.s * W) # B, 1, N_K, Hs, Ws 50 | 51 | out = torch.sum(weight * x_com, dim=2) 52 | 53 | if gt is not None: 54 | loss_dict = dict(L1=self.criterion(out, gt)) 55 | return loss_dict 56 | else: 57 | return out 58 | 59 | -------------------------------------------------------------------------------- /exps/LAPAR_C_x4/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import os.path as osp 4 | import sys 5 | sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + '/../..') 6 | import logging 7 | 8 | import torch 9 | import torch.distributed as dist 10 | import torch.multiprocessing as mp 11 | from tensorboardX import SummaryWriter 12 | 13 | from config import config 14 | from utils import common, dataloader, solver, model_opr 15 | from dataset import get_dataset 16 | from network import Network 17 | from validate import validate 18 | 19 | 20 | def init_dist(local_rank): 21 | if mp.get_start_method(allow_none=True) != 'spawn': 22 | mp.set_start_method('spawn', force=True) 23 | torch.cuda.set_device(local_rank) 24 | dist.init_process_group(backend="nccl", init_method='env://') 25 | dist.barrier() 26 | 27 | 28 | def main(): 29 | parser = argparse.ArgumentParser() 30 | parser.add_argument('--local_rank', type=int, default=0) 31 | args = parser.parse_args() 32 | 33 | # initialization 34 | rank = 0 35 | num_gpu = 1 36 | distributed = False 37 | if 'WORLD_SIZE' in os.environ: 38 | num_gpu = int(os.environ['WORLD_SIZE']) 39 | distributed = num_gpu > 1 40 | if distributed: 41 | rank = args.local_rank 42 | init_dist(rank) 43 | common.init_random_seed(config.DATASET.SEED + rank) 44 | 45 | # set up dirs and log 46 | exp_dir, cur_dir = osp.split(osp.split(osp.realpath(__file__))[0]) 47 | root_dir = osp.split(exp_dir)[0] 48 | log_dir = osp.join(root_dir, 'logs', cur_dir) 49 | model_dir = osp.join(log_dir, 'models') 50 | solver_dir = osp.join(log_dir, 'solvers') 51 | if rank <= 0: 52 | common.mkdir(log_dir) 53 | ln_log_dir = osp.join(exp_dir, cur_dir, 'log') 54 | if not osp.exists(ln_log_dir): 55 | os.system('ln -s %s log' % log_dir) 56 | common.mkdir(model_dir) 57 | common.mkdir(solver_dir) 58 | save_dir = osp.join(log_dir, 'saved_imgs') 59 | common.mkdir(save_dir) 60 | tb_dir = osp.join(log_dir, 'tb_log') 61 | tb_writer = SummaryWriter(tb_dir) 62 | common.setup_logger('base', log_dir, 'train', level=logging.INFO, screen=True, to_file=True) 63 | logger = logging.getLogger('base') 64 | 65 | # dataset 66 | train_dataset = get_dataset(config.DATASET) 67 | train_loader = dataloader.train_loader(train_dataset, config, rank=rank, seed=config.DATASET.SEED, 68 | is_dist=distributed) 69 | if rank <= 0: 70 | val_dataset = get_dataset(config.VAL) 71 | val_loader = dataloader.val_loader(val_dataset, config, rank, 1) 72 | data_len = val_dataset.data_len 73 | 74 | # model 75 | model = Network(config) 76 | if rank <= 0: 77 | print(model) 78 | 79 | if config.CONTINUE_ITER: 80 | model_path = osp.join(model_dir, '%d.pth' % config.CONTINUE_ITER) 81 | if rank <= 0: 82 | logger.info('[Continue] Iter: %d' % config.CONTINUE_ITER) 83 | model_opr.load_model(model, model_path, strict=True, cpu=True) 84 | elif config.INIT_MODEL: 85 | if rank <= 0: 86 | logger.info('[Initialize] Model: %s' % config.INIT_MODEL) 87 | model_opr.load_model(model, config.INIT_MODEL, strict=True, cpu=True) 88 | 89 | device = torch.device(config.MODEL.DEVICE) 90 | model.to(device) 91 | if distributed: 92 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[torch.cuda.current_device()]) 93 | 94 | # solvers 95 | optimizer = solver.make_optimizer(config, model) # lr without X num_gpu 96 | lr_scheduler = solver.CosineAnnealingLR_warmup(config, optimizer, config.SOLVER.BASE_LR) 97 | iteration = 0 98 | 99 | if config.CONTINUE_ITER: 100 | solver_path = osp.join(solver_dir, '%d.solver' % config.CONTINUE_ITER) 101 | iteration = model_opr.load_solver(optimizer, lr_scheduler, solver_path) 102 | 103 | max_iter = max_psnr = max_ssim = 0 104 | for lr_img, hr_img in train_loader: 105 | model.train() 106 | iteration = iteration + 1 107 | 108 | optimizer.zero_grad() 109 | 110 | lr_img = lr_img.to(device) 111 | hr_img = hr_img.to(device) 112 | 113 | loss_dict = model(lr_img, gt=hr_img) 114 | total_loss = sum(loss for loss in loss_dict.values()) 115 | total_loss.backward() 116 | 117 | optimizer.step() 118 | lr_scheduler.step() 119 | 120 | if rank <= 0: 121 | if iteration % config.LOG_PERIOD == 0 or iteration == config.SOLVER.MAX_ITER: 122 | log_str = 'Iter: %d, LR: %.3e, ' % (iteration, optimizer.param_groups[0]['lr']) 123 | for key in loss_dict: 124 | tb_writer.add_scalar(key, loss_dict[key].mean(), global_step=iteration) 125 | log_str += key + ': %.4f, ' % float(loss_dict[key]) 126 | logger.info(log_str) 127 | 128 | if iteration % config.SAVE_PERIOD == 0 or iteration == config.SOLVER.MAX_ITER: 129 | logger.info('[Saving] Iter: %d' % iteration) 130 | model_path = osp.join(model_dir, '%d.pth' % iteration) 131 | solver_path = osp.join(solver_dir, '%d.solver' % iteration) 132 | model_opr.save_model(model, model_path) 133 | model_opr.save_solver(optimizer, lr_scheduler, iteration, solver_path) 134 | 135 | if iteration % config.VAL.PERIOD == 0 or iteration == config.SOLVER.MAX_ITER: 136 | logger.info('[Validating] Iter: %d' % iteration) 137 | model.eval() 138 | with torch.no_grad(): 139 | psnr, ssim = validate(model, val_loader, config, device, iteration, save_path=save_dir) 140 | if psnr > max_psnr: 141 | max_psnr, max_ssim, max_iter = psnr, ssim, iteration 142 | logger.info('[Val Result] Iter: %d, PSNR: %.4f, SSIM: %.4f' % (iteration, psnr, ssim)) 143 | logger.info('[Best Result] Iter: %d, PSNR: %.4f, SSIM: %.4f' % (max_iter, max_psnr, max_ssim)) 144 | 145 | if iteration >= config.SOLVER.MAX_ITER: 146 | break 147 | 148 | if rank <= 0: 149 | logger.info('Finish training process!') 150 | logger.info('[Final Best Result] Iter: %d, PSNR: %.4f, SSIM: %.4f' % (max_iter, max_psnr, max_ssim)) 151 | 152 | 153 | if __name__ == '__main__': 154 | main() 155 | -------------------------------------------------------------------------------- /exps/LAPAR_C_x4/train.sh: -------------------------------------------------------------------------------- 1 | python3 -m torch.distributed.launch --nproc_per_node=$1 --master_port=$2 train.py 2 | -------------------------------------------------------------------------------- /exps/LAPAR_C_x4/validate.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import os 3 | import numpy as np 4 | import sys 5 | sys.path.insert(0, os.path.dirname(os.path.realpath(__file__)) + '/../..') 6 | 7 | import torch 8 | 9 | from utils.common import tensor2img, calculate_psnr, calculate_ssim, bgr2ycbcr 10 | 11 | 12 | def validate(model, val_loader, config, device, iteration, save_path='.'): 13 | with torch.no_grad(): 14 | psnr_l = [] 15 | ssim_l = [] 16 | 17 | for idx, (lr_img, hr_img) in enumerate(val_loader): 18 | lr_img = lr_img.to(device) 19 | hr_img = hr_img.to(device) 20 | 21 | output = model(lr_img) 22 | 23 | output = tensor2img(output) 24 | gt = tensor2img(hr_img) 25 | 26 | if config.VAL.SAVE_IMG: 27 | ipath = os.path.join(save_path, '%d_%03d.png' % (iteration, idx)) 28 | cv2.imwrite(ipath, np.concatenate([output, gt], axis=1)) 29 | 30 | output = output.astype(np.float32) / 255.0 31 | gt = gt.astype(np.float32) / 255.0 32 | 33 | if config.VAL.TO_Y: 34 | output = bgr2ycbcr(output, only_y=True) 35 | gt = bgr2ycbcr(gt, only_y=True) 36 | 37 | if config.VAL.CROP_BORDER != 0: 38 | cb = config.VAL.CROP_BORDER 39 | output = output[cb:-cb, cb:-cb] 40 | gt = gt[cb:-cb, cb:-cb] 41 | 42 | psnr = calculate_psnr(output * 255, gt * 255) 43 | ssim = calculate_ssim(output * 255, gt * 255) 44 | psnr_l.append(psnr) 45 | ssim_l.append(ssim) 46 | 47 | avg_psnr = sum(psnr_l) / len(psnr_l) 48 | avg_ssim = sum(ssim_l) / len(ssim_l) 49 | 50 | return avg_psnr, avg_ssim 51 | 52 | 53 | if __name__ == '__main__': 54 | from config import config 55 | from network import Network 56 | from dataset import get_dataset 57 | from utils import dataloader 58 | from utils.model_opr import load_model 59 | 60 | config.VAL.DATASET = 'Set5' 61 | 62 | model = Network(config) 63 | if torch.cuda.is_available(): 64 | device = torch.device('cuda') 65 | else: 66 | device = torch.device('cpu') 67 | model = model.to(device) 68 | 69 | model_path = 'log/models/200000.pth' 70 | load_model(model, model_path, cpu=True) 71 | sys.exit() 72 | 73 | val_dataset = get_dataset(config.VAL) 74 | val_loader = dataloader.val_loader(val_dataset, config, 0, 1) 75 | psnr, ssim = validate(model, val_loader, config, device, 0, save_path='.') 76 | print('PSNR: %.4f, SSIM: %.4f' % (psnr, ssim)) 77 | -------------------------------------------------------------------------------- /exps/MuCAN_REDS/__pycache__/config.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/Simple-SR/08c71e9e46ba781df50893f0476ecd0fc004aa45/exps/MuCAN_REDS/__pycache__/config.cpython-35.pyc -------------------------------------------------------------------------------- /exps/MuCAN_REDS/__pycache__/network.cpython-35.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/Simple-SR/08c71e9e46ba781df50893f0476ecd0fc004aa45/exps/MuCAN_REDS/__pycache__/network.cpython-35.pyc -------------------------------------------------------------------------------- /exps/MuCAN_REDS/config.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | 3 | 4 | class Config: 5 | # model 6 | MODEL = edict() 7 | MODEL.N_FRAME = 5 8 | MODEL.SCALE = 4 9 | MODEL.IN_CHANNEL = 3 10 | MODEL.OUT_CHANNEL = 3 11 | MODEL.N_CHANNEL = 128 12 | MODEL.FRONT_BLOCK = 5 13 | MODEL.NEAREST_NEIGHBOR = 4 14 | MODEL.N_GROUP = 8 15 | MODEL.KERNELS = [3, 3, 3, 3] 16 | MODEL.PATCHES = [7, 11, 15] 17 | MODEL.CORRELATION_KERNEL = 3 18 | MODEL.BACK_BLOCK = 40 19 | MODEL.N_LEVEL= 3 20 | MODEL.DOWN = 4 21 | MODEL.DEVICE = 'cuda' 22 | 23 | 24 | config = Config() 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /exps/MuCAN_Vimeo90K/config.py: -------------------------------------------------------------------------------- 1 | from easydict import EasyDict as edict 2 | import numpy as np 3 | 4 | 5 | class Config: 6 | # model 7 | MODEL = edict() 8 | MODEL.N_FRAME = 7 9 | MODEL.SCALE = 4 10 | MODEL.IN_CHANNEL = 3 11 | MODEL.OUT_CHANNEL = 3 12 | MODEL.N_CHANNEL = 128 13 | MODEL.FRONT_BLOCK = 5 14 | MODEL.NEAREST_NEIGHBOR = 4 15 | MODEL.N_GROUP = 8 16 | MODEL.KERNELS = [3, 3, 3, 3] 17 | MODEL.PATCHES = [7, 11, 15] 18 | MODEL.CORRELATION_KERNEL = 3 19 | MODEL.BACK_BLOCK = 20 20 | MODEL.N_LEVEL= 3 21 | MODEL.DOWN = 4 22 | MODEL.DEVICE = 'cuda' 23 | 24 | 25 | config = Config() 26 | 27 | 28 | 29 | -------------------------------------------------------------------------------- /kernel/kernel_14_k5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/Simple-SR/08c71e9e46ba781df50893f0476ecd0fc004aa45/kernel/kernel_14_k5.pkl -------------------------------------------------------------------------------- /kernel/kernel_24_k5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/Simple-SR/08c71e9e46ba781df50893f0476ecd0fc004aa45/kernel/kernel_24_k5.pkl -------------------------------------------------------------------------------- /kernel/kernel_72_k5.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/Simple-SR/08c71e9e46ba781df50893f0476ecd0fc004aa45/kernel/kernel_72_k5.pkl -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | easydict 2 | future 3 | matplotlib 4 | numpy 5 | opencv-python 6 | scikit-image 7 | scipy 8 | torch>=1.2 9 | torchvision 10 | -------------------------------------------------------------------------------- /test_sample.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import cv2 3 | import numpy as np 4 | import os 5 | import sys 6 | 7 | import torch 8 | 9 | from utils.model_opr import load_model 10 | from utils.common import tensor2img, calculate_psnr, calculate_ssim, bgr2ycbcr 11 | 12 | 13 | def get_network(model_path): 14 | if 'REDS' in model_path: 15 | from exps.MuCAN_REDS.config import config 16 | from exps.MuCAN_REDS.network import Network 17 | elif 'Vimeo' in model_path: 18 | from exps.MuCAN_Vimeo90K.config import config 19 | from exps.MuCAN_Vimeo90K.network import Network 20 | elif 'LAPAR_A_x2' in model_path: 21 | from exps.LAPAR_A_x2.config import config 22 | from exps.LAPAR_A_x2.network import Network 23 | elif 'LAPAR_A_x3' in model_path: 24 | from exps.LAPAR_A_x3.config import config 25 | from exps.LAPAR_A_x3.network import Network 26 | elif 'LAPAR_A_x4' in model_path: 27 | from exps.LAPAR_A_x4.config import config 28 | from exps.LAPAR_A_x4.network import Network 29 | elif 'LAPAR_B_x2' in model_path: 30 | from exps.LAPAR_B_x2.config import config 31 | from exps.LAPAR_B_x2.network import Network 32 | elif 'LAPAR_B_x3' in model_path: 33 | from exps.LAPAR_B_x3.config import config 34 | from exps.LAPAR_B_x3.network import Network 35 | elif 'LAPAR_B_x4' in model_path: 36 | from exps.LAPAR_B_x4.config import config 37 | from exps.LAPAR_B_x4.network import Network 38 | elif 'LAPAR_C_x2' in model_path: 39 | from exps.LAPAR_C_x2.config import config 40 | from exps.LAPAR_C_x2.network import Network 41 | elif 'LAPAR_C_x3' in model_path: 42 | from exps.LAPAR_C_x3.config import config 43 | from exps.LAPAR_C_x3.network import Network 44 | elif 'LAPAR_C_x4' in model_path: 45 | from exps.LAPAR_C_x4.config import config 46 | from exps.LAPAR_C_x4.network import Network 47 | elif 'BebyGAN_x4' in model_path: 48 | from exps.BebyGAN.config import config 49 | from exps.BebyGAN.network import Network 50 | else: 51 | print('Illenal model: not implemented!') 52 | sys.exit(1) 53 | 54 | # an ugly operation 55 | if 'KERNEL_PATH' in config.MODEL: 56 | config.MODEL.KERNEL_PATH = config.MODEL.KERNEL_PATH.replace('../', '') 57 | 58 | if 'BebyGAN' in model_path: 59 | return config, Network(config).G 60 | 61 | return config, Network(config) 62 | 63 | 64 | if __name__ == '__main__': 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument('--sr_type', type=str, default='SISR') 67 | parser.add_argument('--model_path', type=str, default=None) 68 | parser.add_argument('--input_path', type=str, default=None) 69 | parser.add_argument('--output_path', type=str, default=None) 70 | parser.add_argument('--gt_path', type=str, default=None) 71 | args = parser.parse_args() 72 | 73 | if args.output_path and not os.path.exists(args.output_path): 74 | os.makedirs(args.output_path) 75 | 76 | print('Loading Network ...') 77 | config, model = get_network(args.model_path) 78 | device = torch.device('cuda') 79 | model = model.to(device) 80 | load_model(model, args.model_path, strict=True) 81 | 82 | down = config.MODEL.DOWN 83 | scale = config.MODEL.SCALE 84 | 85 | print('Reading Images ...') 86 | ipath_l = [] 87 | for f in sorted(os.listdir(args.input_path)): 88 | if f.endswith('png') or f.endswith('jpg'): 89 | ipath_l.append(os.path.join(args.input_path, f)) 90 | 91 | if args.gt_path: 92 | gpath_l = [] 93 | for f in sorted(os.listdir(args.gt_path)): 94 | if f.endswith('png') or f.endswith('jpg'): 95 | gpath_l.append(os.path.join(args.gt_path, f)) 96 | psnr_l = [] 97 | ssim_l = [] 98 | 99 | if args.sr_type == 'SISR': 100 | with torch.no_grad(): 101 | for i, f in enumerate(ipath_l): 102 | img_name = f.split('/')[-1] 103 | print('Processing: %s' % img_name) 104 | lr_img = cv2.imread(f, cv2.IMREAD_COLOR) 105 | lr_img = np.transpose(lr_img[:, :, ::-1], (2, 0, 1)).astype(np.float32) / 255.0 106 | lr_img = torch.from_numpy(lr_img).float().to(device).unsqueeze(0) 107 | 108 | _, C, H, W = lr_img.size() 109 | 110 | need_pad = False 111 | if H % down != 0 or W % down != 0: 112 | need_pad = True 113 | pad_y_t = (down - H % down) % down // 2 114 | pad_y_b = (down - H % down) % down - pad_y_t 115 | pad_x_l = (down - W % down) % down // 2 116 | pad_x_r = (down - W % down) % down - pad_x_l 117 | lr_img = torch.nn.functional.pad(lr_img, pad=(pad_x_l, pad_x_r, pad_y_t, pad_y_b), mode='replicate') 118 | 119 | output = model(lr_img) 120 | 121 | if need_pad: 122 | y_end = -pad_y_b * scale if pad_y_b != 0 else output.size(2) 123 | x_end = -pad_x_r * scale if pad_x_r != 0 else output.size(3) 124 | output = output[:, :, pad_y_t * scale: y_end, pad_x_l * scale: x_end] 125 | 126 | output = tensor2img(output) 127 | if args.output_path: 128 | output_path = os.path.join(args.output_path, img_name) 129 | cv2.imwrite(output_path, output) 130 | 131 | if args.gt_path: 132 | output = output.astype(np.float32) / 255.0 133 | gt = cv2.imread(gpath_l[i], cv2.IMREAD_COLOR).astype(np.float32) / 255.0 134 | 135 | # to y channel 136 | output = bgr2ycbcr(output, only_y=True) 137 | gt = bgr2ycbcr(gt, only_y=True) 138 | 139 | output = output[scale:-scale, scale:-scale] 140 | gt = gt[scale:-scale, scale:-scale] 141 | 142 | psnr = calculate_psnr(output * 255, gt * 255) 143 | ssim = calculate_ssim(output * 255, gt * 255) 144 | 145 | psnr_l.append(psnr) 146 | ssim_l.append(ssim) 147 | 148 | elif args.sr_type == 'VSR': 149 | num_img = len(ipath_l) 150 | 151 | half_n = config.MODEL.N_FRAME // 2 152 | with torch.no_grad(): 153 | for i, f in enumerate(ipath_l): 154 | img_name = f.split('/')[-1] 155 | print('Processing: %s' % img_name) 156 | nbr_l = [] 157 | for j in range(i - half_n, i + half_n + 1): 158 | if j < 0: 159 | ipath = ipath_l[i + half_n - j] 160 | elif j >= num_img: 161 | ipath = ipath_l[i - half_n - (j - num_img + 1)] 162 | else: 163 | ipath = ipath_l[j] 164 | nbr_img = cv2.imread(ipath, cv2.IMREAD_COLOR) 165 | nbr_l.append(nbr_img) 166 | lr_imgs = np.stack(nbr_l, axis=0) 167 | lr_imgs = np.transpose(lr_imgs[:, :, :, ::-1], (0, 3, 1, 2)).astype(np.float32) / 255.0 168 | lr_imgs = torch.from_numpy(lr_imgs).float().to(device) 169 | 170 | N, C, H, W = lr_imgs.size() 171 | 172 | need_pad = False 173 | if H % down != 0 or W % down != 0: 174 | need_pad = True 175 | pad_y_t = (down - H % down) % down // 2 176 | pad_y_b = (down - H % down) % down - pad_y_t 177 | pad_x_l = (down - W % down) % down // 2 178 | pad_x_r = (down - W % down) % down - pad_x_l 179 | lr_imgs = torch.nn.functional.pad(lr_imgs, pad=(pad_x_l, pad_x_r, pad_y_t, pad_y_b), mode='replicate') 180 | lr_imgs = lr_imgs.unsqueeze(0) 181 | 182 | output = model(lr_imgs) 183 | 184 | if need_pad: 185 | y_end = -pad_y_b * scale if pad_y_b != 0 else output.size(2) 186 | x_end = -pad_x_r * scale if pad_x_r != 0 else output.size(3) 187 | output = output[:, :, pad_y_t * scale: y_end, pad_x_l * scale: x_end] 188 | 189 | output = tensor2img(output) 190 | if args.output_path: 191 | output_path = os.path.join(args.output_path, img_name) 192 | cv2.imwrite(output_path, output) 193 | 194 | if args.gt_path: 195 | output = output.astype(np.float32) / 255.0 196 | gt = cv2.imread(gpath_l[i], cv2.IMREAD_COLOR).astype(np.float32) / 255.0 197 | 198 | # to y channel 199 | output = bgr2ycbcr(output, only_y=True) 200 | gt = bgr2ycbcr(gt, only_y=True) 201 | 202 | output = output[scale:-scale, scale:-scale] 203 | gt = gt[scale:-scale, scale:-scale] 204 | 205 | psnr = calculate_psnr(output * 255, gt * 255) 206 | ssim = calculate_ssim(output * 255, gt * 255) 207 | 208 | psnr_l.append(psnr) 209 | ssim_l.append(ssim) 210 | 211 | else: 212 | print('Illenal SR type: not implemented!') 213 | sys.exit(1) 214 | 215 | if args.gt_path: 216 | avg_psnr = sum(psnr_l) / len(psnr_l) 217 | avg_ssim = sum(ssim_l) / len(ssim_l) 218 | print('--------- Result ---------') 219 | print('PSNR: %.2f, SSIM:%.4f' % (avg_psnr, avg_ssim)) 220 | 221 | print('Finished!') 222 | 223 | -------------------------------------------------------------------------------- /utils/data_prep/extract_subimage.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from multiprocessing import Pool 3 | import numpy as np 4 | import os 5 | import sys 6 | from utils.common import ProgressBar 7 | 8 | 9 | def main(): 10 | """A multi-thread tool to crop sub imags.""" 11 | input_folder = '/data/liwenbo/datasets/DIV2K/DIV2K_train_HR' # original image folder 12 | save_folder = '/data/liwenbo/datasets/DIV2K/DIV2K_train_HR_sub' # the created sub-image folder 13 | n_thread = 20 14 | 15 | # for hr 16 | crop_sz = 480 17 | step = 240 18 | thres_sz = 48 19 | 20 | # # for lrx2 21 | # crop_sz = 240 22 | # step = 120 23 | # thres_sz = 24 24 | 25 | # # for lrx3 26 | # crop_sz = 160 27 | # step = 80 28 | # thres_sz = 16 29 | 30 | # # for lrx4 31 | # crop_sz = 120 32 | # step = 60 33 | # thres_sz = 12 34 | 35 | compression_level = 3 # 3 is the default value in cv2 36 | # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer 37 | # compression time. If read raw images during training, use 0 for faster IO speed. 38 | 39 | if not os.path.exists(save_folder): 40 | os.makedirs(save_folder) 41 | print('mkdir [{:s}] ...'.format(save_folder)) 42 | else: 43 | print('Folder [{:s}] already exists. Exit...'.format(save_folder)) 44 | sys.exit(1) 45 | 46 | img_list = [] 47 | for root, _, file_list in sorted(os.walk(input_folder)): 48 | path = [os.path.join(root, x) for x in file_list] # assume only images in the input_folder 49 | img_list.extend(path) 50 | 51 | def update(arg): 52 | pbar.update(arg) 53 | 54 | pbar = ProgressBar(len(img_list)) 55 | 56 | pool = Pool(n_thread) 57 | for path in img_list: 58 | pool.apply_async(worker, 59 | args=(path, save_folder, crop_sz, step, thres_sz, compression_level), 60 | callback=update) 61 | pool.close() 62 | pool.join() 63 | print('All subprocesses done.') 64 | 65 | 66 | def worker(path, save_folder, crop_sz, step, thres_sz, compression_level): 67 | img_name = os.path.basename(path) 68 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED) 69 | 70 | n_channels = len(img.shape) 71 | if n_channels == 2: 72 | h, w = img.shape 73 | elif n_channels == 3: 74 | h, w, c = img.shape 75 | else: 76 | raise ValueError('Wrong image shape - {}'.format(n_channels)) 77 | 78 | h_space = np.arange(0, h - crop_sz + 1, step) 79 | if h - (h_space[-1] + crop_sz) > thres_sz: 80 | h_space = np.append(h_space, h - crop_sz) 81 | w_space = np.arange(0, w - crop_sz + 1, step) 82 | if w - (w_space[-1] + crop_sz) > thres_sz: 83 | w_space = np.append(w_space, w - crop_sz) 84 | 85 | index = 0 86 | for x in h_space: 87 | for y in w_space: 88 | index += 1 89 | if n_channels == 2: 90 | crop_img = img[x:x + crop_sz, y:y + crop_sz] 91 | else: 92 | crop_img = img[x:x + crop_sz, y:y + crop_sz, :] 93 | crop_img = np.ascontiguousarray(crop_img) 94 | # var = np.var(crop_img / 255) 95 | # if var > 0.008: 96 | # print(img_name, index_str, var) 97 | if 'x' in img_name: 98 | iid = img_name.split('x')[0] 99 | new_name = img_name.replace(iid, iid + '_s{:03d}'.format(index)) 100 | else: 101 | new_name = img_name.replace('.png', '_s{:03d}.png'.format(index)) 102 | cv2.imwrite( 103 | os.path.join(save_folder, new_name), 104 | crop_img, [cv2.IMWRITE_PNG_COMPRESSION, compression_level]) 105 | return 'Processing {:s} ...'.format(img_name) 106 | 107 | 108 | if __name__ == '__main__': 109 | main() 110 | -------------------------------------------------------------------------------- /utils/data_prep/generate_lr_bic.m: -------------------------------------------------------------------------------- 1 | function generate_mod_LR_bic() 2 | %% matlab code to genetate mod images, bicubic-downsampled LR, bicubic_upsampled images. 3 | 4 | %% set parameters 5 | % comment the unnecessary line 6 | input_folder = '/mnt/SSD/xtwang/BasicSR_datasets/DIV2K800/DIV2K800_sub'; 7 | % save_mod_folder = ''; 8 | save_LR_folder = '/mnt/SSD/xtwang/BasicSR_datasets/DIV2K800/DIV2K800_sub_bicLRx4'; 9 | % save_bic_folder = ''; 10 | 11 | up_scale = 4; 12 | mod_scale = 4; 13 | 14 | if exist('save_mod_folder', 'var') 15 | if exist(save_mod_folder, 'dir') 16 | disp(['It will cover ', save_mod_folder]); 17 | else 18 | mkdir(save_mod_folder); 19 | end 20 | end 21 | if exist('save_LR_folder', 'var') 22 | if exist(save_LR_folder, 'dir') 23 | disp(['It will cover ', save_LR_folder]); 24 | else 25 | mkdir(save_LR_folder); 26 | end 27 | end 28 | if exist('save_bic_folder', 'var') 29 | if exist(save_bic_folder, 'dir') 30 | disp(['It will cover ', save_bic_folder]); 31 | else 32 | mkdir(save_bic_folder); 33 | end 34 | end 35 | 36 | idx = 0; 37 | filepaths = dir(fullfile(input_folder,'*.*')); 38 | for i = 1 : length(filepaths) 39 | [paths,imname,ext] = fileparts(filepaths(i).name); 40 | if isempty(imname) 41 | disp('Ignore . folder.'); 42 | elseif strcmp(imname, '.') 43 | disp('Ignore .. folder.'); 44 | else 45 | idx = idx + 1; 46 | str_rlt = sprintf('%d\t%s.\n', idx, imname); 47 | fprintf(str_rlt); 48 | % read image 49 | img = imread(fullfile(input_folder, [imname, ext])); 50 | img = im2double(img); 51 | % modcrop 52 | img = modcrop(img, mod_scale); 53 | if exist('save_mod_folder', 'var') 54 | imwrite(img, fullfile(save_mod_folder, [imname, '.png'])); 55 | end 56 | % LR 57 | im_LR = imresize(img, 1/up_scale, 'bicubic'); 58 | if exist('save_LR_folder', 'var') 59 | imwrite(im_LR, fullfile(save_LR_folder, [imname, '_bicLRx4.png'])); 60 | end 61 | % Bicubic 62 | if exist('save_bic_folder', 'var') 63 | im_B = imresize(im_LR, up_scale, 'bicubic'); 64 | imwrite(im_B, fullfile(save_bic_folder, [imname, '_bicx4.png'])); 65 | end 66 | end 67 | end 68 | end 69 | 70 | %% modcrop 71 | function img = modcrop(img, modulo) 72 | if size(img,3) == 1 73 | sz = size(img); 74 | sz = sz - mod(sz, modulo); 75 | img = img(1:sz(1), 1:sz(2)); 76 | else 77 | tmpsz = size(img); 78 | sz = tmpsz(1:2); 79 | sz = sz - mod(sz, modulo); 80 | img = img(1:sz(1), 1:sz(2),:); 81 | end 82 | end -------------------------------------------------------------------------------- /utils/dataloader.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import math 3 | import numpy as np 4 | import random 5 | 6 | import torch 7 | 8 | from utils import samplers 9 | 10 | 11 | def worker_init_fn_seed(worker_id, num_workers, rank, seed): 12 | worker_seed = num_workers * rank + worker_id + seed 13 | np.random.seed(worker_seed) 14 | random.seed(worker_seed) 15 | 16 | 17 | def train_loader(dataset, config, rank, seed=None, is_dist=True, is_shuffle=True, start_iter=0): 18 | if is_dist: 19 | sampler = samplers.DistributedSampler(dataset, shuffle=is_shuffle) 20 | elif is_shuffle: 21 | sampler = torch.utils.data.sampler.RandomSampler(dataset) 22 | else: 23 | sampler = torch.utils.data.sampler.SequentialSampler(dataset) 24 | 25 | batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, config.DATALOADER.IMG_PER_GPU, drop_last=False) 26 | batch_sampler = samplers.IterationBasedBatchSampler(batch_sampler, config.SOLVER.MAX_ITER, start_iter) 27 | 28 | if seed is not None: 29 | worker_init_fn = partial(worker_init_fn_seed, num_workers=config.DATALOADER.NUM_WORKERS, rank=rank, seed=seed) 30 | else: 31 | worker_init_fn = None 32 | loader = torch.utils.data.DataLoader(dataset, num_workers=config.DATALOADER.NUM_WORKERS, 33 | batch_sampler=batch_sampler, worker_init_fn=worker_init_fn) 34 | 35 | return loader 36 | 37 | 38 | def val_loader(dataset, config, local_rank, num_gpu): 39 | num_data = len(dataset) 40 | data_per_gpu = math.ceil(num_data / num_gpu) 41 | st = local_rank * data_per_gpu 42 | ed = min(num_data, st + data_per_gpu) 43 | indices = range(st, ed) 44 | subset = torch.utils.data.Subset(dataset, indices) 45 | 46 | sampler = torch.utils.data.sampler.SequentialSampler(subset) 47 | batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, config.VAL.IMG_PER_GPU, drop_last=False) 48 | 49 | loader = torch.utils.data.DataLoader(subset, num_workers=config.VAL.NUM_WORKERS, batch_sampler=batch_sampler) 50 | 51 | return loader 52 | 53 | -------------------------------------------------------------------------------- /utils/model_opr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import DataParallel 3 | from torch.nn.parallel import DistributedDataParallel 4 | 5 | 6 | def load_model(model, model_path, strict=True, cpu=False): 7 | if isinstance(model, DataParallel) or isinstance(model, DistributedDataParallel): 8 | model = model.module 9 | if cpu: 10 | loaded_model = torch.load(model_path, map_location='cpu') 11 | else: 12 | loaded_model = torch.load(model_path) 13 | model.load_state_dict(loaded_model, strict=strict) 14 | 15 | 16 | def load_solver(optimizer, lr_scheduler, solver_path): 17 | loaded_solver = torch.load(solver_path) 18 | loaded_optimizer = loaded_solver['optimizer'] 19 | loaded_lr_scheduler = loaded_solver['lr_scheduler'] 20 | iteration = loaded_solver['iteration'] 21 | optimizer.load_state_dict(loaded_optimizer) 22 | lr_scheduler.load_state_dict(loaded_lr_scheduler) 23 | 24 | return iteration 25 | 26 | 27 | def save_model(model, model_path): 28 | if isinstance(model, DataParallel) or isinstance(model, DistributedDataParallel): 29 | model = model.module 30 | torch.save(model.state_dict(), model_path) 31 | 32 | 33 | def save_solver(optimizer, lr_scheduler, iteration, solver_path): 34 | solver = dict() 35 | solver['optimizer'] = optimizer.state_dict() 36 | solver['lr_scheduler'] = lr_scheduler.state_dict() 37 | solver['iteration'] = iteration 38 | torch.save(solver, solver_path) 39 | -------------------------------------------------------------------------------- /utils/modules/lightWeightNet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | from torch.nn.parameter import Parameter 5 | 6 | 7 | class Scale(nn.Module): 8 | def __init__(self, init_value=1e-3): 9 | super(Scale, self).__init__() 10 | self.scale = Parameter(torch.FloatTensor([init_value])) 11 | 12 | def forward(self, x): 13 | return x * self.scale 14 | 15 | 16 | class AWRU(nn.Module): 17 | def __init__(self, nf, kernel_size, wn, act=nn.ReLU(True)): 18 | super(AWRU, self).__init__() 19 | self.res_scale = Scale(1) 20 | self.x_scale = Scale(1) 21 | 22 | self.body = nn.Sequential( 23 | wn(nn.Conv2d(nf, nf, kernel_size, padding=kernel_size//2)), 24 | act, 25 | wn(nn.Conv2d(nf, nf, kernel_size, padding=kernel_size//2)), 26 | ) 27 | 28 | def forward(self, x): 29 | res = self.res_scale(self.body(x)) + self.x_scale(x) 30 | return res 31 | 32 | 33 | class AWMS(nn.Module): 34 | def __init__(self, nf, out_chl, wn, act=nn.ReLU(True)): 35 | super(AWMS, self).__init__() 36 | self.tail_k3 = wn(nn.Conv2d(nf, nf, 3, padding=3//2, dilation=1)) 37 | self.tail_k5 = wn(nn.Conv2d(nf, nf, 5, padding=5//2, dilation=1)) 38 | self.scale_k3 = Scale(0.5) 39 | self.scale_k5 = Scale(0.5) 40 | self.fuse = wn(nn.Conv2d(nf, nf, 3, padding=3 // 2)) 41 | self.act = act 42 | self.w_conv = wn(nn.Conv2d(nf, out_chl, 3, padding=3//2)) 43 | 44 | def forward(self, x): 45 | x0 = self.scale_k3(self.tail_k3(x)) 46 | x1 = self.scale_k5(self.tail_k5(x)) 47 | cur_x = x0 + x1 48 | 49 | fuse_x = self.act(self.fuse(cur_x)) 50 | out = self.w_conv(fuse_x) 51 | 52 | return out 53 | 54 | 55 | class LFB(nn.Module): 56 | def __init__(self, nf, wn, act=nn.ReLU(inplace=True)): 57 | super(LFB, self).__init__() 58 | self.b0 = AWRU(nf, 3, wn=wn, act=act) 59 | self.b1 = AWRU(nf, 3, wn=wn, act=act) 60 | self.b2 = AWRU(nf, 3, wn=wn, act=act) 61 | self.b3 = AWRU(nf, 3, wn=wn, act=act) 62 | self.reduction = wn(nn.Conv2d(nf * 4, nf, 3, padding=3//2)) 63 | self.res_scale = Scale(1) 64 | self.x_scale = Scale(1) 65 | 66 | def forward(self, x): 67 | x0 = self.b0(x) 68 | x1 = self.b1(x0) 69 | x2 = self.b2(x1) 70 | x3 = self.b3(x2) 71 | res = self.reduction(torch.cat([x0, x1, x2, x3], dim=1)) 72 | 73 | return self.res_scale(res) + self.x_scale(x) 74 | 75 | 76 | class WeightNet(nn.Module): 77 | def __init__(self, config): 78 | super(WeightNet, self).__init__() 79 | 80 | in_chl = config.IN_CHANNEL 81 | nf = config.N_CHANNEL 82 | n_block = config.RES_BLOCK 83 | out_chl = config.N_WEIGHT 84 | scale = config.SCALE 85 | 86 | act = nn.ReLU(inplace=True) 87 | wn = lambda x: nn.utils.weight_norm(x) 88 | 89 | rgb_mean = torch.FloatTensor([0.4488, 0.4371, 0.4040]).view([1, 3, 1, 1]) 90 | self.register_buffer('rgb_mean', rgb_mean) 91 | 92 | self.head = nn.Sequential( 93 | wn(nn.Conv2d(in_chl, nf, 3, padding=3//2)), 94 | act, 95 | ) 96 | 97 | body = [] 98 | for i in range(n_block): 99 | body.append(LFB(nf, wn=wn, act=act)) 100 | self.body = nn.Sequential(*body) 101 | 102 | self.up = nn.Sequential( 103 | wn(nn.Conv2d(nf, nf * scale ** 2, 3, padding=3//2)), 104 | act, 105 | nn.PixelShuffle(upscale_factor=scale) 106 | ) 107 | 108 | self.tail = AWMS(nf, out_chl, wn, act=act) 109 | 110 | def forward(self, x): 111 | x = x - self.rgb_mean 112 | x = self.head(x) 113 | x = self.body(x) 114 | x = self.up(x) 115 | out = self.tail(x) 116 | 117 | return out 118 | 119 | 120 | if __name__ == '__main__': 121 | from easydict import EasyDict as edict 122 | 123 | config = edict() 124 | config.IN_CHANNEL = 3 125 | config.N_CHANNEL = 32 126 | config.RES_BLOCK = 4 127 | config.N_WEIGHT = 72 128 | config.SCALE = 2 129 | 130 | net = WeightNet(config).cuda() 131 | 132 | cnt = 0 133 | for p in net.parameters(): 134 | cnt += p.numel() 135 | print(cnt) 136 | 137 | x = torch.randn(1, 3, 32, 32).cuda() 138 | out = net(x) 139 | print(out.size()) 140 | -------------------------------------------------------------------------------- /utils/modules/module_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | 6 | 7 | def initialize_weights(net_l, scale=1): 8 | if not isinstance(net_l, list): 9 | net_l = [net_l] 10 | for net in net_l: 11 | for m in net.modules(): 12 | if isinstance(m, nn.Conv2d): 13 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 14 | m.weight.data *= scale # for residual block 15 | if m.bias is not None: 16 | m.bias.data.zero_() 17 | elif isinstance(m, nn.Linear): 18 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 19 | m.weight.data *= scale 20 | if m.bias is not None: 21 | m.bias.data.zero_() 22 | elif isinstance(m, nn.BatchNorm2d): 23 | init.constant_(m.weight, 1) 24 | init.constant_(m.bias.data, 0.0) 25 | 26 | 27 | def make_layer(block, n_layers): 28 | layers = [] 29 | for _ in range(n_layers): 30 | layers.append(block()) 31 | return nn.Sequential(*layers) 32 | 33 | 34 | class ResidualBlock_noBN(nn.Module): 35 | '''Residual block w/o BN 36 | ---Conv-ReLU-Conv-+-ReLU 37 | |________________| 38 | ''' 39 | 40 | def __init__(self, nf=64): 41 | super(ResidualBlock_noBN, self).__init__() 42 | self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 43 | self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 44 | 45 | # initialization 46 | initialize_weights([self.conv1, self.conv2], 0.1) 47 | 48 | def forward(self, x): 49 | identity = x 50 | out = F.relu(self.conv1(x), inplace=True) 51 | out = self.conv2(out) 52 | return F.relu(identity + out) 53 | 54 | 55 | class ResidualBlock_noBN_noAct(nn.Module): 56 | '''Residual block w/o BN 57 | ---Conv-ReLU-Conv-+- 58 | |________________| 59 | ''' 60 | 61 | def __init__(self, nf=64): 62 | super(ResidualBlock_noBN_noAct, self).__init__() 63 | self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 64 | self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 65 | 66 | # initialization 67 | initialize_weights([self.conv1, self.conv2], 0.1) 68 | 69 | def forward(self, x): 70 | identity = x 71 | out = F.relu(self.conv1(x), inplace=True) 72 | out = self.conv2(out) 73 | return identity + out 74 | 75 | 76 | def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'): 77 | """Warp an image or feature map with optical flow 78 | Args: 79 | x (Tensor): size (N, C, H, W) 80 | flow (Tensor): size (N, H, W, 2), normal value 81 | interp_mode (str): 'nearest' or 'bilinear' 82 | padding_mode (str): 'zeros' or 'border' or 'reflection' 83 | 84 | Returns: 85 | Tensor: warped image or feature map 86 | """ 87 | assert x.size()[-2:] == flow.size()[1:3] 88 | B, C, H, W = x.size() 89 | # mesh grid 90 | grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) 91 | grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 92 | grid.requires_grad = False 93 | grid = grid.type_as(x) 94 | vgrid = grid + flow 95 | # scale grid to [-1,1] 96 | vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0 97 | vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0 98 | vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) 99 | output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode) -------------------------------------------------------------------------------- /utils/modules/rrdb.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | def make_layer(block, n_layers): 8 | layers = [] 9 | for _ in range(n_layers): 10 | layers.append(block()) 11 | return nn.Sequential(*layers) 12 | 13 | 14 | class ResidualDenseBlock_5C(nn.Module): 15 | def __init__(self, nf=64, gc=32, bias=True): 16 | super(ResidualDenseBlock_5C, self).__init__() 17 | # gc: growth channel, i.e. intermediate channels 18 | self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) 19 | self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) 20 | self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) 21 | self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) 22 | self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) 23 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 24 | 25 | # initialization 26 | # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) 27 | 28 | def forward(self, x): 29 | x1 = self.lrelu(self.conv1(x)) 30 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) 31 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) 32 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) 33 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 34 | return x5 * 0.2 + x 35 | 36 | 37 | class RRDB(nn.Module): 38 | '''Residual in Residual Dense Block''' 39 | 40 | def __init__(self, nf, gc=32): 41 | super(RRDB, self).__init__() 42 | self.RDB1 = ResidualDenseBlock_5C(nf, gc) 43 | self.RDB2 = ResidualDenseBlock_5C(nf, gc) 44 | self.RDB3 = ResidualDenseBlock_5C(nf, gc) 45 | 46 | def forward(self, x): 47 | out = self.RDB1(x) 48 | out = self.RDB2(out) 49 | out = self.RDB3(out) 50 | return out * 0.2 + x 51 | 52 | 53 | class RRDBNet(nn.Module): 54 | def __init__(self, in_nc, out_nc, nf, nb, gc=32): 55 | super(RRDBNet, self).__init__() 56 | RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) 57 | 58 | self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) 59 | self.RRDB_trunk = make_layer(RRDB_block_f, nb) 60 | self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 61 | #### upsampling 62 | self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 63 | self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 64 | self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 65 | self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) 66 | 67 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 68 | 69 | def forward(self, x): 70 | fea = self.conv_first(x) 71 | trunk = self.trunk_conv(self.RRDB_trunk(fea)) 72 | fea = fea + trunk 73 | 74 | fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) 75 | fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) 76 | out = self.conv_last(self.lrelu(self.HRconv(fea))) 77 | 78 | return out 79 | 80 | 81 | if __name__ == '__main__': 82 | net = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23).cuda() 83 | x = torch.randn(1, 3, 180, 320).cuda() 84 | import time 85 | with torch.no_grad(): 86 | times = 50 87 | t1 = time.time() 88 | for i in range(times): 89 | out = net(x) 90 | print((time.time() - t1) / times) 91 | -------------------------------------------------------------------------------- /utils/modules/vggNet.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | import torch 3 | import torch.nn as nn 4 | from torchvision.models import vgg as vgg 5 | 6 | 7 | NAMES = { 8 | 'vgg11': [ 9 | 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 10 | 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 11 | 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 12 | 'conv5_2', 'relu5_2', 'pool5' 13 | ], 14 | 'vgg13': [ 15 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 16 | 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 'conv3_1', 'relu3_1', 17 | 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 18 | 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5' 19 | ], 20 | 'vgg16': [ 21 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 22 | 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 'conv3_1', 'relu3_1', 23 | 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 24 | 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 25 | 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 26 | 'pool5' 27 | ], 28 | 'vgg19': [ 29 | 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 30 | 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', 'conv3_1', 'relu3_1', 31 | 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 32 | 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 33 | 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1', 34 | 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 35 | 'pool5' 36 | ] 37 | } 38 | 39 | 40 | def insert_bn(names): 41 | """Insert bn layer after each conv. 42 | 43 | Args: 44 | names (list): The list of layer names. 45 | 46 | Returns: 47 | list: The list of layer names with bn layers. 48 | """ 49 | names_bn = [] 50 | for name in names: 51 | names_bn.append(name) 52 | if 'conv' in name: 53 | position = name.replace('conv', '') 54 | names_bn.append('bn' + position) 55 | return names_bn 56 | 57 | 58 | class VGGFeatureExtractor(nn.Module): 59 | """VGG network for feature extraction. 60 | 61 | In this implementation, we allow users to choose whether use normalization 62 | in the input feature and the type of vgg network. Note that the pretrained 63 | path must fit the vgg type. 64 | 65 | Args: 66 | layer_name_list (list[str]): Forward function returns the corresponding 67 | features according to the layer_name_list. 68 | Example: {'relu1_1', 'relu2_1', 'relu3_1'}. 69 | vgg_type (str): Set the type of vgg network. Default: 'vgg19'. 70 | use_input_norm (bool): If True, normalize the input image. Importantly, 71 | the input feature must in the range [0, 1]. Default: True. 72 | requires_grad (bool): If true, the parameters of VGG network will be 73 | optimized. Default: False. 74 | remove_pooling (bool): If true, the max pooling operations in VGG net 75 | will be removed. Default: False. 76 | pooling_stride (int): The stride of max pooling operation. Default: 2. 77 | """ 78 | 79 | def __init__(self, 80 | layer_name_list, 81 | vgg_type='vgg19', 82 | use_input_norm=True, 83 | requires_grad=False, 84 | remove_pooling=False, 85 | pooling_stride=2): 86 | super(VGGFeatureExtractor, self).__init__() 87 | 88 | self.layer_name_list = layer_name_list 89 | self.use_input_norm = use_input_norm 90 | 91 | self.names = NAMES[vgg_type.replace('_bn', '')] 92 | if 'bn' in vgg_type: 93 | self.names = insert_bn(self.names) 94 | 95 | # only borrow layers that will be used to avoid unused params 96 | max_idx = 0 97 | for v in layer_name_list: 98 | idx = self.names.index(v) 99 | if idx > max_idx: 100 | max_idx = idx 101 | 102 | features = getattr(vgg, vgg_type)(pretrained=True).features[:max_idx + 1] 103 | 104 | modified_net = OrderedDict() 105 | for k, v in zip(self.names, features): 106 | if 'pool' in k: 107 | # if remove_pooling is true, pooling operation will be removed 108 | if remove_pooling: 109 | continue 110 | else: 111 | # in some cases, we may want to change the default stride 112 | modified_net[k] = nn.MaxPool2d( 113 | kernel_size=2, stride=pooling_stride) 114 | else: 115 | modified_net[k] = v 116 | 117 | self.vgg_net = nn.Sequential(modified_net) 118 | 119 | if not requires_grad: 120 | self.vgg_net.eval() 121 | for param in self.parameters(): 122 | param.requires_grad = False 123 | else: 124 | self.vgg_net.train() 125 | for param in self.parameters(): 126 | param.requires_grad = True 127 | 128 | if self.use_input_norm: 129 | # the mean is for image with range [0, 1] 130 | self.register_buffer( 131 | 'mean', 132 | torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) 133 | # the std is for image with range [0, 1] 134 | self.register_buffer( 135 | 'std', 136 | torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) 137 | 138 | def forward(self, x): 139 | """Forward function. 140 | 141 | Args: 142 | x (Tensor): Input tensor with shape (n, c, h, w). 143 | 144 | Returns: 145 | Tensor: Forward results. 146 | """ 147 | 148 | if self.use_input_norm: 149 | x = (x - self.mean) / self.std 150 | 151 | output = {} 152 | for key, layer in self.vgg_net._modules.items(): 153 | x = layer(x) 154 | if key in self.layer_name_list: 155 | output[key] = x.clone() 156 | 157 | return output 158 | -------------------------------------------------------------------------------- /utils/region_seperator.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def get_flat_mask(img, kernel_size=7, std_thresh=0.03, scale=1): 9 | img = F.interpolate(img, scale_factor=scale, mode='bicubic', align_corners=False) 10 | B, _, H, W = img.size() 11 | r, g, b = torch.unbind(img, dim=1) 12 | l_img = (0.2989 * r + 0.587 * g + 0.114 * b).unsqueeze(dim=1) 13 | l_img_pad = F.pad(l_img, (kernel_size//2, kernel_size//2, kernel_size//2, kernel_size//2), mode='reflect') 14 | unf_img = F.unfold(l_img_pad, kernel_size=kernel_size, padding=0, stride=1) 15 | std_map = torch.std(unf_img, dim=1, keepdim=True).view(B, 1, H, W) 16 | mask = torch.lt(std_map, std_thresh).float() 17 | 18 | return mask 19 | -------------------------------------------------------------------------------- /utils/samplers/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | from .distributed import DistributedSampler 3 | from .grouped_batch_sampler import GroupedBatchSampler 4 | from .iteration_based_batch_sampler import IterationBasedBatchSampler 5 | 6 | __all__ = ["DistributedSampler", "GroupedBatchSampler", "IterationBasedBatchSampler"] -------------------------------------------------------------------------------- /utils/samplers/distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | # Code is copy-pasted exactly as in torch.utils.data.distributed. 3 | # FIXME remove this once c10d fixes the bug it has 4 | import math 5 | import torch 6 | import torch.distributed as dist 7 | from torch.utils.data.sampler import Sampler 8 | 9 | 10 | class DistributedSampler(Sampler): 11 | """Sampler that restricts data loading to a subset of the dataset. 12 | It is especially useful in conjunction with 13 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 14 | process can pass a DistributedSampler instance as a DataLoader sampler, 15 | and load a subset of the original dataset that is exclusive to it. 16 | .. note:: 17 | Dataset is assumed to be of constant size. 18 | Arguments: 19 | dataset: Dataset used for sampling. 20 | num_replicas (optional): Number of processes participating in 21 | distributed training. 22 | rank (optional): Rank of the current process within num_replicas. 23 | """ 24 | 25 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 26 | if num_replicas is None: 27 | if not dist.is_available(): 28 | raise RuntimeError("Requires distributed package to be available") 29 | num_replicas = dist.get_world_size() 30 | if rank is None: 31 | if not dist.is_available(): 32 | raise RuntimeError("Requires distributed package to be available") 33 | rank = dist.get_rank() 34 | self.dataset = dataset 35 | self.num_replicas = num_replicas 36 | self.rank = rank 37 | self.epoch = 0 38 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 39 | self.total_size = self.num_samples * self.num_replicas 40 | self.shuffle = shuffle 41 | 42 | def __iter__(self): 43 | if self.shuffle: 44 | # deterministically shuffle based on epoch 45 | g = torch.Generator() 46 | g.manual_seed(self.epoch) 47 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 48 | else: 49 | indices = torch.arange(len(self.dataset)).tolist() 50 | 51 | # add extra samples to make it evenly divisible 52 | indices += indices[:(self.total_size - len(indices))] 53 | assert len(indices) == self.total_size 54 | 55 | # subsample 56 | offset = self.num_samples * self.rank 57 | indices = indices[offset: offset + self.num_samples] 58 | assert len(indices) == self.num_samples 59 | 60 | return iter(indices) 61 | 62 | def __len__(self): 63 | return self.num_samples 64 | 65 | def set_epoch(self, epoch): 66 | self.epoch = epoch 67 | -------------------------------------------------------------------------------- /utils/samplers/grouped_batch_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import itertools 3 | import copy 4 | import bisect 5 | 6 | import torch 7 | from torch.utils.data.sampler import BatchSampler 8 | from torch.utils.data.sampler import Sampler 9 | 10 | 11 | def _quantize(x, bins): 12 | bins = copy.copy(bins) 13 | bins = sorted(bins) 14 | quantized = list(map(lambda y: bisect.bisect_right(bins, y), x)) 15 | return quantized 16 | 17 | 18 | def _compute_aspect_ratios(dataset): 19 | aspect_ratios = [] 20 | for i in range(len(dataset)): 21 | img_info = dataset.get_img_info(i) 22 | aspect_ratio = float(img_info["height"]) / float( 23 | img_info["width"]) 24 | aspect_ratios.append(aspect_ratio) 25 | return aspect_ratios 26 | 27 | 28 | class GroupedBatchSampler(BatchSampler): 29 | """ 30 | Wraps another sampler to yield a mini-batch of indices. 31 | It enforces that elements from the same group should appear in 32 | groups of batch_size. 33 | It also tries to provide mini-batches which follows an ordering which is 34 | as close as possible to the ordering from the original sampler. 35 | 36 | Arguments: 37 | sampler (Sampler): Base sampler. 38 | batch_size (int): Size of mini-batch. 39 | drop_uneven (bool): If ``True``, the sampler will drop the batches whose 40 | size is less than ``batch_size`` 41 | 42 | """ 43 | 44 | def __init__(self, sampler, dataset, aspect_grouping, batch_size, 45 | drop_uneven=False): 46 | aspect_ratios = _compute_aspect_ratios(dataset) 47 | group_ids = _quantize(aspect_ratios, aspect_grouping) 48 | if not isinstance(sampler, Sampler): 49 | raise ValueError( 50 | "sampler should be an instance of " 51 | "torch.utils.data.Sampler, but got sampler={}".format(sampler) 52 | ) 53 | self.sampler = sampler 54 | self.group_ids = torch.as_tensor(group_ids) 55 | assert self.group_ids.dim() == 1 56 | self.batch_size = batch_size 57 | self.drop_uneven = drop_uneven 58 | 59 | self.groups = torch.unique(self.group_ids).sort(0)[0] 60 | 61 | self._can_reuse_batches = False 62 | 63 | def _prepare_batches(self): 64 | dataset_size = len(self.group_ids) 65 | # get the sampled indices from the sampler 66 | sampled_ids = torch.as_tensor(list(self.sampler)) 67 | # potentially not all elements of the dataset were sampled 68 | # by the sampler (e.g., DistributedSampler). 69 | # construct a tensor which contains -1 if the element was 70 | # not sampled, and a non-negative number indicating the 71 | # order where the element was sampled. 72 | # for example. if sampled_ids = [3, 1] and dataset_size = 5, 73 | # the order is [-1, 1, -1, 0, -1] 74 | order = torch.full((dataset_size,), -1, dtype=torch.int64) 75 | order[sampled_ids] = torch.arange(len(sampled_ids)) 76 | 77 | # get a mask with the elements that were sampled 78 | mask = order >= 0 79 | 80 | # find the elements that belong to each individual cluster 81 | clusters = [(self.group_ids == i) & mask for i in self.groups] 82 | # get relative order of the elements inside each cluster 83 | # that follows the order from the sampler 84 | relative_order = [order[cluster] for cluster in clusters] 85 | # with the relative order, find the absolute order in the 86 | # sampled space 87 | permutation_ids = [s[s.sort()[1]] for s in relative_order] 88 | # permute each cluster so that they follow the order from 89 | # the sampler 90 | permuted_clusters = [sampled_ids[idx] for idx in permutation_ids] 91 | 92 | # splits each cluster in batch_size, and merge as a list of tensors 93 | splits = [c.split(self.batch_size) for c in permuted_clusters] 94 | merged = tuple(itertools.chain.from_iterable(splits)) 95 | 96 | # now each batch internally has the right order, but 97 | # they are grouped by clusters. Find the permutation between 98 | # different batches that brings them as close as possible to 99 | # the order that we have in the sampler. For that, we will consider the 100 | # ordering as coming from the first element of each batch, and sort 101 | # correspondingly 102 | first_element_of_batch = [t[0].item() for t in merged] 103 | # get and inverse mapping from sampled indices and the position where 104 | # they occur (as returned by the sampler) 105 | inv_sampled_ids_map = {v: k for k, v in enumerate(sampled_ids.tolist())} 106 | # from the first element in each batch, get a relative ordering 107 | first_index_of_batch = torch.as_tensor( 108 | [inv_sampled_ids_map[s] for s in first_element_of_batch] 109 | ) 110 | 111 | # permute the batches so that they approximately follow the order 112 | # from the sampler 113 | permutation_order = first_index_of_batch.sort(0)[1].tolist() 114 | # finally, permute the batches 115 | batches = [merged[i].tolist() for i in permutation_order] 116 | 117 | if self.drop_uneven: 118 | kept = [] 119 | for batch in batches: 120 | if len(batch) == self.batch_size: 121 | kept.append(batch) 122 | batches = kept 123 | return batches 124 | 125 | def __iter__(self): 126 | if self._can_reuse_batches: 127 | batches = self._batches 128 | self._can_reuse_batches = False 129 | else: 130 | batches = self._prepare_batches() 131 | self._batches = batches 132 | return iter(batches) 133 | 134 | def __len__(self): 135 | if not hasattr(self, "_batches"): 136 | self._batches = self._prepare_batches() 137 | self._can_reuse_batches = True 138 | return len(self._batches) 139 | -------------------------------------------------------------------------------- /utils/samplers/iteration_based_batch_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | from torch.utils.data.sampler import BatchSampler 3 | 4 | 5 | class IterationBasedBatchSampler(BatchSampler): 6 | """ 7 | Wraps a BatchSampler, resampling from it until 8 | a specified number of iterations have been sampled 9 | """ 10 | 11 | def __init__(self, batch_sampler, num_iterations, start_iter=0): 12 | self.batch_sampler = batch_sampler 13 | self.num_iterations = num_iterations 14 | self.start_iter = start_iter 15 | 16 | def __iter__(self): 17 | iteration = self.start_iter 18 | while iteration <= self.num_iterations: 19 | # if the underlying sampler has a set_epoch method, like 20 | # DistributedSampler, used for making each process see 21 | # a different split of the dataset, then set it 22 | if hasattr(self.batch_sampler.sampler, "set_epoch"): 23 | self.batch_sampler.sampler.set_epoch(iteration) 24 | for batch in self.batch_sampler: 25 | iteration += 1 26 | if iteration > self.num_iterations: 27 | break 28 | yield batch 29 | 30 | def __len__(self): 31 | return self.num_iterations 32 | -------------------------------------------------------------------------------- /utils/solver.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch.optim as optim 4 | from torch.optim.lr_scheduler import _LRScheduler 5 | 6 | 7 | def make_optimizer(config, model, num_gpu=None): 8 | if num_gpu is None: 9 | lr = config.SOLVER.BASE_LR 10 | else: 11 | lr = config.SOLVER.BASE_LR * num_gpu 12 | 13 | if config.SOLVER.OPTIMIZER == 'Adam': 14 | optimizer = optim.Adam(model.parameters(), 15 | lr=lr, betas=(0.9, 0.999), eps=1e-8, 16 | weight_decay=config.SOLVER.WEIGHT_DECAY) 17 | elif config.SOLVER.OPTIMIZER == 'SGD': 18 | optimizer = optim.SGD(model.parameters(), 19 | lr=lr, momentum=config.SOLVER.MOMENTUM, 20 | weight_decay=config.SOLVER.WEIGHT_DECAY) 21 | else: 22 | raise ValueError('Illegal optimizer.') 23 | 24 | return optimizer 25 | 26 | 27 | def make_optimizer_sep(model, base_lr, type='Adam', beta1=0.9, beta2=0.999, weight_decay=0, momentum=0, num_gpu=None): 28 | if num_gpu is None: 29 | lr = base_lr 30 | else: 31 | lr = base_lr * num_gpu 32 | 33 | if type == 'Adam': 34 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), 35 | lr=lr, betas=(beta1, beta2), eps=1e-8, weight_decay=weight_decay) 36 | elif type == 'SGD': 37 | optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), 38 | lr=lr, momentum=momentum, weight_decay=weight_decay) 39 | elif type == 'AdamW': 40 | optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), 41 | lr=lr, betas=(beta1, beta2), eps=1e-8, weight_decay=0.01) 42 | elif type == 'Adamax': 43 | optimizer = optim.Adamax(filter(lambda p: p.requires_grad, model.parameters()), 44 | lr=lr, betas=(beta1, beta2), eps=1e-8, weight_decay=weight_decay) 45 | else: 46 | raise ValueError('Illegal optimizer.') 47 | 48 | return optimizer 49 | 50 | 51 | def make_lr_scheduler(config, optimizer): 52 | w_iter = config.SOLVER.WARM_UP_ITER 53 | w_fac = config.SOLVER.WARM_UP_FACTOR 54 | max_iter = config.SOLVER.MAX_ITER 55 | lr_lambda = lambda iteration: w_fac + (1 - w_fac) * iteration / w_iter \ 56 | if iteration < w_iter \ 57 | else 1 / 2 * (1 + math.cos((iteration - w_iter) / (max_iter - w_iter) * math.pi)) 58 | scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1) 59 | 60 | return scheduler 61 | 62 | 63 | class CosineAnnealingLR_warmup(_LRScheduler): 64 | def __init__(self, config, optimizer, base_lr, last_epoch=-1, min_lr=1e-7): 65 | self.base_lr = base_lr 66 | self.min_lr = min_lr 67 | self.w_iter = config.SOLVER.WARM_UP_ITER 68 | self.w_fac = config.SOLVER.WARM_UP_FACTOR 69 | self.T_period = config.SOLVER.T_PERIOD 70 | self.last_restart = 0 71 | self.T_max = self.T_period[0] 72 | assert config.SOLVER.MAX_ITER == self.T_period[-1], 'Illegal training period setting.' 73 | super(CosineAnnealingLR_warmup, self).__init__(optimizer, last_epoch=last_epoch) 74 | 75 | def get_lr(self): 76 | if self.last_epoch - self.last_restart < self.w_iter: 77 | ratio = self.w_fac + (1 - self.w_fac) * (self.last_epoch - self.last_restart) / self.w_iter 78 | return [(self.base_lr - self.min_lr) * ratio + self.min_lr for group in self.optimizer.param_groups] 79 | elif self.last_epoch in self.T_period: 80 | self.last_restart = self.last_epoch 81 | if self.last_epoch != self.T_period[-1]: 82 | self.T_max = self.T_period[self.T_period.index(self.last_epoch) + 1] 83 | return [self.min_lr for group in self.optimizer.param_groups] 84 | else: 85 | ratio = 1 / 2 * (1 + math.cos( 86 | (self.last_epoch - self.last_restart - self.w_iter) / (self.T_max - self.last_restart - self.w_iter) * math.pi)) 87 | return [(self.base_lr - self.min_lr) * ratio + self.min_lr for group in self.optimizer.param_groups] 88 | --------------------------------------------------------------------------------