├── .gitignore
├── LICENSE
├── README.md
├── code
│   ├── __init__.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── benchmark.py
│   │   ├── benchmark_texture_sr.py
│   │   ├── common.py
│   │   ├── demo.py
│   │   ├── div2k.py
│   │   ├── div2ksub.py
│   │   ├── srdata.py
│   │   ├── srtexture.py
│   │   ├── texture.py
│   │   └── texture_hr.py
│   ├── dataloader.py
│   ├── dataloader_new.py
│   ├── loss
│   │   ├── __init__.py
│   │   ├── adversarial.py
│   │   ├── discriminator.py
│   │   └── vgg.py
│   ├── main.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── carn.py
│   │   ├── common.py
│   │   ├── ddbpn.py
│   │   ├── edsr.py
│   │   ├── finetune.py
│   │   ├── mdsr.py
│   │   ├── ops.py
│   │   ├── rcan.py
│   │   └── srresnet.py
│   ├── option.py
│   ├── scripts
│   │   ├── 3d_appearance_sr.pdf
│   │   ├── contribution.jpg
│   │   ├── demo.sh
│   │   ├── finetune_sr.sh
│   │   ├── network_NHR.png
│   │   ├── network_NLR.png
│   │   ├── qsub.sh
│   │   ├── qsub_HRST_CNN.sh
│   │   ├── qsub_NHR.sh
│   │   ├── qsub_NLR.sh
│   │   └── vpython.sh
│   ├── template.py
│   ├── trainer.py
│   ├── trainer_finetune.py
│   ├── utility.py
│   └── utils
│       ├── LMDB_TEST.py
│       ├── compare_PSNR_preSR.py
│       ├── compare_PSNR_preSR_image.py
│       ├── compute_PSNR_SR_1NLR.py
│       ├── compute_PSNR_SR_2Sub.py
│       ├── compute_PSNR_SR_3NHR.py
│       ├── compute_PSNR_SR_4HRST.py
│       ├── compute_PSNR_UP.py
│       ├── create_lmdb.py
│       ├── myssim.py
│       ├── prepare_psnr_table.py
│       ├── prepare_ssim_table.py
│       ├── prepare_visual.py
│       └── psnr.py
└── experiment
    └── README.md

/.gitignore:
--------------------------------------------------------------------------------
*.pyc
./code/trainer_multi.py
./code/trainer_single.py
./code/trainer_srgan.py
./code/init_gpu_interactive.sh
./code/qrsh_reserve.sh
./code/results_texture/*
*__pycache__
./experiment/*
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2019 Yawei Li

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
We received the **Best Poster Prize** at [ICVSS 2019](https://iplab.dmi.unict.it/icvss2019/PresentationPrize).

This is the official repository of our work [3D Appearance Super-Resolution with Deep Learning](./code/scripts/3d_appearance_sr.pdf) [(arXiv)](https://arxiv.org/abs/1906.00925), published at CVPR 2019.

We provide 3DASR, a 3D appearance SR dataset that captures both synthetic and real scenes with a large variety of texture characteristics. The dataset contains ground-truth HR texture maps and LR texture maps at scaling factors ×2, ×3, and ×4. The 3D meshes, multi-view images, projection matrices, and normal maps are also provided. We introduce a deep learning-based SR framework for the multi-view setting and show that 2D deep learning-based SR techniques can successfully be adapted to the texture domain by introducing geometric information via normal maps.

![alt text][contribution]

[contribution]: ./code/scripts/contribution.jpg "We introduce 3DASR, a 3D appearance SR dataset, and a deep learning-based approach to super-resolve the appearance of 3D objects."
We introduce 3DASR, a 3D appearance SR dataset, and a deep learning-based approach to super-resolve the appearance of 3D objects.


# Dependencies
* Python 3.6
* PyTorch >= 1.0.0
* numpy
* skimage
* imageio
* matplotlib
* tqdm

# Quick Start (Test)
1. `git clone https://github.com/ofsoundof/3D_Appearance_SR.git`
2. Download the pretrained models and the texture map dataset (links below).
3. Put the pretrained models under [`./experiment/`](./experiment).
4. `cd ./code/scripts`
5. Run the test command:

   `CUDA_VISIBLE_DEVICES=xx python ../main.py --model FINETUNE --submodel NLR --save Test/NLR_first --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train ../../experiment/model/NLR/model_x2_split1.pt --data_train texture --data_test texture --model_one one --subset . --normal_lr lr --input_res lr --chop --reset --save_results --print_model --test_only`

Use `--ext sep_reset` for the first run on a given split of the two cross-validation splits; it prepares separate binary `.npy` files that later runs can reuse.

Be sure to change the log directory `--dir` and the data directory `--dir_data`.

# How to Run the Code
## Prepare the pretrained models
1. Download our pretrained models for 3D appearance SR from [Google Drive](https://drive.google.com/file/d/1TaBua-A0DT0jc4x_I4HVFicKOndzSBxU/view?usp=sharing) or [BaiduNetDisk, extraction code: nnnm](https://pan.baidu.com/s/1-_yozGa3QMMe0TRIUg5WBw). The pretrained NLR and NHR models from the paper are included.

2. Download the pretrained EDSR model from the [EDSR project page](https://github.com/thstkdgus35/EDSR-PyTorch).

3. Put the pretrained models at [`./experiment`](./experiment).

## Prepare the dataset
1. Download the texture maps of the 3D appearance dataset from [Google Drive](https://drive.google.com/file/d/18rHsefdYNSEG7QMwzaS8iFHIdLOB2eND/view?usp=sharing) or [BaiduNetDisk, extraction code: crnw](https://pan.baidu.com/s/1U-bnnG6LjOVtHqX3fMCq2w).

## Train and test
1. Please refer to [`demo.sh`](./code/scripts/demo.sh) for the training and testing demo script. On a batch system, you can also use [`qsub_NLR.sh`](./code/scripts/qsub_NLR.sh).
2. Remember to change the log directory `--dir` and the data directory `--dir_data`: `--dir` is where the log information and the trained models are stored, and `--dir_data` is where the dataset resides.
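
For a quick sanity check of the results, the evaluation scripts under [`./code/utils`](./code/utils) (e.g. `psnr.py` and the `compute_PSNR_SR_*` scripts) report PSNR between super-resolved and ground-truth texture maps. The snippet below is a minimal sketch of that computation for two aligned images; the file names are placeholders, and it is not a drop-in replacement for the provided scripts.

    import numpy as np
    import imageio

    def psnr(sr, hr, rgb_range=255.0):
        """Peak signal-to-noise ratio (dB) between two same-sized images."""
        diff = sr.astype(np.float64) - hr.astype(np.float64)
        mse = np.mean(diff ** 2)
        if mse == 0:
            return float('inf')
        return 10 * np.log10(rgb_range ** 2 / mse)

    # Placeholder file names: a super-resolved texture map and its
    # ground-truth HR counterpart.
    sr = imageio.imread('SR_Texture.png')
    hr = imageio.imread('HR_Texture.png')
    print('PSNR: {:.2f} dB'.format(psnr(sr, hr)))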

# BibTeX
If you find our work useful in your research or publication, please cite it:

Yawei Li, Vagia Tsiminaki, Radu Timofte, Marc Pollefeys, and Luc Van Gool, "**3D Appearance Super-Resolution with Deep Learning**," in *Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition*, 2019.

    @inproceedings{li2019_3dappearance,
      title={3D Appearance Super-Resolution with Deep Learning},
      author={Li, Yawei and Tsiminaki, Vagia and Timofte, Radu and Pollefeys, Marc and Van Gool, Luc},
      booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
      year={2019}
    }

-------------------------------------------------------------------------------- /code/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofsoundof/3D_Appearance_SR/91fc377466e2756c6cf753b7db48ef98e4ea13c2/code/__init__.py -------------------------------------------------------------------------------- /code/data/__init__.py: -------------------------------------------------------------------------------- 1 | from importlib import import_module 2 | 3 | from dataloader_new import MSDataLoader 4 | from torch.utils.data.dataloader import default_collate 5 | # from torch.utils.data.dataloader import DataLoader 6 | 7 | class Data: 8 | def __init__(self, args): 9 | kwargs = {} 10 | if not args.cpu: 11 | kwargs['collate_fn'] = default_collate 12 | kwargs['pin_memory'] = True 13 | else: 14 | kwargs['collate_fn'] = default_collate 15 | kwargs['pin_memory'] = False 16 | kwargs['num_workers'] = args.n_threads 17 | self.loader_train = None 18 | if not args.test_only: 19 | module_train = import_module('data.' + args.data_train.lower()) 20 | trainset = getattr(module_train, args.data_train)(args) 21 | #from IPython import embed; embed(); exit() 22 | 23 | # self.loader_train = DataLoader( 24 | # trainset, 25 | # batch_size=args.batch_size, 26 | # shuffle=True, 27 | # **kwargs 28 | # ) 29 | self.loader_train = MSDataLoader( 30 | args, 31 | trainset, 32 | batch_size=args.batch_size, 33 | shuffle=True, 34 | **kwargs 35 | ) 36 | 37 | 38 | if args.data_test in ['Set5', 'Set14', 'B100', 'Urban100']: 39 | if not args.benchmark_noise: 40 | module_test = import_module('data.benchmark') 41 | testset = getattr(module_test, 'Benchmark')(args, train=False) 42 | else: 43 | module_test = import_module('data.benchmark_noise') 44 | testset = getattr(module_test, 'BenchmarkNoise')( 45 | args, 46 | train=False 47 | ) 48 | elif args.data_test in ['Collection', 'ColMapMiddlebury', 'ETH3D', 'SyB3R']: 49 | module_test = import_module('data.benchmark_texture_sr') 50 | testset = getattr(module_test, 'BenchmarkTextureSR')(args, train=False) 51 | else: 52 | module_test = import_module('data.'
+ args.data_test.lower()) 53 | testset = getattr(module_test, args.data_test)(args, train=False) 54 | 55 | # self.loader_test = DataLoader( 56 | # testset, 57 | # batch_size=1, 58 | # shuffle=False, 59 | # **kwargs 60 | # ) 61 | self.loader_test = MSDataLoader( 62 | args, 63 | testset, 64 | batch_size=1, 65 | shuffle=False, 66 | **kwargs 67 | ) -------------------------------------------------------------------------------- /code/data/benchmark.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data import common 3 | from data import srdata 4 | import glob 5 | import numpy as np 6 | import scipy.misc as misc 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | class Benchmark(srdata.SRData): 12 | def __init__(self, args, train=True): 13 | super(Benchmark, self).__init__(args, train, benchmark=True) 14 | 15 | # def _scan(self): 16 | # list_hr = [] 17 | # list_lr = [[] for _ in self.scale] 18 | # for entry in os.scandir(self.dir_hr): 19 | # filename = os.path.splitext(entry.name)[0] 20 | # list_hr.append(os.path.join(self.dir_hr, filename + self.ext)) 21 | # for si, s in enumerate(self.scale): 22 | # list_lr[si].append(os.path.join(self.dir_lr, filename + self.ext)) 23 | # # list_lr[si].append(os.path.join( 24 | # # self.dir_lr, 25 | # # 'X{}/{}x{}{}'.format(s, filename, s, self.ext) 26 | # # )) 27 | # 28 | # list_hr.sort() 29 | # for l in list_lr: 30 | # l.sort() 31 | # 32 | # return list_hr, list_lr 33 | 34 | def _scan(self): 35 | list_hr = [] 36 | list_lr = [] 37 | for s in self.scale: 38 | list_hr.append(sorted(glob.glob(self.dir_hr.format(s, self.color)))) 39 | list_lr.append(sorted(glob.glob(self.dir_lr.format(s, self.color)))) 40 | 41 | return list_hr, list_lr 42 | 43 | def _set_filesystem(self, dir_data): 44 | self.apath = os.path.join(dir_data, 'Test', self.args.data_test) 45 | self.dir_hr = self.apath + '_X{}_high{}/*.png' 46 | self.dir_lr = self.apath + '_X{}_low{}/*.png' 47 | self.ext = '.png' 48 | -------------------------------------------------------------------------------- /code/data/benchmark_texture_sr.py: -------------------------------------------------------------------------------- 1 | import os 2 | from data import common 3 | from data import srdata 4 | import glob 5 | import numpy as np 6 | import scipy.misc as misc 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | class BenchmarkTextureSR(srdata.SRData): 12 | def __init__(self, args, train=True): 13 | super(BenchmarkTextureSR, self).__init__(args, train, benchmark=True) 14 | 15 | # def _scan(self): 16 | # list_hr = [] 17 | # list_lr = [[] for _ in self.scale] 18 | # for entry in os.scandir(self.dir_hr): 19 | # filename = os.path.splitext(entry.name)[0] 20 | # list_hr.append(os.path.join(self.dir_hr, filename + self.ext)) 21 | # for si, s in enumerate(self.scale): 22 | # list_lr[si].append(os.path.join(self.dir_lr, filename + self.ext)) 23 | # # list_lr[si].append(os.path.join( 24 | # # self.dir_lr, 25 | # # 'X{}/{}x{}{}'.format(s, filename, s, self.ext) 26 | # # )) 27 | # 28 | # list_hr.sort() 29 | # for l in list_lr: 30 | # l.sort() 31 | # 32 | # return list_hr, list_lr 33 | 34 | def _scan(self): 35 | list_hr = [] 36 | list_lr = [] 37 | for s in self.scale: 38 | list_hr.append(sorted(glob.glob(self.dir_hr))) 39 | list_lr.append(sorted(glob.glob(self.dir_lr.format(s)))) 40 | 41 | return list_hr, list_lr 42 | 43 | def _set_filesystem(self, dir_data): 44 | self.apath = os.path.join(dir_data, self.args.data_test, 
self.args.data_test_texture_sr) 45 | self.dir_hr = self.apath + '/x1/Images/*.png' 46 | self.dir_lr = self.apath + '/x{}/Images/*.png' 47 | self.ext = '.png' 48 | -------------------------------------------------------------------------------- /code/data/common.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import skimage.io as sio 5 | import skimage.color as sc 6 | import skimage.transform as st 7 | 8 | import torch 9 | from torchvision import transforms 10 | 11 | def get_patch(img_in, img_tar, patch_size, scale, multi_scale=False): 12 | ih, iw = img_in.shape[:2] 13 | 14 | p = scale if multi_scale else 1 15 | tp = p * patch_size 16 | ip = tp // scale 17 | 18 | ix = random.randrange(0, iw - ip + 1) 19 | iy = random.randrange(0, ih - ip + 1) 20 | tx, ty = scale * ix, scale * iy 21 | 22 | img_in = img_in[iy:iy + ip, ix:ix + ip, :] 23 | img_tar = img_tar[ty:ty + tp, tx:tx + tp, :] 24 | 25 | return img_in, img_tar 26 | 27 | def set_channel(l, n_channel): 28 | def _set_channel(img): 29 | if img.ndim == 2: 30 | img = np.expand_dims(img, axis=2) 31 | 32 | c = img.shape[2] 33 | if n_channel == 1 and c == 3: 34 | img = np.expand_dims(sc.rgb2ycbcr(img)[:, :, 0], 2) 35 | elif n_channel == 3 and c == 1: 36 | img = np.concatenate([img] * n_channel, 2) 37 | 38 | return img 39 | 40 | return [_set_channel(_l) for _l in l] 41 | 42 | def np2Tensor(l, rgb_range): 43 | def _np2Tensor(img): 44 | np_transpose = np.ascontiguousarray(img.transpose((2, 0, 1))) 45 | tensor = torch.from_numpy(np_transpose).float() 46 | tensor.mul_(rgb_range / 255) 47 | 48 | return tensor 49 | 50 | return [_np2Tensor(_l) for _l in l] 51 | 52 | def add_noise(x, noise='.'): 53 | if noise != '.': 54 | noise_type = noise[0] 55 | noise_value = int(noise[1:]) 56 | if noise_type == 'G': 57 | noises = np.random.normal(scale=noise_value, size=x.shape) 58 | noises = noises.round() 59 | elif noise_type == 'S': 60 | noises = np.random.poisson(x * noise_value) / noise_value 61 | noises = noises - noises.mean(axis=0).mean(axis=0) 62 | 63 | x_noise = x.astype(np.int16) + noises.astype(np.int16) 64 | x_noise = x_noise.clip(0, 255).astype(np.uint8) 65 | return x_noise 66 | else: 67 | return x 68 | 69 | def augment(l, hflip=True, rot=True): 70 | hflip = hflip and random.random() < 0.5 71 | vflip = rot and random.random() < 0.5 72 | rot90 = rot and random.random() < 0.5 73 | 74 | def _augment(img): 75 | if hflip: img = img[:, ::-1, :] 76 | if vflip: img = img[::-1, :, :] 77 | if rot90: img = img.transpose(1, 0, 2) 78 | 79 | return img 80 | 81 | return [_augment(_l) for _l in l] 82 | -------------------------------------------------------------------------------- /code/data/demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | 5 | import numpy as np 6 | import scipy.misc as misc 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | class Demo(data.Dataset): 12 | def __init__(self, args, train=False): 13 | self.args = args 14 | self.name = 'Demo' 15 | self.scale = args.scale 16 | self.idx_scale = 0 17 | self.train = False 18 | self.benchmark = False 19 | 20 | self.filelist = [] 21 | for f in os.listdir(args.dir_demo): 22 | if f.find('.png') >= 0 or f.find('.jp') >= 0: 23 | self.filelist.append(os.path.join(args.dir_demo, f)) 24 | self.filelist.sort() 25 | 26 | def __getitem__(self, idx): 27 | filename = os.path.split(self.filelist[idx])[-1] 28 | filename, _ =
os.path.splitext(filename) 29 | lr = misc.imread(self.filelist[idx]) 30 | lr = common.set_channel([lr], self.args.n_colors)[0] 31 | 32 | return common.np2Tensor([lr], self.args.rgb_range)[0], -1, filename 33 | 34 | def __len__(self): 35 | return len(self.filelist) 36 | 37 | def set_scale(self, idx_scale): 38 | self.idx_scale = idx_scale 39 | 40 | -------------------------------------------------------------------------------- /code/data/div2k.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | from data import srdata 5 | 6 | import numpy as np 7 | import scipy.misc as misc 8 | 9 | import torch 10 | import torch.utils.data as data 11 | 12 | class DIV2K(srdata.SRData): 13 | def __init__(self, args, train=True): 14 | super(DIV2K, self).__init__(args, train) 15 | self.repeat = args.test_every // (args.n_train // args.batch_size) 16 | 17 | def _scan(self): 18 | list_hr = [] 19 | list_lr = [[] for _ in self.scale] 20 | if self.train: 21 | idx_begin = 0 22 | idx_end = self.args.n_train 23 | else: 24 | idx_begin = self.args.n_train 25 | idx_end = self.args.offset_val + self.args.n_val 26 | 27 | for i in range(idx_begin + 1, idx_end + 1): 28 | filename = '{:0>4}'.format(i) 29 | list_hr.append(os.path.join(self.dir_hr, filename + self.ext)) 30 | for si, s in enumerate(self.scale): 31 | list_lr[si].append(os.path.join( 32 | self.dir_lr, 33 | 'X{}/{}x{}{}'.format(s, filename, s, self.ext) 34 | )) 35 | 36 | return list_hr, list_lr 37 | 38 | def _set_filesystem(self, dir_data): 39 | self.apath = dir_data + '/DIV2K' 40 | self.dir_hr = os.path.join(self.apath, 'DIV2K_train_HR') 41 | self.dir_lr = os.path.join(self.apath, 'DIV2K_train_LR_bicubic') 42 | self.ext = '.png' 43 | 44 | def _name_hrbin(self): 45 | return os.path.join( 46 | self.apath, 47 | 'bin', 48 | '{}_bin_HR.npy'.format(self.split) 49 | ) 50 | 51 | def _name_lrbin(self, scale): 52 | return os.path.join( 53 | self.apath, 54 | 'bin', 55 | '{}_bin_LR_X{}.npy'.format(self.split, scale) 56 | ) 57 | 58 | def __len__(self): 59 | if self.train: 60 | return len(self.images_hr) * self.repeat 61 | else: 62 | return len(self.images_hr) 63 | 64 | def _get_index(self, idx): 65 | if self.train: 66 | return idx % len(self.images_hr) 67 | else: 68 | return idx 69 | 70 | -------------------------------------------------------------------------------- /code/data/div2ksub.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | from data import srdata 5 | 6 | import numpy as np 7 | import scipy.misc as misc 8 | 9 | import torch 10 | import torch.utils.data as data 11 | import glob 12 | class DIV2KSUB(srdata.SRData): 13 | def __init__(self, args, train=True): 14 | super(DIV2KSUB, self).__init__(args, train) 15 | self.repeat = max(round(args.test_every / (args.n_train / args.batch_size)), 1) 16 | self.n_train = args.n_train 17 | def _scan(self): 18 | list_hr = sorted(glob.glob(os.path.join(self.dir_hr, '*.png'))) 19 | list_lr = [sorted(glob.glob(os.path.join(self.dir_lr + '{}'.format(s), '*.png'))) for s in self.scale] 20 | #for si, s in enumerate(self.scale): 21 | 22 | return list_hr, list_lr 23 | 24 | def _set_filesystem(self, dir_data): 25 | self.apath = dir_data + '/DIV2K' 26 | self.dir_hr = os.path.join(self.apath, 'GT_sub') 27 | self.dir_lr = os.path.join(self.apath, 'GT_sub_bicLRx') 28 | self.ext = '.png' 29 | 30 | def _name_hrbin(self): 31 | return os.path.join( 32 | self.apath, 33 | 'bin', 34 | 
'{}_bin_HR.npy'.format(self.split) 35 | ) 36 | 37 | def _name_lrbin(self, scale): 38 | return os.path.join( 39 | self.apath, 40 | 'bin', 41 | '{}_bin_LR_X{}.npy'.format(self.split, scale) 42 | ) 43 | 44 | def __len__(self): 45 | if self.train: 46 | return self.n_train * self.repeat #len(self.images_hr) * self.repeat 47 | else: 48 | return self.n_train #len(self.images_hr) 49 | 50 | def _get_index(self, idx): 51 | if self.train: 52 | return idx % self.n_train #len(self.images_hr) 53 | else: 54 | return idx 55 | 56 | 57 | # # block2 58 | # class VarBlockSimple(nn.Module): 59 | # ''' 60 | # regression block used for CARN 61 | # ''' 62 | # 63 | # def __init__(self, conv=common.default_conv, n_feats=64, kernel_size=3, reg_act=nn.Softplus(), rescale=1): 64 | # super(VarBlockSimple, self).__init__() 65 | # conv_mask = [conv(n_feats, 1, kernel_size=5), nn.PReLU(), conv(1, 1, kernel_size=5), reg_act] 66 | # conv_body = [conv(n_feats, n_feats, kernel_size), nn.PReLU()] 67 | # self.conv_mask = nn.Sequential(*conv_mask) 68 | # self.conv_body = nn.Sequential(*conv_body) 69 | # 70 | # def forward(self, x): 71 | # #x = torch.matmul(x, self.conv_mask(x)) 72 | # res = self.conv_body(self.conv_mask(x) * x) 73 | # x = res + x 74 | # return x 75 | # #block3 76 | # class VarBlockSimple(nn.Module): 77 | # ''' 78 | # regression block used for CARN 79 | # ''' 80 | # 81 | # def __init__(self, conv=common.default_conv, n_feats=64, kernel_size=3, reg_act=nn.Softplus(), rescale=1): 82 | # super(VarBlockSimple, self).__init__() 83 | # conv_mask = [conv(n_feats, 1, kernel_size=5), nn.PReLU(), conv(1, 1, kernel_size=5), reg_act] 84 | # conv_body = [conv(n_feats, n_feats, kernel_size), nn.PReLU()] 85 | # conv_tail = [conv(n_feats, n_feats, kernel_size), nn.PReLU()] 86 | # self.conv_mask = nn.Sequential(*conv_mask) 87 | # self.conv_body = nn.Sequential(*conv_body) 88 | # self.conv_tail = nn.Sequential(*conv_tail) 89 | # 90 | # def forward(self, x): 91 | # #x = torch.matmul(x, self.conv_mask(x)) 92 | # res = self.conv_body(self.conv_mask(x) * x) 93 | # x = res + self.conv_tail(x) 94 | # return x 95 | -------------------------------------------------------------------------------- /code/data/srdata.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from data import common 4 | 5 | import numpy as np 6 | import scipy.misc as misc 7 | 8 | import torch 9 | import torch.utils.data as data 10 | 11 | class SRData(data.Dataset): 12 | def __init__(self, args, train=True, benchmark=False): 13 | self.args = args 14 | self.train = train 15 | self.split = 'train' if train else 'test' 16 | self.benchmark = benchmark 17 | self.scale = args.scale 18 | self.idx_scale = 0 19 | self.color = 'RGB' if args.n_colors == 3 else 'Y' 20 | 21 | self._set_filesystem(args.dir_data) 22 | 23 | def _load_bin(): 24 | self.images_hr = np.load(self._name_hrbin()) 25 | self.images_lr = [ 26 | np.load(self._name_lrbin(s)) for s in self.scale 27 | ] 28 | 29 | if args.ext == 'img' or benchmark: 30 | self.images_hr, self.images_lr = self._scan() 31 | elif args.ext.find('sep') >= 0: 32 | self.images_hr, self.images_lr = self._scan() 33 | if args.ext.find('reset') >= 0: 34 | print('Preparing seperated binary files') 35 | for v in self.images_hr: 36 | 37 | hr = misc.imread(v) 38 | name_sep = v.replace(self.ext, '.npy') 39 | np.save(name_sep, hr) 40 | # from IPython import embed; embed(); exit() 41 | for si, s in enumerate(self.scale): 42 | for v in self.images_lr[si]: 43 | lr = misc.imread(v) 44 | name_sep = 
v.replace(self.ext, '.npy') 45 | np.save(name_sep, lr) 46 | 47 | self.images_hr = [ 48 | v.replace(self.ext, '.npy') for v in self.images_hr 49 | ] 50 | self.images_lr = [ 51 | [v.replace(self.ext, '.npy') for v in self.images_lr[i]] 52 | for i in range(len(self.scale)) 53 | ] 54 | # from IPython import embed; embed(); exit() 55 | elif args.ext.find('bin') >= 0: 56 | try: 57 | if args.ext.find('reset') >= 0: 58 | raise IOError 59 | print('Loading a binary file') 60 | _load_bin() 61 | except: 62 | print('Preparing a binary file') 63 | bin_path = os.path.join(self.apath, 'bin') 64 | if not os.path.isdir(bin_path): 65 | os.mkdir(bin_path) 66 | 67 | list_hr, list_lr = self._scan() 68 | hr = [misc.imread(f) for f in list_hr] 69 | np.save(self._name_hrbin(), hr) 70 | del hr 71 | for si, s in enumerate(self.scale): 72 | lr_scale = [misc.imread(f) for f in list_lr[si]] 73 | np.save(self._name_lrbin(s), lr_scale) 74 | del lr_scale 75 | _load_bin() 76 | else: 77 | print('Please define data type') 78 | 79 | def _scan(self): 80 | raise NotImplementedError 81 | 82 | def _set_filesystem(self, dir_data): 83 | raise NotImplementedError 84 | 85 | def _name_hrbin(self): 86 | raise NotImplementedError 87 | 88 | def _name_lrbin(self, scale): 89 | raise NotImplementedError 90 | 91 | def __getitem__(self, idx): 92 | lr, hr, filename = self._load_file(idx) 93 | lr, hr = self._get_patch(lr, hr) 94 | lr, hr = common.set_channel([lr, hr], self.args.n_colors) 95 | lr_tensor, hr_tensor = common.np2Tensor([lr, hr], self.args.rgb_range) 96 | return lr_tensor, hr_tensor, filename 97 | 98 | def __len__(self): 99 | return len(self.images_hr) if not self.benchmark else len(self.images_hr[0]) 100 | 101 | def _get_index(self, idx): 102 | return idx 103 | 104 | def _load_file(self, idx): 105 | idx = self._get_index(idx) 106 | # from IPython import embed; embed() 107 | lr = self.images_lr[self.idx_scale][idx] 108 | hr = self.images_hr[idx] if not self.benchmark else self.images_hr[self.idx_scale][idx] 109 | if self.args.ext == 'img' or self.benchmark: 110 | filename = hr 111 | lr = misc.imread(lr) 112 | hr = misc.imread(hr) 113 | elif self.args.ext.find('sep') >= 0: 114 | filename = hr 115 | lr = np.load(lr) 116 | hr = np.load(hr) 117 | else: 118 | filename = str(idx + 1) 119 | 120 | filename = os.path.splitext(os.path.split(filename)[-1])[0] 121 | 122 | return lr, hr, filename 123 | 124 | def _get_patch(self, lr, hr): 125 | patch_size = self.args.patch_size 126 | scale = self.scale[self.idx_scale] 127 | multi_scale = len(self.scale) > 1 128 | if self.train: 129 | #from IPython import embed; embed(); exit() 130 | lr, hr = common.get_patch( 131 | lr, hr, patch_size, scale, multi_scale=multi_scale 132 | ) 133 | lr, hr = common.augment([lr, hr]) 134 | lr = common.add_noise(lr, self.args.noise) 135 | else: 136 | ih, iw = lr.shape[0:2] 137 | hr = hr[0:ih * scale, 0:iw * scale] 138 | 139 | return lr, hr 140 | 141 | def set_scale(self, idx_scale): 142 | self.idx_scale = idx_scale 143 | 144 | -------------------------------------------------------------------------------- /code/data/srtexture.py: -------------------------------------------------------------------------------- 1 | __author__ = 'yawli' 2 | 3 | import os 4 | 5 | from data import common 6 | 7 | import numpy as np 8 | import scipy.misc as misc 9 | 10 | import torch 11 | import torch.utils.data as data 12 | import random 13 | 14 | class SRData(data.Dataset): 15 | def __init__(self, args, train=True, benchmark=False): 16 | self.args = args 17 | self.train = train 18 | 
self.split = 'train' if train else 'test' 19 | self.benchmark = benchmark 20 | self.scale = args.scale 21 | self.idx_scale = 0 22 | self.color = 'RGB' if args.n_colors == 3 else 'Y' 23 | self.model_one = args.model_one == 'one' 24 | self.model_flag = args.model 25 | self.data_train = args.data_train 26 | self.subset = args.subset 27 | self.normal_lr = args.normal_lr == 'lr' 28 | self.input_res = args.input_res 29 | self._set_filesystem(args.dir_data) 30 | 31 | def _load_bin(): 32 | self.images_hr = np.load(self._name_hrbin()) 33 | self.images_lr = [ 34 | np.load(self._name_lrbin(s)) for s in self.scale 35 | ] 36 | 37 | if args.ext == 'img' or benchmark: 38 | self.images_hr, self.images_lr, self.normals_lr, self.masks_lr = self._scan() 39 | elif args.ext.find('sep') >= 0: 40 | self.images_hr, self.images_lr, self.normals_lr, self.masks_lr = self._scan() 41 | if args.ext.find('reset') >= 0: 42 | print('Preparing seperated binary files') 43 | for v in self.images_hr: 44 | 45 | hr = misc.imread(v) 46 | name_sep = v.replace(self.ext, '.npy') 47 | np.save(name_sep, hr) 48 | # from IPython import embed; embed(); exit() 49 | for si, s in enumerate(self.scale): 50 | for n in range(len(self.images_lr[si])): 51 | v = self.images_lr[si][n] 52 | #from IPython import embed; embed(); exit() 53 | lr = misc.imread(v) 54 | name_sep = v.replace(self.ext, '.npy') 55 | np.save(name_sep, lr) 56 | 57 | v = self.normals_lr[si][n] 58 | lr = misc.imread(v) 59 | name_sep = v.replace(self.ext, '.npy') 60 | np.save(name_sep, lr) 61 | 62 | v = self.masks_lr[si][n] 63 | lr = np.expand_dims(misc.imread(v), 2) 64 | name_sep = v.replace(self.ext, '.npy') 65 | np.save(name_sep, lr) 66 | # from IPython import embed; embed(); exit() 67 | self.images_hr = [ 68 | v.replace(self.ext, '.npy') for v in self.images_hr 69 | ] 70 | self.images_lr = [ 71 | [v.replace(self.ext, '.npy') for v in self.images_lr[i]] 72 | for i in range(len(self.scale)) 73 | ] 74 | self.normals_lr = [ 75 | [v.replace(self.ext, '.npy') for v in self.normals_lr[i]] 76 | for i in range(len(self.scale)) 77 | ] 78 | self.masks_lr = [ 79 | [v.replace(self.ext, '.npy') for v in self.masks_lr[i]] 80 | for i in range(len(self.scale)) 81 | ] 82 | # from IPython import embed; embed(); exit() 83 | elif args.ext.find('bin') >= 0: 84 | try: 85 | if args.ext.find('reset') >= 0: 86 | raise IOError 87 | print('Loading a binary file') 88 | _load_bin() 89 | except: 90 | print('Preparing a binary file') 91 | bin_path = os.path.join(self.apath, 'bin') 92 | if not os.path.isdir(bin_path): 93 | os.mkdir(bin_path) 94 | 95 | list_hr, list_lr = self._scan() 96 | hr = [misc.imread(f) for f in list_hr] 97 | np.save(self._name_hrbin(), hr) 98 | del hr 99 | for si, s in enumerate(self.scale): 100 | lr_scale = [misc.imread(f) for f in list_lr[si]] 101 | np.save(self._name_lrbin(s), lr_scale) 102 | del lr_scale 103 | _load_bin() 104 | else: 105 | print('Please define data type') 106 | # from IPython import embed; embed(); exit() 107 | def _scan(self): 108 | raise NotImplementedError 109 | 110 | def _set_filesystem(self, dir_data): 111 | raise NotImplementedError 112 | 113 | def _name_hrbin(self): 114 | raise NotImplementedError 115 | 116 | def _name_lrbin(self, scale): 117 | raise NotImplementedError 118 | 119 | def __getitem__(self, idx): 120 | lr, nl, mk, hr, filename = self._load_file(idx) 121 | lr, nl, mk, hr = self._get_patch(lr, nl, mk, hr) 122 | lr, hr = common.set_channel([lr, hr], self.args.n_colors) 123 | #print('The size of lr, hr images are {}, {}'.format(lr.shape, hr.shape)) 
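# Note: lr (texture), nl (normal map), mk (mask), and hr stay spatially
# aligned up to this point: get_patch() crops them from corresponding
# windows and common.augment() applies the same random flips/rotation to
# all four arrays, so the single np2Tensor call below converts a
# consistent training tuple.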
124 | lr_tensor, nl_tensor, mk_tensor, hr_tensor = common.np2Tensor([lr, nl, mk, hr], self.args.rgb_range) 125 | # if self.model_flag.lower() == 'finetune': 126 | return lr_tensor, nl_tensor, mk_tensor, hr_tensor, filename 127 | # else: 128 | # return lr_tensor, hr_tensor, filename 129 | 130 | def __len__(self): 131 | return len(self.images_hr) 132 | 133 | def _get_index(self, idx): 134 | return idx 135 | 136 | def _load_file(self, idx): 137 | idx = self._get_index(idx) 138 | # from IPython import embed; embed() 139 | # print('The hr images are {}'.format(self.images_hr)) 140 | lr = self.images_lr[self.idx_scale][idx] 141 | nl = self.normals_lr[self.idx_scale][idx] 142 | mk = self.masks_lr[self.idx_scale][idx] 143 | hr = self.images_hr[idx] 144 | 145 | # print(self.images_hr) 146 | # print('......................................................') 147 | # print(hr) 148 | if self.args.ext == 'img' or self.benchmark: 149 | filename = hr 150 | lr = misc.imread(lr) 151 | nl = misc.imread(nl) 152 | mk = misc.imread(mk) 153 | hr = misc.imread(hr) 154 | elif self.args.ext.find('sep') >= 0: 155 | filename = hr 156 | lr = np.load(lr) 157 | nl = np.load(nl) 158 | mk = np.load(mk) 159 | hr = np.load(hr) 160 | 161 | else: 162 | filename = str(idx + 1) 163 | #print('The resolution of lr, hr images are {}, {}'.format(lr.shape, hr.shape)) 164 | filename = os.path.splitext(os.path.split(filename)[-1])[0] 165 | #print('the filename is {}'.format(filename)) 166 | return lr, nl, mk, hr, filename 167 | 168 | def _get_patch(self, lr, nl, mk, hr): 169 | patch_size = self.args.patch_size 170 | scale = self.scale[self.idx_scale] 171 | multi_scale = len(self.scale) > 1 172 | if self.train: 173 | #from IPython import embed; embed(); exit() 174 | lr, nl, mk, hr = self.get_patch( 175 | lr, nl, mk, hr, patch_size, scale, multi_scale=multi_scale 176 | ) 177 | lr, nl, mk, hr = common.augment([lr, nl, mk, hr]) 178 | lr = common.add_noise(lr, self.args.noise) 179 | else: 180 | ih, iw = lr.shape[0:2] 181 | hr = hr[0:ih * scale, 0:iw * scale] 182 | if not self.normal_lr: 183 | nl = nl[0:ih * scale, 0:iw * scale] 184 | 185 | return lr, nl, mk, hr 186 | 187 | def get_patch(self, img_in, nml_in, msk_in, img_tar, patch_size, scale, multi_scale=False): 188 | ih, iw = img_in.shape[:2] 189 | 190 | p = scale if multi_scale else 1 191 | tp = p * patch_size 192 | ip = tp // scale 193 | 194 | ix = random.randrange(0, iw - ip + 1) 195 | iy = random.randrange(0, ih - ip + 1) 196 | tx, ty = scale * ix, scale * iy 197 | if self.input_res == 'lr': 198 | img_in = img_in[iy:iy + ip, ix:ix + ip, :] 199 | msk_in = msk_in[iy:iy + ip, ix:ix + ip] 200 | img_tar = img_tar[ty:ty + tp, tx:tx + tp, :] 201 | if self.normal_lr: 202 | nml_in = nml_in[iy:iy + ip, ix:ix + ip, :] 203 | else: 204 | nml_in = nml_in[ty:ty + tp, tx:tx + tp, :] 205 | else: 206 | img_in = img_in[iy:iy + ip, ix:ix + ip, :] 207 | msk_in = msk_in[iy:iy + ip, ix:ix + ip] 208 | img_tar = img_tar[iy:iy + ip, ix:ix + ip, :] 209 | nml_in = nml_in[iy:iy + ip, ix:ix + ip, :] 210 | 211 | return img_in, nml_in, msk_in, img_tar 212 | 213 | def set_scale(self, idx_scale): 214 | self.idx_scale = idx_scale 215 | 216 | -------------------------------------------------------------------------------- /code/data/texture.py: -------------------------------------------------------------------------------- 1 | __author__ = 'Yawei Li' 2 | 3 | import os 4 | 5 | from data import common 6 | from data import srtexture 7 | 8 | import numpy as np 9 | import scipy.misc as misc 10 | 11 | import torch 12 | 
import torch.utils.data as data 13 | import glob 14 | 15 | class texture(srtexture.SRData): 16 | def __init__(self, args, train=True): 17 | super(texture, self).__init__(args, train) 18 | 19 | if self.subset == '.': 20 | self.num_all = 24 21 | self.num_split = 12 22 | if train: 23 | self.repeat = args.test_every // (self.num_split // args.batch_size) 24 | else: 25 | self.num_all = self.all[self.subset_idx[0]] 26 | if (self.train and self.model_one) or (not self.train and not self.model_one): 27 | self.num_split = self.split[self.subset_idx[0]] 28 | else: 29 | self.num_split = self.num_all - self.split[self.subset_idx[0]] 30 | 31 | def _scan(self): 32 | list_hr = [] 33 | list_lr = [[] for _ in self.scale] 34 | list_lr_normal = [[] for _ in self.scale] 35 | list_lr_mask = [[] for _ in self.scale] 36 | self.subset_idx = list(range(4)) if self.subset == '.' else [self.set.index(self.subset)] 37 | for s in self.subset_idx: 38 | dir_hr = os.path.join(self.apath, self.set[s], 'x1/Texture/*.png') 39 | list_hr_set = sorted(glob.glob(dir_hr)) 40 | if self.data_train == 'texture': 41 | if (self.train and self.model_one) or (not self.train and not self.model_one): 42 | list_hr_split = list_hr_set[:self.split[s]] 43 | list_hr += list_hr_split 44 | else: 45 | list_hr_split = list_hr_set[self.split[s]:] 46 | list_hr += list_hr_split 47 | else: 48 | list_hr += list_hr_set 49 | for si, s in enumerate(self.scale): 50 | list_lr[si] = [n.replace('x1', 'x{}'.format(s)) for n in list_hr] 51 | if self.normal_lr: 52 | # from IPython import embed; embed(); exit() 53 | list_lr_normal[si] = [n.replace('x1', 'x{}'.format(s)).replace('Texture/', 'normal/').replace('Texture.png', 'normal.png') for n in list_hr] 54 | else: 55 | list_lr_normal[si] = [n.replace('Texture/', 'normal/').replace('Texture.png', 'normal.png') for n in list_hr] 56 | list_lr_mask[si] = [n.replace('x1', 'x{}'.format(s)).replace('Texture/', 'mask/').replace('Texture.png', 'mask.png') for n in list_hr] 57 | # from IPython import embed; embed(); exit() 58 | return list_hr, list_lr, list_lr_normal, list_lr_mask 59 | 60 | def _set_filesystem(self, dir_data): 61 | self.apath = dir_data + 'Texture_map' 62 | self.set = ['MiddleBury', 'ETH3D', 'Collection', 'SyB3R'] 63 | self.ext = '.png' 64 | self.split = [1, 6, 3, 2] #[1, 7, 3, 1] 65 | self.all = [2, 13, 6, 3] 66 | # if self.model_one: 67 | # self.split = [1, 6, 3, 2] 68 | 69 | def _name_hrbin(self): 70 | return os.path.join( 71 | self.apath, 72 | 'bin', 73 | '{}_bin_HR.npy'.format(self.split) 74 | ) 75 | 76 | def _name_lrbin(self, scale): 77 | return os.path.join( 78 | self.apath, 79 | 'bin', 80 | '{}_bin_LR_X{}.npy'.format(self.split, scale) 81 | ) 82 | 83 | def __len__(self): 84 | if self.train: 85 | if self.subset == '.': 86 | return self.num_split * self.repeat 87 | else: 88 | return 2400 89 | else: 90 | if self.data_train == 'texture': 91 | return self.num_split 92 | else: 93 | return self.num_all 94 | 95 | def _get_index(self, idx): 96 | if self.train: 97 | return idx % self.num_split 98 | else: 99 | return idx 100 | 101 | 102 | -------------------------------------------------------------------------------- /code/data/texture_hr.py: -------------------------------------------------------------------------------- 1 | __author__ = 'Yawei Li' 2 | 3 | import os 4 | 5 | from data import common 6 | from data import srtexture 7 | 8 | import numpy as np 9 | import scipy.misc as misc 10 | 11 | import torch 12 | import torch.utils.data as data 13 | import glob 14 | 15 | class texture_hr(srtexture.SRData): 
16 | def __init__(self, args, train=True): 17 | super(texture_hr, self).__init__(args, train) 18 | 19 | if self.subset == '.': 20 | self.num_all = 8 21 | self.num_split = 4 22 | #self.repeat = args.test_every // (self.num_split // args.batch_size) 23 | else: 24 | self.num_all = self.all[self.subset_idx[0]] 25 | if (self.train and self.model_one) or (not self.train and not self.model_one): 26 | self.num_split = self.split[self.subset_idx[0]] 27 | else: 28 | self.num_split = self.num_all - self.split[self.subset_idx[0]] 29 | 30 | def _scan(self): 31 | list_hr = [] 32 | list_lr = [[] for _ in self.scale] 33 | list_lr_normal = [[] for _ in self.scale] 34 | list_lr_mask = [[] for _ in self.scale] 35 | self.subset_idx = [0, 2] if self.subset == '.' else [self.set.index(self.subset)] 36 | for s in self.subset_idx: 37 | dir_hr = os.path.join(self.apath, self.set[s], 'x1/Texture/*.png') 38 | list_hr_set = sorted(glob.glob(dir_hr)) 39 | if self.data_train == 'texture_hr': 40 | if (self.train and self.model_one) or (not self.train and not self.model_one): 41 | list_hr_split = list_hr_set[:self.split[s]] 42 | list_hr += list_hr_split 43 | else: 44 | list_hr_split = list_hr_set[self.split[s]:] 45 | list_hr += list_hr_split 46 | else: 47 | list_hr += list_hr_set 48 | for si, s in enumerate(self.scale): 49 | list_lr[si] = [n.replace('/scratch_net/ofsoundof/yawli/Datasets/texture_map', '/home/yawli/Documents/3d-appearance-benchmark/SR/texture').replace('x1/Texture', 'x{}'.format(s)) for n in list_hr] 50 | #list_lr[si] = [n.replace('MiddleBury', 'ColMapMiddlebury') if 'MiddleBury' in n else n for n in list_lr[si]] 51 | list_lr_normal[si] = [n.replace('Texture', 'normal') for n in list_hr] 52 | list_lr_mask[si] = [n.replace('Texture', 'mask') for n in list_hr] 53 | #from IPython import embed; embed(); exit() 54 | return list_hr, list_lr, list_lr_normal, list_lr_mask 55 | 56 | def _set_filesystem(self, dir_data): 57 | self.apath = dir_data + '/texture_map' 58 | self.set = ['MiddleBury', 'ETH3D', 'Collection', 'SyB3R'] 59 | self.ext = '.png' 60 | self.split = [1, 6, 3, 2] #[1, 7, 3, 1] 61 | self.all = [2, 13, 6, 3] 62 | # if self.model_one: 63 | # self.split = [1, 6, 3, 2] 64 | 65 | def _name_hrbin(self): 66 | return os.path.join( 67 | self.apath, 68 | 'bin', 69 | '{}_bin_HR.npy'.format(self.split) 70 | ) 71 | 72 | def _name_lrbin(self, scale): 73 | return os.path.join( 74 | self.apath, 75 | 'bin', 76 | '{}_bin_LR_X{}.npy'.format(self.split, scale) 77 | ) 78 | 79 | def __len__(self): 80 | if self.train: 81 | if self.subset == '.': 82 | return 2400 83 | else: 84 | return 2400 85 | else: 86 | if self.data_train == 'texture_hr': 87 | return self.num_split 88 | else: 89 | return self.num_all 90 | 91 | def _get_index(self, idx): 92 | if self.train: 93 | return idx % self.num_split 94 | else: 95 | return idx 96 | 97 | 98 | -------------------------------------------------------------------------------- /code/dataloader.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import threading 3 | import queue 4 | import random 5 | import collections 6 | 7 | import torch 8 | import torch.multiprocessing as multiprocessing 9 | 10 | from torch._C import _set_worker_signal_handlers, _update_worker_pids, \ 11 | _remove_worker_pids, _error_if_any_worker_fails 12 | from torch.utils.data.dataloader import DataLoader 13 | from torch.utils.data.dataloader import _DataLoaderIter 14 | #from IPython import embed; embed(); exit() 15 | from torch.utils.data.dataloader import 
ExceptionWrapper 16 | from torch.utils.data.dataloader import _use_shared_memory 17 | from torch.utils.data.dataloader import _worker_manager_loop 18 | from torch.utils.data.dataloader import numpy_type_map 19 | from torch.utils.data.dataloader import default_collate 20 | from torch.utils.data.dataloader import pin_memory_batch 21 | from torch.utils.data.dataloader import _SIGCHLD_handler_set 22 | from torch.utils.data.dataloader import _set_SIGCHLD_handler 23 | 24 | if sys.version_info[0] == 2: 25 | import Queue as queue 26 | else: 27 | import queue 28 | 29 | def _ms_loop(dataset, index_queue, data_queue, collate_fn, scale, seed, init_fn, worker_id): 30 | global _use_shared_memory 31 | _use_shared_memory = True 32 | _set_worker_signal_handlers() 33 | 34 | torch.set_num_threads(1) 35 | torch.manual_seed(seed) 36 | while True: 37 | r = index_queue.get() 38 | if r is None: 39 | break 40 | idx, batch_indices = r 41 | try: 42 | idx_scale = 0 43 | if len(scale) > 1 and dataset.train: 44 | idx_scale = random.randrange(0, len(scale)) 45 | dataset.set_scale(idx_scale) 46 | 47 | samples = collate_fn([dataset[i] for i in batch_indices]) 48 | samples.append(idx_scale) 49 | #This is why idx_scale appears in the samples of the train loader 50 | 51 | except Exception: 52 | data_queue.put((idx, ExceptionWrapper(sys.exc_info()))) 53 | else: 54 | data_queue.put((idx, samples)) 55 | 56 | class _MSDataLoaderIter(_DataLoaderIter): 57 | def __init__(self, loader): 58 | # from IPython import embed; embed(); exit() 59 | self.dataset = loader.dataset 60 | self.scale = loader.scale 61 | self.collate_fn = loader.collate_fn 62 | self.batch_sampler = loader.batch_sampler 63 | self.num_workers = loader.num_workers 64 | self.pin_memory = loader.pin_memory and torch.cuda.is_available() 65 | self.timeout = loader.timeout 66 | self.done_event = threading.Event() 67 | 68 | self.sample_iter = iter(self.batch_sampler) 69 | 70 | if self.num_workers > 0: 71 | self.worker_init_fn = loader.worker_init_fn 72 | self.index_queues = [ 73 | multiprocessing.Queue() for _ in range(self.num_workers) 74 | ] 75 | self.worker_queue_idx = 0 76 | self.worker_result_queue = multiprocessing.SimpleQueue() 77 | self.batches_outstanding = 0 78 | self.worker_pids_set = False 79 | self.shutdown = False 80 | self.send_idx = 0 81 | self.rcvd_idx = 0 82 | self.reorder_dict = {} 83 | 84 | base_seed = torch.LongTensor(1).random_()[0] 85 | self.workers = [ 86 | multiprocessing.Process( 87 | target=_ms_loop, 88 | args=( 89 | self.dataset, 90 | self.index_queues[i], 91 | self.worker_result_queue, 92 | self.collate_fn, 93 | self.scale, 94 | base_seed + i, 95 | self.worker_init_fn, 96 | i 97 | ) 98 | ) 99 | for i in range(self.num_workers)] 100 | 101 | if self.pin_memory or self.timeout > 0: 102 | self.data_queue = queue.Queue() 103 | if self.pin_memory: 104 | maybe_device_id = torch.cuda.current_device() 105 | else: 106 | # do not initialize cuda context if not necessary 107 | maybe_device_id = None 108 | self.worker_manager_thread = threading.Thread( 109 | target=_worker_manager_loop, 110 | args=(self.worker_result_queue, self.data_queue, self.done_event, self.pin_memory, 111 | maybe_device_id)) 112 | self.worker_manager_thread.daemon = True 113 | self.worker_manager_thread.start() 114 | else: 115 | self.data_queue = self.worker_result_queue 116 | 117 | for w in self.workers: 118 | w.daemon = True # ensure that the worker exits on process exit 119 | w.start() 120 | 121 | _update_worker_pids(id(self), tuple(w.pid for w in self.workers)) 122 | 
_set_SIGCHLD_handler() 123 | self.worker_pids_set = True 124 | 125 | # prime the prefetch loop 126 | for _ in range(2 * self.num_workers): 127 | self._put_indices() 128 | 129 | class MSDataLoader(DataLoader): 130 | def __init__( 131 | self, args, dataset, batch_size=1, shuffle=False, 132 | sampler=None, batch_sampler=None, 133 | collate_fn=default_collate, pin_memory=False, drop_last=False, 134 | timeout=0, worker_init_fn=None): 135 | 136 | super(MSDataLoader, self).__init__( 137 | dataset, batch_size=batch_size, shuffle=shuffle, 138 | sampler=sampler, batch_sampler=batch_sampler, 139 | num_workers=args.n_threads, collate_fn=collate_fn, 140 | pin_memory=pin_memory, drop_last=drop_last, 141 | timeout=timeout, worker_init_fn=worker_init_fn) 142 | 143 | self.scale = args.scale 144 | 145 | def __iter__(self): 146 | return _MSDataLoaderIter(self) 147 | -------------------------------------------------------------------------------- /code/dataloader_new.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import threading 3 | import queue 4 | import random 5 | import collections 6 | 7 | import torch 8 | import torch.multiprocessing as multiprocessing 9 | 10 | from torch._C import _set_worker_signal_handlers, _update_worker_pids, \ 11 | _remove_worker_pids, _error_if_any_worker_fails 12 | from torch.utils.data.dataloader import DataLoader 13 | from torch.utils.data.dataloader import _DataLoaderIter 14 | from torch.utils.data.dataloader import ExceptionWrapper 15 | from torch.utils.data.dataloader import _pin_memory_loop 16 | from torch.utils.data.dataloader import default_collate 17 | from torch.utils.data.dataloader import _set_SIGCHLD_handler 18 | from torch.utils.data.dataloader import ManagerWatchdog 19 | from torch.utils.data.dataloader import MP_STATUS_CHECK_INTERVAL 20 | 21 | 22 | 23 | if sys.version_info[0] == 2: 24 | import Queue as queue 25 | else: 26 | import queue 27 | 28 | def _ms_loop(dataset, index_queue, data_queue, done_event, collate_fn, scale, seed, init_fn, worker_id): 29 | try: 30 | global _use_shared_memory 31 | _use_shared_memory = True 32 | 33 | _set_worker_signal_handlers() 34 | 35 | torch.set_num_threads(1) 36 | random.seed(seed) 37 | torch.manual_seed(seed) 38 | 39 | data_queue.cancel_join_thread() 40 | 41 | if init_fn is not None: 42 | init_fn(worker_id) 43 | 44 | watchdog = ManagerWatchdog() 45 | 46 | while watchdog.is_alive(): 47 | try: 48 | r = index_queue.get(timeout=MP_STATUS_CHECK_INTERVAL) 49 | except queue.Empty: 50 | continue 51 | if r is None: 52 | # Received the final signal 53 | assert done_event.is_set() 54 | return 55 | elif done_event.is_set(): 56 | # Done event is set. But I haven't received the final signal 57 | # (None) yet. I will keep continuing until get it, and skip the 58 | # processing steps. 
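# (The main process sets done_event before it enqueues the final None
# sentinel, so a worker may observe the event while stale index batches
# are still queued; skipping them here simply drains the queue without
# computing batches that nobody will consume.)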
59 | continue 60 | idx, batch_indices = r 61 | try: 62 | idx_scale = 0 63 | if len(scale) > 1 and dataset.train: 64 | idx_scale = random.randrange(0, len(scale)) 65 | dataset.set_scale(idx_scale) 66 | 67 | samples = collate_fn([dataset[i] for i in batch_indices]) 68 | samples.append(idx_scale) 69 | #This is why idx_scale appears in the samples of the train loader 70 | 71 | except Exception: 72 | # It is important that we don't store exc_info in a variable, 73 | # see NOTE [ Python Traceback Reference Cycle Problem ] 74 | data_queue.put((idx, ExceptionWrapper(sys.exc_info()))) 75 | else: 76 | data_queue.put((idx, samples)) 77 | del samples 78 | except KeyboardInterrupt: 79 | # Main process will raise KeyboardInterrupt anyways. 80 | pass 81 | 82 | 83 | class _MSDataLoaderIter(_DataLoaderIter): 84 | def __init__(self, loader): 85 | self.dataset = loader.dataset 86 | self.collate_fn = loader.collate_fn 87 | self.scale = loader.scale 88 | self.batch_sampler = loader.batch_sampler 89 | self.num_workers = loader.num_workers 90 | self.pin_memory = loader.pin_memory and torch.cuda.is_available() 91 | self.timeout = loader.timeout 92 | 93 | self.sample_iter = iter(self.batch_sampler) 94 | 95 | base_seed = torch.LongTensor(1).random_().item() 96 | 97 | if self.num_workers > 0: 98 | self.worker_init_fn = loader.worker_init_fn 99 | self.worker_queue_idx = 0 100 | self.worker_result_queue = multiprocessing.Queue() 101 | self.batches_outstanding = 0 102 | self.worker_pids_set = False 103 | self.shutdown = False 104 | self.send_idx = 0 105 | self.rcvd_idx = 0 106 | self.reorder_dict = {} 107 | self.done_event = multiprocessing.Event() 108 | 109 | self.index_queues = [] 110 | self.workers = [] 111 | for i in range(self.num_workers): 112 | index_queue = multiprocessing.Queue() 113 | index_queue.cancel_join_thread() 114 | w = multiprocessing.Process( 115 | target=_ms_loop, 116 | args=(self.dataset, index_queue, 117 | self.worker_result_queue, self.done_event, 118 | self.collate_fn, self.scale, base_seed + i, 119 | self.worker_init_fn, i)) 120 | w.daemon = True 121 | # NB: Process.start() actually take some time as it needs to 122 | # start a process and pass the arguments over via a pipe. 123 | # Therefore, we only add a worker to self.workers list after 124 | # it started, so that we do not call .join() if program dies 125 | # before it starts, and __del__ tries to join but will get: 126 | # AssertionError: can only join a started process. 127 | w.start() 128 | self.index_queues.append(index_queue) 129 | self.workers.append(w) 130 | 131 | if self.pin_memory: 132 | self.data_queue = queue.Queue() 133 | pin_memory_thread = threading.Thread( 134 | target=_pin_memory_loop, 135 | args=(self.worker_result_queue, self.data_queue, 136 | torch.cuda.current_device(), self.done_event)) 137 | pin_memory_thread.daemon = True 138 | pin_memory_thread.start() 139 | # Similar to workers (see comment above), we only register 140 | # pin_memory_thread once it is started. 
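# (join() on a thread that was never started raises RuntimeError, so the
# same register-only-after-start rule used for the worker processes above
# applies to the pin_memory_thread as well.)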
141 | self.pin_memory_thread = pin_memory_thread 142 | else: 143 | self.data_queue = self.worker_result_queue 144 | 145 | _update_worker_pids(id(self), tuple(w.pid for w in self.workers)) 146 | _set_SIGCHLD_handler() 147 | self.worker_pids_set = True 148 | 149 | # prime the prefetch loop 150 | for _ in range(2 * self.num_workers): 151 | self._put_indices() 152 | 153 | 154 | class MSDataLoader(DataLoader): 155 | def __init__( 156 | self, args, dataset, batch_size=1, shuffle=False, 157 | sampler=None, batch_sampler=None, num_workers=0, 158 | collate_fn=default_collate, pin_memory=False, drop_last=False, 159 | timeout=0, worker_init_fn=None): 160 | 161 | super(MSDataLoader, self).__init__( 162 | dataset, batch_size=batch_size, shuffle=shuffle, 163 | sampler=sampler, batch_sampler=batch_sampler, 164 | num_workers=num_workers, collate_fn=collate_fn, 165 | pin_memory=pin_memory, drop_last=drop_last, 166 | timeout=timeout, worker_init_fn=worker_init_fn) 167 | 168 | self.scale = args.scale 169 | 170 | def __iter__(self): 171 | return _MSDataLoaderIter(self) 172 | -------------------------------------------------------------------------------- /code/loss/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from importlib import import_module 3 | 4 | import matplotlib 5 | matplotlib.use('Agg') 6 | import matplotlib.pyplot as plt 7 | 8 | import numpy as np 9 | 10 | import torch 11 | import torch.nn as nn 12 | import torch.nn.functional as F 13 | 14 | class Loss(nn.modules.loss._Loss): 15 | def __init__(self, args, ckp): 16 | super(Loss, self).__init__() 17 | print('Preparing loss function:') 18 | 19 | self.n_GPUs = args.n_GPUs 20 | self.loss = [] 21 | self.loss_module = nn.ModuleList() 22 | for loss in args.loss.split('+'): 23 | weight, loss_type = loss.split('*') 24 | if loss_type == 'MSE': 25 | loss_function = nn.MSELoss() 26 | elif loss_type == 'L1': 27 | loss_function = nn.L1Loss() 28 | elif loss_type.find('VGG') >= 0: 29 | module = import_module('loss.vgg') 30 | loss_function = getattr(module, 'VGG')( 31 | loss_type[3:], 32 | rgb_range=args.rgb_range 33 | ) 34 | elif loss_type.find('GAN') >= 0: 35 | module = import_module('loss.adversarial') 36 | loss_function = getattr(module, 'Adversarial')( 37 | args, 38 | loss_type 39 | ) 40 | 41 | self.loss.append({ 42 | 'type': loss_type, 43 | 'weight': float(weight), 44 | 'function': loss_function} 45 | ) 46 | if loss_type.find('GAN') >= 0: 47 | self.loss.append({'type': 'DIS', 'weight': 1, 'function': None}) 48 | 49 | if len(self.loss) > 1: 50 | self.loss.append({'type': 'Total', 'weight': 0, 'function': None}) 51 | 52 | for l in self.loss: 53 | if l['function'] is not None: 54 | print('{:.3f} * {}'.format(l['weight'], l['type'])) 55 | self.loss_module.append(l['function']) 56 | 57 | self.log = torch.Tensor() 58 | 59 | device = torch.device('cpu' if args.cpu else 'cuda') 60 | self.loss_module.to(device) 61 | if args.precision == 'half': self.loss_module.half() 62 | if not args.cpu and args.n_GPUs > 1: 63 | self.loss_module = nn.DataParallel( 64 | self.loss_module, range(args.n_GPUs) 65 | ) 66 | 67 | if args.load != '.': self.load(ckp.dir, cpu=args.cpu) 68 | 69 | def forward(self, sr, hr): 70 | losses = [] 71 | for i, l in enumerate(self.loss): 72 | if l['function'] is not None: 73 | loss = l['function'](sr, hr) 74 | effective_loss = l['weight'] * loss 75 | losses.append(effective_loss) 76 | self.log[-1, i] += effective_loss.item() 77 | elif l['type'] == 'DIS': 78 | self.log[-1, i] += self.loss[i 
- 1]['function'].loss 79 | 80 | loss_sum = sum(losses) 81 | if len(self.loss) > 1: 82 | self.log[-1, -1] += loss_sum.item() 83 | 84 | return loss_sum 85 | 86 | def step(self): 87 | for l in self.get_loss_module(): 88 | if hasattr(l, 'scheduler'): 89 | l.scheduler.step() 90 | 91 | def start_log(self): 92 | self.log = torch.cat((self.log, torch.zeros(1, len(self.loss)))) 93 | 94 | def end_log(self, n_batches): 95 | self.log[-1].div_(n_batches) 96 | 97 | def display_loss(self, batch): 98 | n_samples = batch + 1 99 | log = [] 100 | for l, c in zip(self.loss, self.log[-1]): 101 | log.append('[{}: {:.4f}]'.format(l['type'], c / n_samples)) 102 | 103 | return ''.join(log) 104 | 105 | def plot_loss(self, apath, epoch): 106 | axis = np.linspace(1, epoch, epoch) 107 | for i, l in enumerate(self.loss): 108 | label = '{} Loss'.format(l['type']) 109 | fig = plt.figure() 110 | plt.title(label) 111 | plt.plot(axis, self.log[:, i].numpy(), label=label) 112 | plt.legend() 113 | plt.xlabel('Epochs') 114 | plt.ylabel('Loss') 115 | plt.grid(True) 116 | plt.savefig('{}/loss_{}.pdf'.format(apath, l['type'])) 117 | plt.close(fig) 118 | 119 | def get_loss_module(self): 120 | if self.n_GPUs == 1: 121 | return self.loss_module 122 | else: 123 | return self.loss_module.module 124 | 125 | def save(self, apath): 126 | torch.save(self.state_dict(), os.path.join(apath, 'loss.pt')) 127 | torch.save(self.log, os.path.join(apath, 'loss_log.pt')) 128 | 129 | def load(self, apath, cpu=False): 130 | if cpu: 131 | kwargs = {'map_location': lambda storage, loc: storage} 132 | else: 133 | kwargs = {} 134 | 135 | self.load_state_dict(torch.load( 136 | os.path.join(apath, 'loss.pt'), 137 | **kwargs 138 | )) 139 | self.log = torch.load(os.path.join(apath, 'loss_log.pt')) 140 | for l in self.loss_module: 141 | if hasattr(l, 'scheduler'): 142 | for _ in range(len(self.log)): l.scheduler.step() 143 | 144 | -------------------------------------------------------------------------------- /code/loss/adversarial.py: -------------------------------------------------------------------------------- 1 | import utility 2 | from model import common 3 | from loss import discriminator 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import torch.optim as optim 9 | from torch.autograd import Variable 10 | 11 | class Adversarial(nn.Module): 12 | def __init__(self, args, gan_type): 13 | super(Adversarial, self).__init__() 14 | self.gan_type = gan_type 15 | self.gan_k = args.gan_k 16 | self.discriminator = discriminator.Discriminator(args, gan_type) 17 | if gan_type != 'WGAN_GP': 18 | self.optimizer = utility.make_optimizer(args, self.discriminator) 19 | else: 20 | self.optimizer = optim.Adam( 21 | self.discriminator.parameters(), 22 | betas=(0, 0.9), eps=1e-8, lr=1e-5 23 | ) 24 | self.scheduler = utility.make_scheduler(args, self.optimizer) 25 | 26 | def forward(self, fake, real): 27 | fake_detach = fake.detach() 28 | 29 | self.loss = 0 30 | for _ in range(self.gan_k): 31 | self.optimizer.zero_grad() 32 | d_fake = self.discriminator(fake_detach) 33 | d_real = self.discriminator(real) 34 | if self.gan_type == 'GAN': 35 | label_fake = torch.zeros_like(d_fake) 36 | label_real = torch.ones_like(d_real) 37 | loss_d \ 38 | = F.binary_cross_entropy_with_logits(d_fake, label_fake) \ 39 | + F.binary_cross_entropy_with_logits(d_real, label_real) 40 | elif self.gan_type.find('WGAN') >= 0: 41 | loss_d = (d_fake - d_real).mean() 42 | if self.gan_type.find('GP') >= 0: 43 | epsilon = torch.rand_like(fake).view(-1, 1, 1, 1) 44 
| hat = fake_detach.mul(1 - epsilon) + real.mul(epsilon) 45 | hat.requires_grad = True 46 | d_hat = self.discriminator(hat) 47 | gradients = torch.autograd.grad( 48 | outputs=d_hat.sum(), inputs=hat, 49 | retain_graph=True, create_graph=True, only_inputs=True 50 | )[0] 51 | gradients = gradients.view(gradients.size(0), -1) 52 | gradient_norm = gradients.norm(2, dim=1) 53 | gradient_penalty = 10 * gradient_norm.sub(1).pow(2).mean() 54 | loss_d += gradient_penalty 55 | 56 | # Discriminator update 57 | self.loss += loss_d.item() 58 | loss_d.backward() 59 | self.optimizer.step() 60 | 61 | if self.gan_type == 'WGAN': 62 | for p in self.discriminator.parameters(): 63 | p.data.clamp_(-1, 1) 64 | 65 | self.loss /= self.gan_k 66 | 67 | d_fake_for_g = self.discriminator(fake) 68 | if self.gan_type == 'GAN': 69 | loss_g = F.binary_cross_entropy_with_logits( 70 | d_fake_for_g, label_real 71 | ) 72 | elif self.gan_type.find('WGAN') >= 0: 73 | loss_g = -d_fake_for_g.mean() 74 | 75 | # Generator loss 76 | return loss_g 77 | 78 | def state_dict(self, *args, **kwargs): 79 | state_discriminator = self.discriminator.state_dict(*args, **kwargs) 80 | state_optimizer = self.optimizer.state_dict() 81 | 82 | return dict(**state_discriminator, **state_optimizer) 83 | 84 | # Some references 85 | # https://github.com/kuc2477/pytorch-wgan-gp/blob/master/model.py 86 | # OR 87 | # https://github.com/caogang/wgan-gp/blob/master/gan_cifar10.py 88 | -------------------------------------------------------------------------------- /code/loss/discriminator.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | 5 | class Discriminator(nn.Module): 6 | def __init__(self, args, gan_type='GAN'): 7 | super(Discriminator, self).__init__() 8 | 9 | in_channels = 3 10 | out_channels = 64 11 | depth = 7 12 | #bn = not gan_type == 'WGAN_GP' 13 | bn = True 14 | act = nn.LeakyReLU(negative_slope=0.2, inplace=True) 15 | 16 | m_features = [ 17 | common.BasicBlock(args.n_colors, out_channels, 3, bn=bn, act=act) 18 | ] 19 | for i in range(depth): 20 | in_channels = out_channels 21 | if i % 2 == 1: 22 | stride = 1 23 | out_channels *= 2 24 | else: 25 | stride = 2 26 | m_features.append(common.BasicBlock( 27 | in_channels, out_channels, 3, stride=stride, bn=bn, act=act 28 | )) 29 | 30 | self.features = nn.Sequential(*m_features) 31 | 32 | patch_size = args.patch_size // (2**((depth + 1) // 2)) 33 | m_classifier = [ 34 | nn.Linear(out_channels * patch_size**2, 1024), 35 | act, 36 | nn.Linear(1024, 1) 37 | ] 38 | self.classifier = nn.Sequential(*m_classifier) 39 | 40 | def forward(self, x): 41 | features = self.features(x) 42 | output = self.classifier(features.view(features.size(0), -1)) 43 | 44 | return output 45 | 46 | -------------------------------------------------------------------------------- /code/loss/vgg.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision.models as models 7 | from torch.autograd import Variable 8 | 9 | class VGG(nn.Module): 10 | def __init__(self, conv_index, rgb_range=1): 11 | super(VGG, self).__init__() 12 | vgg_features = models.vgg19(pretrained=True).features 13 | modules = [m for m in vgg_features] 14 | if conv_index == '22': 15 | self.vgg = nn.Sequential(*modules[:8]) 16 | elif conv_index == '54': 17 | self.vgg = nn.Sequential(*modules[:35]) 18 | 19 | 
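# The MeanShift below maps inputs in [0, rgb_range] to the ImageNet-normalized
# range the pretrained VGG expects. Note that `self.vgg.requires_grad = False`
# a few lines down only sets a plain attribute on the module; actually freezing
# the weights would require iterating the parameters, e.g.
# for p in self.vgg.parameters(): p.requires_grad = False
# (the torch.no_grad() in forward() already keeps the HR branch gradient-free).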
vgg_mean = (0.485, 0.456, 0.406) 20 | vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range) 21 | self.sub_mean = common.MeanShift(rgb_range, vgg_mean, vgg_std) 22 | self.vgg.requires_grad = False 23 | 24 | def forward(self, sr, hr): 25 | def _forward(x): 26 | x = self.sub_mean(x) 27 | x = self.vgg(x) 28 | return x 29 | 30 | vgg_sr = _forward(sr) 31 | with torch.no_grad(): 32 | vgg_hr = _forward(hr.detach()) 33 | 34 | loss = F.mse_loss(vgg_sr, vgg_hr) 35 | 36 | return loss 37 | -------------------------------------------------------------------------------- /code/main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import utility 4 | import data 5 | import model 6 | import loss 7 | from option import args 8 | from trainer_finetune import TrainerFT 9 | 10 | torch.manual_seed(args.seed) 11 | checkpoint = utility.checkpoint(args) 12 | 13 | if checkpoint.ok: 14 | loader = data.Data(args) 15 | model = model.Model(args, checkpoint) 16 | # from IPython import embed; embed(); exit() 17 | loss = loss.Loss(args, checkpoint) if not args.test_only else None 18 | t = TrainerFT(args, loader, model, loss, checkpoint) 19 | while not t.terminate(): 20 | t.train() 21 | t.test() 22 | 23 | checkpoint.done() 24 | 25 | -------------------------------------------------------------------------------- /code/model/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | from importlib import import_module 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.autograd import Variable 7 | 8 | class Model(nn.Module): 9 | def __init__(self, args, ckp): 10 | super(Model, self).__init__() 11 | print('Making model...') 12 | 13 | self.scale = args.scale 14 | self.idx_scale = 0 15 | self.self_ensemble = args.self_ensemble 16 | self.chop = args.chop 17 | self.precision = args.precision 18 | self.cpu = args.cpu 19 | self.device = torch.device('cpu' if args.cpu else 'cuda') 20 | self.n_GPUs = args.n_GPUs 21 | self.save_models = args.save_models 22 | self.model_flag = args.model 23 | self.normal_lr = args.normal_lr == 'lr' 24 | self.input_res = args.input_res 25 | module = import_module('model.' 
+ args.model.lower()) 26 | # from IPython import embed; embed(); exit() 27 | self.model = module.make_model(args).to(self.device) 28 | if args.precision == 'half': self.model.half() 29 | 30 | if not args.cpu and args.n_GPUs > 1: 31 | self.model = nn.DataParallel(self.model, range(args.n_GPUs)) 32 | print(ckp.dir) 33 | self.load( 34 | ckp.dir, 35 | pre_train=args.pre_train, 36 | resume=args.resume, 37 | cpu=args.cpu 38 | ) 39 | if args.print_model: print(self.model) 40 | 41 | def forward(self, idx_scale, *x): 42 | self.idx_scale = idx_scale 43 | target = self.get_model() 44 | if hasattr(target, 'set_scale'): 45 | target.set_scale(idx_scale) 46 | # from IPython import embed; embed(); exit() 47 | 48 | if self.self_ensemble and not self.training: 49 | if self.chop: 50 | forward_function = self.forward_chop 51 | else: 52 | forward_function = self.model.forward 53 | 54 | return self.forward_x8(x, forward_function) 55 | elif self.chop and not self.training: 56 | return self.forward_chop(x) 57 | else: 58 | if self.model_flag.lower() == 'finetune': 59 | return self.model(x) 60 | else: 61 | return self.model(x[0]) 62 | 63 | def get_model(self): 64 | if self.n_GPUs == 1: 65 | return self.model 66 | else: 67 | return self.model.module 68 | 69 | def state_dict(self, **kwargs): 70 | target = self.get_model() 71 | return target.state_dict(**kwargs) 72 | 73 | def save(self, apath, epoch, is_best=False): 74 | target = self.get_model() 75 | torch.save( 76 | target.state_dict(), 77 | os.path.join(apath, 'model', 'model_latest.pt') 78 | ) 79 | if is_best: 80 | torch.save( 81 | target.state_dict(), 82 | os.path.join(apath, 'model', 'model_best.pt') 83 | ) 84 | 85 | if self.save_models: 86 | torch.save( 87 | target.state_dict(), 88 | os.path.join(apath, 'model', 'model_{}.pt'.format(epoch)) 89 | ) 90 | 91 | def load(self, apath, pre_train='.', resume=-1, cpu=False): 92 | if cpu: 93 | kwargs = {'map_location': lambda storage, loc: storage} 94 | else: 95 | kwargs = {} 96 | print("The resume flag is {}".format(resume)) 97 | if resume == -1: 98 | self.get_model().load_state_dict( 99 | torch.load( 100 | os.path.join(apath, 'model', 'model_latest.pt'), 101 | **kwargs 102 | ), 103 | strict=False 104 | ) 105 | elif resume == 0: 106 | if pre_train != '.': 107 | print('Loading model from {}'.format(pre_train)) 108 | self.get_model().load_state_dict( 109 | torch.load(pre_train, **kwargs), 110 | strict=False 111 | ) 112 | else: 113 | self.get_model().load_state_dict( 114 | torch.load( 115 | os.path.join(apath, 'model', 'model_{}.pt'.format(resume)), 116 | **kwargs 117 | ), 118 | strict=False 119 | ) 120 | 121 | def forward_chop(self, x, shave=10, min_size=160000): 122 | scale = self.scale[self.idx_scale] 123 | n_GPUs = min(self.n_GPUs, 4) 124 | # from IPython import embed; embed(); 125 | # if len(x[0].size()) == 3: 126 | # c, h, w = x[0].size() 127 | # else: 128 | b, c, h, w = x[0].size() 129 | h_half, w_half = h // 4 * 2, w // 4 * 2 130 | h_size, w_size = h_half + shave, w_half + shave 131 | 132 | def _patch_division(x, h_size, w_size, h, w): 133 | x_list = [ 134 | x[:, :, 0:h_size, 0:w_size], 135 | x[:, :, 0:h_size, (w - w_size):w], 136 | x[:, :, (h - h_size):h, 0:w_size], 137 | x[:, :, (h - h_size):h, (w - w_size):w]] 138 | if len(x_list[0].size()) == 3: 139 | x_list = [torch.unsqueeze(x, 0) for x in x_list] 140 | # from IPython import embed; embed(); 141 | return x_list 142 | 143 | 144 | if self.model_flag.lower() == 'finetune': 145 | lr_list = _patch_division(x[0], h_size, w_size, h, w) 146 | if self.normal_lr or 
self.input_res == 'hr': 147 | nl_list = _patch_division(x[1], h_size, w_size, h, w) 148 | else: 149 | nl_list = _patch_division(x[1], h_size * scale, w_size * scale, h * scale, w * scale) 150 | else: 151 | lr_list = _patch_division(x[0], h_size, w_size, h, w) 152 | # from IPython import embed; embed(); 153 | if w_size * h_size < min_size: 154 | sr_list = [] 155 | for i in range(0, 4, n_GPUs): 156 | lr_batch = torch.cat(lr_list[i:(i + n_GPUs)], dim=0) 157 | if self.model_flag.lower() == 'finetune': 158 | nl_batch = torch.cat(nl_list[i:(i + n_GPUs)], dim=0) 159 | sr_batch = self.model((lr_batch, nl_batch)) 160 | else: 161 | sr_batch = self.model(lr_batch) 162 | sr_list.extend(sr_batch.chunk(n_GPUs, dim=0)) 163 | else: 164 | if self.model_flag.lower() == 'finetune': 165 | sr_list = [ 166 | self.forward_chop(patch, shave=shave, min_size=min_size) \ 167 | for patch in zip(lr_list, nl_list) 168 | ] 169 | else: 170 | sr_list = [ 171 | self.forward_chop((patch,), shave=shave, min_size=min_size) \ 172 | for patch in lr_list 173 | ] 174 | if self.input_res == 'lr': 175 | h, w = scale * h, scale * w 176 | h_half, w_half = scale * h_half, scale * w_half 177 | h_size, w_size = scale * h_size, scale * w_size 178 | shave *= scale 179 | 180 | output = x[0].new(b, c, h, w)# if self.model_flag.lower() == 'finetune' else x.new(b, c, h, w) 181 | output[:, :, 0:h_half, 0:w_half] \ 182 | = sr_list[0][:, :, 0:h_half, 0:w_half] 183 | output[:, :, 0:h_half, w_half:w] \ 184 | = sr_list[1][:, :, 0:h_half, (w_size - w + w_half):w_size] 185 | output[:, :, h_half:h, 0:w_half] \ 186 | = sr_list[2][:, :, (h_size - h + h_half):h_size, 0:w_half] 187 | output[:, :, h_half:h, w_half:w] \ 188 | = sr_list[3][:, :, (h_size - h + h_half):h_size, (w_size - w + w_half):w_size] 189 | 190 | return output 191 | 192 | def forward_x8(self, x, forward_function): 193 | def _transform(v, op): 194 | if self.precision != 'single': v = v.float() 195 | 196 | v2np = v.data.cpu().numpy() 197 | if op == 'v': 198 | tfnp = v2np[:, :, :, ::-1].copy() 199 | elif op == 'h': 200 | tfnp = v2np[:, :, ::-1, :].copy() 201 | elif op == 't': 202 | tfnp = v2np.transpose((0, 1, 3, 2)).copy() 203 | 204 | ret = torch.Tensor(tfnp).to(self.device) 205 | if self.precision == 'half': ret = ret.half() 206 | 207 | return ret 208 | 209 | lr_list = [x] 210 | for tf in 'v', 'h', 't': 211 | lr_list.extend([_transform(t, tf) for t in lr_list]) 212 | 213 | sr_list = [forward_function(aug) for aug in lr_list] 214 | for i in range(len(sr_list)): 215 | if i > 3: 216 | sr_list[i] = _transform(sr_list[i], 't') 217 | if i % 4 > 1: 218 | sr_list[i] = _transform(sr_list[i], 'h') 219 | if (i % 4) % 2 == 1: 220 | sr_list[i] = _transform(sr_list[i], 'v') 221 | 222 | output_cat = torch.cat(sr_list, dim=0) 223 | output = output_cat.mean(dim=0, keepdim=True) 224 | 225 | return output 226 | -------------------------------------------------------------------------------- /code/model/carn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import model.ops as ops 4 | from model import common 5 | 6 | def make_model(args, parent=False): 7 | return CARN(args) 8 | 9 | class Block(nn.Module): 10 | def __init__(self, 11 | in_channels, out_channels, 12 | group=1): 13 | super(Block, self).__init__() 14 | 15 | self.b1 = ops.ResidualBlock(64, 64) 16 | self.b2 = ops.ResidualBlock(64, 64) 17 | self.b3 = ops.ResidualBlock(64, 64) 18 | self.c1 = ops.BasicBlock(64*2, 64, 1, 1, 0) 19 | self.c2 = ops.BasicBlock(64*3, 64, 1, 1, 
0) 20 | self.c3 = ops.BasicBlock(64*4, 64, 1, 1, 0) 21 | 22 | def forward(self, x): 23 | c0 = o0 = x 24 | 25 | b1 = self.b1(o0) 26 | c1 = torch.cat([c0, b1], dim=1) 27 | o1 = self.c1(c1) 28 | 29 | b2 = self.b2(o1) 30 | c2 = torch.cat([c1, b2], dim=1) 31 | o2 = self.c2(c2) 32 | 33 | b3 = self.b3(o2) 34 | c3 = torch.cat([c2, b3], dim=1) 35 | o3 = self.c3(c3) 36 | 37 | return o3 38 | 39 | 40 | class CARN(nn.Module): 41 | def __init__(self, args): 42 | super(CARN, self).__init__() 43 | 44 | #scale = kwargs.get("scale") 45 | #multi_scale = kwargs.get("multi_scale") 46 | #group = kwargs.get("group", 1) 47 | multi_scale = len(args.scale) > 1 48 | self.scale_idx = 0 49 | scale = args.scale[self.scale_idx] 50 | group = 1 51 | self.scale = args.scale 52 | rgb_mean = (0.4488, 0.4371, 0.4040) 53 | rgb_std = (1.0, 1.0, 1.0) 54 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std) 55 | self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) 56 | #self.sub_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=True) 57 | #self.add_mean = ops.MeanShift((0.4488, 0.4371, 0.4040), sub=False) 58 | 59 | self.entry = nn.Conv2d(3, 64, 3, 1, 1) 60 | 61 | self.b1 = Block(64, 64) 62 | self.b2 = Block(64, 64) 63 | self.b3 = Block(64, 64) 64 | self.c1 = ops.BasicBlock(64*2, 64, 1, 1, 0) 65 | self.c2 = ops.BasicBlock(64*3, 64, 1, 1, 0) 66 | self.c3 = ops.BasicBlock(64*4, 64, 1, 1, 0) 67 | 68 | self.upsample = ops.UpsampleBlock(64, scale=scale, 69 | multi_scale=multi_scale, 70 | group=group) 71 | self.exit = nn.Conv2d(64, 3, 3, 1, 1) 72 | 73 | def forward(self, x): 74 | x = self.sub_mean(x) 75 | x = self.entry(x) 76 | c0 = o0 = x 77 | 78 | b1 = self.b1(o0) 79 | c1 = torch.cat([c0, b1], dim=1) 80 | o1 = self.c1(c1) 81 | 82 | b2 = self.b2(o1) 83 | c2 = torch.cat([c1, b2], dim=1) 84 | o2 = self.c2(c2) 85 | 86 | b3 = self.b3(o2) 87 | c3 = torch.cat([c2, b3], dim=1) 88 | o3 = self.c3(c3) 89 | 90 | scale = self.scale[self.scale_idx] 91 | out = self.upsample(o3, scale=scale) 92 | 93 | out = self.exit(out) 94 | out = self.add_mean(out) 95 | 96 | return out 97 | 98 | def set_scale(self, scale_idx): 99 | self.scale_idx = scale_idx 100 | -------------------------------------------------------------------------------- /code/model/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from torch.autograd import Variable 8 | 9 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 10 | return nn.Conv2d( 11 | in_channels, out_channels, kernel_size, 12 | padding=(kernel_size//2), bias=bias) 13 | 14 | def act_vconv(res_act): 15 | res_act = res_act.lower() 16 | if res_act == 'softplus': 17 | act = nn.Softplus() 18 | elif res_act == 'sigmoid': 19 | act = nn.Sigmoid() 20 | elif res_act == 'tanh': 21 | act = nn.Tanh() 22 | elif res_act == 'elu': 23 | act = nn.ELU() 24 | else: 25 | raise NotImplementedError 26 | return act 27 | 28 | class MeanShift(nn.Conv2d): 29 | def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1): 30 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 31 | std = torch.Tensor(rgb_std) 32 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) 33 | self.weight.data.div_(std.view(3, 1, 1, 1)) 34 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) 35 | self.bias.data.div_(std) 36 | self.requires_grad = False 37 | 38 | class GroupNorm(nn.Module): 39 | def __init__(self, num_features, num_groups=32, eps=1e-5): 40 | 
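# A hand-rolled group normalization, equivalent in spirit to nn.GroupNorm:
# forward() below reshapes (N, C, H, W) to (N, G, C//G * H * W), normalizes
# each group to zero mean and unit variance, then applies the per-channel
# affine weight and bias (Wu & He, "Group Normalization", 2018).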
super(GroupNorm, self).__init__() 41 | self.weight = nn.Parameter(torch.ones(1,num_features,1,1)) 42 | self.bias = nn.Parameter(torch.zeros(1,num_features,1,1)) 43 | self.num_groups = num_groups 44 | self.eps = eps 45 | 46 | def forward(self, x): 47 | N,C,H,W = x.size() 48 | G = self.num_groups 49 | assert C % G == 0 50 | 51 | x = x.view(N,G,-1) 52 | mean = x.mean(-1, keepdim=True) 53 | var = x.var(-1, keepdim=True) 54 | 55 | x = (x-mean) / (var+self.eps).sqrt() 56 | x = x.view(N,C,H,W) 57 | return x * self.weight + self.bias 58 | 59 | class BasicBlock(nn.Sequential): 60 | def __init__( 61 | self, in_channels, out_channels, kernel_size, stride=1, bias=False, 62 | bn=True, act=nn.ReLU(True)): 63 | 64 | m = [nn.Conv2d( 65 | in_channels, out_channels, kernel_size, 66 | padding=(kernel_size//2), stride=stride, bias=bias) 67 | ] 68 | if bn: m.append(nn.BatchNorm2d(out_channels)) 69 | if act is not None: m.append(act) 70 | super(BasicBlock, self).__init__(*m) 71 | 72 | class ResBlock(nn.Module): 73 | def __init__( 74 | self, conv, n_feat, kernel_size, 75 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1, num_conv=2): 76 | 77 | super(ResBlock, self).__init__() 78 | m = [] 79 | for i in range(num_conv): 80 | m.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 81 | if bn: m.append(nn.BatchNorm2d(n_feat)) 82 | if i == 0: m.append(act) 83 | 84 | self.body = nn.Sequential(*m) 85 | self.res_scale = res_scale 86 | 87 | def forward(self, x): 88 | res = self.body(x).mul(self.res_scale) 89 | res += x 90 | 91 | return res 92 | 93 | class Upsampler(nn.Sequential): 94 | def __init__(self, conv, scale, n_feat, bn=False, act=False, bias=True): 95 | 96 | m = [] 97 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 98 | for _ in range(int(math.log(scale, 2))): 99 | m.append(conv(n_feat, 4 * n_feat, 3, bias)) 100 | m.append(nn.PixelShuffle(2)) 101 | if bn: m.append(nn.BatchNorm2d(n_feat)) 102 | if act: m.append(nn.PReLU()) 103 | elif scale == 3: 104 | m.append(conv(n_feat, 9 * n_feat, 3, bias)) 105 | m.append(nn.PixelShuffle(3)) 106 | if bn: m.append(nn.BatchNorm2d(n_feat)) 107 | if act: m.append(nn.PReLU()) 108 | else: 109 | raise NotImplementedError 110 | 111 | super(Upsampler, self).__init__(*m) 112 | -------------------------------------------------------------------------------- /code/model/ddbpn.py: -------------------------------------------------------------------------------- 1 | # Deep Back-Projection Networks For Super-Resolution 2 | # https://arxiv.org/abs/1803.02735 3 | 4 | from model import common 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | 10 | def make_model(args, parent=False): 11 | return DDBPN(args) 12 | 13 | def projection_conv(in_channels, out_channels, scale, up=True): 14 | kernel_size, stride, padding = { 15 | 2: (6, 2, 2), 16 | 4: (8, 4, 2), 17 | 8: (12, 8, 2) 18 | }[scale] 19 | if up: 20 | conv_f = nn.ConvTranspose2d 21 | else: 22 | conv_f = nn.Conv2d 23 | 24 | return conv_f( 25 | in_channels, out_channels, kernel_size, 26 | stride=stride, padding=padding 27 | ) 28 | 29 | class DenseProjection(nn.Module): 30 | def __init__(self, in_channels, nr, scale, up=True, bottleneck=True): 31 | super(DenseProjection, self).__init__() 32 | if bottleneck: 33 | self.bottleneck = nn.Sequential(*[ 34 | nn.Conv2d(in_channels, nr, 1), 35 | nn.PReLU(nr) 36 | ]) 37 | inter_channels = nr 38 | else: 39 | self.bottleneck = None 40 | inter_channels = in_channels 41 | 42 | self.conv_1 = nn.Sequential(*[ 43 | projection_conv(inter_channels, nr, scale, up), 44 | nn.PReLU(nr) 45 | ]) 46 | 
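# DBPN (back-)projection unit: conv_1 projects the features to the target
# resolution, conv_2 projects them back, and conv_3 maps the back-projection
# residual e = b_0 - x; forward() below returns a_0 + a_1, i.e. the first
# projection refined by the projected residual.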
self.conv_2 = nn.Sequential(*[ 47 | projection_conv(nr, inter_channels, scale, not up), 48 | nn.PReLU(inter_channels) 49 | ]) 50 | self.conv_3 = nn.Sequential(*[ 51 | projection_conv(inter_channels, nr, scale, up), 52 | nn.PReLU(nr) 53 | ]) 54 | 55 | def forward(self, x): 56 | if self.bottleneck is not None: 57 | x = self.bottleneck(x) 58 | 59 | a_0 = self.conv_1(x) 60 | b_0 = self.conv_2(a_0) 61 | e = b_0.sub(x) 62 | a_1 = self.conv_3(e) 63 | 64 | out = a_0.add(a_1) 65 | 66 | return out 67 | 68 | class DDBPN(nn.Module): 69 | def __init__(self, args): 70 | super(DDBPN, self).__init__() 71 | scale = args.scale[0] 72 | 73 | n0 = 128 74 | nr = 32 75 | self.depth = 6 76 | 77 | rgb_mean = (0.4488, 0.4371, 0.4040) 78 | rgb_std = (1.0, 1.0, 1.0) 79 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std) 80 | initial = [ 81 | nn.Conv2d(args.n_colors, n0, 3, padding=1), 82 | nn.PReLU(n0), 83 | nn.Conv2d(n0, nr, 1), 84 | nn.PReLU(nr) 85 | ] 86 | self.initial = nn.Sequential(*initial) 87 | 88 | self.upmodules = nn.ModuleList() 89 | self.downmodules = nn.ModuleList() 90 | channels = nr 91 | for i in range(self.depth): 92 | self.upmodules.append( 93 | DenseProjection(channels, nr, scale, True, i > 1) 94 | ) 95 | if i != 0: 96 | channels += nr 97 | 98 | channels = nr 99 | for i in range(self.depth - 1): 100 | self.downmodules.append( 101 | DenseProjection(channels, nr, scale, False, i != 0) 102 | ) 103 | channels += nr 104 | 105 | reconstruction = [ 106 | nn.Conv2d(self.depth * nr, args.n_colors, 3, padding=1) 107 | ] 108 | self.reconstruction = nn.Sequential(*reconstruction) 109 | 110 | self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) 111 | 112 | def forward(self, x): 113 | x = self.sub_mean(x) 114 | x = self.initial(x) 115 | 116 | h_list = [] 117 | l_list = [] 118 | for i in range(self.depth - 1): 119 | if i == 0: 120 | l = x 121 | else: 122 | l = torch.cat(l_list, dim=1) 123 | h_list.append(self.upmodules[i](l)) 124 | l_list.append(self.downmodules[i](torch.cat(h_list, dim=1))) 125 | 126 | h_list.append(self.upmodules[-1](torch.cat(l_list, dim=1))) 127 | out = self.reconstruction(torch.cat(h_list, dim=1)) 128 | out = self.add_mean(out) 129 | 130 | return out 131 | 132 | -------------------------------------------------------------------------------- /code/model/edsr.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | 5 | def make_model(args, parent=False): 6 | return EDSR(args) 7 | 8 | class EDSR(nn.Module): 9 | def __init__(self, args, conv=common.default_conv): 10 | super(EDSR, self).__init__() 11 | 12 | n_resblock = args.n_resblocks 13 | n_feats = args.n_feats 14 | kernel_size = 3 15 | scale = args.scale[0] 16 | act = nn.ReLU(True) 17 | 18 | rgb_mean = (0.4488, 0.4371, 0.4040) 19 | rgb_std = (1.0, 1.0, 1.0) 20 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std) 21 | 22 | # define head module 23 | m_head = [conv(args.n_colors, n_feats, kernel_size)] 24 | 25 | # define body module 26 | m_body = [ 27 | common.ResBlock( 28 | conv, n_feats, kernel_size, act=act, res_scale=args.res_scale 29 | ) for _ in range(n_resblock) 30 | ] 31 | m_body.append(conv(n_feats, n_feats, kernel_size)) 32 | 33 | # define tail module 34 | m_tail = [ 35 | common.Upsampler(conv, scale, n_feats, act=False), 36 | conv(n_feats, args.n_colors, kernel_size) 37 | ] 38 | 39 | self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) 40 | 41 | self.head = 
nn.Sequential(*m_head) 42 | self.body = nn.Sequential(*m_body) 43 | self.tail = nn.Sequential(*m_tail) 44 | # from IPython import embed; embed(); exit() 45 | 46 | def forward(self, x): 47 | x = self.sub_mean(x) 48 | x = self.head(x) 49 | 50 | res = self.body(x) 51 | res += x 52 | 53 | x = self.tail(res) 54 | x = self.add_mean(x) 55 | 56 | return x 57 | 58 | def load_state_dict(self, state_dict, strict=True): 59 | own_state = self.state_dict() 60 | for name, param in state_dict.items(): 61 | if name in own_state: 62 | if isinstance(param, nn.Parameter): 63 | param = param.data 64 | try: 65 | own_state[name].copy_(param) 66 | except Exception: 67 | if name.find('tail') == -1: 68 | raise RuntimeError('While copying the parameter named {}, ' 69 | 'whose dimensions in the model are {} and ' 70 | 'whose dimensions in the checkpoint are {}.' 71 | .format(name, own_state[name].size(), param.size())) 72 | elif strict: 73 | if name.find('tail') == -1: 74 | raise KeyError('unexpected key "{}" in state_dict' 75 | .format(name)) 76 | 77 | -------------------------------------------------------------------------------- /code/model/mdsr.py: -------------------------------------------------------------------------------- 1 | from model import common 2 | 3 | import torch.nn as nn 4 | 5 | def make_model(args, parent=False): 6 | return MDSR(args) 7 | 8 | class MDSR(nn.Module): 9 | def __init__(self, args, conv=common.default_conv): 10 | super(MDSR, self).__init__() 11 | n_resblocks = args.n_resblocks 12 | n_feats = args.n_feats 13 | kernel_size = 3 14 | self.scale_idx = 0 15 | 16 | act = nn.ReLU(True) 17 | 18 | rgb_mean = (0.4488, 0.4371, 0.4040) 19 | rgb_std = (1.0, 1.0, 1.0) 20 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std) 21 | 22 | m_head = [conv(args.n_colors, n_feats, kernel_size)] 23 | 24 | self.pre_process = nn.ModuleList([ 25 | nn.Sequential( 26 | common.ResBlock(conv, n_feats, 5, act=act), 27 | common.ResBlock(conv, n_feats, 5, act=act) 28 | ) for _ in args.scale 29 | ]) 30 | 31 | m_body = [ 32 | common.ResBlock( 33 | conv, n_feats, kernel_size, act=act 34 | ) for _ in range(n_resblocks) 35 | ] 36 | m_body.append(conv(n_feats, n_feats, kernel_size)) 37 | 38 | self.upsample = nn.ModuleList([ 39 | common.Upsampler( 40 | conv, s, n_feats, act=False 41 | ) for s in args.scale 42 | ]) 43 | 44 | m_tail = [conv(n_feats, args.n_colors, kernel_size)] 45 | 46 | self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) 47 | 48 | self.head = nn.Sequential(*m_head) 49 | self.body = nn.Sequential(*m_body) 50 | self.tail = nn.Sequential(*m_tail) 51 | 52 | def forward(self, x): 53 | x = self.sub_mean(x) 54 | x = self.head(x) 55 | x = self.pre_process[self.scale_idx](x) 56 | 57 | res = self.body(x) 58 | res += x 59 | 60 | x = self.upsample[self.scale_idx](res) 61 | x = self.tail(x) 62 | x = self.add_mean(x) 63 | 64 | return x 65 | 66 | def set_scale(self, scale_idx): 67 | self.scale_idx = scale_idx 68 | 69 | -------------------------------------------------------------------------------- /code/model/ops.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.init as init 5 | import torch.nn.functional as F 6 | 7 | def init_weights(modules): 8 | pass 9 | 10 | 11 | class MeanShift(nn.Module): 12 | def __init__(self, mean_rgb, sub): 13 | super(MeanShift, self).__init__() 14 | 15 | sign = -1 if sub else 1 16 | r = mean_rgb[0] * sign 17 | g = mean_rgb[1] * sign 18 | b = 
mean_rgb[2] * sign 19 | 20 | self.shifter = nn.Conv2d(3, 3, 1, 1, 0) 21 | self.shifter.weight.data = torch.eye(3).view(3, 3, 1, 1) 22 | self.shifter.bias.data = torch.Tensor([r, g, b]) 23 | 24 | # Freeze the mean shift layer 25 | for params in self.shifter.parameters(): 26 | params.requires_grad = False 27 | 28 | def forward(self, x): 29 | x = self.shifter(x) 30 | return x 31 | 32 | 33 | class BasicBlock(nn.Module): 34 | def __init__(self, 35 | in_channels, out_channels, 36 | ksize=3, stride=1, pad=1): 37 | super(BasicBlock, self).__init__() 38 | 39 | self.body = nn.Sequential( 40 | nn.Conv2d(in_channels, out_channels, ksize, stride, pad), 41 | nn.ReLU(inplace=True) 42 | ) 43 | 44 | init_weights(self.modules) 45 | 46 | def forward(self, x): 47 | out = self.body(x) 48 | return out 49 | 50 | 51 | class ResidualBlock(nn.Module): 52 | def __init__(self, 53 | in_channels, out_channels): 54 | super(ResidualBlock, self).__init__() 55 | 56 | self.body = nn.Sequential( 57 | nn.Conv2d(in_channels, out_channels, 3, 1, 1), 58 | nn.ReLU(inplace=True), 59 | nn.Conv2d(out_channels, out_channels, 3, 1, 1), 60 | ) 61 | 62 | init_weights(self.modules) 63 | 64 | def forward(self, x): 65 | out = self.body(x) 66 | out = F.relu(out + x) 67 | return out 68 | 69 | 70 | class EResidualBlock(nn.Module): 71 | def __init__(self, 72 | in_channels, out_channels, 73 | group=1): 74 | super(EResidualBlock, self).__init__() 75 | 76 | self.body = nn.Sequential( 77 | nn.Conv2d(in_channels, out_channels, 3, 1, 1, groups=group), 78 | nn.ReLU(inplace=True), 79 | nn.Conv2d(out_channels, out_channels, 3, 1, 1, groups=group), 80 | nn.ReLU(inplace=True), 81 | nn.Conv2d(out_channels, out_channels, 1, 1, 0), 82 | ) 83 | 84 | init_weights(self.modules) 85 | 86 | def forward(self, x): 87 | out = self.body(x) 88 | out = F.relu(out + x) 89 | return out 90 | 91 | 92 | class UpsampleBlock(nn.Module): 93 | def __init__(self, 94 | n_channels, scale, multi_scale, 95 | group=1): 96 | super(UpsampleBlock, self).__init__() 97 | 98 | if multi_scale: 99 | self.up2 = _UpsampleBlock(n_channels, scale=2, group=group) 100 | self.up3 = _UpsampleBlock(n_channels, scale=3, group=group) 101 | self.up4 = _UpsampleBlock(n_channels, scale=4, group=group) 102 | else: 103 | self.up = _UpsampleBlock(n_channels, scale=scale, group=group) 104 | 105 | self.multi_scale = multi_scale 106 | 107 | def forward(self, x, scale): 108 | if self.multi_scale: 109 | if scale == 2: 110 | return self.up2(x) 111 | elif scale == 3: 112 | return self.up3(x) 113 | elif scale == 4: 114 | return self.up4(x) 115 | else: 116 | return self.up(x) 117 | 118 | 119 | class _UpsampleBlock(nn.Module): 120 | def __init__(self, 121 | n_channels, scale, 122 | group=1): 123 | super(_UpsampleBlock, self).__init__() 124 | 125 | modules = [] 126 | if scale == 2 or scale == 4 or scale == 8: 127 | for _ in range(int(math.log(scale, 2))): 128 | modules += [nn.Conv2d(n_channels, 4*n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)] 129 | modules += [nn.PixelShuffle(2)] 130 | elif scale == 3: 131 | modules += [nn.Conv2d(n_channels, 9*n_channels, 3, 1, 1, groups=group), nn.ReLU(inplace=True)] 132 | modules += [nn.PixelShuffle(3)] 133 | 134 | self.body = nn.Sequential(*modules) 135 | init_weights(self.modules) 136 | 137 | def forward(self, x): 138 | out = self.body(x) 139 | return out 140 | -------------------------------------------------------------------------------- /code/model/rcan.py: -------------------------------------------------------------------------------- 1 | from model import common 
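# Residual Channel Attention Network (RCAN), Zhang et al., ECCV 2018:
# residual groups stack RCAB blocks, and each RCAB reweights its channels
# with the squeeze-and-excitation style CALayer defined below.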
2 | 3 | import torch.nn as nn 4 | 5 | def make_model(args, parent=False): 6 | return RCAN(args) 7 | 8 | ## Channel Attention (CA) Layer 9 | class CALayer(nn.Module): 10 | def __init__(self, channel, reduction=16): 11 | super(CALayer, self).__init__() 12 | # global average pooling: feature --> point 13 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 14 | # feature channel downscale and upscale --> channel weight 15 | self.conv_du = nn.Sequential( 16 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 17 | nn.ReLU(inplace=True), 18 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 19 | nn.Sigmoid() 20 | ) 21 | 22 | def forward(self, x): 23 | y = self.avg_pool(x) 24 | y = self.conv_du(y) 25 | return x * y 26 | 27 | ## Residual Channel Attention Block (RCAB) 28 | class RCAB(nn.Module): 29 | def __init__( 30 | self, conv, n_feat, kernel_size, reduction, 31 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 32 | 33 | super(RCAB, self).__init__() 34 | modules_body = [] 35 | for i in range(2): 36 | modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) 37 | if bn: modules_body.append(nn.BatchNorm2d(n_feat)) 38 | if i == 0: modules_body.append(act) 39 | modules_body.append(CALayer(n_feat, reduction)) 40 | self.body = nn.Sequential(*modules_body) 41 | self.res_scale = res_scale 42 | 43 | def forward(self, x): 44 | res = self.body(x) 45 | #res = self.body(x).mul(self.res_scale) 46 | res += x 47 | return res 48 | 49 | ## Residual Group (RG) 50 | class ResidualGroup(nn.Module): 51 | def __init__(self, conv, n_feat, kernel_size, reduction, act, res_scale, n_resblocks): 52 | super(ResidualGroup, self).__init__() 53 | modules_body = [] 54 | modules_body = [ 55 | RCAB( 56 | conv, n_feat, kernel_size, reduction, bias=True, bn=False, act=nn.ReLU(True), res_scale=1) \ 57 | for _ in range(n_resblocks)] 58 | modules_body.append(conv(n_feat, n_feat, kernel_size)) 59 | self.body = nn.Sequential(*modules_body) 60 | 61 | def forward(self, x): 62 | res = self.body(x) 63 | res += x 64 | return res 65 | 66 | ## Residual Channel Attention Network (RCAN) 67 | class RCAN(nn.Module): 68 | def __init__(self, args, conv=common.default_conv): 69 | super(RCAN, self).__init__() 70 | 71 | n_resgroups = args.n_resgroups 72 | n_resblocks = args.n_resblocks 73 | n_feats = args.n_feats 74 | kernel_size = 3 75 | reduction = args.reduction 76 | scale = args.scale[0] 77 | act = nn.ReLU(True) 78 | 79 | # RGB mean for DIV2K 80 | rgb_mean = (0.4488, 0.4371, 0.4040) 81 | rgb_std = (1.0, 1.0, 1.0) 82 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std) 83 | 84 | # define head module 85 | modules_head = [conv(args.n_colors, n_feats, kernel_size)] 86 | 87 | # define body module 88 | modules_body = [ 89 | ResidualGroup( 90 | conv, n_feats, kernel_size, reduction, act=act, res_scale=args.res_scale, n_resblocks=n_resblocks) \ 91 | for _ in range(n_resgroups)] 92 | 93 | modules_body.append(conv(n_feats, n_feats, kernel_size)) 94 | 95 | # define tail module 96 | modules_tail = [ 97 | common.Upsampler(conv, scale, n_feats, act=False), 98 | conv(n_feats, args.n_colors, kernel_size)] 99 | 100 | self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) 101 | 102 | self.head = nn.Sequential(*modules_head) 103 | self.body = nn.Sequential(*modules_body) 104 | self.tail = nn.Sequential(*modules_tail) 105 | 106 | def forward(self, x): 107 | x = self.sub_mean(x) 108 | x = self.head(x) 109 | 110 | res = self.body(x) 111 | res += x 112 | 113 | x = self.tail(res) 114 | x = 
self.add_mean(x) 115 | 116 | return x 117 | 118 | def load_state_dict(self, state_dict, strict=False): 119 | own_state = self.state_dict() 120 | for name, param in state_dict.items(): 121 | if name in own_state: 122 | if isinstance(param, nn.Parameter): 123 | param = param.data 124 | try: 125 | own_state[name].copy_(param) 126 | except Exception: 127 | if name.find('tail') >= 0: 128 | print('Replace pre-trained upsampler to new one...') 129 | else: 130 | raise RuntimeError('While copying the parameter named {}, ' 131 | 'whose dimensions in the model are {} and ' 132 | 'whose dimensions in the checkpoint are {}.' 133 | .format(name, own_state[name].size(), param.size())) 134 | elif strict: 135 | if name.find('tail') == -1: 136 | raise KeyError('unexpected key "{}" in state_dict' 137 | .format(name)) 138 | 139 | if strict: 140 | missing = set(own_state.keys()) - set(state_dict.keys()) 141 | if len(missing) > 0: 142 | raise KeyError('missing keys in state_dict: "{}"'.format(missing)) -------------------------------------------------------------------------------- /code/model/srresnet.py: -------------------------------------------------------------------------------- 1 | __author__ = 'yawli' 2 | import math 3 | import torch.nn as nn 4 | import torch 5 | from model import common 6 | # from model import ops 7 | 8 | def make_model(args, parent=False): 9 | return SRResNet(args) 10 | 11 | def norm(norm_type, channel, group): 12 | if norm_type == 'batchnorm': 13 | norm = nn.BatchNorm2d(channel) 14 | elif norm_type == 'groupnorm': 15 | norm = nn.GroupNorm(group, channel) 16 | elif norm_type == 'instancenorm': 17 | norm = nn.InstanceNorm2d(channel) 18 | elif norm_type == 'instancenorm_affine': 19 | norm = nn.InstanceNorm2d(channel, affine=True) 20 | elif norm_type == 'layernorm': 21 | norm = nn.LayerNorm(channel) 22 | else: 23 | norm = None 24 | return norm 25 | 26 | class VarBlockSimple(nn.Module): 27 | 28 | def __init__(self, conv=common.default_conv, n_feats=64, kernel_size=3, reg_act=nn.Softplus(), rescale=1, norm_f=None): 29 | super(VarBlockSimple, self).__init__() 30 | if norm_f is not None: 31 | conv_mask = [norm_f, nn.Conv2d(n_feats, n_feats, kernel_size=kernel_size, padding=kernel_size//2, groups=n_feats), reg_act] 32 | else: 33 | conv_mask = [nn.Conv2d(n_feats, n_feats, kernel_size=kernel_size, padding=kernel_size//2, groups=n_feats), reg_act] 34 | conv_body = [conv(n_feats, n_feats, kernel_size), nn.PReLU()] 35 | self.rescale = rescale 36 | self.conv_mask = nn.Sequential(*conv_mask) 37 | self.conv_body = nn.Sequential(*conv_body) 38 | 39 | def forward(self, x): 40 | res = self.conv_body(self.conv_mask(x) * x) 41 | x = res.mul(self.rescale) + x 42 | return x 43 | 44 | class JointAttention(nn.Module): 45 | 46 | def __init__(self, conv=common.default_conv, n_feats=64, kernel_size=3, reg_act=nn.Softplus(), rescale=1, norm_f=None): 47 | super(JointAttention, self).__init__() 48 | mask_conv = [nn.Conv2d(n_feats, 16, kernel_size=kernel_size, stride=4, padding=kernel_size//2), nn.PReLU()] 49 | mask_deconv = nn.ConvTranspose2d(16, n_feats, kernel_size=kernel_size, stride=4, padding=1) 50 | mask_deconv_act = nn.Softmax2d() 51 | conv_body = [conv(n_feats, n_feats, kernel_size), nn.PReLU()] 52 | self.mask_conv = nn.Sequential(*mask_conv) 53 | self.mask_deconv = mask_deconv 54 | self.mask_deconv_act = mask_deconv_act 55 | # self.ca = CALayer(n_feats) 56 | self.conv_body = nn.Sequential(*conv_body) 57 | 58 | def forward(self, x): 59 | mask = self.mask_deconv_act(self.mask_deconv(self.mask_conv(x), 
output_size=x.size())) 60 | res = mask * x 61 | # res = self.ca(res) 62 | res = self.conv_body(res) 63 | x = res + x 64 | return x 65 | 66 | class UpsampleBlock(nn.Module): 67 | def __init__(self, 68 | n_channels, scale, multi_scale, 69 | group=1): 70 | super(UpsampleBlock, self).__init__() 71 | 72 | if multi_scale: 73 | self.up2 = _UpsampleBlock(n_channels, scale=2, group=group) 74 | self.up3 = _UpsampleBlock(n_channels, scale=3, group=group) 75 | self.up4 = _UpsampleBlock(n_channels, scale=4, group=group) 76 | else: 77 | self.up = _UpsampleBlock(n_channels, scale=scale, group=group) 78 | 79 | self.multi_scale = multi_scale 80 | 81 | def forward(self, x, scale): 82 | if self.multi_scale: 83 | if scale == 2: 84 | return self.up2(x) 85 | elif scale == 3: 86 | return self.up3(x) 87 | elif scale == 4: 88 | return self.up4(x) 89 | else: 90 | return self.up(x) 91 | 92 | 93 | class _UpsampleBlock(nn.Module): 94 | def __init__(self, 95 | n_channels, scale, 96 | group=1): 97 | super(_UpsampleBlock, self).__init__() 98 | 99 | modules = [] 100 | if scale == 2 or scale == 4 or scale == 8: 101 | for _ in range(int(math.log(scale, 2))): 102 | modules += [nn.Conv2d(n_channels, 4*n_channels, 3, 1, 1, groups=group), nn.PReLU()] 103 | modules += [nn.PixelShuffle(2)] 104 | elif scale == 3: 105 | modules += [nn.Conv2d(n_channels, 9*n_channels, 3, 1, 1, groups=group), nn.PReLU()] 106 | modules += [nn.PixelShuffle(3)] 107 | 108 | self.body = nn.Sequential(*modules) 109 | 110 | def forward(self, x): 111 | out = self.body(x) 112 | return out 113 | 114 | 115 | class SRResNet(nn.Module): 116 | def __init__(self, args, conv=common.default_conv): 117 | super(SRResNet, self).__init__() 118 | 119 | n_resblocks = args.n_resblocks 120 | n_feats = args.n_feats 121 | kernel_size = 3 122 | # scale = args.scale[0] 123 | act = nn.PReLU() 124 | 125 | multi_scale = len(args.scale) > 1 126 | self.scale_idx = 0 127 | scale = args.scale[self.scale_idx] 128 | group = 1 129 | self.scale = args.scale 130 | 131 | rgb_mean = (0.4488, 0.4371, 0.4040) 132 | rgb_std = (1.0, 1.0, 1.0) 133 | 134 | norm_f = norm(args.norm_type, args.n_feats, args.n_groups) 135 | act_vconv = common.act_vconv(args.res_act) 136 | 137 | head = [conv(args.n_colors, n_feats, kernel_size), act] 138 | body_r = [JointAttention(conv, n_feats, kernel_size, reg_act=act_vconv, norm_f=norm_f, rescale=args.res_scale) 139 | for _ in range(n_resblocks)] 140 | #body_r = [common.ResBlock(conv, n_feats, kernel_size, bn=False, act=act, res_scale=args.res_scale, num_conv=2) 141 | # for _ in range(n_resblocks)] 142 | 143 | 144 | body_conv = [conv(n_feats, n_feats, kernel_size)] 145 | #body_conv = [conv(n_feats, n_feats, kernel_size), nn.BatchNorm2d(n_feats)] 146 | 147 | # tail = [ 148 | # common.Upsampler(conv, scale, n_feats, act=act), 149 | # conv(n_feats, args.n_colors, kernel_size) 150 | # ] 151 | 152 | tail = UpsampleBlock(n_feats, 153 | scale=scale, 154 | multi_scale=multi_scale, 155 | group=group) 156 | tail_conv = [conv(n_feats, args.n_colors, kernel_size)] 157 | 158 | self.sub_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std) 159 | self.head = nn.Sequential(*head) 160 | self.body_r = nn.Sequential(*body_r) 161 | self.body_conv = nn.Sequential(*body_conv) 162 | self.tail = tail 163 | self.tail_conv = nn.Sequential(*tail_conv) 164 | self.add_mean = common.MeanShift(args.rgb_range, rgb_mean, rgb_std, 1) 165 | 166 | def forward(self, x): 167 | x = self.sub_mean(x) 168 | x = self.head(x) 169 | f = self.body_r(x) 170 | f = self.body_conv(f) 171 | scale = 
self.scale[self.scale_idx] 172 | x = self.tail(x + f, scale) 173 | x = self.tail_conv(x) 174 | x = self.add_mean(x) 175 | return x 176 | 177 | def set_scale(self, scale_idx): 178 | self.scale_idx = scale_idx 179 | -------------------------------------------------------------------------------- /code/option.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import template 3 | 4 | parser = argparse.ArgumentParser(description='EDSR and MDSR') 5 | 6 | parser.add_argument('--debug', action='store_true', 7 | help='Enables debug mode') 8 | parser.add_argument('--template', default='.', 9 | help='You can set various templates in option.py') 10 | 11 | # Hardware specifications 12 | parser.add_argument('--n_threads', type=int, default=3, 13 | help='number of threads for data loading') 14 | parser.add_argument('--cpu', action='store_true', 15 | help='use cpu only') 16 | parser.add_argument('--n_GPUs', type=int, default=1, 17 | help='number of GPUs') 18 | parser.add_argument('--seed', type=int, default=1, 19 | help='random seed') 20 | 21 | # Data specifications 22 | parser.add_argument('--dir', type=str, default='/scratch_net/ofsoundof/yawli/logs_3d/', 23 | help='log directory') 24 | parser.add_argument('--dir_data', type=str, default='/scratch_net/ofsoundof/yawli/Datasets/', 25 | help='dataset directory') 26 | parser.add_argument('--dir_demo', type=str, default='../test', 27 | help='demo image directory') 28 | parser.add_argument('--data_train', type=str, default='DIV2K', 29 | help='train dataset name') 30 | parser.add_argument('--data_test', type=str, default='DIV2K', 31 | help='test dataset name') 32 | parser.add_argument('--benchmark_noise', action='store_true', 33 | help='use noisy benchmark sets') 34 | parser.add_argument('--n_train', type=int, default=800, 35 | help='number of training images') 36 | parser.add_argument('--n_val', type=int, default=5, 37 | help='number of validation images') 38 | parser.add_argument('--offset_val', type=int, default=800, 39 | help='validation index offset') 40 | parser.add_argument('--ext', type=str, default='sep', 41 | help='dataset file extension') 42 | parser.add_argument('--scale', default='4', 43 | help='super resolution scale') 44 | parser.add_argument('--patch_size', type=int, default=192, 45 | help='output patch size') 46 | parser.add_argument('--rgb_range', type=int, default=255, 47 | help='maximum value of RGB') 48 | parser.add_argument('--n_colors', type=int, default=3, 49 | help='number of color channels to use') 50 | parser.add_argument('--noise', type=str, default='.', 51 | help='Gaussian noise std.') 52 | parser.add_argument('--chop', action='store_true', 53 | help='enable memory-efficient forward') 54 | 55 | # Model specifications 56 | parser.add_argument('--model', default='RCAN', 57 | help='model name') 58 | 59 | parser.add_argument('--act', type=str, default='relu', 60 | help='activation function') 61 | parser.add_argument('--pre_train', type=str, default='.', 62 | help='pre-trained model directory') 63 | parser.add_argument('--extend', type=str, default='.', 64 | help='pre-trained model directory') 65 | parser.add_argument('--n_resblocks', type=int, default=5, 66 | help='number of residual blocks') 67 | parser.add_argument('--n_feats', type=int, default=64, 68 | help='number of feature maps') 69 | parser.add_argument('--res_scale', type=float, default=1, 70 | help='residual scaling') 71 | parser.add_argument('--shift_mean', default=True, 72 | help='subtract pixel mean from the input') 73 |
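# Note: --scale (defined above) is a '+'-separated string parsed at the bottom
# of this file into a list of ints, e.g. '2+3+4' -> [2, 3, 4]; multi-scale
# models (CARN/MDSR/SRResNet) detect this via len(args.scale) > 1.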
parser.add_argument('--precision', type=str, default='single', 74 | choices=('single', 'half'), 75 | help='FP precision for test (single | half)') 76 | 77 | parser.add_argument('--res_act', default='SIGMOID', 78 | help='activation function in res block') 79 | parser.add_argument('--reg_anchor', type=int, default=16, 80 | help='number of anchors') 81 | parser.add_argument('--reg_out', type=int, default=16, 82 | help='number of channels in the regression block') 83 | parser.add_argument('--submodel', default='carn', 84 | help='submodel name') 85 | 86 | # Texture map SR specifications 87 | parser.add_argument('--n_resblocks_ft', type=int, default=2, 88 | help='number of resblocks used for finetuning') 89 | parser.add_argument('--model_one', default='one', 90 | help='used to split the dataset. We use cross validation for training and testing. The dataset is ' 91 | 'split into 2 equal groups. When one group is used for training, the other is used for testing. ' 92 | 'The groups are formed automatically according to the file system.') 93 | parser.add_argument('--subset', default='.', 94 | help='extract a subset of the whole dataset. Possible choices are ., ETH3D, MiddleBury, Collection, ' 95 | 'SyB3R. The default choice . means the combination of all of the four subsets.') 96 | parser.add_argument('--normal_lr', default='hr', 97 | help='use hr or lr normal map. HR normal map is used for NHR while LR normal map is used for NLR.') 98 | parser.add_argument('--input_res', default='lr', 99 | help='use hr or lr input. HR input texture map is used for HRST-CNN while LR input texture map is ' 100 | 'used for the other methods.') 101 | 102 | parser.add_argument('--n_resunits', type=int, default=6, 103 | help='number of resunits used for level one residual') 104 | parser.add_argument('--norm_type', type=str, default='groupnorm', 105 | help='Normalization type') 106 | parser.add_argument('--n_groups', type=int, default=4, 107 | help='number of groups in group normalization') 108 | 109 | parser.add_argument('--data_test_texture_sr', type=str, default='DIV2K', 110 | help='super-resolve the input images for texture SR') 111 | 112 | 113 | # Training specifications 114 | parser.add_argument('--reset', action='store_true', 115 | help='reset the training') 116 | parser.add_argument('--test_every', type=int, default=1000, 117 | help='do test per every N batches') 118 | parser.add_argument('--epochs', type=int, default=1000, 119 | help='number of epochs to train') 120 | parser.add_argument('--batch_size', type=int, default=16, 121 | help='input batch size for training') 122 | parser.add_argument('--split_batch', type=int, default=1, 123 | help='split the batch into smaller chunks') 124 | parser.add_argument('--self_ensemble', action='store_true', 125 | help='use self-ensemble method for test') 126 | parser.add_argument('--test_only', action='store_true', 127 | help='set this option to test the model') 128 | parser.add_argument('--gan_k', type=int, default=1, 129 | help='k value for adversarial loss') 130 | 131 | # Optimization specifications 132 | parser.add_argument('--lr', type=float, default=1e-4, 133 | help='learning rate') 134 | parser.add_argument('--lr_decay', type=int, default=200, 135 | help='learning rate decay per N epochs') 136 | parser.add_argument('--decay_type', type=str, default='step', 137 | help='learning rate decay type') 138 | parser.add_argument('--gamma', type=float, default=0.5, 139 | help='learning rate decay factor for step decay') 140 | parser.add_argument('--optimizer', default='ADAM', 141 |
choices=('SGD', 'ADAM', 'RMSprop'), 142 | help='optimizer to use (SGD | ADAM | RMSprop)') 143 | parser.add_argument('--momentum', type=float, default=0.9, 144 | help='SGD momentum') 145 | parser.add_argument('--beta1', type=float, default=0.9, 146 | help='ADAM beta1') 147 | parser.add_argument('--beta2', type=float, default=0.999, 148 | help='ADAM beta2') 149 | parser.add_argument('--epsilon', type=float, default=1e-8, 150 | help='ADAM epsilon for numerical stability') 151 | parser.add_argument('--weight_decay', type=float, default=0, 152 | help='weight decay') 153 | 154 | # Loss specifications 155 | parser.add_argument('--loss', type=str, default='1*L1', 156 | help='loss function configuration') 157 | parser.add_argument('--skip_threshold', type=float, default='1e6', 158 | help='skip batches that have large error') 159 | 160 | # Log specifications 161 | parser.add_argument('--save', type=str, default='test', 162 | help='file name to save') 163 | parser.add_argument('--load', type=str, default='.', 164 | help='file name to load') 165 | parser.add_argument('--resume', type=int, default=0, 166 | help='resume from specific checkpoint') 167 | parser.add_argument('--print_model', action='store_true', 168 | help='print model') 169 | parser.add_argument('--save_models', action='store_true', 170 | help='save all intermediate models') 171 | parser.add_argument('--print_every', type=int, default=100, 172 | help='how many batches to wait before logging training status') 173 | parser.add_argument('--save_results', action='store_true', 174 | help='save output results') 175 | 176 | # options for residual group and feature channel reduction 177 | parser.add_argument('--n_resgroups', type=int, default=5, 178 | help='number of residual groups') 179 | parser.add_argument('--reduction', type=int, default=16, 180 | help='number of feature maps reduction') 181 | # options for test 182 | parser.add_argument('--testpath', type=str, default='../test/DIV2K_val_LR_our', 183 | help='dataset directory for testing') 184 | parser.add_argument('--testset', type=str, default='Set5', 185 | help='dataset name for testing') 186 | 187 | args = parser.parse_args() 188 | template.set_template(args) 189 | 190 | args.scale = list(map(lambda x: int(x), args.scale.split('+'))) 191 | 192 | if args.epochs == 0: 193 | args.epochs = 1e8 194 | 195 | for arg in vars(args): 196 | if vars(args)[arg] == 'True': 197 | vars(args)[arg] = True 198 | elif vars(args)[arg] == 'False': 199 | vars(args)[arg] = False 200 | 201 | -------------------------------------------------------------------------------- /code/scripts/3d_appearance_sr.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofsoundof/3D_Appearance_SR/91fc377466e2756c6cf753b7db48ef98e4ea13c2/code/scripts/3d_appearance_sr.pdf -------------------------------------------------------------------------------- /code/scripts/contribution.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofsoundof/3D_Appearance_SR/91fc377466e2756c6cf753b7db48ef98e4ea13c2/code/scripts/contribution.jpg -------------------------------------------------------------------------------- /code/scripts/demo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # explanation of some of the important options used. 3 | # --ext sep_reset: used to reset the dataset (generate .npy files).
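# --ext sep: once a sep_reset run has generated the .npy files, later runs
#   can keep the default --ext sep and reuse them (an assumption based on the
#   default value in option.py, not spelled out in this script).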
4 | # --data_train: used to decide how to test the model, namely, whether to use cross-validation. For all of the models trained with texture or texture_hr, cross-validation should be used. Only testing EDSR does not use cross-validation. 5 | # --model_one: used to split the dataset. 6 | # --subset: used to choose a subset. Five options: . (means the combination of all of the subsets), ETH3D, MiddleBury, Collection, SyB3R. 7 | # --normal_lr: use hr or lr normal map 8 | # --input_res: HRST_CNN should use the HR input texture map 9 | 10 | 11 | ########################################################################## 12 | # TEST 13 | 14 | # NLR, cross-validation used, test for the first split 15 | CUDA_VISIBLE_DEVICES=xx python ../main.py --model FINETUNE --submodel NLR --save Test/NLR_first --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train ../../experiment/model/NLR/model_x2_split1.pt --data_train texture --data_test texture --model_one one --subset . --normal_lr lr --input_res lr --chop --reset --save_results --print_model --test_only 16 | 17 | # NLR, cross-validation used, test for the second split 18 | CUDA_VISIBLE_DEVICES=xx python ../main.py --model FINETUNE --submodel NLR --save Test/NLR_second --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train ../../experiment/model/NLR/model_x2_split2.pt --data_train texture --data_test texture --model_one two --subset . --normal_lr lr --input_res lr --chop --reset --save_results --print_model --test_only 19 | 20 | # NHR, cross-validation used, test for the first split 21 | CUDA_VISIBLE_DEVICES=xx python ../main.py --model FINETUNE --submodel NHR --save Test/NHR_first --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train ../../experiment/model/NHR/model_x2_split1.pt --data_train texture --data_test texture --model_one one --subset . --normal_lr hr --input_res lr --chop --reset --save_results --print_model --test_only 22 | 23 | # NHR, cross-validation used, test for the second split 24 | CUDA_VISIBLE_DEVICES=xx python ../main.py --model FINETUNE --submodel NHR --save Test/NHR_second --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train ../../experiment/model/NHR/model_x2_split2.pt --data_train texture --data_test texture --model_one two --subset . --normal_lr hr --input_res lr --chop --reset --save_results --print_model --test_only 25 | 26 | # EDSR 27 | CUDA_VISIBLE_DEVICES=xx python ../main.py --model EDSR --submodel EDSR --save Test/EDSR --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train ../../experiment/model/EDSR_model/EDSR_x4.pt --data_train div2k --data_test texture --chop --reset --save_results --print_model --test_only 28 | 29 | # EDSR_FT, finetune EDSR, cross-validation used, test for the first split 30 | CUDA_VISIBLE_DEVICES=xx python ../main.py --model EDSR --submodel EDSR --save Test/EDSR_FT_first --scale 4 --n_resblocks 32 --n_feats 256 --res_scale 0.1 --pre_train xx --data_train texture --data_test texture --model_one one --subset . 
--input_res lr --chop --reset --save_results --print_model 31 | 32 | 33 | ######################################################################### 34 | #Train 35 | 36 | # NLR, cross-validation used, train for the first split 37 | CUDA_VISIBLE_DEVICES=xx python ../main.py --model FINETUNE --submodel NLR --save NLR_first --scale 4 --patch_size 192 --n_resblocks 32 --n_feats 256 --epochs 100 --res_scale 0.1 --print_every 100 --lr 0.0001 --lr_decay 100 --batch_size 8 --pre_train ../../experiment/model/EDSR_model/EDSR_x4.pt --test_every 200 --data_train texture --data_test texture --model_one one --subset . --normal_lr lr --input_res lr --chop --reset --save_results --print_model 38 | 39 | # NLR, cross-validation used, train for the second split 40 | CUDA_VISIBLE_DEVICES=xx python ../main.py --model FINETUNE --submodel NLR --save NLR_second --scale 4 --patch_size 192 --n_resblocks 32 --n_feats 256 --epochs 100 --res_scale 0.1 --print_every 100 --lr 0.0001 --lr_decay 100 --batch_size 8 --pre_train ../../experiment/model/EDSR_model/EDSR_x4.pt --test_every 200 --data_train texture --data_test texture --model_one two --subset . --normal_lr lr --input_res lr --chop --reset --save_results --print_model 41 | 42 | # NHR, cross-validation used, train for the first split 43 | CUDA_VISIBLE_DEVICES=xx python ../main.py --model FINETUNE --submodel NHR --save NHR_first --scale 4 --patch_size 192 --n_resblocks 32 --n_feats 256 --epochs 100 --res_scale 0.1 --print_every 100 --lr 0.0001 --lr_decay 100 --batch_size 8 --pre_train ../../experiment/model/EDSR_model/EDSR_x4.pt --test_every 200 --data_train texture --data_test texture --model_one one --subset . --normal_lr hr --input_res lr --chop --reset --save_results --print_model 44 | 45 | # NHR, cross-validation used, train for the second split 46 | CUDA_VISIBLE_DEVICES=xx python ../main.py --model FINETUNE --submodel NHR --save NHR_second --scale 4 --patch_size 192 --n_resblocks 32 --n_feats 256 --epochs 100 --res_scale 0.1 --print_every 100 --lr 0.0001 --lr_decay 100 --batch_size 8 --pre_train ../../experiment/model/EDSR_model/EDSR_x4.pt --test_every 200 --data_train texture --data_test texture --model_one two --subset . --normal_lr hr --input_res lr --chop --reset --save_results --print_model 47 | 48 | # EDSR-FT, finetune EDSR, cross-validation used, train for the first split 49 | CUDA_VISIBLE_DEVICES=xx python ../main.py --model EDSR --submodel EDSR --save EDSR_FT_first --scale 4 --patch_size 192 --n_resblocks 32 --n_feats 256 --epochs 100 --res_scale 0.1 --print_every 100 --lr 0.0001 --lr_decay 100 --batch_size 8 --pre_train ../../experiment/model/EDSR_model/EDSR_x4.pt --test_every 200 --data_train texture --data_test texture --model_one one --subset . --input_res lr --chop --reset --save_results --print_model 50 | 51 | # EDSR-FT, finetune EDSR, cross-validation used, train for the second split 52 | CUDA_VISIBLE_DEVICES=xx python ../main.py --model EDSR --submodel EDSR --save EDSR_FT_second --scale 4 --patch_size 192 --n_resblocks 32 --n_feats 256 --epochs 100 --res_scale 0.1 --print_every 100 --lr 0.0001 --lr_decay 100 --batch_size 8 --pre_train ../../experiment/model/EDSR_model/EDSR_x4.pt --test_every 200 --data_train texture --data_test texture --model_one two --subset . 
53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /code/scripts/finetune_sr.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Called by qsub_NLR.sh, qsub_NHR.sh and qsub_HRST_CNN.sh 3 | # Submit the job to a GPU node 4 | # We use the Grid Engine batch system in our lab; jobs are submitted with the qsub command. If you do not use a batch system, just run the commands in demo.sh directly. 5 | 6 | 7 | ######################################################################################## 8 | #MODEL=EDSR 9 | #SUBMODEL=EDSR 10 | N_BLOCK=32 11 | N_FEATS=256 12 | #N_PATCH=192 13 | #SCALE=4 14 | #EPOCH=100 15 | #MODEL_ONE=one 16 | #SUBSET=. 17 | 18 | if [ $MODEL == 'SR' ]; then 19 | CHECKPOINT="${SUBMODEL}_X${SCALE}_F${N_FEATS}B${N_BLOCK}P${N_PATCH}E${EPOCH}" 20 | elif [ $MODEL == 'FINETUNE' ]; then 21 | # NLR-Sub, NLR, NHR, HRST-CNN 22 | CHECKPOINT="${SUBMODEL}_X${SCALE}_F${N_FEATS}B${N_BLOCK}P${N_PATCH}E${EPOCH}_${NORMAL}_${MODEL_ONE}_Input" 23 | else 24 | # EDSR-FT 25 | CHECKPOINT="${MODEL}_X${SCALE}_F${N_FEATS}B${N_BLOCK}P${N_PATCH}E${EPOCH}_${MODEL_ONE}_WO_NORMAL" 26 | fi 27 | 28 | if [ $MODEL == 'FINETUNE' ]; then 29 | LR=0.0001 30 | else 31 | # learning rate used for EDSR-FT 32 | LR=0.000025 33 | fi 34 | 35 | if [ ! ${SUBSET} == '.' ]; then 36 | CHECKPOINT=${CHECKPOINT}_${SUBSET} 37 | fi 38 | export CHECKPOINT=${CHECKPOINT} 39 | 40 | echo $LR 41 | echo $CHECKPOINT 42 | qsub -N $CHECKPOINT ./vpython.sh main.py --model $MODEL --submodel $SUBMODEL --save $CHECKPOINT --scale $SCALE --patch_size $N_PATCH --n_resblocks $N_BLOCK --n_feats $N_FEATS --epochs $EPOCH --res_scale 0.1 --print_every 100 --lr $LR --lr_decay 100 --batch_size 8 --pre_train ../experiment/model/EDSR_model/EDSR_x${SCALE}.pt --test_every 200 --data_train texture_hr --data_test texture_hr --model_one ${MODEL_ONE} --subset ${SUBSET} --normal_lr $NORMAL --input_res ${INPUT_RES} 43 | 44 | # --ext sep_reset: used to reset the dataset (generate .npy files). 45 | # --data_train: determines how the model is tested, namely whether cross-validation is used. All of the models trained with texture or texture_hr should be tested with cross-validation; only testing EDSR does not use cross-validation. 46 | # --model_one: selects which of the two cross-validation splits to use (one or two). 47 | # --subset: chooses the subset. Five options: . (the combination of all of the subsets), ETH3D, MiddleBury, Collection, SyB3R. 48 | # --normal_lr: use the HR or LR normal map 49 | # --input_res: resolution of the input texture map; HRST_CNN should use the HR input texture map 50 | 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /code/scripts/network_NHR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofsoundof/3D_Appearance_SR/91fc377466e2756c6cf753b7db48ef98e4ea13c2/code/scripts/network_NHR.png -------------------------------------------------------------------------------- /code/scripts/network_NLR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ofsoundof/3D_Appearance_SR/91fc377466e2756c6cf753b7db48ef98e4ea13c2/code/scripts/network_NLR.png -------------------------------------------------------------------------------- /code/scripts/qsub.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | 3 | 4 | export SUBSET=. 5 | export EPOCH=100 6 | for s in 4; do 7 | for n in one; do 8 | #if [ ! $s = 4 ] || [ ! $n = one ]; then 9 | for res in lr hr; do 10 | while true; do 11 | job_string=$(qstat -u yawli) 12 | sub_string='yawli' 13 | diff=${job_string//${sub_string}} 14 | count=$(((${#job_string} - ${#diff}) / ${#sub_string})) 15 | echo "The number of running jobs is $count." 16 | if [ 6 -gt $count ]; then 17 | export SCALE=$s 18 | export N_PATCH=$(($s * 48)) 19 | export MODEL_ONE=$n 20 | export MODEL=FINETUNE 21 | export NORMAL=$res 22 | echo "The number of running jobs ${count} is smaller than 6." 23 | echo "Submit one job for scale $SCALE, patch size $N_PATCH, model $MODEL, split ${MODEL_ONE}, resolution ${NORMAL}." 24 | bash finetune_sr.sh 25 | break 26 | fi 27 | echo "The number of running jobs is equal to 6." 28 | t=900 29 | echo "Wait for $t seconds." 30 | sleep $t 31 | done 32 | done 33 | #fi 34 | done 35 | done 36 | -------------------------------------------------------------------------------- /code/scripts/qsub_HRST_CNN.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | 4 | # run HRST-CNN 5 | export EPOCH=100 6 | export INPUT_RES=hr 7 | for subset in .; 8 | do 9 | export SUBSET=$subset 10 | for s in 2 4; do 11 | for n in one two; do 12 | 13 | while true; do 14 | # count the number of submitted jobs. 15 | job_string=$(qstat -u yawli) 16 | sub_string='yawli' 17 | diff=${job_string//${sub_string}} 18 | count=$(((${#job_string} - ${#diff}) / ${#sub_string})) 19 | echo "The number of running jobs is $count." 20 | #num_eth3d=$(ls /home/yawli/Documents/3d-appearance-benchmark/SR/texture/ETH3D/x${s}/*.png | wc -l) 21 | 22 | # qsub jobs 23 | if [ 6 -gt $count ]; then 24 | export SCALE=$s 25 | export N_PATCH=$(($s * 48)) 26 | export MODEL_ONE=$n 27 | export MODEL=FINETUNE 28 | export SUBMODEL=HRST_CNN 29 | echo "The number of running jobs ${count} is smaller than 6." 30 | echo "Submit one job for scale $SCALE, patch size $N_PATCH, model $MODEL, split ${MODEL_ONE}, resolution ${NORMAL}, subset ${SUBSET}." 31 | bash finetune_sr.sh 32 | break 33 | fi 34 | echo "The number of running jobs is equal to 6." 35 | 36 | # wait for some time 37 | t=300 38 | echo "Wait for $t seconds." 39 | sleep $t 40 | done 41 | 42 | done 43 | done 44 | done 45 | -------------------------------------------------------------------------------- /code/scripts/qsub_NHR.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | 4 | # run NHR 5 | export SUBSET=. 6 | export EPOCH=100 7 | for s in 2 3 4; do 8 | for n in one two; do 9 | for normal in hr; do 10 | while true; do 11 | # count the number of submitted jobs. 12 | job_string=$(qstat -u yawli) 13 | sub_string='yawli' 14 | diff=${job_string//${sub_string}} 15 | count=$(((${#job_string} - ${#diff}) / ${#sub_string})) 16 | echo "The number of running jobs is $count." 17 | 18 | # qsub jobs 19 | if [ 6 -gt $count ]; then 20 | export SCALE=$s 21 | export N_PATCH=$(($s * 48)) 22 | export MODEL_ONE=$n 23 | export MODEL=FINETUNE 24 | export SUBMODEL=NHR 25 | export NORMAL=$normal 26 | echo "The number of running jobs ${count} is smaller than 6." 27 | echo "Submit one job for scale $SCALE, patch size $N_PATCH, model $MODEL, split ${MODEL_ONE}, resolution ${NORMAL}." 28 | bash finetune_sr.sh 29 | break 30 | fi 31 | echo "The number of running jobs is equal to 6." 32 | 33 | # wait for some time 34 | t=900 35 | echo "Wait for $t seconds." 
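# Hedged aside on the counting idiom shared by these qsub_*.sh scripts:
# ${job_string//${sub_string}} deletes every occurrence of ${sub_string}, so
# (original length - shrunken length) / pattern length equals the number of
# occurrences, i.e. the number of qstat entries for the user. A standalone
# sketch with illustrative values:
#   s='yawli yawli'; p='yawli'; d=${s//${p}}
#   echo $(((${#s} - ${#d}) / ${#p}))   # prints 2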
36 | sleep $t 37 | done 38 | done 39 | done 40 | done 41 | -------------------------------------------------------------------------------- /code/scripts/qsub_NLR.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | 4 | # run NLR and NLR-Sub 5 | for d in ETH3D; do 6 | for s in 4 3 2; do 7 | for n in one two; do 8 | 9 | for m in FINETUNE; do 10 | while true; do 11 | 12 | # another method to count the number of submitted jobs. 13 | count=$(qstat -u yawli | tail -n +3 | wc -l) 14 | echo "The number of running jobs is $count." 15 | 16 | # submit jobs 17 | if [ 6 -gt $count ]; then 18 | export SCALE=$s 19 | export N_PATCH=$(($s * 48)) 20 | export MODEL_ONE=$n 21 | export MODEL=$m 22 | export SUBMODEL=NLR 23 | export SUBSET=$d 24 | export EPOCH=100 25 | echo "The number of running jobs ${count} is smaller than 6." 26 | echo "Submit one job for scale $SCALE, patch size $N_PATCH, epoch $EPOCH, model $MODEL, set split ${MODEL_ONE}, subset ${SUBSET}." 27 | bash finetune_sr.sh 28 | break 29 | fi 30 | 31 | # wait for some time 32 | echo "The number of running jobs is equal to 6." 33 | echo "Current time is $(date +%Y-%m-%d/%l:%M:%S)" 34 | t=450 35 | echo "Wait for $t seconds." 36 | sleep $t 37 | done 38 | done 39 | 40 | done 41 | done 42 | done 43 | -------------------------------------------------------------------------------- /code/scripts/vpython.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #$ -l gpu=1 3 | #$ -l h_rt=120:00:00 4 | #$ -l h_vmem=40G 5 | #$ -l h="biwirender15" 6 | #$ -cwd 7 | #$ -V 8 | #$ -j y 9 | #$ -o logs/ 10 | 11 | echo "$@" 12 | echo "Reserved GPU: $SGE_GPU" 13 | export CUDA_VISIBLE_DEVICES=$SGE_GPU 14 | echo "CUDA_VISIBLE_DEVICES: $CUDA_VISIBLE_DEVICES" 15 | 16 | CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES python -u $@ 17 | 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /code/template.py: -------------------------------------------------------------------------------- 1 | def set_template(args): 2 | # Set the templates here 3 | if args.template.find('jpeg') >= 0: 4 | args.data_train = 'DIV2K_jpeg' 5 | args.data_test = 'DIV2K_jpeg' 6 | args.epochs = 200 7 | args.lr_decay = 100 8 | 9 | if args.template.find('EDSR_paper') >= 0: 10 | args.model = 'EDSR' 11 | args.n_resblocks = 32 12 | args.n_feats = 256 13 | args.res_scale = 0.1 14 | 15 | if args.template.find('MDSR') >= 0: 16 | args.model = 'MDSR' 17 | args.patch_size = 48 18 | args.epochs = 650 19 | 20 | if args.template.find('DDBPN') >= 0: 21 | args.model = 'DDBPN' 22 | args.patch_size = 128 23 | args.scale = '4' 24 | 25 | args.data_test = 'Set5' 26 | 27 | args.batch_size = 20 28 | args.epochs = 1000 29 | args.lr_decay = 500 30 | args.gamma = 0.1 31 | args.weight_decay = 1e-4 32 | 33 | args.loss = '1*MSE' 34 | 35 | if args.template.find('GAN') >= 0: 36 | args.epochs = 200 37 | args.lr = 5e-5 38 | args.lr_decay = 150 39 | 40 | -------------------------------------------------------------------------------- /code/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | from decimal import Decimal 4 | 5 | import utility 6 | 7 | import torch 8 | from torch.autograd import Variable 9 | from tqdm import tqdm 10 | 11 | class Trainer(object): 12 | def __init__(self, args, loader, my_model, my_loss, ckp): 13 | self.args = args 14 | self.scale = args.scale 15 | 16 | self.ckp = ckp 17 | self.loader_train 
= loader.loader_train 18 | self.loader_test = loader.loader_test 19 | self.model = my_model 20 | self.loss = my_loss 21 | self.optimizer = utility.make_optimizer(args, self.model) 22 | self.scheduler = utility.make_scheduler(args, self.optimizer) 23 | 24 | if self.args.load != '.': 25 | self.optimizer.load_state_dict( 26 | torch.load(os.path.join(ckp.dir, 'optimizer.pt')) 27 | ) 28 | for _ in range(len(ckp.log)): self.scheduler.step() 29 | 30 | self.error_last = 1e8 31 | 32 | def train(self): 33 | self.scheduler.step() 34 | self.loss.step() 35 | epoch = self.scheduler.last_epoch + 1 36 | lr = self.scheduler.get_lr()[0] 37 | 38 | self.ckp.write_log( 39 | '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr)) 40 | ) 41 | self.loss.start_log() 42 | self.model.train() 43 | # from IPython import embed; embed(); exit() 44 | timer_data, timer_model = utility.timer(), utility.timer() 45 | for batch, (lr, hr, _, idx_scale) in enumerate(self.loader_train): 46 | lr, hr = self.prepare([lr, hr]) 47 | timer_data.hold() 48 | timer_model.tic() 49 | # from IPython import embed; embed(); exit() 50 | 51 | self.optimizer.zero_grad() 52 | sr = self.model(idx_scale, lr) 53 | loss = self.loss(sr, hr) 54 | if loss.item() < self.args.skip_threshold * self.error_last: 55 | loss.backward() 56 | self.optimizer.step() 57 | else: 58 | print('Skip this batch {}! (Loss: {})'.format( 59 | batch + 1, loss.item() 60 | )) 61 | 62 | timer_model.hold() 63 | 64 | if (batch + 1) % self.args.print_every == 0: 65 | self.ckp.write_log('[{}/{}]\t{}\t{:.3f}+{:.3f}s'.format( 66 | (batch + 1) * self.args.batch_size, 67 | len(self.loader_train.dataset), 68 | self.loss.display_loss(batch), 69 | timer_model.release(), 70 | timer_data.release())) 71 | 72 | timer_data.tic() 73 | 74 | self.loss.end_log(len(self.loader_train)) 75 | self.error_last = self.loss.log[-1, -1] 76 | 77 | def test(self): 78 | epoch = self.scheduler.last_epoch + 1 79 | self.ckp.write_log('\nEvaluation:') 80 | self.ckp.add_log(torch.zeros(1, len(self.scale))) 81 | self.model.eval() 82 | 83 | timer_test = utility.timer() 84 | with torch.no_grad(): 85 | for idx_scale, scale in enumerate(self.scale): 86 | eval_acc = 0 87 | self.loader_test.dataset.set_scale(idx_scale) 88 | tqdm_test = tqdm(self.loader_test, ncols=80) 89 | for idx_img, (lr, hr, filename, _) in enumerate(tqdm_test): 90 | 91 | # from IPython import embed; embed(); 92 | filename = filename[0] 93 | no_eval = (hr.nelement() == 1) 94 | if not no_eval: 95 | lr, hr = self.prepare([lr, hr]) 96 | else: 97 | lr = self.prepare([lr])[0] 98 | 99 | sr = self.model(idx_scale, lr) 100 | sr = utility.quantize(sr, self.args.rgb_range) 101 | 102 | save_list = [sr] 103 | if not no_eval: 104 | eval_acc += utility.calc_psnr( 105 | sr, hr, scale, self.args.rgb_range, 106 | benchmark=self.loader_test.dataset.benchmark 107 | ) 108 | save_list.extend([lr, hr]) 109 | 110 | if self.args.save_results: 111 | self.ckp.save_results(filename, save_list, scale) 112 | 113 | self.ckp.log[-1, idx_scale] = eval_acc / len(self.loader_test) 114 | best = self.ckp.log.max(0) 115 | self.ckp.write_log( 116 | '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format( 117 | self.args.data_test, 118 | scale, 119 | self.ckp.log[-1, idx_scale], 120 | best[0][idx_scale], 121 | best[1][idx_scale] + 1 122 | ) 123 | ) 124 | 125 | self.ckp.write_log( 126 | 'Total time: {:.2f}s\n'.format(timer_test.toc()), refresh=True 127 | ) 128 | if not self.args.test_only: 129 | self.ckp.save(self, epoch, is_best=(best[1][0] + 1 == epoch)) 130 | 131 | def 
prepare(self, l, volatile=False): 132 | device = torch.device('cpu' if self.args.cpu else 'cuda') 133 | def _prepare(tensor): 134 | if self.args.precision == 'half': tensor = tensor.half() 135 | return tensor.to(device) 136 | 137 | return [_prepare(_l) for _l in l] 138 | 139 | def terminate(self): 140 | if self.args.test_only: 141 | self.test() 142 | return True 143 | else: 144 | epoch = self.scheduler.last_epoch + 1 145 | return epoch >= self.args.epochs 146 | 147 | -------------------------------------------------------------------------------- /code/trainer_finetune.py: -------------------------------------------------------------------------------- 1 | __author__ = 'yawli' 2 | 3 | 4 | import os 5 | import math 6 | from decimal import Decimal 7 | import torch.optim as optim 8 | import utility 9 | 10 | import torch 11 | from torch.autograd import Variable 12 | from tqdm import tqdm 13 | from trainer import Trainer 14 | 15 | 16 | class TrainerFT(Trainer): 17 | def __init__(self, args, loader, my_model, my_loss, ckp): 18 | super(TrainerFT, self).__init__(args, loader, my_model, my_loss, ckp) 19 | # self.args = args 20 | # self.scale = args.scale 21 | # 22 | # self.ckp = ckp 23 | # self.loader_train = loader.loader_train 24 | # self.loader_test = loader.loader_test 25 | # self.model = my_model 26 | # self.loss = my_loss 27 | if self.args.model.lower() == 'finetune': 28 | self.optimizer = self.make_optimizer(args, self.model) 29 | # self.scheduler = utility.make_scheduler(args, self.optimizer) 30 | # 31 | # if self.args.load != '.': 32 | # self.optimizer.load_state_dict( 33 | # torch.load(os.path.join(ckp.dir, 'optimizer.pt')) 34 | # ) 35 | # for _ in range(len(ckp.log)): self.scheduler.step() 36 | # 37 | # self.error_last = 1e8 38 | 39 | def train(self): 40 | self.scheduler.step() 41 | self.loss.step() 42 | epoch = self.scheduler.last_epoch + 1 43 | lr = self.scheduler.get_lr()[0] 44 | 45 | self.ckp.write_log( 46 | '[Epoch {}]\tLearning rate: {:.2e}'.format(epoch, Decimal(lr)) 47 | ) 48 | self.loss.start_log() 49 | self.model.train() 50 | # from IPython import embed; embed(); exit() 51 | timer_data, timer_model = utility.timer(), utility.timer() 52 | for batch, (lr, nl, mk, hr, _, idx_scale) in enumerate(self.loader_train): 53 | # from IPython import embed; embed(); exit() 54 | lr, nl, mk, hr = self.prepare([lr, nl, mk, hr]) 55 | timer_data.hold() 56 | timer_model.tic() 57 | 58 | self.optimizer.zero_grad() 59 | sr = self.model(idx_scale, lr, nl, mk) 60 | # from IPython import embed; embed(); exit() 61 | loss = self.loss(sr, hr) 62 | if loss.item() < self.args.skip_threshold * self.error_last: 63 | loss.backward() 64 | self.optimizer.step() 65 | else: 66 | print('Skip this batch {}! 
(Loss: {})'.format( 67 | batch + 1, loss.item() 68 | )) 69 | 70 | timer_model.hold() 71 | 72 | if (batch + 1) % self.args.print_every == 0: 73 | self.ckp.write_log('[{}/{}]\t{}\t{:.3f}+{:.3f}s'.format( 74 | (batch + 1) * self.args.batch_size, 75 | len(self.loader_train.dataset), 76 | self.loss.display_loss(batch), 77 | timer_model.release(), 78 | timer_data.release())) 79 | 80 | timer_data.tic() 81 | # from IPython import embed; embed(); exit() 82 | self.loss.end_log(len(self.loader_train)) 83 | self.error_last = self.loss.log[-1, -1] 84 | 85 | def test(self): 86 | epoch = self.scheduler.last_epoch + 1 87 | self.ckp.write_log('\nEvaluation:') 88 | self.ckp.add_log(torch.zeros(1, len(self.scale))) 89 | self.model.eval() 90 | 91 | timer_test = utility.timer() 92 | with torch.no_grad(): 93 | for idx_scale, scale in enumerate(self.scale): 94 | eval_acc = 0 95 | self.loader_test.dataset.set_scale(idx_scale) 96 | tqdm_test = tqdm(self.loader_test, ncols=80) 97 | for idx_img, (lr, nl, mk, hr, filename, _) in enumerate(tqdm_test): 98 | # print('FLAG') 99 | # print(filename) 100 | filename = filename[0] 101 | print(filename) 102 | no_eval = (hr.nelement() == 1) 103 | if not no_eval: 104 | lr, nl, mk, hr = self.prepare([lr, nl, mk, hr]) 105 | else: 106 | lr, nl, mk, = self.prepare([lr, nl, mk]) 107 | 108 | sr = self.model(idx_scale, lr, nl, mk) 109 | sr = utility.quantize(sr, self.args.rgb_range) 110 | # print(sr.shape) 111 | b, c, h, w = sr.shape 112 | hr = hr[:, :, :h, :w] 113 | save_list = [sr] 114 | if not no_eval: 115 | eval_acc += utility.calc_psnr( 116 | sr, hr, scale, self.args.rgb_range, 117 | benchmark=self.loader_test.dataset.benchmark 118 | ) 119 | save_list.extend([lr, hr]) 120 | 121 | if self.args.save_results: 122 | self.ckp.save_results(filename, save_list, scale) 123 | 124 | self.ckp.log[-1, idx_scale] = eval_acc / len(self.loader_test) 125 | best = self.ckp.log.max(0) 126 | self.ckp.write_log( 127 | '[{} x{}]\tPSNR: {:.3f} (Best: {:.3f} @epoch {})'.format( 128 | self.args.data_test, 129 | scale, 130 | self.ckp.log[-1, idx_scale], 131 | best[0][idx_scale], 132 | best[1][idx_scale] + 1 133 | ) 134 | ) 135 | 136 | self.ckp.write_log( 137 | 'Total time: {:.2f}s\n'.format(timer_test.toc()), refresh=True 138 | ) 139 | if not self.args.test_only: 140 | self.ckp.save(self, epoch, is_best=(best[1][0] + 1 == epoch)) 141 | 142 | def make_optimizer(self, args, model): 143 | trainable = filter(lambda x: x.requires_grad, model.model.parameters()) 144 | # from IPython import embed; embed(); exit() 145 | finetune_id = list(map(id, model.model.body_ft.parameters())) \ 146 | + list(map(id, model.model.tail_ft.parameters()))#\ 147 | #+ list(map(id, model.model.tail_ft2.parameters())) 148 | base_params = filter(lambda x: id(x) not in finetune_id, trainable) 149 | trainable = filter(lambda x: x.requires_grad, model.model.parameters()) 150 | finetune_params = filter(lambda x: id(x) in finetune_id, trainable) 151 | if args.optimizer == 'SGD': 152 | optimizer_function = optim.SGD 153 | kwargs = {'momentum': args.momentum} 154 | elif args.optimizer == 'ADAM': 155 | optimizer_function = optim.Adam 156 | kwargs = { 157 | 'betas': (args.beta1, args.beta2), 158 | 'eps': args.epsilon 159 | } 160 | elif args.optimizer == 'RMSprop': 161 | optimizer_function = optim.RMSprop 162 | kwargs = {'eps': args.epsilon} 163 | 164 | kwargs['lr'] = args.lr * 0.1 165 | kwargs['weight_decay'] = args.weight_decay 166 | # from IPython import embed; embed(); exit() 167 | return optimizer_function([ 168 | {'params': base_params}, 
169 | {'params': finetune_params, 'lr': args.lr} 170 | ], **kwargs) 171 | -------------------------------------------------------------------------------- /code/utility.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import time 4 | import datetime 5 | from functools import reduce 6 | 7 | import matplotlib 8 | matplotlib.use('Agg') 9 | import matplotlib.pyplot as plt 10 | 11 | import numpy as np 12 | import scipy.misc as misc 13 | 14 | import torch 15 | import torch.optim as optim 16 | import torch.optim.lr_scheduler as lrs 17 | 18 | class timer(): 19 | def __init__(self): 20 | self.acc = 0 21 | self.tic() 22 | 23 | def tic(self): 24 | self.t0 = time.time() 25 | 26 | def toc(self): 27 | return time.time() - self.t0 28 | 29 | def hold(self): 30 | self.acc += self.toc() 31 | 32 | def release(self): 33 | ret = self.acc 34 | self.acc = 0 35 | 36 | return ret 37 | 38 | def reset(self): 39 | self.acc = 0 40 | 41 | class checkpoint(): 42 | def __init__(self, args): 43 | self.args = args 44 | self.ok = True 45 | self.log = torch.Tensor() 46 | now = datetime.datetime.now().strftime('%Y-%m-%d-%H:%M:%S') 47 | self.img_sfx = args.model + '_' + args.submodel 48 | 49 | if args.load == '.': 50 | if args.save == '.': args.save = now 51 | self.dir = args.dir + args.save 52 | else: 53 | self.dir = args.dir + args.load 54 | if not os.path.exists(self.dir): 55 | args.load = '.' 56 | else: 57 | self.log = torch.load(self.dir + '/psnr_log.pt') 58 | print('Continue from epoch {}...'.format(len(self.log))) 59 | 60 | if args.reset: 61 | os.system('rm -rf ' + self.dir) 62 | args.load = '.' 63 | 64 | def _make_dir(path): 65 | if not os.path.exists(path): os.makedirs(path) 66 | 67 | _make_dir(self.dir) 68 | _make_dir(self.dir + '/model') 69 | _make_dir(self.dir + '/results') 70 | 71 | open_type = 'a' if os.path.exists(self.dir + '/log.txt') else 'w' 72 | self.log_file = open(self.dir + '/log.txt', open_type) 73 | with open(self.dir + '/config.txt', open_type) as f: 74 | f.write(now + '\n\n') 75 | for arg in vars(args): 76 | f.write('{}: {}\n'.format(arg, getattr(args, arg))) 77 | f.write('\n') 78 | 79 | def save(self, trainer, epoch, is_best=False): 80 | trainer.model.save(self.dir, epoch, is_best=is_best) 81 | trainer.loss.save(self.dir) 82 | trainer.loss.plot_loss(self.dir, epoch) 83 | 84 | self.plot_psnr(epoch) 85 | torch.save(self.log, os.path.join(self.dir, 'psnr_log.pt')) 86 | torch.save( 87 | trainer.optimizer.state_dict(), 88 | os.path.join(self.dir, 'optimizer.pt') 89 | ) 90 | 91 | def add_log(self, log): 92 | self.log = torch.cat([self.log, log]) 93 | 94 | def write_log(self, log, refresh=False): 95 | print(log) 96 | self.log_file.write(log + '\n') 97 | if refresh: 98 | self.log_file.close() 99 | self.log_file = open(self.dir + '/log.txt', 'a') 100 | 101 | def done(self): 102 | self.log_file.close() 103 | 104 | def plot_psnr(self, epoch): 105 | axis = np.linspace(1, epoch, epoch) 106 | label = 'SR on {}'.format(self.args.data_test) 107 | fig = plt.figure() 108 | plt.title(label) 109 | for idx_scale, scale in enumerate(self.args.scale): 110 | plt.plot( 111 | axis, 112 | self.log[:, idx_scale].numpy(), 113 | label='Scale {}'.format(scale) 114 | ) 115 | plt.legend() 116 | plt.xlabel('Epochs') 117 | plt.ylabel('PSNR') 118 | plt.grid(True) 119 | plt.savefig('{}/test_{}.pdf'.format(self.dir, self.args.data_test)) 120 | plt.close(fig) 121 | 122 | def save_results(self, filename, save_list, scale): 123 | filename = 
'{}/results/{}_x{}_'.format(self.dir, filename, scale) 124 | postfix = (self.img_sfx, 'LR', 'HR') 125 | for v, p in zip(save_list, postfix): 126 | normalized = v[0].data.mul(255 / self.args.rgb_range) 127 | ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy() 128 | misc.imsave('{}{}.png'.format(filename, p), ndarr) 129 | 130 | #def save_results(self, filename, save_list, scale): 131 | # filepath = os.path.join(self.args.dir_data, self.args.data_test, self.args.data_test_texture_sr, 132 | # 'x{}_{}/Images/Frame000'.format(scale, self.args.model)) 133 | # if not os.path.exists(filepath): 134 | # os.makedirs(filepath) 135 | # image_id = int(os.path.splitext(filename)[0][5:]) 136 | 137 | # normalized = save_list[0][0].data.mul(255 / self.args.rgb_range) 138 | # ndarr = normalized.byte().permute(1, 2, 0).cpu().numpy() 139 | #from IPython import embed; embed(); exit() 140 | # misc.imsave(os.path.join(filepath, 'Image{}.png'.format(image_id-1)), ndarr) 141 | # misc.imsave(os.path.join(os.path.dirname(filepath), 'Image{}.png'.format(image_id)), ndarr) 142 | 143 | def quantize(img, rgb_range): 144 | pixel_range = 255 / rgb_range 145 | return img.mul(pixel_range).clamp(0, 255).round().div(pixel_range) 146 | 147 | def calc_psnr(sr, hr, scale, rgb_range, benchmark=False): 148 | diff = (sr - hr).data.div(rgb_range) 149 | shave = scale 150 | if diff.size(1) > 1: 151 | convert = diff.new(1, 3, 1, 1) 152 | convert[0, 0, 0, 0] = 65.738 153 | convert[0, 1, 0, 0] = 129.057 154 | convert[0, 2, 0, 0] = 25.064 155 | diff.mul_(convert).div_(256) 156 | diff = diff.sum(dim=1, keepdim=True) 157 | ''' 158 | if benchmark: 159 | shave = scale 160 | if diff.size(1) > 1: 161 | convert = diff.new(1, 3, 1, 1) 162 | convert[0, 0, 0, 0] = 65.738 163 | convert[0, 1, 0, 0] = 129.057 164 | convert[0, 2, 0, 0] = 25.064 165 | diff.mul_(convert).div_(256) 166 | diff = diff.sum(dim=1, keepdim=True) 167 | else: 168 | shave = scale + 6 169 | ''' 170 | valid = diff[:, :, shave:-shave, shave:-shave] 171 | mse = valid.pow(2).mean() 172 | 173 | return -10 * math.log10(mse) 174 | 175 | def make_optimizer(args, my_model): 176 | trainable = filter(lambda x: x.requires_grad, my_model.parameters()) 177 | 178 | if args.optimizer == 'SGD': 179 | optimizer_function = optim.SGD 180 | kwargs = {'momentum': args.momentum} 181 | elif args.optimizer == 'ADAM': 182 | optimizer_function = optim.Adam 183 | kwargs = { 184 | 'betas': (args.beta1, args.beta2), 185 | 'eps': args.epsilon 186 | } 187 | elif args.optimizer == 'RMSprop': 188 | optimizer_function = optim.RMSprop 189 | kwargs = {'eps': args.epsilon} 190 | 191 | kwargs['lr'] = args.lr 192 | kwargs['weight_decay'] = args.weight_decay 193 | 194 | return optimizer_function(trainable, **kwargs) 195 | 196 | def make_scheduler(args, my_optimizer): 197 | if args.decay_type == 'step': 198 | scheduler = lrs.StepLR( 199 | my_optimizer, 200 | step_size=args.lr_decay, 201 | gamma=args.gamma 202 | ) 203 | elif args.decay_type.find('step') >= 0: 204 | milestones = args.decay_type.split('_') 205 | milestones.pop(0) 206 | milestones = list(map(lambda x: int(x), milestones)) 207 | scheduler = lrs.MultiStepLR( 208 | my_optimizer, 209 | milestones=milestones, 210 | gamma=args.gamma 211 | ) 212 | 213 | return scheduler 214 | 215 | -------------------------------------------------------------------------------- /code/utils/LMDB_TEST.py: -------------------------------------------------------------------------------- 1 | __author__ = 'yawli' 2 | 3 | import lmdb 4 | import caffe 5 | import os 6 | import pickle 7 | 
import cv2 8 | import numpy as np 9 | 10 | def _get_paths_from_lmdb(dataroot): 11 | env = lmdb.open(dataroot, max_readers=1, readonly=True, lock=False, readahead=False, meminit=False) 12 | keys_cache_file = os.path.join(dataroot, '_keys_cache.p') 13 | if os.path.isfile(keys_cache_file): 14 | print('read lmdb keys from cache: {}'.format(keys_cache_file)) 15 | keys = pickle.load(open(keys_cache_file, "rb")) 16 | else: 17 | with env.begin(write=False) as txn: 18 | print('creating lmdb keys cache: {}'.format(keys_cache_file)) 19 | keys = [key.decode('ascii') for key, _ in txn.cursor()] 20 | pickle.dump(keys, open(keys_cache_file, 'wb')) 21 | paths = sorted([key.encode('ascii') for key in keys if not key.endswith('.meta')]) 22 | return env, paths 23 | 24 | 25 | def get_image_paths(data_type, dataroot): 26 | env, paths = None, None 27 | if dataroot is not None: 28 | if data_type == 'lmdb': 29 | env, paths = _get_paths_from_lmdb(dataroot) 30 | elif data_type == 'img': 31 | paths = [] 32 | else: 33 | raise NotImplementedError('data_type [{:s}] is not recognized.'.format(data_type)) 34 | return env, paths 35 | 36 | 37 | def _read_lmdb_img(env, path): 38 | with env.begin(write=False) as txn: 39 | buf = txn.get(path)#.encode('ascii')) 40 | #buf_meta = txn.get((path + '.meta').encode('ascii')).decode('ascii') 41 | datum = caffe.proto.caffe_pb2.Datum() 42 | datum.ParseFromString(buf) 43 | img_flat = np.frombuffer(datum.data, dtype=np.uint8) 44 | #H, W, C = [int(s) for s in buf_meta.split(',')] 45 | #print(img_flat.shape) 46 | if img_flat.shape[0] == 691200: 47 | img = img_flat.reshape(480, 480, 3) 48 | else: 49 | img = img_flat.reshape(int(480/4), int(480/4), 3) 50 | return img 51 | 52 | 53 | def read_img(env, path): 54 | # read image by cv2 or from lmdb 55 | # return: Numpy float32, HWC, BGR, [0,1] 56 | if env is None: # img 57 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED) 58 | else: 59 | img = _read_lmdb_img(env, path) 60 | img = img.astype(np.float32) 61 | if img.ndim == 2: 62 | img = np.expand_dims(img, axis=2) 63 | # some images have 4 channels 64 | if img.shape[2] > 3: 65 | img = img[:, :, :3] 66 | return img 67 | 68 | lmdb_path = '/scratch_net/ofsoundof/yawli/Datasets/DIV2K/GT_sub_image.lmdb' # must end with .lmdb 69 | env, paths = get_image_paths('lmdb', lmdb_path) 70 | from IPython import embed; embed(); exit() 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | -------------------------------------------------------------------------------- /code/utils/compare_PSNR_preSR.py: -------------------------------------------------------------------------------- 1 | __author__ = 'yawli' 2 | 3 | import numpy as np 4 | import os 5 | import glob 6 | from PIL import Image 7 | import pickle as pkl 8 | from compute_PSNR_UP import * 9 | 10 | 11 | path = '/scratch_net/ofsoundof/yawli/3D_appearance_dataset/Collection/Texture/x1' 12 | path_sr = '/scratch_net/ofsoundof/yawli/3D_appearance_dataset/Collection/Texture_SR/x2_SR' 13 | 14 | dataset = ['Collection'] #['MiddleBury', 'Collection', 'ETH3D', 'SyB3R'] 15 | psnr_all_y = [] 16 | psnr_all = [] 17 | 18 | 19 | for d in dataset: 20 | img_sr_list = glob.glob(os.path.join(path_sr, '*.png')) 21 | print('The images are: {}'.format(img_sr_list)) 22 | psnr_ys = np.zeros([len(img_sr_list)+1]) 23 | psnr_s = np.zeros([len(img_sr_list)+1]) 24 | 25 | for i in range(len(img_sr_list)): 26 | img_sr_n = img_sr_list[i] 27 | name_img = os.path.splitext(os.path.basename(img_sr_n))[0] 28 | #print(name_img) 29 | 30 | img_sr = Image.open(img_sr_n) 31 | 
img_hr = Image.open(os.path.join(path, name_img + '.png')) 32 | #from IPython import embed; embed(); 33 | w, h = img_sr.size 34 | s = 2 35 | img_hr_s = np.asarray(img_hr)[:h, :w, :] 36 | img_hr_s = shave(img_hr_s, s) 37 | img_sr_s = shave(np.asarray(img_sr), s) 38 | psnr_ys[i], psnr_s[i] = cal_pnsr_all(img_hr_s, img_sr_s) 39 | print('Image, {}: PSNR, {}; Difference, {}'.format(name_img, psnr_s[i], np.sum((img_hr_s - img_sr_s)**2))) 40 | 41 | psnr_ys[-1] = np.mean(psnr_ys[:-1], axis=0) 42 | psnr_s[-1] = np.mean(psnr_s[:-1], axis=0) 43 | psnr_all_y.append(psnr_ys) 44 | psnr_all.append(psnr_s) 45 | print(np.round(psnr_s,2)) 46 | -------------------------------------------------------------------------------- /code/utils/compare_PSNR_preSR_image.py: -------------------------------------------------------------------------------- 1 | __author__ = 'yawli' 2 | 3 | import numpy as np 4 | import os 5 | import glob 6 | from PIL import Image 7 | import pickle as pkl 8 | from compute_PSNR_UP import * 9 | 10 | 11 | path = '/scratch_net/ofsoundof/yawli/3D_appearance_dataset' 12 | 13 | dataset = ['Collection'] #['MiddleBury', 'Collection', 'ETH3D', 'SyB3R'] 14 | subset = [['Buddha', 'Bunny', 'Fountain', 'Beethoven', 'Relief', 'Bird']] 15 | 16 | 17 | 18 | for d in range(len(dataset)): 19 | D = dataset[d] 20 | psnr_all = [] 21 | psnr_mean = np.zeros(len(subset[d]) + 1) 22 | for s in range(len(subset[d])): 23 | S = subset[d][s] 24 | img_hr_list = glob.glob(os.path.join(path, D, S, 'x1/Images/*.png')) 25 | img_sr_list = glob.glob(os.path.join(path, D, S, 'x2_RCAN/Images/*.png')) 26 | 27 | psnr_s = np.zeros([len(img_sr_list)+1]) 28 | 29 | for i in range(len(img_sr_list)): 30 | img_sr_n = img_sr_list[i] 31 | img_hr_n = img_hr_list[i] 32 | 33 | img_sr = Image.open(img_sr_n) 34 | img_hr = Image.open(img_hr_n) 35 | #from IPython import embed; embed(); 36 | w, h = img_sr.size 37 | crop = 2 38 | img_hr_s = np.asarray(img_hr)[:h, :w, :] 39 | img_hr_s = shave(img_hr_s, crop) 40 | img_sr_s = shave(np.asarray(img_sr), crop) 41 | _, psnr_s[i] = cal_pnsr_all(img_hr_s, img_sr_s) 42 | 43 | psnr_s[-1] = np.mean(psnr_s[:-1], axis=0) 44 | psnr_all.append(psnr_s) 45 | psnr_mean[s] = psnr_s[-1] 46 | #from IPython import embed; embed(); 47 | print(psnr_s) 48 | print('The mean for {} in {} is {}.'.format(S, D, psnr_mean[s])) 49 | #print(psnr_all) 50 | psnr_mean[-1] = np.mean(psnr_mean[:-1]) 51 | print(np.round(psnr_mean,2)) 52 | -------------------------------------------------------------------------------- /code/utils/compute_PSNR_SR_1NLR.py: -------------------------------------------------------------------------------- 1 | __author__ = 'yawli' 2 | 3 | import numpy as np 4 | import os 5 | import glob 6 | from PIL import Image 7 | import pickle as pkl 8 | from compute_PSNR_UP import * 9 | 10 | # compute psnr value for EDSR, EDSR-FT, and NLR in the paper. 
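# Hedged note: cal_pnsr_all (imported from compute_PSNR_UP) excludes the black
# background of the texture atlas: pixels whose Y channel equals 16, the value
# pure black maps to under the Matlab-style rgb2ycbcr, are masked out of the
# pixel count. Roughly, as an illustrative sketch rather than the exact code:
#   mask = (rgb2ycbcr(hr) != 16)
#   mse = ((hr - sr) ** 2).sum() / (3 * mask.sum())
#   psnr = 10 * np.log10(255 ** 2 / mse)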
11 | #path_sr = '/home/yawli/projects/RCAN/RCAN_TrainCode/experiment/test' 12 | path_sr = '/scratch_net/ofsoundof/yawli/experiment/test' 13 | method_sr = ['EDSR'] 14 | len_m = len(method_sr) 15 | train_flag = [[v+'_TEST', v+'_CONTINUE', 'FINETUNE_'+v] for v in method_sr] 16 | psnr_flag = ['_T', '_WON', '_WN'] 17 | len_f = len(psnr_flag) 18 | method_up = ['Nearest', 'Bilinear', 'Bicubic', 'Lanczos'] 19 | method_all = method_up + [method_sr[m]+psnr_flag[f] for m in range(len_m) for f in range(len_f)] 20 | 21 | dataset = ['MiddleBury', 'Collection', 'ETH3D', 'SyB3R'] 22 | path = '/scratch_net/ofsoundof/yawli/Datasets/texture_map' 23 | psnr_all_y = [] 24 | psnr_all = [] 25 | 26 | 27 | with open('./results/UP_PSNR.pkl', 'rb') as f: 28 | psnr_up = pkl.load(f) 29 | 30 | for d in dataset: 31 | img_hr_list = glob.glob(os.path.join(path, d, 'x1/Texture/*.png')) 32 | num = len_m * len_f 33 | print('The images are: {}'.format(img_hr_list)) 34 | psnr_ys = np.zeros([len(img_hr_list)+1, 4+num, 3]) 35 | psnr_ys[:, :4, :] = psnr_up['psnr_up_y'][dataset.index(d)] 36 | psnr_s = np.zeros([len(img_hr_list)+1, 4+num, 3]) 37 | psnr_s[:, :4, :] = psnr_up['psnr_up'][dataset.index(d)] 38 | 39 | for i in range(len(img_hr_list)): 40 | img_hr_n = img_hr_list[i] 41 | name_img = os.path.splitext(os.path.basename(img_hr_n))[0] 42 | print(name_img) 43 | img_hr = Image.open(img_hr_n) 44 | for s in range(2, 5): 45 | for m in range(len_m): 46 | for f in range(len_f): 47 | img_sr_n = os.path.join(path_sr, method_sr[m]+'_X{}_'.format(s)+train_flag[m][f], 'results', 48 | name_img+'_x{}_'.format(s)+train_flag[m][f]+'.png') 49 | print(img_sr_n) 50 | #from IPython import embed; embed(); 51 | if os.path.exists(img_sr_n): 52 | img_sr = Image.open(img_sr_n) 53 | w, h = img_sr.size 54 | img_hr_s = np.asarray(img_hr)[:h, :w, :] 55 | img_hr_s = shave(img_hr_s, s) 56 | img_sr_s = shave(np.asarray(img_sr), s) 57 | psnr_ys[i, m*len_f+f+4, s-2], psnr_s[i, m*len_f+f+4, s-2] = cal_pnsr_all(img_hr_s, img_sr_s) 58 | 59 | psnr_ys[-1, :, :] = np.mean(psnr_ys[:-1, :, :], axis=0) 60 | psnr_s[-1, :, :] = np.mean(psnr_s[:-1, :, :], axis=0) 61 | psnr_all_y.append(psnr_ys) 62 | psnr_all.append(psnr_s) 63 | save_html([os.path.splitext(os.path.basename(n))[0] for n in img_hr_list], method_all, psnr_ys, d, './results/ALL_PSNR_Y.html') 64 | save_html([os.path.splitext(os.path.basename(n))[0] for n in img_hr_list], method_all, psnr_s, d, './results/ALL_PSNR_RGB.html') 65 | 66 | with open('./results/ALL_PSNR.pkl', 'wb') as f: 67 | pkl.dump({'psnr_all_y': psnr_all_y, 'psnr_all': psnr_all, 'method': method_all}, f) 68 | 69 | psnr_sm_y = np.zeros((5, len(method_all), 3)) 70 | psnr_sm = np.zeros((5, len(method_all), 3)) 71 | 72 | for i in range(len(dataset)): 73 | psnr_sm_y[i, :, :] = psnr_all_y[i][-1, :, :] 74 | psnr_sm[i, :, :] = psnr_all[i][-1, :, :] 75 | psnr_sm_y[-1, :, :] += np.sum(psnr_all_y[i][:-1, :, :], axis=0) 76 | psnr_sm[-1, :, :] += np.sum(psnr_all[i][:-1, :, :], axis=0) 77 | psnr_sm_y[-1, :, :] = psnr_sm_y[-1, :, :]/24 78 | psnr_sm[-1, :, :] = psnr_sm[-1, :, :]/24 79 | save_html(dataset, method_all, psnr_sm_y, 'All', './results/ALL_Summary_Y.html') 80 | save_html(dataset, method_all, psnr_sm, 'All', './results/ALL_Summary_RGB.html') 81 | -------------------------------------------------------------------------------- /code/utils/compute_PSNR_SR_2Sub.py: -------------------------------------------------------------------------------- 1 | __author__ = 'yawli' 2 | 3 | import numpy as np 4 | import os 5 | import glob 6 | from PIL import Image 7 | import 
pickle as pkl 8 | from compute_PSNR_UP import * 9 | 10 | #path_sr = '/home/yawli/projects/RCAN/RCAN_TrainCode/experiment/test' 11 | path_sr = '/scratch_net/ofsoundof/yawli/experiment/test' 12 | method_sr = ['EDSR'] 13 | len_m = len(method_sr) 14 | train_flag = [[v+'_CONTINUE', 'FINETUNE_'+v] for v in method_sr] 15 | psnr_flag = ['_WON_Solo', '_WN_Solo']#, '_WN_Solo_E100'] 16 | len_f = len(psnr_flag) 17 | 18 | dataset = ['MiddleBury', 'Collection', 'ETH3D', 'SyB3R'] 19 | path = '/scratch_net/ofsoundof/yawli/Datasets/texture_map' 20 | psnr_all_y = [] 21 | psnr_all = [] 22 | 23 | 24 | with open('./results/ALL_PSNR.pkl', 'rb') as f: 25 | psnr_up = pkl.load(f) 26 | method_all = psnr_up['method'] + [method_sr[m]+psnr_flag[f] for m in range(len_m) for f in range(len_f)] 27 | num_pre = len(psnr_up['method']) 28 | 29 | for d in dataset: 30 | img_hr_list = glob.glob(os.path.join(path, d, 'x1/Texture/*.png')) 31 | num = len_m * len_f 32 | psnr_ys = np.zeros([len(img_hr_list)+1, num_pre+num, 3]) 33 | psnr_ys[:, :num_pre, :] = psnr_up['psnr_all_y'][dataset.index(d)] 34 | psnr_s = np.zeros([len(img_hr_list)+1, num_pre+num, 3]) 35 | psnr_s[:, :num_pre, :] = psnr_up['psnr_all'][dataset.index(d)] 36 | 37 | for i in range(len(img_hr_list)): 38 | img_hr_n = img_hr_list[i] 39 | name_img = os.path.splitext(os.path.basename(img_hr_n))[0] 40 | print(name_img) 41 | img_hr = Image.open(img_hr_n) 42 | for s in range(2, 5): 43 | for m in range(len_m): 44 | for f in range(len_f): 45 | flag = d + '_E100' if f == 2 else d 46 | #from IPython import embed; embed() 47 | t_flag = train_flag[m][1] if f == 2 else train_flag[m][f] 48 | img_sr_n = os.path.join(path_sr, method_sr[m]+'_X{}_'.format(s)+t_flag+'_'+flag, 'results', 49 | name_img+'_x{}_'.format(s)+t_flag+'.png') 50 | if os.path.exists(img_sr_n): 51 | print(img_sr_n) 52 | img_sr = Image.open(img_sr_n) 53 | w, h = img_sr.size 54 | img_hr_s = np.asarray(img_hr)[:h, :w, :] 55 | img_hr_s = shave(img_hr_s, s) 56 | img_sr_s = shave(np.asarray(img_sr), s) 57 | psnr_ys[i, m*len_f+f+num_pre, s-2], psnr_s[i, m*len_f+f+num_pre, s-2] = cal_pnsr_all(img_hr_s, img_sr_s) 58 | 59 | psnr_ys[-1, :, :] = np.mean(psnr_ys[:-1, :, :], axis=0) 60 | psnr_s[-1, :, :] = np.mean(psnr_s[:-1, :, :], axis=0) 61 | psnr_all_y.append(psnr_ys) 62 | psnr_all.append(psnr_s) 63 | save_html([os.path.splitext(os.path.basename(n))[0] for n in img_hr_list], method_all, psnr_ys, d, './results/TOTAL_PSNR_Y.html') 64 | save_html([os.path.splitext(os.path.basename(n))[0] for n in img_hr_list], method_all, psnr_s, d, './results/TOTAL_PSNR_RGB.html') 65 | 66 | with open('./results/TOTAL_PSNR.pkl', 'wb') as f: 67 | pkl.dump({'psnr_all_y': psnr_all_y, 'psnr_all': psnr_all, 'method': method_all}, f) 68 | 69 | psnr_sm_y = np.zeros((5, len(method_all), 3)) 70 | psnr_sm = np.zeros((5, len(method_all), 3)) 71 | 72 | for i in range(len(dataset)): 73 | psnr_sm_y[i, :, :] = psnr_all_y[i][-1, :, :] 74 | psnr_sm[i, :, :] = psnr_all[i][-1, :, :] 75 | psnr_sm_y[-1, :, :] += np.sum(psnr_all_y[i][:-1, :, :], axis=0) 76 | psnr_sm[-1, :, :] += np.sum(psnr_all[i][:-1, :, :], axis=0) 77 | psnr_sm_y[-1, :, :] = psnr_sm_y[-1, :, :]/24 78 | psnr_sm[-1, :, :] = psnr_sm[-1, :, :]/24 79 | 80 | #psnr_sm_y[-1, -2:, :] = 0 81 | #psnr_sm[-1, -2:, :] = 0 82 | 83 | save_html(dataset, method_all, psnr_sm_y, 'All', './results/TOTAL_Summary_Y.html') 84 | save_html(dataset, method_all, psnr_sm, 'All', './results/TOTAL_Summary_RGB.html') 85 | -------------------------------------------------------------------------------- 
/code/utils/compute_PSNR_SR_3NHR.py: -------------------------------------------------------------------------------- 1 | __author__ = 'yawli' 2 | 3 | import numpy as np 4 | import os 5 | import glob 6 | from PIL import Image 7 | import pickle as pkl 8 | from compute_PSNR_UP import * 9 | 10 | # compute PSNR value for NHR. 11 | #path_sr = '/home/yawli/projects/RCAN/RCAN_TrainCode/experiment/test' 12 | path_sr = '/scratch_net/ofsoundof/yawli/experiment/' 13 | method_sr = ['EDSR'] 14 | len_m = len(method_sr) 15 | train_flag = ['hr','lr','hr'] 16 | psnr_flag = ['_NORMAL_HR', '_NORMAL_LR', '_Res32']#, '_WN_Solo_E100'] 17 | len_f = len(psnr_flag) 18 | 19 | dataset = ['MiddleBury', 'Collection', 'ETH3D', 'SyB3R'] 20 | path = '/scratch_net/ofsoundof/yawli/Datasets/texture_map' 21 | psnr_all_y = [] 22 | psnr_all = [] 23 | 24 | 25 | with open('./results/TOTAL_PSNR.pkl', 'rb') as f: 26 | psnr = pkl.load(f) 27 | method_all = psnr['method'] + [method_sr[m]+psnr_flag[f] for m in range(len_m) for f in range(len_f)] 28 | num_pre = len(psnr['method']) 29 | 30 | for d in dataset: 31 | img_hr_list = glob.glob(os.path.join(path, d, 'x1/Texture/*.png')) 32 | num = len_m * len_f 33 | psnr_ys = np.zeros([len(img_hr_list)+1, num_pre+num, 3]) 34 | psnr_ys[:, :num_pre, :] = psnr['psnr_all_y'][dataset.index(d)] 35 | psnr_s = np.zeros([len(img_hr_list)+1, num_pre+num, 3]) 36 | psnr_s[:, :num_pre, :] = psnr['psnr_all'][dataset.index(d)] 37 | 38 | for i in range(len(img_hr_list)): 39 | img_hr_n = img_hr_list[i] 40 | name_img = os.path.splitext(os.path.basename(img_hr_n))[0] 41 | print(name_img) 42 | img_hr = Image.open(img_hr_n) 43 | for s in range(2, 5): 44 | for m in range(len_m): 45 | for f in range(len_f): 46 | flag = d + '_E100' if f == 2 else d 47 | #from IPython import embed; embed() 48 | tail = '_Res32' if f == 2 else '' 49 | img_sr_n1 = os.path.join(path_sr, method_sr[m]+'_X{}_F256B32P{}E100_{}_one'.format(s,s*48,train_flag[f]) + tail, 'results', 50 | name_img+'_x{}_'.format(s)+'FINETUNE_EDSR.png') 51 | img_sr_n2 = os.path.join(path_sr, method_sr[m]+'_X{}_F256B32P{}E100_{}_two'.format(s,s*48,train_flag[f]) + tail, 'results', 52 | name_img+'_x{}_'.format(s)+'FINETUNE_EDSR.png') 53 | img_sr_n = img_sr_n1 if os.path.exists(img_sr_n1) else img_sr_n2 54 | if os.path.exists(img_sr_n): 55 | print(img_sr_n) 56 | img_sr = Image.open(img_sr_n) 57 | w, h = img_sr.size 58 | img_hr_s = np.asarray(img_hr)[:h, :w, :] 59 | img_hr_s = shave(img_hr_s, s) 60 | img_sr_s = shave(np.asarray(img_sr), s) 61 | psnr_ys[i, m*len_f+f+num_pre, s-2], psnr_s[i, m*len_f+f+num_pre, s-2] = cal_pnsr_all(img_hr_s, img_sr_s) 62 | 63 | psnr_ys[-1, :, :] = np.mean(psnr_ys[:-1, :, :], axis=0) 64 | psnr_s[-1, :, :] = np.mean(psnr_s[:-1, :, :], axis=0) 65 | psnr_all_y.append(psnr_ys) 66 | psnr_all.append(psnr_s) 67 | save_html([os.path.splitext(os.path.basename(n))[0] for n in img_hr_list], method_all, psnr_ys, d, './results/THL_PSNR_Y.html') 68 | save_html([os.path.splitext(os.path.basename(n))[0] for n in img_hr_list], method_all, psnr_s, d, './results/THL_PSNR_RGB.html') 69 | 70 | with open('./results/THL_PSNR.pkl', 'wb') as f: 71 | pkl.dump({'psnr_all_y': psnr_all_y, 'psnr_all': psnr_all, 'method': method_all}, f) 72 | 73 | psnr_sm_y = np.zeros((5, len(method_all), 3)) 74 | psnr_sm = np.zeros((5, len(method_all), 3)) 75 | 76 | for i in range(len(dataset)): 77 | psnr_sm_y[i, :, :] = psnr_all_y[i][-1, :, :] 78 | psnr_sm[i, :, :] = psnr_all[i][-1, :, :] 79 | psnr_sm_y[-1, :, :] += np.sum(psnr_all_y[i][:-1, :, :], axis=0) 80 | psnr_sm[-1, :, :] += 
np.sum(psnr_all[i][:-1, :, :], axis=0) 81 | psnr_sm_y[-1, :, :] = psnr_sm_y[-1, :, :]/24 82 | psnr_sm[-1, :, :] = psnr_sm[-1, :, :]/24 83 | 84 | #psnr_sm_y[-1, -2:, :] = 0 85 | #psnr_sm[-1, -2:, :] = 0 86 | 87 | save_html(dataset, method_all, psnr_sm_y, 'All', './results/THL_Summary_Y.html') 88 | save_html(dataset, method_all, psnr_sm, 'All', './results/THL_Summary_RGB.html') 89 | -------------------------------------------------------------------------------- /code/utils/compute_PSNR_SR_4HRST.py: -------------------------------------------------------------------------------- 1 | __author__ = 'yawli' 2 | 3 | import numpy as np 4 | import os 5 | import glob 6 | from PIL import Image 7 | import pickle as pkl 8 | from compute_PSNR_UP import * 9 | 10 | #path_sr = '/home/yawli/projects/RCAN/RCAN_TrainCode/experiment/test' 11 | path_sr = '/scratch_net/ofsoundof/yawli/experiment/' 12 | method_sr = ['FSRCNN', 'SRRESNET', 'RCAN', 'HRTM2014', 'HRTM2014_Post_Joint']#, 'HRTM2014_Post_Sep'] 13 | len_m = len(method_sr) 14 | train_flag = [''] 15 | psnr_flag = [''] 16 | len_f = len(psnr_flag) 17 | 18 | dataset = ['MiddleBury', 'Collection', 'ETH3D', 'SyB3R'] 19 | dataset_HRTM2014 = ['ColMapMiddlebury', 'Collection', 'ETH3D', 'SyB3R'] 20 | path = '/scratch_net/ofsoundof/yawli/Datasets/texture_map' 21 | psnr_all_y = [] 22 | psnr_all = [] 23 | 24 | 25 | with open('./results/THL_PSNR.pkl', 'rb') as f: 26 | psnr = pkl.load(f) 27 | method_all = psnr['method'] + [method_sr[m]+psnr_flag[f] for m in range(len_m) for f in range(len_f)] 28 | num_pre = len(psnr['method']) 29 | 30 | for d in dataset: 31 | img_hr_list = glob.glob(os.path.join(path, d, 'x1/Texture/*.png')) 32 | num = len_m * len_f 33 | psnr_ys = np.zeros([len(img_hr_list)+1, num_pre+num, 3]) 34 | psnr_ys[:, :num_pre, :] = psnr['psnr_all_y'][dataset.index(d)] 35 | psnr_s = np.zeros([len(img_hr_list)+1, num_pre+num, 3]) 36 | psnr_s[:, :num_pre, :] = psnr['psnr_all'][dataset.index(d)] 37 | 38 | for i in range(len(img_hr_list)): 39 | img_hr_n = img_hr_list[i] 40 | name_img = os.path.splitext(os.path.basename(img_hr_n))[0] 41 | print(name_img) 42 | img_hr = Image.open(img_hr_n) 43 | for s in range(2, 5): 44 | for m in range(len_m): 45 | for f in range(len_f): 46 | #from IPython import embed; embed() 47 | if method_sr[m] == 'FSRCNN': 48 | img_sr_n = os.path.join(path_sr, 'test', method_sr[m], 'x{}'.format(s), 49 | name_img+'_{}.png'.format(method_sr[m])) 50 | elif method_sr[m] == 'SRRESNET': 51 | img_sr_n = os.path.join(path_sr, 'test', method_sr[m], method_sr[m]+'_X{}_B16F64P{}'.format(s,s*24), 'results', 52 | name_img+'_x{}_SR_{}.png'.format(s,method_sr[m])) 53 | elif method_sr[m] == 'RCAN': 54 | img_sr_n = os.path.join(path_sr, 'test', method_sr[m], method_sr[m]+'_X{}'.format(s), 'results', 55 | name_img+'_x{}_{}_{}.png'.format(s,method_sr[m],method_sr[m])) 56 | elif method_sr[m] == 'HRTM2014': 57 | img_sr_n = os.path.join('/home/yawli/Documents/3d-appearance-benchmark/SR/texture', d, 58 | 'x{}/{}.png'.format(s,name_img)) 59 | else: 60 | tail = '_' + d if method_sr[m] == 'HRTM2014_Post_Sep' else '' 61 | img_sr_n1 = os.path.join(path_sr, 'test/HRST+/EDSR_X{}_F256B32P{}E100_hr_one_Input'.format(s,s*48) + tail, 'results', 62 | name_img+'_x{}_'.format(s)+'FINETUNE_EDSR.png') 63 | img_sr_n2 = os.path.join(path_sr, 'test/HRST+/EDSR_X{}_F256B32P{}E100_hr_two_Input'.format(s,s*48) + tail, 'results', 64 | name_img+'_x{}_'.format(s)+'FINETUNE_EDSR.png') 65 | img_sr_n = img_sr_n1 if os.path.exists(img_sr_n1) else img_sr_n2 66 | if os.path.exists(img_sr_n): 67 | 
print(img_sr_n) 68 | img_sr = Image.open(img_sr_n) 69 | w, h = img_sr.size 70 | img_hr_s = np.asarray(img_hr)[:h, :w, :] 71 | img_hr_s = shave(img_hr_s, s) 72 | img_sr_s = shave(np.asarray(img_sr), s) 73 | psnr_ys[i, m*len_f+f+num_pre, s-2], psnr_s[i, m*len_f+f+num_pre, s-2] = cal_pnsr_all(img_hr_s, img_sr_s) 74 | 75 | psnr_ys[-1, :, :] = np.mean(psnr_ys[:-1, :, :], axis=0) 76 | psnr_s[-1, :, :] = np.mean(psnr_s[:-1, :, :], axis=0) 77 | psnr_all_y.append(psnr_ys) 78 | psnr_all.append(psnr_s) 79 | save_html([os.path.splitext(os.path.basename(n))[0] for n in img_hr_list], method_all, psnr_ys, d, './results/TA_PSNR_Y.html') 80 | save_html([os.path.splitext(os.path.basename(n))[0] for n in img_hr_list], method_all, psnr_s, d, './results/TA_PSNR_RGB.html') 81 | 82 | with open('./results/TA_PSNR.pkl', 'wb') as f: 83 | pkl.dump({'psnr_all_y': psnr_all_y, 'psnr_all': psnr_all, 'method': method_all}, f) 84 | 85 | psnr_sm_y = np.zeros((5, len(method_all), 3)) 86 | psnr_sm = np.zeros((5, len(method_all), 3)) 87 | 88 | for i in range(len(dataset)): 89 | psnr_sm_y[i, :, :] = psnr_all_y[i][-1, :, :] 90 | psnr_sm[i, :, :] = psnr_all[i][-1, :, :] 91 | psnr_sm_y[-1, :, :] += np.sum(psnr_all_y[i][:-1, :, :], axis=0) 92 | psnr_sm[-1, :, :] += np.sum(psnr_all[i][:-1, :, :], axis=0) 93 | psnr_sm_y[-1, :, :] = psnr_sm_y[-1, :, :]/24 94 | psnr_sm[-1, :, :] = psnr_sm[-1, :, :]/24 95 | 96 | #psnr_sm_y[-1, -2:, :] = 0 97 | #psnr_sm[-1, -2:, :] = 0 98 | 99 | save_html(dataset, method_all, psnr_sm_y, 'All', './results/TA_Summary_Y.html') 100 | save_html(dataset, method_all, psnr_sm, 'All', './results/TA_Summary_RGB.html') 101 | -------------------------------------------------------------------------------- /code/utils/compute_PSNR_UP.py: -------------------------------------------------------------------------------- 1 | __author__ = 'yawli' 2 | 3 | import numpy as np 4 | import os 5 | import glob 6 | from PIL import Image 7 | import copy 8 | import pickle as pkl 9 | 10 | def shave(img, scale): 11 | return img[scale:-scale, scale:-scale, :] 12 | 13 | def cal_pnsr(img_hr, img_mr, mask): 14 | ''' 15 | compute psnr value. Black regions are excluded from the computing. 16 | ''' 17 | mask_sum = np.sum(mask) 18 | mask_sum = mask_sum * 3 if img_hr.ndim == 3 else mask_sum 19 | mse = np.sum(np.square(img_hr - img_mr))/mask_sum 20 | psnr = 10 * np.log10(255**2/mse) 21 | return psnr 22 | 23 | def cal_pnsr_all(img_hr, img_mr): 24 | img_hr_y = rgb2ycbcr(img_hr).astype(np.float32) 25 | img_mr_y = rgb2ycbcr(img_mr).astype(np.float32) 26 | mask = (img_hr_y != 16).astype(np.float32) 27 | img_hr = img_hr.astype(np.float32) 28 | img_mr = img_mr.astype(np.float32) 29 | # from IPython import embed; embed(); exit() 30 | psnr_y = cal_pnsr(img_hr_y, img_mr_y, mask) 31 | psnr = cal_pnsr(img_hr, img_mr, mask) 32 | return psnr_y, psnr 33 | 34 | def rgb2ycbcr(img, only_y=True): 35 | '''same as matlab rgb2ycbcr 36 | only_y: only return Y channel 37 | Input: 38 | uint8, [0, 255] 39 | float, [0, 1] 40 | ''' 41 | in_img_type = img.dtype 42 | img = img.astype(np.float32) 43 | if in_img_type != np.uint8: 44 | img *= 255. 45 | # convert 46 | if only_y: 47 | rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 48 | else: 49 | rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], 50 | [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] 51 | if in_img_type == np.uint8: 52 | rlt = rlt.round() 53 | else: 54 | rlt /= 255. 
55 | return rlt.astype(in_img_type) 56 | 57 | def bgr2ycbcr(img, only_y=True): 58 | '''bgr version of rgb2ycbcr 59 | only_y: only return Y channel 60 | Input: 61 | uint8, [0, 255] 62 | float, [0, 1] 63 | ''' 64 | in_img_type = img.dtype 65 | img = img.astype(np.float32) 66 | if in_img_type != np.uint8: 67 | img *= 255. 68 | # convert 69 | if only_y: 70 | rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 71 | else: 72 | rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], 73 | [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] 74 | if in_img_type == np.uint8: 75 | rlt = rlt.round() 76 | else: 77 | rlt /= 255. 78 | return rlt.astype(in_img_type) 79 | 80 | def ycbcr2rgb(img): 81 | '''same as matlab ycbcr2rgb 82 | Input: 83 | uint8, [0, 255] 84 | float, [0, 1] 85 | ''' 86 | in_img_type = img.dtype 87 | img = img.astype(np.float32) 88 | if in_img_type != np.uint8: 89 | img *= 255. 90 | # convert 91 | rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], 92 | [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] 93 | if in_img_type == np.uint8: 94 | rlt = rlt.round() 95 | else: 96 | rlt /= 255. 97 | return rlt.astype(in_img_type) 98 | 99 | def modcrop(img_in, scale): 100 | # img_in: Numpy, HWC or HW 101 | img = np.copy(img_in) 102 | if img.ndim == 2: 103 | H, W = img.shape 104 | H_r, W_r = H % scale, W % scale 105 | img = img[:H - H_r, :W - W_r] 106 | elif img.ndim == 3: 107 | H, W, C = img.shape 108 | H_r, W_r = H % scale, W % scale 109 | img = img[:H - H_r, :W - W_r, :] 110 | else: 111 | raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim)) 112 | return img 113 | 114 | def save_file(cont, path): 115 | with open(path, 'a') as f: 116 | f.write(cont) 117 | 118 | def html_table(lol, cap):
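# Hedged usage note: html_table is a generator yielding one HTML fragment per
# row; callers join and write the pieces, e.g. (illustrative):
#   save_file('\n'.join(html_table(rows, 'My caption')), 'out.html')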
119 | yield '<table>' 120 | yield '<caption>' + cap + '</caption>' 121 | for sublist in lol: 122 | yield '<tr>' 123 | yield '<td>' + '</td><td>'.join(sublist) + '</td></tr>' 124 | yield '