├── illustrations
├── face_result.png
├── architecture.png
└── computation_graph.png
├── datasets
├── example_face_8X
│   ├── HR
│   │   ├── 16.png
│   │   ├── 36.png
│   │   ├── 110.png
│   │   ├── 217.png
│   │   ├── 253.png
│   │   └── 334.png
│   └── LR
│   │   ├── 16.png
│   │   ├── 36.png
│   │   ├── 110.png
│   │   ├── 217.png
│   │   ├── 253.png
│   │   └── 334.png
└── example_general_4X
│   ├── HR
│   └── butterfly.png
│   └── LR
│   └── butterfly.png
├── experiments
└── pretrained_models
│   └── README.md
├── download-weights
├── cog.yaml
├── codes
├── models
│   ├── modules
│   │   ├── thops.py
│   │   ├── HCFlowNet_Rescaling_arch.py
│   │   ├── FlowStep.py
│   │   ├── HCFlowNet_SR_arch.py
│   │   ├── loss.py
│   │   ├── module_util.py
│   │   ├── ActNorms.py
│   │   ├── ConditionalFlow.py
│   │   ├── Permutations.py
│   │   ├── FlowNet_SR_x4.py
│   │   ├── FlowNet_Rescaling_x4.py
│   │   ├── discriminator_vgg_arch.py
│   │   ├── FlowNet_SR_x8.py
│   │   └── AffineCouplings.py
│   ├── __init__.py
│   ├── networks.py
│   ├── lr_scheduler.py
│   └── base_model.py
├── scripts
│   ├── png2npy.py
│   └── prepare_data_pkl.py
├── utils
│   ├── timer.py
│   ├── dist_util.py
│   └── imresize.py
├── options
│   ├── test
│   │   ├── test_SR_CelebA_8X_HCFlow.yml
│   │   ├── test_Rescaling_DF2K_4X_HCFlow.yml
│   │   └── test_SR_DF2K_4X_HCFlow.yml
│   ├── train
│   │   ├── train_SR_DF2K_4X_HCFlow.yml
│   │   ├── train_SR_DF2K_4X_HCFlow+.yml
│   │   ├── train_SR_DF2K_4X_HCFlow++.yml
│   │   ├── train_SR_CelebA_8X_HCFlow.yml
│   │   ├── train_SR_CelebA_8X_HCFlow+.yml
│   │   ├── train_SR_CelebA_8X_HCFlow++.yml
│   │   └── train_Rescaling_DF2K_4X_HCFlow.yml
│   └── options.py
├── data
│   ├── __init__.py
│   ├── GTLQnpy_dataset.py
│   ├── LQ_dataset.py
│   ├── data_sampler.py
│   ├── GT_dataset.py
│   ├── GTLQ_dataset.py
│   ├── GTLQx_dataset.py
│   └── LRHR_PKL_dataset.py
└── test_HCFlow.py
├── requirements.txt
├── README.md
└── LICENSE

/illustrations/face_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JingyunLiang/HCFlow/HEAD/illustrations/face_result.png
--------------------------------------------------------------------------------
/illustrations/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JingyunLiang/HCFlow/HEAD/illustrations/architecture.png
--------------------------------------------------------------------------------
/datasets/example_face_8X/HR/16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JingyunLiang/HCFlow/HEAD/datasets/example_face_8X/HR/16.png
--------------------------------------------------------------------------------
/datasets/example_face_8X/HR/36.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JingyunLiang/HCFlow/HEAD/datasets/example_face_8X/HR/36.png
--------------------------------------------------------------------------------
/datasets/example_face_8X/LR/16.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JingyunLiang/HCFlow/HEAD/datasets/example_face_8X/LR/16.png
--------------------------------------------------------------------------------
/datasets/example_face_8X/LR/36.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JingyunLiang/HCFlow/HEAD/datasets/example_face_8X/LR/36.png
--------------------------------------------------------------------------------
/datasets/example_face_8X/HR/110.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JingyunLiang/HCFlow/HEAD/datasets/example_face_8X/HR/110.png
--------------------------------------------------------------------------------
/datasets/example_face_8X/HR/217.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JingyunLiang/HCFlow/HEAD/datasets/example_face_8X/HR/217.png
--------------------------------------------------------------------------------
/datasets/example_face_8X/HR/253.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JingyunLiang/HCFlow/HEAD/datasets/example_face_8X/HR/253.png
--------------------------------------------------------------------------------
/datasets/example_face_8X/HR/334.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JingyunLiang/HCFlow/HEAD/datasets/example_face_8X/HR/334.png
--------------------------------------------------------------------------------
/datasets/example_face_8X/LR/110.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JingyunLiang/HCFlow/HEAD/datasets/example_face_8X/LR/110.png
--------------------------------------------------------------------------------
/datasets/example_face_8X/LR/217.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JingyunLiang/HCFlow/HEAD/datasets/example_face_8X/LR/217.png
--------------------------------------------------------------------------------
/datasets/example_face_8X/LR/253.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JingyunLiang/HCFlow/HEAD/datasets/example_face_8X/LR/253.png
--------------------------------------------------------------------------------
/datasets/example_face_8X/LR/334.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JingyunLiang/HCFlow/HEAD/datasets/example_face_8X/LR/334.png
--------------------------------------------------------------------------------
/experiments/pretrained_models/README.md:
--------------------------------------------------------------------------------

Put pretrained models here as `./pretrained_models/SR_CelebA_X8_HCFlow++.pth`.
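For reference, the `download-weights` script at the repository root does exactly this: it fetches `SR_DF2K_X4_HCFlow++.pth` and `SR_CelebA_X8_HCFlow++.pth` from the v0.0 GitHub release into this directory.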

--------------------------------------------------------------------------------
/illustrations/computation_graph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JingyunLiang/HCFlow/HEAD/illustrations/computation_graph.png
--------------------------------------------------------------------------------
/datasets/example_general_4X/HR/butterfly.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JingyunLiang/HCFlow/HEAD/datasets/example_general_4X/HR/butterfly.png
--------------------------------------------------------------------------------
/datasets/example_general_4X/LR/butterfly.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JingyunLiang/HCFlow/HEAD/datasets/example_general_4X/LR/butterfly.png
--------------------------------------------------------------------------------
/download-weights:
--------------------------------------------------------------------------------
#!/bin/sh
wget https://github.com/JingyunLiang/HCFlow/releases/download/v0.0/SR_DF2K_X4_HCFlow++.pth -P ./experiments/pretrained_models
wget https://github.com/JingyunLiang/HCFlow/releases/download/v0.0/SR_CelebA_X8_HCFlow++.pth -P ./experiments/pretrained_models
--------------------------------------------------------------------------------
/cog.yaml:
--------------------------------------------------------------------------------
build:
  gpu: true
  python_version: "3.8"
  system_packages:
    - "libgl1-mesa-glx"
    - "libglib2.0-0"
  python_packages:
    - "lmdb==1.2.1"
    - "torchvision==0.9.0"
    - "torch==1.8.0"
    - "ipython==7.19.0"
    - "lpips==0.1.3"
    - "matplotlib==3.3.2"
    - "natsort==7.0.1"
    - "numpy==1.19.4"
    - "opencv-python==4.4.0.46"
    - "pandas==1.1.4"
    - "Pillow==8.0.1"
    - "scipy==1.5.3"
    - "tqdm==4.51.0"

predict: "predict.py:Predictor"


--------------------------------------------------------------------------------
/codes/models/modules/thops.py:
--------------------------------------------------------------------------------
import torch


def sum(tensor, dim=None, keepdim=False):
    if dim is None:
        # sum up all dim
        return torch.sum(tensor)
    else:
        if isinstance(dim, int):
            dim = [dim]
        dim = sorted(dim)
        for d in dim:
            tensor = tensor.sum(dim=d, keepdim=True)
        if not keepdim:
            for i, d in enumerate(dim):
                tensor.squeeze_(d-i)
        return tensor


def mean(tensor, dim=None, keepdim=False):
    if dim is None:
        # mean all dim
        return torch.mean(tensor)
    else:
        if isinstance(dim, int):
            dim = [dim]
        dim = sorted(dim)
        for d in dim:
            tensor = tensor.mean(dim=d, keepdim=True)
        if not keepdim:
            for i, d in enumerate(dim):
                tensor.squeeze_(d-i)
        return tensor



def split_feature(tensor, type="split"):
    """
    type = ["split", "cross"]
    """
    C = tensor.size(1)
    if type == "split":
        return tensor[:, :C // 2, ...], tensor[:, C // 2:, ...]
    elif type == "cross":
        return tensor[:, 0::2, ...], tensor[:, 1::2, ...]
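
# Illustrative example (not part of the original module): for a tensor with
# C=4 channels, "split" returns contiguous halves (channels 0-1 and 2-3),
# while "cross" returns interleaved halves (channels 0,2 and 1,3), e.g.
#   x = torch.randn(1, 4, 8, 8)
#   a, b = split_feature(x, "split")  # a: channels 0-1, b: channels 2-3
#   c, d = split_feature(x, "cross")  # c: channels 0,2, d: channels 1,3
# cat_feature(a, b) below undoes "split" exactly; for "cross" it recovers
# the channels only up to the interleaving permutation.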


def cat_feature(tensor_a, tensor_b):
    return torch.cat((tensor_a, tensor_b), dim=1)


def pixels(tensor):
    return int(tensor.size(2) * tensor.size(3))
--------------------------------------------------------------------------------
/codes/models/__init__.py:
--------------------------------------------------------------------------------
import importlib
import logging
import os

try:
    import local_config
except:
    local_config = None


logger = logging.getLogger('base')


def find_model_using_name(model_name):
    # Given the option --model [modelname],
    # the file "models/modelname_model.py"
    # will be imported.
    model_filename = "models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)

    # In the file, the class called ModelNameModel() will
    # be instantiated. It has to be a subclass of torch.nn.Module,
    # and it is case-insensitive.
    model = None
    target_model_name = model_name.replace('_', '') + 'Model'
    for name, cls in modellib.__dict__.items():
        if name.lower() == target_model_name.lower():
            model = cls

    if model is None:
        print(
            "In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s." % (
                model_filename, target_model_name))
        exit(0)

    return model


def create_model(opt, step=0, **opt_kwargs):
    if local_config is not None:
        opt['path']['pretrain_model_G'] = os.path.join(local_config.checkpoint_path, os.path.basename(opt['path']['results_root'] + '.pth'))

    for k, v in opt_kwargs.items():
        opt[k] = v

    model = opt['model']

    M = find_model_using_name(model)

    m = M(opt, step)
    logger.info('Model [{:s}] is created.'.format(m.__class__.__name__))
    return m
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
appnope==0.1.0
argon2-cffi==20.1.0
async-generator==1.10
attrs==20.2.0
backcall==0.2.0
bleach==3.2.1
certifi==2020.6.20
cffi==1.14.3
cycler==0.10.0
dataclasses==0.6
decorator==4.4.2
defusedxml==0.6.0
entrypoints==0.3
environment-kernels==1.1.1
future==0.18.2
imageio==2.9.0
importlib-metadata==2.0.0
ipykernel==5.3.4
ipython==7.19.0
ipython-genutils==0.2.0
ipywidgets==7.5.1
jedi==0.17.2
Jinja2==2.11.2
jsonschema==3.2.0
jupyter==1.0.0
jupyter-client==6.1.7
jupyter-console==6.2.0
jupyter-core==4.6.3
jupyterlab-pygments==0.1.2
kiwisolver==1.3.1
lpips==0.1.3
MarkupSafe==1.1.1
matplotlib==3.3.2
mistune==0.8.4
natsort==7.0.1
nbclient==0.5.1
nbconvert==6.0.7
nbformat==5.0.8
nest-asyncio==1.4.2
networkx==2.5
notebook==6.1.4
numpy==1.19.4
opencv-python==4.4.0.46
packaging==20.4
pandas==1.1.4
pandocfilters==1.4.3
parso==0.7.1
pexpect==4.8.0
pickleshare==0.7.5
Pillow==8.0.1
prometheus-client==0.8.0
prompt-toolkit==3.0.8
ptyprocess==0.6.0
pycparser==2.20
Pygments==2.7.2
pyparsing==2.4.7
pyrsistent==0.17.3
python-dateutil==2.8.1
pytz==2020.4
PyWavelets==1.1.1
PyYAML==5.3.1
pyzmq==19.0.2
qtconsole==4.7.7
QtPy==1.9.0
scikit-image==0.17.2
scipy==1.5.3
Send2Trash==1.5.0
six==1.15.0
terminado==0.9.1
tensorboard==2.4.0
testpath==0.4.4
tifffile==2020.10.1
torch==1.7.1
torchvision==0.8.1
tornado==6.1
tqdm==4.51.0
traitlets==5.0.5
typing-extensions==3.7.4.3
wcwidth==0.2.5
webencodings==0.5.1
widgetsnbextension==3.5.1
zipp==3.4.0
--------------------------------------------------------------------------------
/codes/scripts/png2npy.py:
--------------------------------------------------------------------------------
import os
import argparse
import skimage.io as sio
import numpy as np

# usage: python scripts/png2npy.py --pathFrom ../datasets/DIV2K+Flickr2K/DIV2K+Flickr2K_HR/ --pathTo ../datasets/DIV2K+Flickr2K_decoded/DIV2K+Flickr2K_HR/
#        python scripts/png2npy.py --pathFrom ../datasets/DIV2K+Flickr2K/DIV2K+Flickr2K_LR_bicubic/ --pathTo ../datasets/DIV2K+Flickr2K_decoded/DIV2K+Flickr2K_LR_bicubic/


parser = argparse.ArgumentParser(description='Pre-processing .png images')
parser.add_argument('--pathFrom', default='',
                    help='directory of images to convert')
parser.add_argument('--pathTo', default='',
                    help='directory of images to save')
parser.add_argument('--split', default=True,
                    help='save individual images')
parser.add_argument('--select', default='',
                    help='select certain path')

args = parser.parse_args()

for (path, dirs, files) in os.walk(args.pathFrom):
    print(path)
    targetDir = os.path.join(args.pathTo, path[len(args.pathFrom) + 1:])
    if len(args.select) > 0 and path.find(args.select) == -1:
        continue

    if not os.path.exists(targetDir):
        os.mkdir(targetDir)

    if len(dirs) == 0:
        pack = {}
        n = 0
        for fileName in files:
            (idx, ext) = os.path.splitext(fileName)
            if ext == '.png':
                image = sio.imread(os.path.join(path, fileName))
                if args.split:
                    np.save(os.path.join(targetDir, idx + '.npy'), image)
                n += 1
                if n % 100 == 0:
                    print('Converted ' + str(n) + ' images.')
--------------------------------------------------------------------------------
/codes/utils/timer.py:
--------------------------------------------------------------------------------
import time


class ScopeTimer:
    def __init__(self, name):
        self.name = name

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, *args):
        self.end = time.time()
        self.interval = self.end - self.start
        print("{} {:.3E}".format(self.name, self.interval))


class Timer:
    def __init__(self):
        self.times = []

    def tick(self):
        self.times.append(time.time())

    def get_average_and_reset(self):
        if len(self.times) < 2:
            return -1
        avg = (self.times[-1] - self.times[0]) / (len(self.times) - 1)
        self.times = [self.times[-1]]
        return avg

    def get_last_iteration(self):
        if len(self.times) < 2:
            return 0
        return self.times[-1] - self.times[-2]


class TickTock:
    def __init__(self):
        self.time_pairs = []
        self.current_time = None

    def tick(self):
        self.current_time = time.time()

    def tock(self):
        assert self.current_time is not None, self.current_time
        self.time_pairs.append([self.current_time, time.time()])
        self.current_time = None

    def get_average_and_reset(self):
        if len(self.time_pairs) == 0:
            return -1
        deltas = [t2 - t1 for t1, t2 in self.time_pairs]
        avg = sum(deltas) / len(deltas)
        self.time_pairs = []
        return avg

    def get_last_iteration(self):
        if len(self.time_pairs) == 0:
            return -1
        return self.time_pairs[-1][1] - self.time_pairs[-1][0]
--------------------------------------------------------------------------------
/codes/options/test/test_SR_CelebA_8X_HCFlow.yml:
--------------------------------------------------------------------------------
#### general settings
name: 002_HCFlow_CelebA_x8_bicSR_test
suffix: ~
use_tb_logger: true
model: HCFlow_SR
distortion: sr
scale: 8
quant: 256
gpu_ids: [0]


#### datasets
datasets:
  test0:
    name: example_face_8X
    mode: GTLQ
    dataroot_GT: ../datasets/example_face_8X/HR
    dataroot_LQ: ../datasets/example_face_8X/LR

#  val:
#    name: SR_CelebA_8X_160_val
#    mode: LRHR_PKL
#    dataroot_GT: /cluster/work/cvl/sr_datasets/srflow/celebA/CelebAHq_160_MBic_va.pklv4
#    dataroot_LQ: /cluster/work/cvl/sr_datasets/srflow/celebA/CelebAHq_160_MBic_va_X8.pklv4
#    n_max: 20
#
#  test:
#    name: SR_CelebA_8X_160_test
#    mode: LRHR_PKL
#    dataroot_GT: /cluster/work/cvl/sr_datasets/srflow/celebA/CelebAHq_160_MBic_te.pklv4
#    dataroot_LQ: /cluster/work/cvl/sr_datasets/srflow/celebA/CelebAHq_160_MBic_te_X8.pklv4
#    n_max: 5000




#### network structures
network_G:
  which_model_G: HCFlowNet_SR
  in_nc: 3
  out_nc: 3
  act_norm_start_step: 100

  flowDownsampler:
    K: 26
    L: 3
    flow_permutation: invconv
    flow_coupling: Affine
    nn_module: FCN
    hidden_channels: 64
    cond_channels: ~
    splitOff:
      enable: true
      after_flowstep: [13, 13, 13]
      flow_permutation: invconv
      flow_coupling: Affine
      stage1: True
      nn_module: FCN
      nn_module_last: Conv2dZeros
      hidden_channels: 64
      RRDB_nb: [5, 5]
      RRDB_nf: 64
      RRDB_gc: 32



#### validation settings
val:
  heats: [0, 0.8]
  n_sample: 1


path:
  strict_load: true
  load_submodule: ~
#  pretrain_model_G: ../experiments/pretrained_models/SR_CelebA_X8_HCFlow.pth
#  pretrain_model_G: ../experiments/pretrained_models/SR_CelebA_X8_HCFlow+.pth
  pretrain_model_G: ../experiments/pretrained_models/SR_CelebA_X8_HCFlow++.pth


--------------------------------------------------------------------------------
/codes/options/test/test_Rescaling_DF2K_4X_HCFlow.yml:
--------------------------------------------------------------------------------
#### general settings
name: 003_HCFlow_DF2K_x4_rescaling_test
suffix: ~
use_tb_logger: true
model: HCFlow_Rescaling
distortion: sr
scale: 4
gpu_ids: [0]


datasets:
  test0:
    name: example
    mode: GTLQ
    dataroot_GT: ../datasets/example_general_4X/HR
    dataroot_LQ: ../datasets/example_general_4X/LR

#  test_1:
#    name: Set5
#    mode: GTLQx
#    dataroot_GT: ../datasets/Set5/HR
#    dataroot_LQ: ../datasets/Set5/LR_bicubic/X4
#
#  test_2:
#    name: Set14
#    mode: GTLQx
#    dataroot_GT: ../datasets/Set14/HR
#    dataroot_LQ: ../datasets/Set14/LR_bicubic/X4
#
#  test_3:
#    name: BSD100
#    mode: GTLQx
#    dataroot_GT: ../datasets/BSD100/HR
#    dataroot_LQ: ../datasets/BSD100/LR_bicubic/X4
#
#  test_4:
#    name: Urban100
#    mode: GTLQx
#    dataroot_GT: ../datasets/Urban100/HR
#    dataroot_LQ: ../datasets/Urban100/LR_bicubic/X4
#
#  test_5:
#    name: DIV2K-validation
#    mode: GTLQx
#    dataroot_GT: ../datasets/DIV2K/HR
#    dataroot_LQ: ../datasets/DIV2K/LR_bicubic/X4


#### network structures
network_G:
  which_model_G: HCFlowNet_Rescaling
  in_nc: 3
  out_nc: 3
  act_norm_start_step: 100

  flowDownsampler:
    K: 14
    L: 2
    squeeze: haar # better than squeeze2d
    flow_permutation: none # better than invconv
    flow_coupling: Affine3shift # better than affine
    nn_module: DenseBlock # better than FCN
    hidden_channels: 32
    cond_channels: ~
    splitOff:
      enable: true
      after_flowstep: [6, 6]
      flow_permutation: invconv
      flow_coupling: Affine
      stage1: True
      feature_extractor: RRDB
      nn_module: FCN
      nn_module_last: Conv2dZeros
      hidden_channels: 64
      RRDB_nb: [2,1]
      RRDB_nf: 64
      RRDB_gc: 16



#### validation settings
val:
  heats: [1.0]
  n_sample: 1


path:
  strict_load: true
  load_submodule: ~
  pretrain_model_G: ../experiments/pretrained_models/Rescaling_DF2K_X4_HCFlow.pth



--------------------------------------------------------------------------------
/codes/options/test/test_SR_DF2K_4X_HCFlow.yml:
--------------------------------------------------------------------------------
#### general settings
name: 001_HCFlow_DF2K_x4_bicSR_test
suffix: ~
use_tb_logger: true
model: HCFlow_SR
distortion: sr
scale: 4
quant: 64
gpu_ids: [0]



datasets:
  test0:
    name: example
    mode: GTLQ
    dataroot_GT: ../datasets/example_general_4X/HR
    dataroot_LQ: ../datasets/example_general_4X/LR

#  test_1:
#    name: Set5
#    mode: GTLQx
#    dataroot_GT: ../datasets/Set5/HR
#    dataroot_LQ: ../datasets/Set5/LR_bicubic/X4

#  test_2:
#    name: Set14
#    mode: GTLQx
#    dataroot_GT: ../datasets/Set14/HR
#    dataroot_LQ: ../datasets/Set14/LR_bicubic/X4
#
#  test_3:
#    name: BSD100
#    mode: GTLQx
#    dataroot_GT: ../datasets/BSD100/HR
#    dataroot_LQ: ../datasets/BSD100/LR_bicubic/X4
#
#  test_4:
#    name: Urban100
#    mode: GTLQx
#    dataroot_GT: ../datasets/Urban100/HR
#    dataroot_LQ: ../datasets/Urban100/LR_bicubic/X4
#
#  test_5:
#    name: DIV2K-va-4X
#    mode: GTLQ
#    dataroot_GT: ../datasets/srflow_datasets/div2k-validation-modcrop8-gt
#    dataroot_LQ: ../datasets/srflow_datasets/div2k-validation-modcrop8-x4


#### network structures
network_G:
  which_model_G: HCFlowNet_SR
  in_nc: 3
  out_nc: 3
  act_norm_start_step: 100

  flowDownsampler:
    K: 26
    L: 2
    flow_permutation: invconv
    flow_coupling: Affine
    nn_module: FCN
    hidden_channels: 64
    cond_channels: ~
    splitOff:
      enable: true
      after_flowstep: [13, 13]
      flow_permutation: invconv
      flow_coupling: Affine
      nn_module: FCN
      nn_module_last: Conv2dZeros
      hidden_channels: 64
      RRDB_nb: [7, 7]
      RRDB_nf: 64
      RRDB_gc: 32


#### validation settings
val:
  heats: [0.0, 0.9]
  n_sample: 1


path:
  strict_load: true
  load_submodule: ~
#  pretrain_model_G: ../experiments/pretrained_models/SR_DF2K_X4_HCFlow.pth
#  pretrain_model_G: ../experiments/pretrained_models/SR_DF2K_X4_HCFlow+.pth
  pretrain_model_G: ../experiments/pretrained_models/SR_DF2K_X4_HCFlow++.pth



--------------------------------------------------------------------------------
/codes/models/modules/HCFlowNet_Rescaling_arch.py:
--------------------------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from utils.util import opt_get
from models.modules.FlowNet_Rescaling_x4 import FlowNet
from models.modules import Basic, thops



class HCFlowNet_Rescaling(nn.Module):
    def __init__(self, opt, step=None):
        super(HCFlowNet_Rescaling, self).__init__()
        self.opt = opt
        self.quant = opt_get(opt, ['datasets', 'train', 'quant'], 256)

        hr_size = opt_get(opt, ['datasets', 'train', 'GT_size'], 160)
        hr_channel = opt_get(opt, ['network_G', 'in_nc'], 3)

        # hr->lr+z
        self.flow = FlowNet((hr_size, hr_size, hr_channel), opt=opt)

    # hr: HR image, lr: LR image, z: latent variable, u: conditional variable
    def forward(self, hr=None, lr=None, z=None, u=None, eps_std=None,
                add_gt_noise=False, step=None, reverse=False, training=True):

        # hr->z
        if not reverse:
            return self.normal_flow_diracLR(hr, lr, u, step=step, training=training)
        # z->hr
        else:  # setting z to lr!!!
            return self.reverse_flow_diracLR(lr, z, u, eps_std=eps_std, training=training)


    #########################################diracLR
    # hr->lr+z, diracLR
    def normal_flow_diracLR(self, hr, lr, u=None, step=None, training=True):
        # 1. quantize HR
        # hr = hr + (torch.rand(hr.shape, device=hr.device)) / self.quant # no quantization is better

        # 2. hr->lr+z
        fake_lr_from_hr, fake_z1, fake_z2 = self.flow(hr=hr, u=u, logdet=None, reverse=False, training=training)

        return torch.clamp(fake_lr_from_hr, 0, 1), fake_z1, fake_z2

    # lr+z->hr
    def reverse_flow_diracLR(self, lr, z, u, eps_std, training=True):

        # lr+z->hr
        fake_hr = self.flow(z=lr, u=u, eps_std=eps_std, reverse=True, training=training)

        return torch.clamp(fake_hr, 0, 1)


    def get_score(self, disc_loss_sigma, z):
        score_real = 0.5 * (1 - 1 / (disc_loss_sigma ** 2)) * thops.sum(z ** 2, dim=[1, 2, 3]) - \
                     z.shape[1] * z.shape[2] * z.shape[3] * math.log(disc_loss_sigma)
        return -score_real
--------------------------------------------------------------------------------
/codes/data/__init__.py:
--------------------------------------------------------------------------------
'''create dataset and dataloader'''
import logging
import torch
import torch.utils.data


def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
    phase = dataset_opt['phase']
    if phase == 'train':
        if opt['dist']:
            world_size = torch.distributed.get_world_size()
            num_workers = dataset_opt['n_workers']
            assert dataset_opt['batch_size'] % world_size == 0
            batch_size = dataset_opt['batch_size'] // world_size
            shuffle = False
        else:
            num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids'])
            batch_size = dataset_opt['batch_size']
            shuffle = True
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                                           num_workers=num_workers, sampler=sampler, drop_last=True,
                                           pin_memory=True)
    else:
        return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0,
                                           pin_memory=True)


def create_dataset(dataset_opt):
    mode = dataset_opt['mode']
    if mode == 'GT':  # load HR image and generate LR on-the-fly
        from data.GT_dataset import GTDataset as D
        dataset = D(dataset_opt)
    elif mode == 'GTLQ':  # load generated HR-LR image pairs
        from data.GTLQ_dataset import GTLQDataset as D
        dataset = D(dataset_opt)
    elif mode == 'GTLQx':  # load generated HR-LR image pairs, and replace with x4
        from data.GTLQx_dataset import GTLQxDataset as D
        dataset = D(dataset_opt)
    elif mode == 'LQ':  # load LR image for testing
        from data.LQ_dataset import LQDataset as D
        dataset = D(dataset_opt)
    elif mode == 'LRHR_PKL':
        from data.LRHR_PKL_dataset import LRHR_PKLDataset as D
        dataset = D(dataset_opt)
    elif mode == 'GTLQnpy':  # load generated HR-LR image pairs
        from data.GTLQnpy_dataset import GTLQnpyDataset as D
        dataset = D(dataset_opt)
    else:
        raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))

    logger = logging.getLogger('base')
    logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__,
                                                           dataset_opt['name']))
    return dataset
--------------------------------------------------------------------------------
/codes/models/modules/FlowStep.py:
--------------------------------------------------------------------------------
import torch
from torch import nn as nn

from utils.util import opt_get
from models.modules import ActNorms, Permutations, AffineCouplings


class FlowStep(nn.Module):
    def __init__(self, in_channels, cond_channels=None, flow_permutation='invconv', flow_coupling='Affine', LRvsothers=True,
                 actnorm_scale=1.0, LU_decomposed=False, opt=None):
        super().__init__()
        self.flow_permutation = flow_permutation
        self.flow_coupling = flow_coupling

        # 1. actnorm
        self.actnorm = ActNorms.ActNorm2d(in_channels, actnorm_scale)

        # 2. permute # todo: maybe hurtful for downsampling; preserve the structure of downsampling
        if self.flow_permutation == "invconv":
            self.permute = Permutations.InvertibleConv1x1(in_channels, LU_decomposed=LU_decomposed)
        elif self.flow_permutation == "none":
            self.permute = None

        # 3. coupling
        if self.flow_coupling == "AffineInjector":
            self.affine = AffineCouplings.AffineCouplingInjector(in_channels=in_channels, cond_channels=cond_channels, opt=opt)
        elif self.flow_coupling == "noCoupling":
            pass
        elif self.flow_coupling == "Affine":
            self.affine = AffineCouplings.AffineCoupling(in_channels=in_channels, cond_channels=cond_channels, opt=opt)
        elif self.flow_coupling == "Affine3shift":
            self.affine = AffineCouplings.AffineCoupling3shift(in_channels=in_channels, cond_channels=cond_channels, LRvsothers=LRvsothers, opt=opt)

    def forward(self, z, u=None, logdet=None, reverse=False):
        if not reverse:
            return self.normal_flow(z, u, logdet)
        else:
            return self.reverse_flow(z, u)

    def normal_flow(self, z, u=None, logdet=None):
        # 1. actnorm
        z, logdet = self.actnorm(z, logdet=logdet, reverse=False)

        # 2. permute
        if self.permute is not None:
            z, logdet = self.permute(z, logdet=logdet, reverse=False)

        # 3. coupling
        z, logdet = self.affine(z, u=u, logdet=logdet, reverse=False)

        return z, logdet

    def reverse_flow(self, z, u=None, logdet=None):
        # 1. coupling
        z, _ = self.affine(z, u=u, reverse=True)

        # 2. permute
        if self.permute is not None:
            z, _ = self.permute(z, reverse=True)

        # 3. actnorm
        z, _ = self.actnorm(z, reverse=True)

        return z, logdet

--------------------------------------------------------------------------------
/codes/models/networks.py:
--------------------------------------------------------------------------------
import importlib
import logging
import torch
import models.modules.discriminator_vgg_arch as SRGAN_arch

logger = logging.getLogger('base')


def find_model_using_name(model_name):
    model_filename = "models.modules." + model_name + "_arch"
    modellib = importlib.import_module(model_filename)

    model = None
    target_model_name = model_name.replace('_Net', '')
    for name, cls in modellib.__dict__.items():
        if name.lower() == target_model_name.lower():
            model = cls

    if model is None:
        print(
            "In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s." % (
                model_filename, target_model_name))
        exit(0)

    return model

def define_Flow(opt, step):
    opt_net = opt['network_G']
    which_model = opt_net['which_model_G']

    Arch = find_model_using_name(which_model)
    netG = Arch(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
                nf=opt_net['nf'], nb=opt_net['nb'], scale=opt['scale'], K=opt_net['flow']['K'], opt=opt, step=step)
    return netG

def define_G(opt, step):
    which_model = opt['network_G']['which_model_G']

    Arch = find_model_using_name(which_model)
    netG = Arch(opt=opt, step=step)
    return netG

#### Discriminator
def define_D(opt):
    opt_net = opt['network_D']
    which_model = opt_net['which_model_D']

    if which_model == 'discriminator_vgg_128':
        netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
    elif which_model == 'discriminator_vgg_160':
        netD = SRGAN_arch.Discriminator_VGG_160(in_nc=opt_net['in_nc'], nf=opt_net['nf'])
    elif which_model == 'PatchGANDiscriminator':
        netD = SRGAN_arch.PatchGANDiscriminator(in_nc=opt_net['in_nc'], ndf=opt_net['ndf'], n_layers=opt_net['n_layers'],)
    else:
        raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
    return netD


#### Define Network used for Perceptual Loss
def define_F(opt, use_bn=False):
    gpu_ids = opt['gpu_ids']
    device = torch.device('cuda' if gpu_ids else 'cpu')
    # PyTorch pretrained VGG19-54, before ReLU.
    if use_bn:
        feature_layer = 49
    else:
        feature_layer = 34
    netF = SRGAN_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
                                          use_input_norm=True, device=device)
    netF.eval()  # No need to train
    return netF
--------------------------------------------------------------------------------
/codes/options/train/train_SR_DF2K_4X_HCFlow.yml:
--------------------------------------------------------------------------------
#### general settings
name: 001_DF2K_x4_bicSR_HCFlow
use_tb_logger: true
model: HCFlow_SR
distortion: sr
scale: 4
quant: 64
gpu_ids: [0]


#### datasets
datasets:
  train:
    name: DF2K_tr
    mode: LRHR_PKL
    dataroot_GT: ../datasets/srflow_datasets/DF2K-tr.pklv4
    dataroot_LQ: ../datasets/srflow_datasets/DF2K-tr_X4.pklv4

    use_shuffle: true
    n_workers: 16
    batch_size: 16
    GT_size: 160
    use_flip: true
    color: RGB
  val:
    name: Set5
    mode: GTLQx
    dataroot_GT: ../datasets/Set5/HR
    dataroot_LQ: ../datasets/Set5/LR_bicubic/X4


#### network structures
network_G:
  which_model_G: HCFlowNet_SR
  in_nc: 3
  out_nc: 3
  act_norm_start_step: 100

  flowDownsampler:
    K: 26
    L: 2
    flow_permutation: invconv
    flow_coupling: Affine
    nn_module: FCN
    hidden_channels: 64
    cond_channels: ~
    splitOff:
      enable: true
      after_flowstep: [13, 13]
      flow_permutation: invconv
      flow_coupling: Affine
      nn_module: FCN
      nn_module_last: Conv2dZeros
      hidden_channels: 64
      RRDB_nb: [7, 7]
      RRDB_nf: 64
      RRDB_gc: 32

network_D:
  which_model_D: discriminator_vgg_160
  in_nc: 3
  nf: 64


#### path
path:
  pretrain_model_G: ~
  strict_load: true
  resume_state: auto


#### training settings: learning rate scheme, loss
train:
  lr_G: !!float 2.5e-4
  lr_scheme: MultiStepLR
  weight_decay_G: 0
  max_grad_clip: 5
  max_grad_norm: 100
  beta1: 0.9
  beta2: 0.99
  niter: 300000
  warmup_iter: -1 # no warm up
  lr_steps_rel: [0.5,0.75,0.9,0.95]
  lr_gamma: 0.5
  restarts: ~
  restart_weights: ~
  eta_min: !!float 1e-8

  nll_weight: 1

  # pixel loss
  pixel_criterion_hr: l1
  pixel_weight_hr: 0

  # perceptual loss
  eps_std_reverse: 0.9
  feature_criterion: l1
  feature_weight: 0

  # gan loss
  gan_type: gan # gan | lsgan | wgangp | ragan (patchgan uses lsgan)
  gan_weight: 0

  lr_D: 0
  beta1_D: 0.9
  beta2_D: 0.99
  D_update_ratio: 1
  D_init_iters: 1500

  manual_seed: 0
  val_freq: !!float 5e3


#### validation settings
val:
  heats: [0.0, 0.9]
  n_sample: 3


#### logger
logger:
  print_freq: 200
  save_checkpoint_freq: !!float 5e3

--------------------------------------------------------------------------------
/codes/options/train/train_SR_DF2K_4X_HCFlow+.yml:
--------------------------------------------------------------------------------
#### general settings
name: 001_DF2K_x4_bicSR_HCFlow+
use_tb_logger: true
model: HCFlow_SR
distortion: sr
scale: 4
quant: 64
gpu_ids: [0]


#### datasets
datasets:
  train:
    name: DF2K_tr
    mode: LRHR_PKL
    dataroot_GT: ../datasets/srflow_datasets/DF2K-tr.pklv4
    dataroot_LQ: ../datasets/srflow_datasets/DF2K-tr_X4.pklv4

    use_shuffle: true
    n_workers: 16
    batch_size: 16
    GT_size: 160
    use_flip: true
    color: RGB
  val:
    name: Set5
    mode: GTLQx
    dataroot_GT: ../datasets/Set5/HR
    dataroot_LQ: ../datasets/Set5/LR_bicubic/X4


#### network structures
network_G:
  which_model_G: HCFlowNet_SR
  in_nc: 3
  out_nc: 3
  act_norm_start_step: 100

  flowDownsampler:
    K: 26
    L: 2
    flow_permutation: invconv
    flow_coupling: Affine
    nn_module: FCN
    hidden_channels: 64
    cond_channels: ~
    splitOff:
      enable: true
      after_flowstep: [13, 13]
      flow_permutation: invconv
      flow_coupling: Affine
      nn_module: FCN
      nn_module_last: Conv2dZeros
      hidden_channels: 64
      RRDB_nb: [7, 7]
      RRDB_nf: 64
      RRDB_gc: 32

network_D:
  which_model_D: discriminator_vgg_160
  in_nc: 3
  nf: 64


#### path
path:
  pretrain_model_G: ../experiments/pretrained_models/SR_DF2K_X4_HCFlow.pth
  strict_load: true
  resume_state: auto


#### training settings: learning rate scheme, loss
train:
  lr_G: !!float 5e-5
  lr_scheme: MultiStepLR
  weight_decay_G: 0
  max_grad_clip: 5
  max_grad_norm: 100
  beta1: 0.9
  beta2: 0.99
  niter: 50000
  warmup_iter: -1 # no warm up
  lr_steps: [20000, 40000]
  lr_gamma: 0.5
  restarts: ~
  restart_weights: ~
  eta_min: !!float 1e-8

  nll_weight: !!float 2e-3

  # pixel loss
  pixel_criterion_hr: l1
  pixel_weight_hr: 1.0

  # perceptual loss
  eps_std_reverse: 0.9
  feature_criterion: l1
  feature_weight: 0

  # gan loss
  gan_type: gan # gan | lsgan | wgangp | ragan (patchgan uses lsgan)
  gan_weight: 0

  lr_D: 0
  beta1_D: 0.9
  beta2_D: 0.99
  D_update_ratio: 1
  D_init_iters: 1500

  manual_seed: 0
  val_freq: !!float 5e3


#### validation settings
val:
  heats: [0.0, 0.9] # 0.9 has best visual quality for general SR
  n_sample: 3


#### logger
logger:
  print_freq: 200
  save_checkpoint_freq: !!float 5e3

--------------------------------------------------------------------------------
/codes/utils/dist_util.py:
--------------------------------------------------------------------------------
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
import functools
import os
import subprocess
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def init_dist(launcher, backend='nccl', **kwargs):
    if mp.get_start_method(allow_none=True) is None:
        mp.set_start_method('spawn')
    if launcher == 'pytorch':
        _init_dist_pytorch(backend, **kwargs)
    elif launcher == 'slurm':
        _init_dist_slurm(backend, **kwargs)
    else:
        raise ValueError(f'Invalid launcher type: {launcher}')


def _init_dist_pytorch(backend, **kwargs):
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)


def _init_dist_slurm(backend, port=None):
    """Initialize slurm distributed training environment.

    If argument ``port`` is not specified, then the master port will be system
    environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
    environment variable, then a default port ``29500`` will be used.

    Args:
        backend (str): Backend of torch.distributed.
        port (int, optional): Master port. Defaults to None.
38 | """ 39 | proc_id = int(os.environ['SLURM_PROCID']) 40 | ntasks = int(os.environ['SLURM_NTASKS']) 41 | node_list = os.environ['SLURM_NODELIST'] 42 | num_gpus = torch.cuda.device_count() 43 | torch.cuda.set_device(proc_id % num_gpus) 44 | addr = subprocess.getoutput( 45 | f'scontrol show hostname {node_list} | head -n1') 46 | # specify master port 47 | if port is not None: 48 | os.environ['MASTER_PORT'] = str(port) 49 | elif 'MASTER_PORT' in os.environ: 50 | pass # use MASTER_PORT in the environment variable 51 | else: 52 | # 29500 is torch.distributed default port 53 | os.environ['MASTER_PORT'] = '29500' 54 | os.environ['MASTER_ADDR'] = addr 55 | os.environ['WORLD_SIZE'] = str(ntasks) 56 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 57 | os.environ['RANK'] = str(proc_id) 58 | dist.init_process_group(backend=backend) 59 | 60 | 61 | def get_dist_info(): 62 | if dist.is_available(): 63 | initialized = dist.is_initialized() 64 | else: 65 | initialized = False 66 | if initialized: 67 | rank = dist.get_rank() 68 | world_size = dist.get_world_size() 69 | else: 70 | rank = 0 71 | world_size = 1 72 | return rank, world_size 73 | 74 | 75 | def master_only(func): 76 | 77 | @functools.wraps(func) 78 | def wrapper(*args, **kwargs): 79 | rank, _ = get_dist_info() 80 | if rank == 0: 81 | return func(*args, **kwargs) 82 | 83 | return wrapper 84 | -------------------------------------------------------------------------------- /codes/options/train/train_SR_DF2K_4X_HCFlow++.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | name: 001_DF2K_x4_bicSR_HCFlow++ 3 | use_tb_logger: true 4 | model: HCFlow_SR 5 | distortion: sr 6 | scale: 4 7 | quant: 64 8 | gpu_ids: [0] 9 | 10 | 11 | #### datasets 12 | datasets: 13 | train: 14 | name: DF2K_tr 15 | mode: LRHR_PKL 16 | dataroot_GT: ../datasets/srflow_datasets/DF2K-tr.pklv4 17 | dataroot_LQ: ../datasets/srflow_datasets/DF2K-tr_X4.pklv4 18 | 19 | use_shuffle: true 20 | n_workers: 16 21 | batch_size: 16 22 | GT_size: 160 23 | use_flip: true 24 | color: RGB 25 | val: 26 | name: Set5 27 | mode: GTLQx 28 | dataroot_GT: ../datasets/Set5/HR 29 | dataroot_LQ: ../datasets/Set5/LR_bicubic/X4 30 | 31 | 32 | #### network structures 33 | network_G: 34 | which_model_G: HCFlowNet_SR 35 | in_nc: 3 36 | out_nc: 3 37 | act_norm_start_step: 100 38 | 39 | flowDownsampler: 40 | K: 26 41 | L: 2 42 | flow_permutation: invconv 43 | flow_coupling: Affine 44 | nn_module: FCN 45 | hidden_channels: 64 46 | cond_channels: ~ 47 | splitOff: 48 | enable: true 49 | after_flowstep: [13, 13] 50 | flow_permutation: invconv 51 | flow_coupling: Affine 52 | nn_module: FCN 53 | nn_module_last: Conv2dZeros 54 | hidden_channels: 64 55 | RRDB_nb: [7, 7] 56 | RRDB_nf: 64 57 | RRDB_gc: 32 58 | 59 | network_D: 60 | which_model_D: discriminator_vgg_160 61 | in_nc: 3 62 | nf: 64 63 | 64 | 65 | #### path 66 | path: 67 | pretrain_model_G: ../experiments/pretrained_models/SR_DF2K_X4_HCFlow.pth 68 | strict_load: true 69 | resume_state: auto 70 | 71 | 72 | #### training settings: learning rate scheme, loss 73 | train: 74 | lr_G: !!float 1.25e-5 75 | lr_scheme: MultiStepLR 76 | weight_decay_G: 0 77 | max_grad_clip: 5 78 | max_grad_norm: 100 79 | beta1: 0.9 80 | beta2: 0.99 81 | niter: 50000 82 | warmup_iter: -1 # no warm up 83 | lr_steps: [20000, 40000] 84 | lr_gamma: 0.5 85 | restarts: ~ 86 | restart_weights: ~ 87 | eta_min: !!float 1e-8 88 | 89 | nll_weight: !!float 2e-3 90 | 91 | # pixel loss 92 | pixel_criterion_hr: l1 93 | pixel_weight_hr: 
1.0 94 | 95 | # perceptual loss 96 | eps_std_reverse: 0.9 97 | feature_criterion: l1 98 | feature_weight: !!float 5e-2 # balance diversity and lpips 99 | 100 | # gan loss 101 | gan_type: gan # gan | lsgan | wgangp | ragan (patchgan uses lsgan) 102 | gan_weight: !!float 5e-1 103 | 104 | lr_D: !!float 5e-5 105 | beta1_D: 0.9 106 | beta2_D: 0.99 107 | D_update_ratio: 1 108 | D_init_iters: 1500 109 | 110 | manual_seed: 0 111 | val_freq: !!float 5e3 112 | 113 | 114 | #### validation settings 115 | val: 116 | heats: [0.0, 0.9] # 0.9 has best visual quality for general SR 117 | n_sample: 3 118 | 119 | 120 | #### logger 121 | logger: 122 | print_freq: 200 123 | save_checkpoint_freq: !!float 5e3 124 | 125 | -------------------------------------------------------------------------------- /codes/options/train/train_SR_CelebA_8X_HCFlow.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | name: 002_CelebA_x8_bicSR_HCFlow 3 | use_tb_logger: true 4 | model: HCFlow_SR 5 | distortion: sr 6 | scale: 8 7 | quant: 256 8 | gpu_ids: [0] 9 | 10 | 11 | #### datasets 12 | datasets: 13 | train: 14 | name: CelebA_160_tr 15 | mode: LRHR_PKL 16 | dataroot_GT: /cluster/work/cvl/sr_datasets/srflow/celebA/CelebAHq_160_MBic_tr.pklv4 17 | dataroot_LQ: /cluster/work/cvl/sr_datasets/srflow/celebA/CelebAHq_160_MBic_tr_X8.pklv4 18 | 19 | use_shuffle: true 20 | n_workers: 16 21 | batch_size: 16 22 | GT_size: 160 23 | use_flip: true 24 | color: RGB 25 | val: 26 | name: CelebA_160_va 27 | mode: LRHR_PKL 28 | dataroot_GT: /cluster/work/cvl/sr_datasets/srflow/celebA/CelebAHq_160_MBic_va.pklv4 29 | dataroot_LQ: /cluster/work/cvl/sr_datasets/srflow/celebA/CelebAHq_160_MBic_va_X8.pklv4 30 | n_max: 20 31 | 32 | 33 | #### network structures 34 | network_G: 35 | which_model_G: HCFlowNet_SR 36 | in_nc: 3 37 | out_nc: 3 38 | act_norm_start_step: 100 39 | 40 | flowDownsampler: 41 | K: 26 42 | L: 3 43 | flow_permutation: invconv 44 | flow_coupling: Affine 45 | nn_module: FCN 46 | hidden_channels: 64 47 | cond_channels: ~ 48 | splitOff: 49 | enable: true 50 | after_flowstep: [13, 13, 13] 51 | flow_permutation: invconv 52 | flow_coupling: Affine 53 | nn_module: FCN 54 | nn_module_last: Conv2dZeros 55 | hidden_channels: 64 56 | RRDB_nb: [5, 5] 57 | RRDB_nf: 64 58 | RRDB_gc: 32 59 | 60 | network_D: 61 | which_model_D: discriminator_vgg_160 62 | in_nc: 3 63 | nf: 64 64 | 65 | 66 | #### path 67 | path: 68 | pretrain_model_G: ~ 69 | strict_load: true 70 | resume_state: auto 71 | 72 | 73 | #### training settings: learning rate scheme, loss 74 | train: 75 | lr_G: !!float 2.5e-4 76 | lr_scheme: MultiStepLR 77 | weight_decay_G: 0 78 | max_grad_clip: 5 79 | max_grad_norm: 100 80 | beta1: 0.9 81 | beta2: 0.99 82 | niter: 350000 83 | warmup_iter: -1 # no warm up 84 | lr_steps: [200000, 250000, 280000, 310000, 340000] 85 | lr_gamma: 0.5 86 | restarts: ~ 87 | restart_weights: ~ 88 | eta_min: !!float 1e-8 89 | 90 | nll_weight: 1 91 | 92 | # pixel loss 93 | pixel_criterion_hr: l1 94 | pixel_weight_hr: 0 95 | 96 | # perceptual loss 97 | eps_std_reverse: 0.8 98 | feature_criterion: l1 99 | feature_weight: 0 100 | 101 | # gan loss 102 | gan_type: gan # gan | lsgan | wgangp | ragan (patchgan uses lsgan) 103 | gan_weight: 0 104 | 105 | lr_D: 0 106 | beta1_D: 0.9 107 | beta2_D: 0.99 108 | D_update_ratio: 1 109 | D_init_iters: 1500 110 | 111 | manual_seed: 0 112 | val_freq: !!float 5e3 113 | 114 | 115 | #### validation settings 116 | val: 117 | heats: [0.0, 0.8] 118 | n_sample: 3 119 | 120 | 121 | #### 
logger 122 | logger: 123 | print_freq: 200 124 | save_checkpoint_freq: !!float 5e3 125 | 126 | -------------------------------------------------------------------------------- /codes/models/modules/HCFlowNet_SR_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from utils.util import opt_get 7 | from models.modules import Basic, thops 8 | 9 | 10 | 11 | class HCFlowNet_SR(nn.Module): 12 | def __init__(self, opt, step=None): 13 | super(HCFlowNet_SR, self).__init__() 14 | self.opt = opt 15 | self.quant = opt_get(opt, ['quant'], 256) 16 | 17 | hr_size = opt_get(opt, ['datasets', 'train', 'GT_size'], 160) 18 | hr_channel = opt_get(opt, ['network_G', 'in_nc'], 3) 19 | scale = opt_get(opt, ['scale']) 20 | 21 | if scale == 4: 22 | from models.modules.FlowNet_SR_x4 import FlowNet 23 | elif scale == 8: 24 | from models.modules.FlowNet_SR_x8 import FlowNet 25 | else: 26 | raise NotImplementedError('Scale {} is not implemented'.format(scale)) 27 | 28 | # hr->lr+z 29 | self.flow = FlowNet((hr_size, hr_size, hr_channel), opt=opt) 30 | 31 | self.quantization = Basic.Quantization() 32 | 33 | # hr: HR image, lr: LR image, z: latent variable, u: conditional variable 34 | def forward(self, hr=None, lr=None, z=None, u=None, eps_std=None, 35 | add_gt_noise=False, step=None, reverse=False, training=True): 36 | 37 | # hr->z 38 | if not reverse: 39 | return self.normal_flow_diracLR(hr, lr, u, step=step, training=training) 40 | # z->hr 41 | else: 42 | return self.reverse_flow_diracLR(lr, z, u, eps_std=eps_std, training=training) 43 | 44 | 45 | #########################################diracLR 46 | # hr->lr+z, diracLR 47 | def normal_flow_diracLR(self, hr, lr, u=None, step=None, training=True): 48 | # 1. quantitize HR 49 | pixels = thops.pixels(hr) 50 | 51 | # according to Glow and ours, it should be u~U(0,a) (0.06 better in practice), not u~U(-0.5,0.5) (though better in theory) 52 | hr = hr + (torch.rand(hr.shape, device=hr.device)) / self.quant 53 | logdet = torch.zeros_like(hr[:, 0, 0, 0]) + float(-np.log(self.quant) * pixels) 54 | 55 | # 2. hr->lr+z 56 | fake_lr_from_hr, logdet = self.flow(hr=hr, u=u, logdet=logdet, reverse=False, training=training) 57 | 58 | # note in rescaling, we use LR for LR loss before quantization 59 | fake_lr_from_hr = self.quantization(fake_lr_from_hr) 60 | 61 | # 3. loss, Gaussian with small variance to approximate Dirac delta function of LR. 62 | # for the second term, using small log-variance may lead to svd problem, for both exp and tanh version 63 | objective = logdet + Basic.GaussianDiag.logp(lr, -torch.ones_like(lr)*6, fake_lr_from_hr) 64 | 65 | nll = ((-objective) / float(np.log(2.) 
* pixels)).mean() 66 | 67 | return torch.clamp(fake_lr_from_hr, 0, 1), nll 68 | 69 | # lr+z->hr 70 | def reverse_flow_diracLR(self, lr, z, u, eps_std, training=True): 71 | 72 | # lr+z->hr 73 | fake_hr = self.flow(z=lr, u=u, eps_std=eps_std, reverse=True, training=training) 74 | 75 | return torch.clamp(fake_hr, 0, 1) 76 | -------------------------------------------------------------------------------- /codes/options/train/train_SR_CelebA_8X_HCFlow+.yml: -------------------------------------------------------------------------------- 1 | #### general settings 2 | name: 002_CelebA_x8_bicSR_HCFlow+ 3 | use_tb_logger: true 4 | model: HCFlow_SR 5 | distortion: sr 6 | scale: 8 7 | quant: 256 8 | gpu_ids: [0] 9 | 10 | 11 | #### datasets 12 | datasets: 13 | train: 14 | name: CelebA_160_tr 15 | mode: LRHR_PKL 16 | dataroot_GT: /cluster/work/cvl/sr_datasets/srflow/celebA/CelebAHq_160_MBic_tr.pklv4 17 | dataroot_LQ: /cluster/work/cvl/sr_datasets/srflow/celebA/CelebAHq_160_MBic_tr_X8.pklv4 18 | 19 | use_shuffle: true 20 | n_workers: 16 21 | batch_size: 16 22 | GT_size: 160 23 | use_flip: true 24 | color: RGB 25 | val: 26 | name: CelebA_160_va 27 | mode: LRHR_PKL 28 | dataroot_GT: /cluster/work/cvl/sr_datasets/srflow/celebA/CelebAHq_160_MBic_va.pklv4 29 | dataroot_LQ: /cluster/work/cvl/sr_datasets/srflow/celebA/CelebAHq_160_MBic_va_X8.pklv4 30 | n_max: 20 31 | 32 | 33 | #### network structures 34 | network_G: 35 | which_model_G: HCFlowNet_SR 36 | in_nc: 3 37 | out_nc: 3 38 | act_norm_start_step: 100 39 | 40 | flowDownsampler: 41 | K: 26 42 | L: 3 43 | flow_permutation: invconv 44 | flow_coupling: Affine 45 | nn_module: FCN 46 | hidden_channels: 64 # 64 and 128 are similar 47 | cond_channels: ~ # affine coupling in the main trunk, testo 3 or None 48 | splitOff: 49 | enable: true 50 | after_flowstep: [13, 13, 13] 51 | flow_permutation: invconv 52 | flow_coupling: Affine 53 | stage1: True 54 | nn_module: FCN 55 | nn_module_last: Conv2dZeros 56 | hidden_channels: 64 57 | RRDB_nb: [5, 5] 58 | RRDB_nf: 64 59 | RRDB_gc: 32 60 | 61 | network_D: 62 | which_model_D: discriminator_vgg_160 63 | in_nc: 3 64 | nf: 64 65 | 66 | 67 | #### path 68 | path: 69 | pretrain_model_G: ../experiments/pretrained_models/SR_CelebA_X8_HCFlow.pth 70 | strict_load: true 71 | resume_state: auto 72 | 73 | 74 | #### training settings: learning rate scheme, loss 75 | train: 76 | lr_G: !!float 5e-5 77 | lr_scheme: MultiStepLR 78 | weight_decay_G: 0 79 | max_grad_clip: 5 80 | max_grad_norm: 100 81 | beta1: 0.9 82 | beta2: 0.99 83 | niter: 50000 84 | warmup_iter: -1 # no warm up 85 | lr_steps: [20000, 40000] 86 | lr_gamma: 0.5 87 | restarts: ~ 88 | restart_weights: ~ 89 | eta_min: !!float 1e-8 90 | 91 | nll_weight: !!float 2e-3 92 | 93 | # pixel loss 94 | pixel_criterion_hr: l1 95 | pixel_weight_hr: 1.0 96 | 97 | # perceptual loss 98 | eps_std_reverse: 0.8 99 | feature_criterion: l1 100 | feature_weight: 0 101 | 102 | # gan loss 103 | gan_type: gan # gan | lsgan | wgangp | ragan (patchgan uses lsgan) 104 | gan_weight: 0 105 | 106 | lr_D: 0 107 | beta1_D: 0.9 108 | beta2_D: 0.99 109 | D_update_ratio: 1 110 | D_init_iters: 1500 111 | 112 | manual_seed: 0 113 | val_freq: !!float 5e3 114 | 115 | 116 | #### validation settings 117 | val: 118 | heats: [0.0, 0.8] # 0.8 has best visual quality for face SR 119 | n_sample: 3 120 | 121 | 122 | #### logger 123 | logger: 124 | print_freq: 200 125 | save_checkpoint_freq: !!float 5e3 126 | 127 | -------------------------------------------------------------------------------- 
/codes/options/train/train_SR_CelebA_8X_HCFlow++.yml:
--------------------------------------------------------------------------------
#### general settings
name: 002_CelebA_x8_bicSR_HCFlow++
use_tb_logger: true
model: HCFlow_SR
distortion: sr
scale: 8
quant: 256
gpu_ids: [0]


#### datasets
datasets:
  train:
    name: CelebA_160_tr
    mode: LRHR_PKL
    dataroot_GT: /cluster/work/cvl/sr_datasets/srflow/celebA/CelebAHq_160_MBic_tr.pklv4
    dataroot_LQ: /cluster/work/cvl/sr_datasets/srflow/celebA/CelebAHq_160_MBic_tr_X8.pklv4

    use_shuffle: true
    n_workers: 16
    batch_size: 16
    GT_size: 160
    use_flip: true
    color: RGB
  val:
    name: CelebA_160_va
    mode: LRHR_PKL
    dataroot_GT: /cluster/work/cvl/sr_datasets/srflow/celebA/CelebAHq_160_MBic_va.pklv4
    dataroot_LQ: /cluster/work/cvl/sr_datasets/srflow/celebA/CelebAHq_160_MBic_va_X8.pklv4
    n_max: 20


#### network structures
network_G:
  which_model_G: HCFlowNet_SR
  in_nc: 3
  out_nc: 3
  act_norm_start_step: 100

  flowDownsampler:
    K: 26
    L: 3
    flow_permutation: invconv
    flow_coupling: Affine
    nn_module: FCN
    hidden_channels: 64 # 64 and 128 are similar
    cond_channels: ~ # affine coupling in the main trunk, test 3 or None
    splitOff:
      enable: true
      after_flowstep: [13, 13, 13]
      flow_permutation: invconv
      flow_coupling: Affine
      stage1: True
      nn_module: FCN
      nn_module_last: Conv2dZeros
      hidden_channels: 64
      RRDB_nb: [5, 5]
      RRDB_nf: 64
      RRDB_gc: 32

network_D:
  which_model_D: discriminator_vgg_160
  in_nc: 3
  nf: 64


#### path
path:
  pretrain_model_G: ../experiments/pretrained_models/SR_CelebA_X8_HCFlow.pth
  strict_load: true
  resume_state: auto


#### training settings: learning rate scheme, loss
train:
  lr_G: !!float 1.25e-5
  lr_scheme: MultiStepLR
  weight_decay_G: 0
  max_grad_clip: 5
  max_grad_norm: 100
  beta1: 0.9
  beta2: 0.99
  niter: 50000
  warmup_iter: -1 # no warm up
  lr_steps: [20000, 40000]
  lr_gamma: 0.5
  restarts: ~
  restart_weights: ~
  eta_min: !!float 1e-8

  nll_weight: !!float 2e-3

  # pixel loss
  pixel_criterion_hr: l1
  pixel_weight_hr: 1.0

  # perceptual loss
  eps_std_reverse: 0.8
  feature_criterion: l1
  feature_weight: !!float 5e-2 # balance diversity and lpips

  # gan loss
  gan_type: gan # gan | lsgan | wgangp | ragan (patchgan uses lsgan)
  gan_weight: !!float 5e-1

  lr_D: !!float 5e-5
  beta1_D: 0.9
  beta2_D: 0.99
  D_update_ratio: 1
  D_init_iters: 1500

  manual_seed: 0
  val_freq: !!float 5e3


#### validation settings
val:
  heats: [0.0, 0.8] # 0.8 has best visual quality for face SR
  n_sample: 3


#### logger
logger:
  print_freq: 200
  save_checkpoint_freq: !!float 5e3

--------------------------------------------------------------------------------
/codes/options/train/train_Rescaling_DF2K_4X_HCFlow.yml:
--------------------------------------------------------------------------------
#### general settings
name: 003_DF2K_x4_rescaling_HCFlow
use_tb_logger: true
model: HCFlow_Rescaling
distortion: sr
scale: 4
gpu_ids: [0]


#### datasets
datasets:
  train:
name: DF2K_tr 14 | mode: GTLQnpy 15 | dataroot_GT: /cluster/work/cvl/jinliang_dataset/DIV2K+Flickr2K_decoded/DIV2K+Flickr2K_HR 16 | dataroot_LQ: /cluster/work/cvl/jinliang_dataset/DIV2K+Flickr2K_decoded/DIV2K+Flickr2K_LR_bicubic/X4 17 | 18 | use_shuffle: true 19 | n_workers: 16 20 | batch_size: 16 21 | GT_size: 160 22 | use_flip: true 23 | use_rot: true 24 | color: RGB 25 | 26 | val: 27 | name: Set5 28 | mode: GTLQx 29 | dataroot_GT: ../datasets/Set5/HR 30 | dataroot_LQ: ../datasets/Set5/LR_bicubic/X4 31 | 32 | 33 | # Optimization may be unstable for rescaling (±0.1dB). A simple trick: at each learning-rate stage, 34 | # resume training from the best model of the previous stage. 35 | #### network structures 36 | network_G: 37 | which_model_G: HCFlowNet_Rescaling 38 | in_nc: 3 39 | out_nc: 3 40 | act_norm_start_step: 100 41 | 42 | flowDownsampler: 43 | K: 14 44 | L: 2 45 | squeeze: haar # better than squeeze2d 46 | flow_permutation: none # better than invconv 47 | flow_coupling: Affine3shift # better than affine 48 | nn_module: DenseBlock # better than FCN 49 | hidden_channels: 32 50 | cond_channels: ~ 51 | splitOff: 52 | enable: true 53 | after_flowstep: [6, 6] 54 | flow_permutation: invconv 55 | flow_coupling: Affine 56 | stage1: True 57 | feature_extractor: RRDB 58 | nn_module: FCN 59 | nn_module_last: Conv2dZeros 60 | hidden_channels: 64 61 | RRDB_nb: [2,1] 62 | RRDB_nf: 64 63 | RRDB_gc: 16 64 | 65 | 66 | #### path 67 | path: 68 | pretrain_model_G: ~ 69 | strict_load: true 70 | resume_state: auto 71 | 72 | #### training settings: learning rate scheme, loss 73 | train: 74 | two_stage_opt: True 75 | 76 | lr_G: !!float 2.5e-4 77 | lr_scheme: MultiStepLR 78 | weight_decay_G: 0 79 | max_grad_clip: 5 80 | max_grad_norm: 100 81 | beta1: 0.9 82 | beta2: 0.99 83 | niter: 500000 84 | warmup_iter: -1 # no warm up 85 | lr_steps: [100000, 200000, 300000, 400000, 450000] 86 | lr_gamma: 0.5 87 | restarts: ~ 88 | restart_weights: ~ 89 | eta_min: !!float 1e-8 90 | 91 | weight_z: !!float 1e-5 92 | 93 | pixel_criterion_lr: l2 94 | pixel_weight_lr: !!float 5e-2 95 | 96 | eps_std_reverse: 1.0 97 | pixel_criterion_hr: l1 98 | pixel_weight_hr: 1.0 99 | 100 | # perceptual loss 101 | feature_criterion: l1 102 | feature_weight: 0 103 | 104 | # gan loss 105 | gan_type: gan # gan | lsgan | wgangp | ragan (patchgan uses lsgan) 106 | gan_weight: 0 107 | 108 | lr_D: 0 109 | beta1_D: 0.9 110 | beta2_D: 0.99 111 | D_update_ratio: 1 112 | D_init_iters: 1500 113 | 114 | manual_seed: 0 115 | val_freq: !!float 5e3 116 | 117 | 118 | #### validation settings 119 | val: 120 | heats: [0.0, 1.0] 121 | 122 | 123 | #### logger 124 | logger: 125 | print_freq: 200 126 | save_checkpoint_freq: !!float 5e3 127 | 128 | -------------------------------------------------------------------------------- /codes/models/modules/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class CharbonnierLoss(nn.Module): 6 | """Charbonnier Loss (a differentiable variant of L1)""" 7 | 8 | def __init__(self, eps=1e-6): 9 | super(CharbonnierLoss, self).__init__() 10 | self.eps = eps 11 | 12 | def forward(self, x, y): 13 | diff = x - y 14 | loss = torch.sum(torch.sqrt(diff * diff + self.eps)) 15 | return loss 16 | 17 | 18 | # Define GAN loss: [gan(vanilla) | lsgan | wgan-gp | ragan] 19 | class GANLoss(nn.Module): 20 | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0): 21 | super(GANLoss, self).__init__() 22 | self.gan_type = gan_type.lower() 23
| self.real_label_val = real_label_val 24 | self.fake_label_val = fake_label_val 25 | 26 | if self.gan_type == 'gan' or self.gan_type == 'ragan': 27 | self.loss = nn.BCEWithLogitsLoss() 28 | elif self.gan_type == 'lsgan': 29 | self.loss = nn.MSELoss() 30 | elif self.gan_type == 'wgan-gp': 31 | 32 | def wgan_loss(input, target): 33 | # target is boolean 34 | return -1 * input.mean() if target else input.mean() 35 | 36 | self.loss = wgan_loss 37 | else: 38 | raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type)) 39 | 40 | def get_target_label(self, input, target_is_real): 41 | if self.gan_type == 'wgan-gp': 42 | return target_is_real 43 | if target_is_real: 44 | return torch.empty_like(input).fill_(self.real_label_val) 45 | else: 46 | return torch.empty_like(input).fill_(self.fake_label_val) 47 | 48 | def forward(self, input, target_is_real): 49 | target_label = self.get_target_label(input, target_is_real) 50 | loss = self.loss(input, target_label) 51 | return loss 52 | 53 | 54 | class GradientPenaltyLoss(nn.Module): 55 | def __init__(self, device=torch.device('cpu')): 56 | super(GradientPenaltyLoss, self).__init__() 57 | self.register_buffer('grad_outputs', torch.Tensor()) 58 | self.grad_outputs = self.grad_outputs.to(device) 59 | 60 | def get_grad_outputs(self, input): 61 | if self.grad_outputs.size() != input.size(): 62 | self.grad_outputs.resize_(input.size()).fill_(1.0) 63 | return self.grad_outputs 64 | 65 | def forward(self, interp, interp_crit): 66 | grad_outputs = self.get_grad_outputs(interp_crit) 67 | grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp, 68 | grad_outputs=grad_outputs, create_graph=True, 69 | retain_graph=True, only_inputs=True)[0] 70 | grad_interp = grad_interp.view(grad_interp.size(0), -1) 71 | grad_interp_norm = grad_interp.norm(2, dim=1) 72 | 73 | loss = ((grad_interp_norm - 1)**2).mean() 74 | return loss 75 | 76 | class ReconstructionLoss(nn.Module): 77 | def __init__(self, losstype='l2', eps=1e-6): 78 | super(ReconstructionLoss, self).__init__() 79 | self.losstype = losstype 80 | self.eps = eps 81 | 82 | def forward(self, x, target): 83 | if self.losstype == 'l2': 84 | return torch.mean(torch.sum((x - target)**2, (1, 2, 3))) 85 | elif self.losstype == 'l1': 86 | diff = x - target 87 | return torch.mean(torch.sum(torch.sqrt(diff * diff + self.eps), (1, 2, 3))) 88 | else: 89 | raise NotImplementedError( 90 | 'reconstruction loss type [{:s}] is not found'.format(self.losstype)) 91 | 92 | -------------------------------------------------------------------------------- /codes/data/GTLQnpy_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import cv2 4 | import lmdb 5 | import torch 6 | import torch.utils.data as data 7 | import data.util as util 8 | import sys 9 | import os 10 | 11 | try: 12 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 13 | from data.util import imresize_np 14 | from utils import util as utils 15 | except ImportError: 16 | pass 17 | 18 | 19 | class GTLQnpyDataset(data.Dataset): 20 | ''' 21 | Load HR-LR image npy pairs. Make sure the HR and LR images are in the same order.
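LR paths are derived from the GT paths by string replacement in __getitem__, so the HR and LR directory trees must mirror each other.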
22 | ''' 23 | 24 | def __init__(self, opt): 25 | super(GTLQnpyDataset, self).__init__() 26 | self.opt = opt 27 | self.LR_paths, self.GT_paths = None, None 28 | self.scale = opt['scale'] 29 | if self.opt['phase'] == 'train': 30 | self.GT_size = opt['GT_size'] 31 | self.LR_size = self.GT_size // self.scale 32 | 33 | self.LR_paths = util.get_image_paths(opt['data_type'], opt['dataroot_LQ']) # LR list 34 | self.GT_paths = util.get_image_paths(opt['data_type'], opt['dataroot_GT']) # GT list 35 | 36 | assert self.GT_paths, 'Error: GT paths are empty.' 37 | if self.LR_paths and self.GT_paths: 38 | assert len(self.LR_paths) == len( 39 | self.GT_paths), 'GT and LR datasets have different number of images - {}, {}.'.format( 40 | len(self.LR_paths), len(self.GT_paths)) 41 | 42 | def __getitem__(self, index): 43 | # get GT and LR image 44 | GT_path = self.GT_paths[index] 45 | # LR_path = self.LR_paths[index] 46 | LR_path = GT_path.replace('DIV2K+Flickr2K_HR', 'DIV2K+Flickr2K_LR_bicubic/X4').replace('.npy','x{}.npy'.format(self.scale)) 47 | img_GT = util.read_img_fromnpy(np.load(GT_path)) 48 | img_LR = util.read_img_fromnpy(np.load(LR_path)) # return: Numpy float32, HWC, BGR, [0,1] 49 | 50 | if self.opt['phase'] == 'train': 51 | # crop 52 | H, W, C = img_LR.shape 53 | rnd_top_LR = random.randint(0, max(0, H - self.LR_size)) 54 | rnd_left_LR = random.randint(0, max(0, W - self.LR_size)) 55 | rnd_top_GT = rnd_top_LR * self.scale 56 | rnd_left_GT = rnd_left_LR * self.scale 57 | 58 | img_GT = img_GT[rnd_top_GT:rnd_top_GT + self.GT_size, rnd_left_GT:rnd_left_GT + self.GT_size, :] 59 | img_LR = img_LR[rnd_top_LR:rnd_top_LR + self.LR_size, rnd_left_LR:rnd_left_LR + self.LR_size, :] 60 | 61 | # augmentation - flip, rotate 62 | img_GT, img_LR = util.augment([img_GT, img_LR], self.opt['use_flip'], 63 | self.opt['use_rot'], self.opt['mode']) 64 | 65 | # change color space if necessary, deal with gray image 66 | if self.opt['color']: 67 | img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0] 68 | img_LR = util.channel_convert(img_LR.shape[2], self.opt['color'], [img_LR])[0] 69 | 70 | # BGR to RGB, HWC to CHW, numpy to tensor 71 | if img_GT.shape[2] == 3: 72 | img_GT = img_GT[:, :, [2, 1, 0]] 73 | if img_LR.shape[2] == 3: 74 | img_LR = img_LR[:, :, [2, 1, 0]] 75 | img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() 76 | img_LR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LR, (2, 0, 1)))).float() 77 | 78 | return {'LQ': img_LR, 'GT': img_GT, 'LQ_path': LR_path, 'GT_path': GT_path} 79 | 80 | def __len__(self): 81 | return len(self.GT_paths) 82 | 83 | -------------------------------------------------------------------------------- /codes/models/modules/module_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | 6 | 7 | def initialize_weights(net_l, scale=1): 8 | if not isinstance(net_l, list): 9 | net_l = [net_l] 10 | for net in net_l: 11 | for m in net.modules(): 12 | if isinstance(m, nn.Conv2d): 13 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 14 | m.weight.data *= scale # for residual block 15 | if m.bias is not None: 16 | m.bias.data.zero_() 17 | elif isinstance(m, nn.Linear): 18 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 19 | m.weight.data *= scale 20 | if m.bias is not None: 21 | m.bias.data.zero_() 22 | elif isinstance(m, nn.BatchNorm2d): 23 | init.constant_(m.weight, 1) 
24 | init.constant_(m.bias.data, 0.0) 25 | 26 | def initialize_weights_xavier(net_l, scale=1): 27 | if not isinstance(net_l, list): 28 | net_l = [net_l] 29 | for net in net_l: 30 | for m in net.modules(): 31 | if isinstance(m, nn.Conv2d): 32 | init.xavier_normal_(m.weight) 33 | m.weight.data *= scale # for residual block 34 | if m.bias is not None: 35 | m.bias.data.zero_() 36 | elif isinstance(m, nn.Linear): 37 | init.xavier_normal_(m.weight) 38 | m.weight.data *= scale 39 | if m.bias is not None: 40 | m.bias.data.zero_() 41 | elif isinstance(m, nn.BatchNorm2d): 42 | init.constant_(m.weight, 1) 43 | init.constant_(m.bias.data, 0.0) 44 | 45 | 46 | def make_layer(block, n_layers): 47 | layers = [] 48 | for _ in range(n_layers): 49 | layers.append(block()) 50 | return nn.Sequential(*layers) 51 | 52 | 53 | class ResidualBlock_noBN(nn.Module): 54 | '''Residual block w/o BN 55 | ---Conv-ReLU-Conv-+- 56 | |________________| 57 | ''' 58 | 59 | def __init__(self, nf=64): 60 | super(ResidualBlock_noBN, self).__init__() 61 | self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 62 | self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 63 | 64 | # initialization 65 | initialize_weights([self.conv1, self.conv2], 0.1) 66 | 67 | def forward(self, x): 68 | identity = x 69 | out = F.relu(self.conv1(x), inplace=True) 70 | out = self.conv2(out) 71 | return identity + out 72 | 73 | 74 | def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'): 75 | """Warp an image or feature map with optical flow 76 | Args: 77 | x (Tensor): size (N, C, H, W) 78 | flow (Tensor): size (N, H, W, 2), normal value 79 | interp_mode (str): 'nearest' or 'bilinear' 80 | padding_mode (str): 'zeros' or 'border' or 'reflection' 81 | 82 | Returns: 83 | Tensor: warped image or feature map 84 | """ 85 | assert x.size()[-2:] == flow.size()[1:3] 86 | B, C, H, W = x.size() 87 | # mesh grid 88 | grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) 89 | grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 90 | grid.requires_grad = False 91 | grid = grid.type_as(x) 92 | vgrid = grid + flow 93 | # scale grid to [-1,1] 94 | vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0 95 | vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0 96 | vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) 97 | output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode) 98 | return output 99 | -------------------------------------------------------------------------------- /codes/scripts/prepare_data_pkl.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
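# Note: this script packs random 160x160 HR crops and their bicubic X4-downscaled
# LR counterparts into .pklv4 pickle files, i.e. the inputs expected by the
# `mode: LRHR_PKL` datasets in the training configs above; the crop size, the
# 47 crops per image and the X4 scale are hard-coded in main() below.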
14 | 15 | import glob 16 | import os 17 | import sys 18 | import pickle 19 | import numpy as np 20 | import random 21 | import imageio 22 | 23 | from natsort import natsort 24 | from tqdm import tqdm 25 | 26 | def get_img_paths(dir_path, wildcard='*.png'): 27 | return natsort.natsorted(glob.glob(dir_path + '/' + wildcard)) 28 | 29 | def create_all_dirs(path): 30 | if "." in path.split("/")[-1]: 31 | dirs = os.path.dirname(path) 32 | else: 33 | dirs = path 34 | os.makedirs(dirs, exist_ok=True) 35 | 36 | def to_pklv4(obj, path, verbose=False): 37 | create_all_dirs(path) 38 | with open(path, 'wb') as f: 39 | pickle.dump(obj, f, protocol=4) 40 | if verbose: 41 | print("Wrote {}".format(path)) 42 | 43 | 44 | from imresize import imresize 45 | 46 | def random_crop(img, size): 47 | h, w, c = img.shape 48 | 49 | h_start = np.random.randint(0, h - size + 1) # +1 so images exactly `size` pixels tall/wide are valid 50 | h_end = h_start + size 51 | 52 | w_start = np.random.randint(0, w - size + 1) 53 | w_end = w_start + size 54 | 55 | return img[h_start:h_end, w_start:w_end] 56 | 57 | 58 | def imread(img_path): 59 | img = imageio.imread(img_path) 60 | if len(img.shape) == 2: 61 | img = np.stack([img, ] * 3, axis=2) 62 | return img 63 | 64 | 65 | def to_pklv4_1pct(obj, path, verbose): 66 | n = int(round(len(obj) * 0.01)) 67 | path = path.replace(".", "_1pct.") 68 | to_pklv4(obj[:n], path, verbose=verbose) 69 | 70 | 71 | def main(dir_path): 72 | hrs = [] 73 | lqs = [] 74 | 75 | img_paths = get_img_paths(dir_path) 76 | for img_path in tqdm(img_paths): 77 | img = imread(img_path) 78 | 79 | for i in range(47): 80 | crop = random_crop(img, 160) 81 | cropX4 = imresize(crop, scalar_scale=0.25) 82 | hrs.append(crop) 83 | lqs.append(cropX4) 84 | 85 | shuffle_combined(hrs, lqs) 86 | 87 | hrs_path = get_hrs_path(dir_path) 88 | to_pklv4(hrs, hrs_path, verbose=True) 89 | to_pklv4_1pct(hrs, hrs_path, verbose=True) 90 | 91 | lqs_path = get_lqs_path(dir_path) 92 | to_pklv4(lqs, lqs_path, verbose=True) 93 | to_pklv4_1pct(lqs, lqs_path, verbose=True) 94 | 95 | 96 | def get_hrs_path(dir_path): 97 | base_dir = os.path.dirname(dir_path) 98 | name = os.path.basename(dir_path) 99 | hrs_path = os.path.join(base_dir, 'pkls', name + '.pklv4') 100 | return hrs_path 101 | 102 | 103 | def get_lqs_path(dir_path): 104 | base_dir = os.path.dirname(dir_path) 105 | name = os.path.basename(dir_path) 106 | lqs_path = os.path.join(base_dir, 'pkls', name + '_X4.pklv4') 107 | return lqs_path 108 | 109 | 110 | def shuffle_combined(hrs, lqs): 111 | combined = list(zip(hrs, lqs)) 112 | random.shuffle(combined) 113 | hrs[:], lqs[:] = zip(*combined) 114 | 115 | 116 | if __name__ == "__main__": 117 | dir_path = sys.argv[1] 118 | assert os.path.isdir(dir_path) 119 | main(dir_path) 120 | -------------------------------------------------------------------------------- /codes/data/LQ_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import cv2 4 | import lmdb 5 | import torch 6 | import torch.utils.data as data 7 | import data.util as util 8 | import sys 9 | import os 10 | 11 | try: 12 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 13 | from data.util import imresize_np 14 | from utils import util as utils 15 | except ImportError: 16 | pass 17 | 18 | 19 | class LQDataset(data.Dataset): 20 | ''' 21 | Load LR images only.
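Typically used at test time, when only LR inputs (and no ground-truth HR images) are available.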
22 | ''' 23 | 24 | def __init__(self, opt): 25 | super(LQDataset, self).__init__() 26 | self.opt = opt 27 | self.LR_paths, self.GT_paths = None, None 28 | self.LR_env, self.GT_env = None, None # environment for lmdb 29 | self.scale = opt['scale'] 30 | if self.opt['phase'] == 'train': 31 | self.GT_size = opt['GT_size'] 32 | self.LR_size = self.GT_size // self.scale 33 | 34 | # read image list from lmdb or image files 35 | if opt['data_type'] == 'lmdb': 36 | self.LR_paths, self.LR_sizes = util.get_image_paths(opt['data_type'], opt['dataroot_LQ']) 37 | self.GT_paths, self.GT_sizes = util.get_image_paths(opt['data_type'], opt['dataroot_GT']) 38 | elif opt['data_type'] == 'img': 39 | self.LR_paths = util.get_image_paths(opt['data_type'], opt['dataroot_LQ']) # LR list 40 | # self.GT_paths = util.get_image_paths(opt['data_type'], opt['dataroot_GT']) # GT list 41 | else: 42 | print('Error: data_type is not matched in Dataset') 43 | assert self.LR_paths, 'Error: LR paths are empty.' 44 | 45 | def _init_lmdb(self): 46 | # https://github.com/chainer/chainermn/issues/129 47 | self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False, 48 | meminit=False) 49 | if self.opt['dataroot_LQ'] is not None: 50 | self.LR_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, 51 | meminit=False) 52 | else: 53 | self.LR_env = 'No lmdb input for LR' 54 | 55 | def __getitem__(self, index): 56 | if self.opt['data_type'] == 'lmdb': 57 | if (self.GT_env is None) or (self.LR_env is None): 58 | self._init_lmdb() 59 | 60 | if self.opt['data_type'] == 'lmdb': 61 | resolution = [int(s) for s in self.GT_sizes[index].split('_')] 62 | else: 63 | resolution = None 64 | 65 | 66 | # loading code from srflow test 67 | # img_GT = cv2.imread(GT_path)[:, :, [2, 1, 0]] 68 | # img_GT = torch.Tensor(img_GT.transpose([2, 0, 1]).astype(np.float32)) / 255 69 | # img_LR = cv2.imread(LR_path)[:, :, [2, 1, 0]] 70 | # pad_factor = 2 71 | # h, w, c = img_LR.shape 72 | # img_LR = impad(img_LR, bottom=int(np.ceil(h / pad_factor) * pad_factor - h), 73 | # right=int(np.ceil(w / pad_factor) * pad_factor - w)) 74 | # img_LR = torch.Tensor(img_LR.transpose([2, 0, 1]).astype(np.float32)) / 255 75 | 76 | 77 | # get LR image 78 | LR_path = self.LR_paths[index] 79 | img_LR = util.read_img(self.LR_env, LR_path, resolution) 80 | 81 | # change color space if necessary, deal with gray image 82 | if self.opt['color']: 83 | img_LR = util.channel_convert(img_LR.shape[2], self.opt['color'], [img_LR])[0] 84 | 85 | # BGR to RGB, HWC to CHW, numpy to tensor 86 | if img_LR.shape[2] == 3: 87 | img_LR = img_LR[:, :, [2, 1, 0]] 88 | img_LR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LR, (2, 0, 1)))).float() 89 | 90 | return {'LQ': img_LR, 'LQ_path': LR_path} 91 | 92 | def __len__(self): 93 | return len(self.LR_paths) 94 | 95 | 96 | def impad(img, top=0, bottom=0, left=0, right=0, color=255): 97 | return np.pad(img, [(top, bottom), (left, right), (0, 0)], 'reflect') 98 | -------------------------------------------------------------------------------- /codes/data/data_sampler.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from torch.utils.data.distributed.DistributedSampler 3 | Support enlarging the dataset for *iter-oriented* training, for saving time when restart the 4 | dataloader after each epoch 5 | """ 6 | import math 7 | import torch 8 | from torch.utils.data.sampler import Sampler 9 | import torch.distributed as dist 10 | 11 | 12 
| class DistIterSampler(Sampler): 13 | """Sampler that restricts data loading to a subset of the dataset. 14 | 15 | It is especially useful in conjunction with 16 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 17 | process can pass a DistributedSampler instance as a DataLoader sampler, 18 | and load a subset of the original dataset that is exclusive to it. 19 | 20 | .. note:: 21 | Dataset is assumed to be of constant size. 22 | 23 | Arguments: 24 | dataset: Dataset used for sampling. 25 | num_replicas (optional): Number of processes participating in 26 | distributed training. 27 | rank (optional): Rank of the current process within num_replicas. 28 | """ 29 | 30 | def __init__(self, dataset, num_replicas=None, rank=None, ratio=100): 31 | if num_replicas is None: 32 | if not dist.is_available(): 33 | raise RuntimeError("Requires distributed package to be available") 34 | num_replicas = dist.get_world_size() 35 | if rank is None: 36 | if not dist.is_available(): 37 | raise RuntimeError("Requires distributed package to be available") 38 | rank = dist.get_rank() 39 | self.dataset = dataset 40 | self.num_replicas = num_replicas 41 | self.rank = rank 42 | self.epoch = 0 43 | self.num_samples = int(math.ceil(len(self.dataset) * ratio / self.num_replicas)) 44 | self.total_size = self.num_samples * self.num_replicas 45 | 46 | def __iter__(self): 47 | # deterministically shuffle based on epoch 48 | g = torch.Generator() 49 | g.manual_seed(self.epoch) 50 | indices = torch.randperm(self.total_size, generator=g).tolist() #Returns a random permutation of integers from 0 to n - 1 51 | 52 | dsize = len(self.dataset) 53 | indices = [v % dsize for v in indices] 54 | 55 | # subsample 56 | indices = indices[self.rank:self.total_size:self.num_replicas] 57 | assert len(indices) == self.num_samples 58 | 59 | return iter(indices) 60 | 61 | def __len__(self): 62 | return self.num_samples 63 | 64 | def set_epoch(self, epoch): 65 | self.epoch = epoch 66 | 67 | 68 | class EnlargedSampler(Sampler): 69 | """Sampler that restricts data loading to a subset of the dataset. 70 | Modified from torch.utils.data.distributed.DistributedSampler 71 | Support enlarging the dataset for iteration-based training, for saving 72 | time when restart the dataloader after each epoch 73 | Args: 74 | dataset (torch.utils.data.Dataset): Dataset used for sampling. 75 | num_replicas (int | None): Number of processes participating in 76 | the training. It is usually the world_size. 77 | rank (int | None): Rank of the current process within num_replicas. 78 | ratio (int): Enlarging ratio. Default: 1. 
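Indices are drawn from a virtual dataset of size len(dataset) * ratio and wrapped modulo the real dataset size, so one pass covers the data about `ratio` times.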
79 | """ 80 | 81 | def __init__(self, dataset, num_replicas, rank, ratio=1): 82 | self.dataset = dataset 83 | self.num_replicas = num_replicas 84 | self.rank = rank 85 | self.epoch = 0 86 | self.num_samples = math.ceil( 87 | len(self.dataset) * ratio / self.num_replicas) 88 | self.total_size = self.num_samples * self.num_replicas 89 | 90 | def __iter__(self): 91 | # deterministically shuffle based on epoch 92 | g = torch.Generator() 93 | g.manual_seed(self.epoch) 94 | indices = torch.randperm(self.total_size, generator=g).tolist() 95 | 96 | dataset_size = len(self.dataset) 97 | indices = [v % dataset_size for v in indices] 98 | 99 | # subsample 100 | indices = indices[self.rank:self.total_size:self.num_replicas] 101 | assert len(indices) == self.num_samples 102 | 103 | return iter(indices) 104 | 105 | def __len__(self): 106 | return self.num_samples 107 | 108 | def set_epoch(self, epoch): 109 | self.epoch = epoch -------------------------------------------------------------------------------- /codes/models/modules/ActNorms.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | 4 | from models.modules import thops 5 | 6 | 7 | class _ActNorm(nn.Module): 8 | """ 9 | Activation Normalization 10 | Initialize the bias and scale with a given minibatch, 11 | so that the output per-channel have zero mean and unit variance for that. 12 | 13 | After initialization, `bias` and `logs` will be trained as parameters. 14 | """ 15 | 16 | def __init__(self, num_features, scale=1.): 17 | super().__init__() 18 | # register mean and scale 19 | size = [1, num_features, 1, 1] 20 | self.register_parameter("bias", nn.Parameter(torch.zeros(*size))) 21 | self.register_parameter("logs", nn.Parameter(torch.zeros(*size))) 22 | self.num_features = num_features 23 | self.scale = float(scale) 24 | self.inited = False 25 | 26 | def _check_input_dim(self, input): 27 | return NotImplemented 28 | 29 | def initialize_parameters(self, input): 30 | self._check_input_dim(input) 31 | if not self.training: 32 | return 33 | if (self.bias != 0).any(): 34 | self.inited = True 35 | return 36 | assert input.device == self.bias.device, (input.device, self.bias.device) 37 | with torch.no_grad(): 38 | bias = thops.mean(input.clone(), dim=[0, 2, 3], keepdim=True) * -1.0 39 | vars = thops.mean((input.clone() + bias) ** 2, dim=[0, 2, 3], keepdim=True) 40 | logs = torch.log(self.scale / (torch.sqrt(vars) + 1e-6)) 41 | self.bias.data.copy_(bias.data) 42 | self.logs.data.copy_(logs.data) 43 | self.inited = True 44 | 45 | def _center(self, input, reverse=False, offset=None): 46 | bias = self.bias 47 | 48 | if offset is not None: 49 | bias = bias + offset 50 | 51 | if not reverse: 52 | return input + bias 53 | else: 54 | return input - bias 55 | 56 | def _scale(self, input, logdet=None, reverse=False, offset=None): 57 | logs = self.logs 58 | 59 | if offset is not None: 60 | logs = logs + offset 61 | 62 | if not reverse: 63 | input = input * torch.exp(logs) # should have shape batchsize, n_channels, 1, 1 64 | # input = input * torch.exp(logs+logs_offset) 65 | else: 66 | input = input * torch.exp(-logs) 67 | if logdet is not None: 68 | """ 69 | logs is log_std of `mean of channels` 70 | so we need to multiply pixels 71 | """ 72 | dlogdet = thops.sum(logs) * thops.pixels(input) 73 | if reverse: 74 | dlogdet *= -1 75 | logdet = logdet + dlogdet 76 | return input, logdet 77 | 78 | def forward(self, input, logdet=None, reverse=False, offset_mask=None, logs_offset=None, 
bias_offset=None): 79 | if not self.inited: 80 | self.initialize_parameters(input) 81 | 82 | if offset_mask is not None: 83 | logs_offset *= offset_mask 84 | bias_offset *= offset_mask 85 | # no need to permute dims as old version 86 | if not reverse: 87 | # center and scale 88 | input = self._center(input, reverse, bias_offset) 89 | input, logdet = self._scale(input, logdet, reverse, logs_offset) 90 | else: 91 | # scale and center 92 | input, logdet = self._scale(input, logdet, reverse, logs_offset) 93 | input = self._center(input, reverse, bias_offset) 94 | return input, logdet 95 | 96 | 97 | class ActNorm2d(_ActNorm): 98 | def __init__(self, num_features, scale=1.): 99 | super().__init__(num_features, scale) 100 | 101 | def _check_input_dim(self, input): 102 | assert len(input.size()) == 4 103 | assert input.size(1) == self.num_features, ( 104 | "[ActNorm]: input should be in shape as `BCHW`," 105 | " channels should be {} rather than {}".format( 106 | self.num_features, input.size())) 107 | 108 | 109 | class MaskedActNorm2d(ActNorm2d): 110 | def __init__(self, num_features, scale=1.): 111 | super().__init__(num_features, scale) 112 | 113 | def forward(self, input, mask, logdet=None, reverse=False): 114 | 115 | assert mask.dtype == torch.bool 116 | output, logdet_out = super().forward(input, logdet, reverse) 117 | 118 | input[mask] = output[mask] 119 | logdet[mask] = logdet_out[mask] 120 | 121 | return input, logdet 122 | 123 | -------------------------------------------------------------------------------- /codes/models/modules/ConditionalFlow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from . import thops 7 | from utils.util import opt_get 8 | from models.modules.Basic import Conv2d, Conv2dZeros, GaussianDiag, DenseBlock, RRDB, FCN 9 | from models.modules.FlowStep import FlowStep 10 | 11 | import functools 12 | import models.modules.module_util as mutil 13 | 14 | 15 | class ConditionalFlow(nn.Module): 16 | def __init__(self, num_channels, num_channels_split, n_flow_step=0, opt=None, num_levels_condition=0, SR=True): 17 | super().__init__() 18 | self.SR = SR 19 | 20 | # number of levels of RRDB features. 
One level of conditional feature is enough for image rescaling 21 | num_features_condition = 2 if self.SR else 1 22 | 23 | # feature extraction 24 | RRDB_nb = opt_get(opt, ['RRDB_nb'], [5, 5]) 25 | RRDB_nf = opt_get(opt, ['RRDB_nf'], 64) 26 | RRDB_gc = opt_get(opt, ['RRDB_gc'], 32) 27 | RRDB_f = functools.partial(RRDB, nf=RRDB_nf, gc=RRDB_gc) 28 | self.conv_first = nn.Conv2d(num_channels_split + RRDB_nf*num_features_condition*num_levels_condition, RRDB_nf, 3, 1, 1, bias=True) 29 | self.RRDB_trunk0 = mutil.make_layer(RRDB_f, RRDB_nb[0]) 30 | self.RRDB_trunk1 = mutil.make_layer(RRDB_f, RRDB_nb[1]) 31 | self.trunk_conv1 = nn.Conv2d(RRDB_nf, RRDB_nf, 3, 1, 1, bias=True) 32 | 33 | # conditional flow 34 | self.additional_flow_steps = nn.ModuleList() 35 | for k in range(n_flow_step): 36 | self.additional_flow_steps.append(FlowStep(in_channels=num_channels-num_channels_split, 37 | cond_channels=RRDB_nf*num_features_condition, 38 | flow_permutation=opt['flow_permutation'], 39 | flow_coupling=opt['flow_coupling'], opt=opt)) 40 | 41 | self.f = Conv2dZeros(RRDB_nf*num_features_condition, (num_channels-num_channels_split)*2) 42 | 43 | 44 | def forward(self, z, u, eps_std=None, logdet=0., reverse=False, training=True): 45 | # for image SR 46 | if self.SR: 47 | if not reverse: 48 | conditional_feature = self.get_conditional_feature_SR(u) 49 | 50 | for layer in self.additional_flow_steps: 51 | z, logdet = layer(z, u=conditional_feature, logdet=logdet, reverse=False) 52 | 53 | h = self.f(conditional_feature) 54 | mean, logs = thops.split_feature(h, "cross") 55 | logdet += GaussianDiag.logp(mean, logs, z) 56 | 57 | return logdet, conditional_feature 58 | 59 | else: 60 | conditional_feature = self.get_conditional_feature_SR(u) 61 | 62 | h = self.f(conditional_feature) 63 | mean, logs = thops.split_feature(h, "cross") 64 | z = GaussianDiag.sample(mean, logs, eps_std) 65 | 66 | for layer in reversed(self.additional_flow_steps): 67 | z, _ = layer(z, u=conditional_feature, reverse=True) 68 | 69 | return z, logdet, conditional_feature 70 | else: 71 | # for image rescaling 72 | if not reverse: 73 | conditional_feature = self.get_conditional_feature_Rescaling(u) 74 | 75 | for layer in self.additional_flow_steps: 76 | z, logdet = layer(z, u=conditional_feature, logdet=logdet, reverse=False) 77 | 78 | h = self.f(conditional_feature) 79 | mean, scale = thops.split_feature(h, "cross") 80 | logscale = 0.318 * torch.atan(2 * scale) 81 | z = (z - mean) * torch.exp(-logscale) 82 | 83 | return z, conditional_feature 84 | 85 | else: 86 | conditional_feature = self.get_conditional_feature_Rescaling(u) 87 | 88 | h = self.f(conditional_feature) 89 | mean, scale = thops.split_feature(h, "cross") 90 | logscale = 0.318 * torch.atan(2 * scale) 91 | z = GaussianDiag.sample(mean, logscale, eps_std) 92 | 93 | for layer in reversed(self.additional_flow_steps): 94 | z, _ = layer(z, u=conditional_feature, reverse=True) 95 | 96 | return z, conditional_feature 97 | 98 | 99 | def get_conditional_feature_SR(self, u): 100 | u_feature_first = self.conv_first(u) 101 | u_feature1 = self.RRDB_trunk0(u_feature_first) 102 | u_feature2 = self.trunk_conv1(self.RRDB_trunk1(u_feature1)) + u_feature_first 103 | 104 | return torch.cat([u_feature1, u_feature2], 1) 105 | 106 | def get_conditional_feature_Rescaling(self, u): 107 | u_feature_first = self.conv_first(u) 108 | u_feature = self.trunk_conv1(self.RRDB_trunk1(self.RRDB_trunk0(u_feature_first))) + u_feature_first 109 | 110 | return u_feature 111 | 112 | 113 | 
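# Note on the rescaling branch above: the latent log-scale is bounded, since
# logscale = 0.318 * atan(2 * scale) lies in (-0.318*pi/2, 0.318*pi/2), i.e. roughly
# (-0.5, 0.5), so exp(+-logscale) stays within about (0.61, 1.65) and the normalization
# z = (z - mean) * exp(-logscale) cannot explode, unlike an unbounded exp(scale).
# A standalone sanity check of the bound (assumes only torch):
#   import torch
#   scale = torch.linspace(-100., 100., steps=5)
#   logscale = 0.318 * torch.atan(2 * scale)
#   assert float(logscale.abs().max()) < 0.318 * torch.pi / 2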
-------------------------------------------------------------------------------- /codes/models/modules/Permutations.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn as nn 4 | from torch.nn import functional as F 5 | import scipy.linalg 6 | 7 | from models.modules import thops 8 | 9 | 10 | class Permute2d(nn.Module): 11 | def __init__(self, num_channels, shuffle): 12 | super().__init__() 13 | self.num_channels = num_channels 14 | self.indices = np.arange(self.num_channels - 1, -1, -1).astype(np.int64) # np.int64 instead of the np.long alias removed in NumPy 1.24 15 | self.indices_inverse = np.zeros((self.num_channels), dtype=np.int64) 16 | for i in range(self.num_channels): 17 | self.indices_inverse[self.indices[i]] = i 18 | if shuffle: 19 | self.reset_indices() 20 | 21 | def reset_indices(self): 22 | np.random.shuffle(self.indices) 23 | for i in range(self.num_channels): 24 | self.indices_inverse[self.indices[i]] = i 25 | 26 | def forward(self, input, logdet=None, reverse=False): 27 | if not reverse: 28 | return input[:, self.indices, :, :], logdet 29 | else: 30 | return input[:, self.indices_inverse, :, :], logdet 31 | 32 | 33 | class InvertibleConv1x1(nn.Module): 34 | def __init__(self, num_channels, LU_decomposed=False): 35 | super().__init__() 36 | w_shape = [num_channels, num_channels] 37 | w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype(np.float32) 38 | if not LU_decomposed: 39 | # Sample a random orthogonal matrix: 40 | self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init))) 41 | else: 42 | # W = PL(U+diag(s)) 43 | np_p, np_l, np_u = scipy.linalg.lu(w_init) 44 | np_s = np.diag(np_u) 45 | np_sign_s = np.sign(np_s) 46 | np_log_s = np.log(np.abs(np_s)) 47 | np_u = np.triu(np_u, k=1) 48 | l_mask = np.tril(np.ones(w_shape, dtype=np.float32), -1) 49 | eye = np.eye(*w_shape, dtype=np.float32) 50 | 51 | self.register_buffer('p', torch.Tensor(np_p.astype(np.float32))) # remains fixed 52 | self.register_buffer('sign_s', torch.Tensor(np_sign_s.astype(np.float32))) # the sign is fixed 53 | self.l = nn.Parameter(torch.Tensor(np_l.astype(np.float32))) # optimized except diagonal 1 54 | self.log_s = nn.Parameter(torch.Tensor(np_log_s.astype(np.float32))) 55 | self.u = nn.Parameter(torch.Tensor(np_u.astype(np.float32))) # optimized 56 | self.l_mask = torch.Tensor(l_mask) 57 | self.eye = torch.Tensor(eye) 58 | self.w_shape = w_shape 59 | self.LU = LU_decomposed 60 | 61 | def get_weight(self, input, reverse): 62 | # The difference in computational cost will become significant for large c, although for the networks in 63 | # our experiments we did not measure a large difference in wallclock computation time. 64 | if not self.LU: 65 | if not reverse: 66 | # pixels = thops.pixels(input) 67 | # GPU version 68 | # dlogdet = torch.slogdet(self.weight)[1] * pixels 69 | # CPU version is 2x faster, https://github.com/didriknielsen/survae_flows/issues/5.
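# Note: log|det W| is multiplied by thops.pixels(input) because the same 1x1
# channel-mixing matrix W is applied independently at every spatial position;
# the scalar slogdet below is computed on CPU and the result is moved back to
# the weight's device.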
70 | dlogdet = (torch.slogdet(self.weight.to('cpu'))[1] * thops.pixels(input)).to(self.weight.device) 71 | weight = self.weight.view(self.w_shape[0], self.w_shape[1], 1, 1) 72 | else: 73 | dlogdet = 0 74 | weight = torch.inverse(self.weight.double()).float().view(self.w_shape[0], self.w_shape[1], 1, 1) 75 | 76 | 77 | return weight, dlogdet 78 | else: 79 | self.p = self.p.to(input.device) 80 | self.sign_s = self.sign_s.to(input.device) 81 | self.l_mask = self.l_mask.to(input.device) 82 | self.eye = self.eye.to(input.device) 83 | l = self.l * self.l_mask + self.eye 84 | u = self.u * self.l_mask.transpose(0, 1).contiguous() + torch.diag(self.sign_s * torch.exp(self.log_s)) 85 | dlogdet = thops.sum(self.log_s) * thops.pixels(input) 86 | if not reverse: 87 | w = torch.matmul(self.p, torch.matmul(l, u)) 88 | else: 89 | l = torch.inverse(l.double()).float() 90 | u = torch.inverse(u.double()).float() 91 | w = torch.matmul(u, torch.matmul(l, self.p.inverse())) 92 | return w.view(self.w_shape[0], self.w_shape[1], 1, 1), dlogdet 93 | 94 | def forward(self, input, logdet=None, reverse=False): 95 | """ 96 | log-det = log|abs(|W|)| * pixels 97 | """ 98 | weight, dlogdet = self.get_weight(input, reverse) 99 | if not reverse: 100 | z = F.conv2d(input, weight) # fc layer, ie, permute channel 101 | if logdet is not None: 102 | logdet = logdet + dlogdet 103 | return z, logdet 104 | else: 105 | z = F.conv2d(input, weight) 106 | if logdet is not None: 107 | logdet = logdet - dlogdet 108 | return z, logdet 109 | -------------------------------------------------------------------------------- /codes/data/GT_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import cv2 4 | import lmdb 5 | import torch 6 | import torch.utils.data as data 7 | import data.util as util 8 | import sys 9 | import os 10 | 11 | try: 12 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 13 | from data.util import imresize_np 14 | from utils import util as utils 15 | except ImportError: 16 | pass 17 | 18 | 19 | class GTDataset(data.Dataset): 20 | ''' 21 | Load GT images only, generate LR on-the-fly. 22 | ''' 23 | 24 | def __init__(self, opt): 25 | super(GTDataset, self).__init__() 26 | self.opt = opt 27 | self.LR_paths, self.GT_paths = None, None 28 | self.LR_env, self.GT_env = None, None # environment for lmdb 29 | self.scale = opt['scale'] 30 | if self.opt['phase'] == 'train': 31 | self.GT_size = opt['GT_size'] 32 | self.LR_size = self.GT_size // self.scale 33 | 34 | # read image list from lmdb or image files 35 | if opt['data_type'] == 'lmdb': 36 | self.LR_paths, self.LR_sizes = util.get_image_paths(opt['data_type'], opt['dataroot_LQ']) 37 | self.GT_paths, self.GT_sizes = util.get_image_paths(opt['data_type'], opt['dataroot_GT']) 38 | elif opt['data_type'] == 'img': 39 | # self.LR_paths = util.get_image_paths(opt['data_type'], opt['dataroot_LQ']) # LR list 40 | self.GT_paths = util.get_image_paths(opt['data_type'], opt['dataroot_GT']) # GT list 41 | else: 42 | print('Error: data_type is not matched in Dataset') 43 | assert self.GT_paths, 'Error: GT paths are empty.' 
44 | 45 | 46 | def _init_lmdb(self): 47 | # https://github.com/chainer/chainermn/issues/129 48 | self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False, 49 | meminit=False) 50 | if self.opt['dataroot_LQ'] is not None: 51 | self.LR_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, 52 | meminit=False) 53 | else: 54 | self.LR_env = 'No lmdb input for LR' 55 | 56 | def __getitem__(self, index): 57 | if self.opt['data_type'] == 'lmdb': 58 | if (self.GT_env is None) or (self.LR_env is None): 59 | self._init_lmdb() 60 | 61 | if self.opt['data_type'] == 'lmdb': 62 | resolution = [int(s) for s in self.GT_sizes[index].split('_')] 63 | else: 64 | resolution = None 65 | 66 | 67 | # loading code from srflow test 68 | # img_GT = cv2.imread(GT_path)[:, :, [2, 1, 0]] 69 | # img_GT = torch.Tensor(img_GT.transpose([2, 0, 1]).astype(np.float32)) / 255 70 | # img_LR = cv2.imread(LR_path)[:, :, [2, 1, 0]] 71 | # pad_factor = 2 72 | # h, w, c = img_LR.shape 73 | # img_LR = impad(img_LR, bottom=int(np.ceil(h / pad_factor) * pad_factor - h), 74 | # right=int(np.ceil(w / pad_factor) * pad_factor - w)) 75 | # img_LR = torch.Tensor(img_LR.transpose([2, 0, 1]).astype(np.float32)) / 255 76 | 77 | 78 | # get GT and LR image 79 | GT_path = self.GT_paths[index] 80 | LR_path = GT_path 81 | img_GT = util.read_img(self.GT_env, GT_path, resolution) # return: Numpy float32, HWC, BGR, [0,1] 82 | img_LR = imresize_np(img_GT, 1/self.scale) 83 | 84 | 85 | if self.opt['phase'] == 'train': 86 | # crop 87 | H, W, C = img_LR.shape 88 | rnd_top_LR = random.randint(0, max(0, H - self.LR_size)) 89 | rnd_left_LR = random.randint(0, max(0, W - self.LR_size)) 90 | rnd_top_GT = rnd_top_LR * self.scale 91 | rnd_left_GT = rnd_left_LR * self.scale 92 | 93 | img_GT = img_GT[rnd_top_GT:rnd_top_GT + self.GT_size, rnd_left_GT:rnd_left_GT + self.GT_size, :] 94 | img_LR = img_LR[rnd_top_LR:rnd_top_LR + self.LR_size, rnd_left_LR:rnd_left_LR + self.LR_size, :] 95 | 96 | # augmentation - flip, rotate 97 | img_GT, img_LR = util.augment([img_GT, img_LR], self.opt['use_flip'], 98 | self.opt['use_rot'], self.opt['mode']) 99 | 100 | # change color space if necessary, deal with gray image 101 | if self.opt['color']: 102 | img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0] 103 | img_LR = util.channel_convert(img_LR.shape[2], self.opt['color'], [img_LR])[0] 104 | 105 | # BGR to RGB, HWC to CHW, numpy to tensor 106 | if img_GT.shape[2] == 3: 107 | img_GT = img_GT[:, :, [2, 1, 0]] 108 | if img_LR.shape[2] == 3: 109 | img_LR = img_LR[:, :, [2, 1, 0]] 110 | img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() 111 | img_LR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LR, (2, 0, 1)))).float() 112 | 113 | 114 | # modcrop 115 | _, H, W = img_LR.size() 116 | img_GT = img_GT[:, :H*self.scale, :W*self.scale] 117 | 118 | return {'LQ': img_LR, 'GT': img_GT, 'LQ_path': LR_path, 'GT_path': GT_path} 119 | 120 | def __len__(self): 121 | return len(self.GT_paths) 122 | 123 | 124 | def impad(img, top=0, bottom=0, left=0, right=0, color=255): 125 | return np.pad(img, [(top, bottom), (left, right), (0, 0)], 'reflect') 126 | -------------------------------------------------------------------------------- /codes/data/GTLQ_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import cv2 4 | import lmdb 5 | import torch 6 | import torch.utils.data as data 7 
| import data.util as util 8 | import sys 9 | import os 10 | 11 | try: 12 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 13 | from data.util import imresize_np 14 | from utils import util as utils 15 | except ImportError: 16 | pass 17 | 18 | 19 | class GTLQDataset(data.Dataset): 20 | ''' 21 | Load HR-LR image pairs. 22 | ''' 23 | 24 | def __init__(self, opt): 25 | super(GTLQDataset, self).__init__() 26 | self.opt = opt 27 | self.LR_paths, self.GT_paths = None, None 28 | self.LR_env, self.GT_env = None, None # environment for lmdb 29 | self.scale = opt['scale'] 30 | if self.opt['phase'] == 'train': 31 | self.GT_size = opt['GT_size'] 32 | self.LR_size = self.GT_size // self.scale 33 | 34 | # read image list from lmdb or image files 35 | if opt['data_type'] == 'lmdb': 36 | self.LR_paths, self.LR_sizes = util.get_image_paths(opt['data_type'], opt['dataroot_LQ']) 37 | self.GT_paths, self.GT_sizes = util.get_image_paths(opt['data_type'], opt['dataroot_GT']) 38 | elif opt['data_type'] == 'img': 39 | self.LR_paths = util.get_image_paths(opt['data_type'], opt['dataroot_LQ']) # LR list 40 | self.GT_paths = util.get_image_paths(opt['data_type'], opt['dataroot_GT']) # GT list 41 | else: 42 | print('Error: data_type is not matched in Dataset') 43 | assert self.GT_paths, 'Error: GT paths are empty.' 44 | if self.LR_paths and self.GT_paths: 45 | assert len(self.LR_paths) == len( 46 | self.GT_paths), 'GT and LR datasets have different number of images - {}, {}.'.format( 47 | len(self.LR_paths), len(self.GT_paths)) 48 | 49 | def _init_lmdb(self): 50 | # https://github.com/chainer/chainermn/issues/129 51 | self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False, 52 | meminit=False) 53 | if self.opt['dataroot_LQ'] is not None: 54 | self.LR_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, 55 | meminit=False) 56 | else: 57 | self.LR_env = 'No lmdb input for LR' 58 | 59 | def __getitem__(self, index): 60 | if self.opt['data_type'] == 'lmdb': 61 | if (self.GT_env is None) or (self.LR_env is None): 62 | self._init_lmdb() 63 | 64 | if self.opt['data_type'] == 'lmdb': 65 | resolution = [int(s) for s in self.GT_sizes[index].split('_')] 66 | else: 67 | resolution = None 68 | 69 | 70 | # loading code from srflow test 71 | # img_GT = cv2.imread(GT_path)[:, :, [2, 1, 0]] 72 | # img_GT = torch.Tensor(img_GT.transpose([2, 0, 1]).astype(np.float32)) / 255 73 | # img_LR = cv2.imread(LR_path)[:, :, [2, 1, 0]] 74 | # pad_factor = 2 75 | # h, w, c = img_LR.shape 76 | # img_LR = impad(img_LR, bottom=int(np.ceil(h / pad_factor) * pad_factor - h), 77 | # right=int(np.ceil(w / pad_factor) * pad_factor - w)) 78 | # img_LR = torch.Tensor(img_LR.transpose([2, 0, 1]).astype(np.float32)) / 255 79 | 80 | 81 | # get GT and LR image 82 | GT_path = self.GT_paths[index] 83 | LR_path = self.LR_paths[index] 84 | # LR_path = GT_path.replace('HR', 'LR_bicubic/X4').replace('.png','x{}.png'.format(self.scale)) 85 | img_GT = util.read_img(self.GT_env, GT_path, resolution) # return: Numpy float32, HWC, BGR, [0,1] 86 | img_LR = util.read_img(self.LR_env, LR_path, resolution) 87 | 88 | 89 | if self.opt['phase'] == 'train': 90 | # crop 91 | H, W, C = img_LR.shape 92 | rnd_top_LR = random.randint(0, max(0, H - self.LR_size)) 93 | rnd_left_LR = random.randint(0, max(0, W - self.LR_size)) 94 | rnd_top_GT = rnd_top_LR * self.scale 95 | rnd_left_GT = rnd_left_LR * self.scale 96 | 97 | img_GT = img_GT[rnd_top_GT:rnd_top_GT + self.GT_size, 
rnd_left_GT:rnd_left_GT + self.GT_size, :] 98 | img_LR = img_LR[rnd_top_LR:rnd_top_LR + self.LR_size, rnd_left_LR:rnd_left_LR + self.LR_size, :] 99 | 100 | # augmentation - flip, rotate 101 | img_GT, img_LR = util.augment([img_GT, img_LR], self.opt['use_flip'], 102 | self.opt['use_rot'], self.opt['mode']) 103 | 104 | # change color space if necessary, deal with gray image 105 | if self.opt['color']: 106 | img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0] 107 | img_LR = util.channel_convert(img_LR.shape[2], self.opt['color'], [img_LR])[0] 108 | 109 | # BGR to RGB, HWC to CHW, numpy to tensor 110 | if img_GT.shape[2] == 3: 111 | img_GT = img_GT[:, :, [2, 1, 0]] 112 | if img_LR.shape[2] == 3: 113 | img_LR = img_LR[:, :, [2, 1, 0]] 114 | img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() 115 | img_LR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LR, (2, 0, 1)))).float() 116 | 117 | 118 | # modcrop 119 | _, H, W = img_LR.size() 120 | img_GT = img_GT[:, :H*self.scale, :W*self.scale] 121 | 122 | return {'LQ': img_LR, 'GT': img_GT, 'LQ_path': LR_path, 'GT_path': GT_path} 123 | 124 | def __len__(self): 125 | return len(self.GT_paths) 126 | 127 | 128 | def impad(img, top=0, bottom=0, left=0, right=0, color=255): 129 | return np.pad(img, [(top, bottom), (left, right), (0, 0)], 'reflect') 130 | -------------------------------------------------------------------------------- /codes/data/GTLQx_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import cv2 4 | import lmdb 5 | import torch 6 | import torch.utils.data as data 7 | import data.util as util 8 | import sys 9 | import os 10 | 11 | try: 12 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 13 | from data.util import imresize_np 14 | from utils import util as utils 15 | except ImportError: 16 | pass 17 | 18 | 19 | class GTLQxDataset(data.Dataset): 20 | ''' 21 | Load HR-LR image pairs. 22 | ''' 23 | 24 | def __init__(self, opt): 25 | super(GTLQxDataset, self).__init__() 26 | self.opt = opt 27 | self.LR_paths, self.GT_paths = None, None 28 | self.LR_env, self.GT_env = None, None # environment for lmdb 29 | self.scale = opt['scale'] 30 | if self.opt['phase'] == 'train': 31 | self.GT_size = opt['GT_size'] 32 | self.LR_size = self.GT_size // self.scale 33 | 34 | # read image list from lmdb or image files 35 | if opt['data_type'] == 'lmdb': 36 | self.LR_paths, self.LR_sizes = util.get_image_paths(opt['data_type'], opt['dataroot_LQ']) 37 | self.GT_paths, self.GT_sizes = util.get_image_paths(opt['data_type'], opt['dataroot_GT']) 38 | elif opt['data_type'] == 'img': 39 | self.LR_paths = util.get_image_paths(opt['data_type'], opt['dataroot_LQ']) # LR list 40 | self.GT_paths = util.get_image_paths(opt['data_type'], opt['dataroot_GT']) # GT list 41 | else: 42 | print('Error: data_type is not matched in Dataset') 43 | assert self.GT_paths, 'Error: GT paths are empty.' 
44 | if self.LR_paths and self.GT_paths: 45 | assert len(self.LR_paths) == len( 46 | self.GT_paths), 'GT and LR datasets have different number of images - {}, {}.'.format( 47 | len(self.LR_paths), len(self.GT_paths)) 48 | 49 | def _init_lmdb(self): 50 | # https://github.com/chainer/chainermn/issues/129 51 | self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False, 52 | meminit=False) 53 | if self.opt['dataroot_LQ'] is not None: 54 | self.LR_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, 55 | meminit=False) 56 | else: 57 | self.LR_env = 'No lmdb input for LR' 58 | 59 | def __getitem__(self, index): 60 | if self.opt['data_type'] == 'lmdb': 61 | if (self.GT_env is None) or (self.LR_env is None): 62 | self._init_lmdb() 63 | 64 | if self.opt['data_type'] == 'lmdb': 65 | resolution = [int(s) for s in self.GT_sizes[index].split('_')] 66 | else: 67 | resolution = None 68 | 69 | 70 | # loading code from srflow test 71 | # img_GT = cv2.imread(GT_path)[:, :, [2, 1, 0]] 72 | # img_GT = torch.Tensor(img_GT.transpose([2, 0, 1]).astype(np.float32)) / 255 73 | # img_LR = cv2.imread(LR_path)[:, :, [2, 1, 0]] 74 | # pad_factor = 2 75 | # h, w, c = img_LR.shape 76 | # img_LR = impad(img_LR, bottom=int(np.ceil(h / pad_factor) * pad_factor - h), 77 | # right=int(np.ceil(w / pad_factor) * pad_factor - w)) 78 | # img_LR = torch.Tensor(img_LR.transpose([2, 0, 1]).astype(np.float32)) / 255 79 | 80 | 81 | # get GT and LR image 82 | GT_path = self.GT_paths[index] 83 | # LR_path = self.LR_paths[index] 84 | LR_path = GT_path.replace('HR', 'LR_bicubic/X4').replace('.png','x{}.png'.format(self.scale)) 85 | img_GT = util.read_img(self.GT_env, GT_path, resolution) # return: Numpy float32, HWC, BGR, [0,1] 86 | img_LR = util.read_img(self.LR_env, LR_path, resolution) 87 | 88 | 89 | if self.opt['phase'] == 'train': 90 | # crop 91 | H, W, C = img_LR.shape 92 | rnd_top_LR = random.randint(0, max(0, H - self.LR_size)) 93 | rnd_left_LR = random.randint(0, max(0, W - self.LR_size)) 94 | rnd_top_GT = rnd_top_LR * self.scale 95 | rnd_left_GT = rnd_left_LR * self.scale 96 | 97 | img_GT = img_GT[rnd_top_GT:rnd_top_GT + self.GT_size, rnd_left_GT:rnd_left_GT + self.GT_size, :] 98 | img_LR = img_LR[rnd_top_LR:rnd_top_LR + self.LR_size, rnd_left_LR:rnd_left_LR + self.LR_size, :] 99 | 100 | # augmentation - flip, rotate 101 | img_GT, img_LR = util.augment([img_GT, img_LR], self.opt['use_flip'], 102 | self.opt['use_rot'], self.opt['mode']) 103 | 104 | # change color space if necessary, deal with gray image 105 | if self.opt['color']: 106 | img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0] 107 | img_LR = util.channel_convert(img_LR.shape[2], self.opt['color'], [img_LR])[0] 108 | 109 | # BGR to RGB, HWC to CHW, numpy to tensor 110 | if img_GT.shape[2] == 3: 111 | img_GT = img_GT[:, :, [2, 1, 0]] 112 | if img_LR.shape[2] == 3: 113 | img_LR = img_LR[:, :, [2, 1, 0]] 114 | img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() 115 | img_LR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LR, (2, 0, 1)))).float() 116 | 117 | 118 | # modcrop 119 | _, H, W = img_LR.size() 120 | img_GT = img_GT[:, :H*self.scale, :W*self.scale] 121 | 122 | return {'LQ': img_LR, 'GT': img_GT, 'LQ_path': LR_path, 'GT_path': GT_path} 123 | 124 | def __len__(self): 125 | return len(self.GT_paths) 126 | 127 | 128 | def impad(img, top=0, bottom=0, left=0, right=0, color=255): 129 | return np.pad(img, [(top, bottom), (left, right), 
(0, 0)], 'reflect') 130 | -------------------------------------------------------------------------------- /codes/options/options.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import logging 4 | import yaml 5 | from utils.util import OrderedYaml 6 | 7 | Loader, Dumper = OrderedYaml() 8 | 9 | 10 | def parse(opt_path, gpu_ids=None, is_train=True): 11 | with open(opt_path, mode='r') as f: 12 | opt = yaml.load(f, Loader=Loader) 13 | # export CUDA_VISIBLE_DEVICES 14 | if gpu_ids is not None: opt['gpu_ids'] = [int(x) for x in gpu_ids.split(',')] 15 | gpu_list = ','.join(str(x) for x in opt['gpu_ids']) 16 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list 17 | print('exporting CUDA_VISIBLE_DEVICES=' + gpu_list) 18 | 19 | opt['is_train'] = is_train 20 | if opt['distortion'] == 'sr': 21 | scale = opt['scale'] 22 | 23 | if 'datasets' in opt: 24 | # datasets 25 | for phase, dataset in opt['datasets'].items(): 26 | phase = phase.split('_')[0] 27 | print(dataset) 28 | dataset['phase'] = phase 29 | if opt['distortion'] == 'sr': 30 | dataset['scale'] = scale 31 | is_lmdb = False 32 | if dataset.get('dataroot_GT', None) is not None: 33 | dataset['dataroot_GT'] = osp.expanduser(dataset['dataroot_GT']) 34 | if dataset['dataroot_GT'].endswith('lmdb'): 35 | is_lmdb = True 36 | # if dataset.get('dataroot_GT_bg', None) is not None: 37 | # dataset['dataroot_GT_bg'] = osp.expanduser(dataset['dataroot_GT_bg']) 38 | if dataset.get('dataroot_LQ', None) is not None: 39 | dataset['dataroot_LQ'] = osp.expanduser(dataset['dataroot_LQ']) 40 | if dataset['dataroot_LQ'].endswith('lmdb'): 41 | is_lmdb = True 42 | dataset['data_type'] = 'lmdb' if is_lmdb else 'img' 43 | if dataset['mode'].endswith('mc'): # for memcached 44 | dataset['data_type'] = 'mc' 45 | dataset['mode'] = dataset['mode'].replace('_mc', '') 46 | 47 | # path 48 | for key, path in opt['path'].items(): 49 | if path and key in opt['path'] and key != 'strict_load': 50 | opt['path'][key] = osp.expanduser(path) 51 | opt['path']['root'] = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir, osp.pardir)) 52 | 53 | if is_train: 54 | experiments_root = osp.join(opt['path']['root'], 'experiments', opt['name']) 55 | opt['path']['experiments_root'] = experiments_root 56 | opt['path']['models'] = osp.join(experiments_root, 'models') 57 | opt['path']['training_state'] = osp.join(experiments_root, 'training_state') 58 | opt['path']['log'] = experiments_root 59 | opt['path']['val_images'] = osp.join(experiments_root, 'val_images') 60 | 61 | # change some options for debug mode 62 | if 'debug' in opt['name']: 63 | opt['train']['val_freq'] = 8 64 | opt['logger']['print_freq'] = 1 65 | opt['logger']['save_checkpoint_freq'] = 8 66 | else: # test 67 | results_root = osp.join(opt['path']['root'], 'results', opt['name']) 68 | opt['path']['results_root'] = results_root 69 | opt['path']['log'] = results_root 70 | 71 | 72 | # network 73 | if opt['distortion'] == 'sr': 74 | opt['network_G']['scale'] = scale 75 | 76 | # relative learning rate 77 | if 'train' in opt: 78 | niter = opt['train']['niter'] 79 | if 'T_period_rel' in opt['train']: 80 | opt['train']['T_period'] = [int(x * niter) for x in opt['train']['T_period_rel']] 81 | if 'restarts_rel' in opt['train']: 82 | opt['train']['restarts'] = [int(x * niter) for x in opt['train']['restarts_rel']] 83 | if 'lr_steps_rel' in opt['train']: 84 | opt['train']['lr_steps'] = [int(x * niter) for x in opt['train']['lr_steps_rel']] 85 | if 'lr_steps_inverse_rel' 
in opt['train']: 86 | opt['train']['lr_steps_inverse'] = [int(x * niter) for x in opt['train']['lr_steps_inverse_rel']] 87 | print(opt['train']) 88 | 89 | 90 | return opt 91 | 92 | 93 | def dict2str(opt, indent_l=1): 94 | '''dict to string for logger''' 95 | msg = '' 96 | for k, v in opt.items(): 97 | if isinstance(v, dict): 98 | msg += ' ' * (indent_l * 2) + k + ':[\n' 99 | msg += dict2str(v, indent_l + 1) 100 | msg += ' ' * (indent_l * 2) + ']\n' 101 | else: 102 | msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n' 103 | return msg 104 | 105 | 106 | class NoneDict(dict): 107 | def __missing__(self, key): 108 | return None 109 | 110 | 111 | # convert to NoneDict, which return None for missing key. 112 | def dict_to_nonedict(opt): 113 | if isinstance(opt, dict): 114 | new_opt = dict() 115 | for key, sub_opt in opt.items(): 116 | new_opt[key] = dict_to_nonedict(sub_opt) 117 | return NoneDict(**new_opt) 118 | elif isinstance(opt, list): 119 | return [dict_to_nonedict(sub_opt) for sub_opt in opt] 120 | else: 121 | return opt 122 | 123 | 124 | def check_resume(opt, resume_iter): 125 | '''Check resume states and pretrain_model paths (overriding pretrain_paths)''' 126 | logger = logging.getLogger('base') 127 | if opt['path']['resume_state']: 128 | if opt['path'].get('pretrain_model_G', None) is not None or opt['path'].get( 129 | 'pretrain_model_D', None) is not None: 130 | logger.warning('pretrain_model path will be ignored when resuming training.') 131 | 132 | opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'], '{}_G.pth'.format(resume_iter)) 133 | logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G']) 134 | 135 | if opt['train']['gan_weight'] > 0: 136 | opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'], 137 | '{}_D.pth'.format(resume_iter)) 138 | logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D']) 139 | -------------------------------------------------------------------------------- /codes/models/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import Counter 3 | from collections import defaultdict 4 | import torch 5 | from torch.optim.lr_scheduler import _LRScheduler 6 | 7 | 8 | class MultiStepLR_Restart(_LRScheduler): 9 | def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1, 10 | clear_state=False, last_epoch=-1): 11 | self.milestones = Counter(milestones) 12 | self.gamma = gamma 13 | self.clear_state = clear_state 14 | self.restarts = restarts if restarts else [0] 15 | self.restarts = [v + 1 for v in self.restarts] 16 | self.restart_weights = weights if weights else [1] 17 | assert len(self.restarts) == len( 18 | self.restart_weights), 'restarts and their weights do not match.' 
19 | super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch) 20 | 21 | def get_lr(self): 22 | if self.last_epoch in self.restarts: 23 | if self.clear_state: 24 | self.optimizer.state = defaultdict(dict) 25 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 26 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 27 | if self.last_epoch not in self.milestones: 28 | return [group['lr'] for group in self.optimizer.param_groups] 29 | return [ 30 | group['lr'] * self.gamma**self.milestones[self.last_epoch] 31 | for group in self.optimizer.param_groups 32 | ] 33 | 34 | 35 | class CosineAnnealingLR_Restart(_LRScheduler): 36 | def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1): 37 | self.T_period = T_period 38 | self.T_max = self.T_period[0] # current T period 39 | self.eta_min = eta_min 40 | self.restarts = restarts if restarts else [0] 41 | self.restarts = [v + 1 for v in self.restarts] 42 | self.restart_weights = weights if weights else [1] 43 | self.last_restart = 0 44 | assert len(self.restarts) == len( 45 | self.restart_weights), 'restarts and their weights do not match.' 46 | super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch) 47 | 48 | def get_lr(self): 49 | if self.last_epoch == 0: 50 | return self.base_lrs 51 | elif self.last_epoch in self.restarts: 52 | self.last_restart = self.last_epoch 53 | self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1] 54 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 55 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 56 | elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0: 57 | return [ 58 | group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 59 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 60 | ] 61 | return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) / 62 | (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) * 63 | (group['lr'] - self.eta_min) + self.eta_min 64 | for group in self.optimizer.param_groups] 65 | 66 | 67 | if __name__ == "__main__": 68 | optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0, 69 | betas=(0.9, 0.99)) 70 | ############################## 71 | # MultiStepLR_Restart 72 | ############################## 73 | ## Original 74 | lr_steps = [200000, 400000, 600000, 800000] 75 | restarts = None 76 | restart_weights = None 77 | 78 | ## two 79 | lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000] 80 | restarts = [500000] 81 | restart_weights = [1] 82 | 83 | ## four 84 | lr_steps = [ 85 | 50000, 100000, 150000, 200000, 240000, 300000, 350000, 400000, 450000, 490000, 550000, 86 | 600000, 650000, 700000, 740000, 800000, 850000, 900000, 950000, 990000 87 | ] 88 | restarts = [250000, 500000, 750000] 89 | restart_weights = [1, 1, 1] 90 | 91 | scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5, 92 | clear_state=False) 93 | 94 | ############################## 95 | # Cosine Annealing Restart 96 | ############################## 97 | ## two 98 | T_period = [500000, 500000] 99 | restarts = [500000] 100 | restart_weights = [1] 101 | 102 | ## four 103 | T_period = [250000, 250000, 250000, 250000] 104 | restarts = [250000, 500000, 750000] 105 | restart_weights = [1,1, 1] 106 | 107 | scheduler = 
CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts, 108 | weights=restart_weights) 109 | 110 | ############################## 111 | # Draw figure 112 | ############################## 113 | N_iter = 1000000 114 | lr_l = list(range(N_iter)) 115 | for i in range(N_iter): 116 | scheduler.step() 117 | current_lr = optimizer.param_groups[0]['lr'] 118 | lr_l[i] = current_lr 119 | 120 | import matplotlib as mpl 121 | from matplotlib import pyplot as plt 122 | import matplotlib.ticker as mtick 123 | mpl.style.use('default') 124 | import seaborn 125 | seaborn.set(style='whitegrid') 126 | seaborn.set_context('paper') 127 | 128 | plt.figure(1) 129 | plt.subplot(111) 130 | plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0)) 131 | plt.title('Title', fontsize=16, color='k') 132 | plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label='learning rate scheme') 133 | legend = plt.legend(loc='upper right', shadow=False) 134 | ax = plt.gca() 135 | labels = ax.get_xticks().tolist() 136 | for k, v in enumerate(labels): 137 | labels[k] = str(int(v / 1000)) + 'K' 138 | ax.set_xticklabels(labels) 139 | ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e')) 140 | 141 | ax.set_ylabel('Learning rate') 142 | ax.set_xlabel('Iteration') 143 | fig = plt.gcf() 144 | plt.show() 145 | -------------------------------------------------------------------------------- /codes/models/modules/FlowNet_SR_x4.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn as nn 4 | import torch.nn.functional as F 5 | 6 | from utils.util import opt_get 7 | from models.modules import Basic 8 | from models.modules.FlowStep import FlowStep 9 | from models.modules.ConditionalFlow import ConditionalFlow 10 | 11 | class FlowNet(nn.Module): 12 | def __init__(self, image_shape, opt=None): 13 | assert image_shape[2] == 1 or image_shape[2] == 3 14 | super().__init__() 15 | H, W, self.C = image_shape 16 | self.opt = opt 17 | self.L = opt_get(opt, ['network_G', 'flowDownsampler', 'L']) 18 | self.K = opt_get(opt, ['network_G', 'flowDownsampler', 'K']) 19 | if isinstance(self.K, int): self.K = [self.K] * (self.L + 1) 20 | 21 | n_additionalFlowNoAffine = opt_get(self.opt, ['network_G', 'flowDownsampler', 'additionalFlowNoAffine'], 0) 22 | flow_permutation = opt_get(self.opt, ['network_G', 'flowDownsampler', 'flow_permutation'], 'invconv') 23 | flow_coupling = opt_get(self.opt, ['network_G', 'flowDownsampler', 'flow_coupling'], 'Affine') 24 | cond_channels = opt_get(self.opt, ['network_G', 'flowDownsampler', 'cond_channels'], None) 25 | enable_splitOff = opt_get(opt, ['network_G', 'flowDownsampler', 'splitOff', 'enable'], False) 26 | after_splitOff_flowStep = opt_get(opt, ['network_G', 'flowDownsampler', 'splitOff', 'after_flowstep'], 0) 27 | if isinstance(after_splitOff_flowStep, int): after_splitOff_flowStep = [after_splitOff_flowStep] * (self.L + 1) 28 | 29 | # construct flow 30 | self.layers = nn.ModuleList() 31 | self.output_shapes = [] 32 | 33 | for level in range(self.L): 34 | # 1. Squeeze 35 | self.layers.append(Basic.SqueezeLayer(factor=2)) # may need a better way for squeezing 36 | self.C, H, W = self.C * 4, H // 2, W // 2 37 | self.output_shapes.append([-1, self.C, H, W]) 38 | 39 | # 2. 
main FlowSteps (unconditional flow) 40 | for k in range(self.K[level]-after_splitOff_flowStep[level]): 41 | self.layers.append(FlowStep(in_channels=self.C, cond_channels=cond_channels, 42 | flow_permutation=flow_permutation, 43 | flow_coupling=flow_coupling, 44 | opt=opt['network_G']['flowDownsampler'])) 45 | self.output_shapes.append([-1, self.C, H, W]) 46 | 47 | # 3. additional FlowSteps (split + conditional flow) 48 | if enable_splitOff: 49 | if level == 0: 50 | self.layers.append(Basic.Split(num_channels_split=self.C // 2 if level < self.L-1 else 3, level=level)) 51 | self.level0_condFlow = ConditionalFlow(num_channels=self.C, 52 | num_channels_split=self.C // 2 if level < self.L-1 else 3, 53 | n_flow_step=after_splitOff_flowStep[level], 54 | opt=opt['network_G']['flowDownsampler']['splitOff'], 55 | num_levels_condition=1, SR=True) 56 | elif level == 1: 57 | self.layers.append(Basic.Split(num_channels_split=self.C // 2 if level < self.L-1 else 3, level=level)) 58 | self.level1_condFlow = ConditionalFlow(num_channels=self.C, 59 | num_channels_split=self.C // 2 if level < self.L-1 else 3, 60 | n_flow_step=after_splitOff_flowStep[level], 61 | opt=opt['network_G']['flowDownsampler']['splitOff'], 62 | num_levels_condition=0, SR=True) 63 | self.C = self.C // 2 if level < self.L-1 else 3 64 | self.output_shapes.append([-1, self.C, H, W]) 65 | 66 | 67 | self.H = H 68 | self.W = W 69 | self.scaleH = image_shape[0] / H 70 | self.scaleW = image_shape[1] / W 71 | print('shapes:', self.output_shapes) 72 | 73 | # nodetach version; 0.05 better than detach version, 0.30 better when using only nll loss 74 | def forward(self, hr=None, z=None, u=None, eps_std=None, logdet=None, reverse=False, training=True): 75 | if not reverse: 76 | return self.normal_flow(hr, u=u, logdet=logdet, training=training) 77 | else: 78 | return self.reverse_flow(z, u=u, eps_std=eps_std, training=training) 79 | 80 | 81 | ''' 82 | hr->y1+z1->y2+z2 83 | ''' 84 | def normal_flow(self, z, u=None, logdet=None, training=True): 85 | for layer, shape in zip(self.layers, self.output_shapes): 86 | if isinstance(layer, FlowStep): 87 | z, logdet = layer(z, u, logdet=logdet, reverse=False) 88 | elif isinstance(layer, Basic.SqueezeLayer): 89 | z, logdet = layer(z, logdet=logdet, reverse=False) 90 | elif isinstance(layer, Basic.Split): 91 | if layer.level == 0: 92 | z, a1 = layer(z, reverse=False) 93 | y1 = z.clone() 94 | elif layer.level == 1: 95 | z, a2 = layer(z, reverse=False) 96 | logdet, conditional_feature2 = self.level1_condFlow(a2, z, logdet=logdet, reverse=False, training=training) 97 | 98 | conditional_feature1 = torch.cat([y1, F.interpolate(conditional_feature2, scale_factor=2, mode='nearest')],1) 99 | logdet, _ = self.level0_condFlow(a1, conditional_feature1, logdet=logdet, reverse=False, training=training) 100 | 101 | return z, logdet 102 | 103 | ''' 104 | y2+z2->y1+z1->hr 105 | ''' 106 | def reverse_flow(self, z, u=None, eps_std=None, training=True): 107 | for layer, shape in zip(reversed(self.layers), reversed(self.output_shapes)): 108 | if isinstance(layer, FlowStep): 109 | z, _ = layer(z, u, reverse=True) 110 | elif isinstance(layer, Basic.SqueezeLayer): 111 | z, _ = layer(z, reverse=True) 112 | elif isinstance(layer, Basic.Split): 113 | if layer.level == 1: 114 | a2, _, conditional_feature2 = self.level1_condFlow(None, z, eps_std=eps_std, reverse=True, training=training) 115 | z = layer(z, a2, reverse=True) 116 | elif layer.level == 0: 117 | conditional_feature1 = torch.cat([z, F.interpolate(conditional_feature2, 
scale_factor=2, mode='nearest')],1) 118 | a1, _, _ = self.level0_condFlow(None, conditional_feature1, eps_std=eps_std, reverse=True, training=training) 119 | z = layer(z, a1, reverse=True) 120 | 121 | 122 | 123 | return z 124 | 125 | -------------------------------------------------------------------------------- /codes/models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn.parallel import DistributedDataParallel 6 | import natsort 7 | import glob 8 | 9 | 10 | class BaseModel(): 11 | def __init__(self, opt): 12 | self.opt = opt 13 | self.device = torch.device('cuda' if opt.get('gpu_ids', None) is not None else 'cpu') 14 | self.is_train = opt['is_train'] 15 | self.schedulers = [] 16 | self.optimizers = [] 17 | 18 | def feed_data(self, data): 19 | pass 20 | 21 | def optimize_parameters(self): 22 | pass 23 | 24 | def get_current_visuals(self): 25 | pass 26 | 27 | def get_current_losses(self): 28 | pass 29 | 30 | def print_network(self): 31 | pass 32 | 33 | def save(self, label): 34 | pass 35 | 36 | def load(self): 37 | pass 38 | 39 | def _set_lr(self, lr_groups_l): 40 | ''' set learning rate for warmup, 41 | lr_groups_l: list for lr_groups. each for a optimizer''' 42 | for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): 43 | for param_group, lr in zip(optimizer.param_groups, lr_groups): 44 | param_group['lr'] = lr 45 | 46 | def _get_init_lr(self): 47 | # get the initial lr, which is set by the scheduler 48 | init_lr_groups_l = [] 49 | for optimizer in self.optimizers: 50 | init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups]) 51 | return init_lr_groups_l 52 | 53 | def update_learning_rate(self, cur_iter, warmup_iter=-1): 54 | for scheduler in self.schedulers: 55 | scheduler.step() 56 | #### set up warm up learning rate 57 | if cur_iter < warmup_iter: 58 | # get initial lr for each group 59 | init_lr_g_l = self._get_init_lr() 60 | # modify warming-up learning rates 61 | warm_up_lr_l = [] 62 | for init_lr_g in init_lr_g_l: 63 | warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g]) 64 | # set learning rate 65 | self._set_lr(warm_up_lr_l) 66 | 67 | def get_current_learning_rate(self): 68 | # return self.schedulers[0].get_lr()[0] 69 | return self.optimizers[0].param_groups[0]['lr'] 70 | 71 | def get_network_description(self, network): 72 | '''Get the string and total parameters of the network''' 73 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 74 | network = network.module 75 | s = str(network) 76 | n = sum(map(lambda x: x.numel(), network.parameters())) 77 | return s, n 78 | 79 | def save_network(self, network, network_label, iter_label): 80 | paths = natsort.natsorted(glob.glob(os.path.join(self.opt['path']['models'], "*_{}.pth".format(network_label))), 81 | reverse=True) 82 | paths = [p for p in paths if 83 | "latest_" not in p and not any([str(i * 5000) in p.split("/")[-1].split("_") for i in range(101)])] 84 | if len(paths) > 2: 85 | for path in paths[2:]: 86 | os.remove(path) 87 | save_filename = '{}_{}.pth'.format(iter_label, network_label) 88 | save_path = os.path.join(self.opt['path']['models'], save_filename) 89 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 90 | network = network.module 91 | state_dict = network.state_dict() 92 | for key, param in state_dict.items(): 93 | state_dict[key] = 
param.cpu() 94 | torch.save(state_dict, save_path) 95 | 96 | def load_network(self, load_path, network, strict=True, submodule=None): 97 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 98 | network = network.module 99 | if not (submodule is None or submodule.lower() == 'none'.lower()): 100 | network = network.__getattr__(submodule) 101 | load_net = torch.load(load_path) 102 | load_net_clean = OrderedDict() # remove unnecessary 'module.' 103 | for k, v in load_net.items(): 104 | 105 | # if 'flowLRNet' in k or 'quantization' in k: 106 | # continue 107 | # k = k.replace('flowDownpsamplerNet', 'flow') 108 | # 109 | # # step2 110 | # k = k.replace('Split2d712', 'level0_condFlow') 111 | # k = k.replace('layers.29', 'level1_condFlow') 112 | # # k = k.replace('Split2d932', 'level2_condFlow') 113 | 114 | 115 | if k.startswith('module.'): 116 | load_net_clean[k[7:]] = v 117 | else: 118 | load_net_clean[k] = v 119 | 120 | network.load_state_dict(load_net_clean, strict=strict) 121 | 122 | # if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 123 | # network = network.module 124 | # state_dict = network.state_dict() 125 | # for key, param in state_dict.items(): 126 | # state_dict[key] = param.cpu() 127 | # torch.save(state_dict, '../experiments/SR_DF2K_X4_HCFlow32.pth') 128 | 129 | def save_training_state(self, epoch, iter_step): 130 | '''Saves training state during training, which will be used for resuming''' 131 | state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': []} 132 | for s in self.schedulers: 133 | state['schedulers'].append(s.state_dict()) 134 | for o in self.optimizers: 135 | state['optimizers'].append(o.state_dict()) 136 | save_filename = '{}.state'.format(iter_step) 137 | save_path = os.path.join(self.opt['path']['training_state'], save_filename) 138 | 139 | paths = natsort.natsorted(glob.glob(os.path.join(self.opt['path']['training_state'], "*.state")), 140 | reverse=True) 141 | paths = [p for p in paths if "latest_" not in p] 142 | if len(paths) > 2: 143 | for path in paths[2:]: 144 | os.remove(path) 145 | 146 | torch.save(state, save_path) 147 | 148 | def resume_training(self, resume_state): 149 | '''Resume the optimizers and schedulers for training''' 150 | resume_optimizers = resume_state['optimizers'] 151 | resume_schedulers = resume_state['schedulers'] 152 | assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers' 153 | assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers' 154 | for i, o in enumerate(resume_optimizers): 155 | self.optimizers[i].load_state_dict(o) 156 | for i, s in enumerate(resume_schedulers): 157 | # manually change lr milestones 158 | # from collections import Counter 159 | # s['milestones'] = Counter([100000, 150000, 200000, 250000, 300000, 350000, 400000]) # for multistage lr 160 | # s['restarts'] = [120001, 240001, 360001] # for cosine_restart_lr 161 | 162 | self.schedulers[i].load_state_dict(s) 163 | 164 | 165 | 166 | 167 | -------------------------------------------------------------------------------- /codes/utils/imresize.py: -------------------------------------------------------------------------------- 1 | # https://github.com/fatheral/matlab_imresize 2 | # 3 | # MIT License 4 | # 5 | # Copyright (c) 2020 Alex 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the 
Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 24 | 25 | 26 | from __future__ import print_function 27 | import numpy as np 28 | from math import ceil, floor 29 | 30 | 31 | def deriveSizeFromScale(img_shape, scale): 32 | output_shape = [] 33 | for k in range(2): 34 | output_shape.append(int(ceil(scale[k] * img_shape[k]))) 35 | return output_shape 36 | 37 | 38 | def deriveScaleFromSize(img_shape_in, img_shape_out): 39 | scale = [] 40 | for k in range(2): 41 | scale.append(1.0 * img_shape_out[k] / img_shape_in[k]) 42 | return scale 43 | 44 | 45 | def triangle(x): 46 | x = np.array(x).astype(np.float64) 47 | lessthanzero = np.logical_and((x >= -1), x < 0) 48 | greaterthanzero = np.logical_and((x <= 1), x >= 0) 49 | f = np.multiply((x + 1), lessthanzero) + np.multiply((1 - x), greaterthanzero) 50 | return f 51 | 52 | 53 | def cubic(x): 54 | x = np.array(x).astype(np.float64) 55 | absx = np.absolute(x) 56 | absx2 = np.multiply(absx, absx) 57 | absx3 = np.multiply(absx2, absx) 58 | f = np.multiply(1.5 * absx3 - 2.5 * absx2 + 1, absx <= 1) + np.multiply(-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2, 59 | (1 < absx) & (absx <= 2)) 60 | return f 61 | 62 | 63 | def contributions(in_length, out_length, scale, kernel, k_width): 64 | if scale < 1: 65 | h = lambda x: scale * kernel(scale * x) 66 | kernel_width = 1.0 * k_width / scale 67 | else: 68 | h = kernel 69 | kernel_width = k_width 70 | x = np.arange(1, out_length + 1).astype(np.float64) 71 | u = x / scale + 0.5 * (1 - 1 / scale) 72 | left = np.floor(u - kernel_width / 2) 73 | P = int(ceil(kernel_width)) + 2 74 | ind = np.expand_dims(left, axis=1) + np.arange(P) - 1 # -1 because indexing from 0 75 | indices = ind.astype(np.int32) 76 | weights = h(np.expand_dims(u, axis=1) - indices - 1) # -1 because indexing from 0 77 | weights = np.divide(weights, np.expand_dims(np.sum(weights, axis=1), axis=1)) 78 | aux = np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1))).astype(np.int32) 79 | indices = aux[np.mod(indices, aux.size)] 80 | ind2store = np.nonzero(np.any(weights, axis=0)) 81 | weights = weights[:, ind2store] 82 | indices = indices[:, ind2store] 83 | return weights, indices 84 | 85 | 86 | def imresizemex(inimg, weights, indices, dim): 87 | in_shape = inimg.shape 88 | w_shape = weights.shape 89 | out_shape = list(in_shape) 90 | out_shape[dim] = w_shape[0] 91 | outimg = np.zeros(out_shape) 92 | if dim == 0: 93 | for i_img in range(in_shape[1]): 94 | for i_w in range(w_shape[0]): 95 | w = weights[i_w, :] 96 | ind = indices[i_w, :] 97 | im_slice = inimg[ind, i_img].astype(np.float64) 98 | outimg[i_w, i_img] = 
np.sum(np.multiply(np.squeeze(im_slice, axis=0), w.T), axis=0) 99 | elif dim == 1: 100 | for i_img in range(in_shape[0]): 101 | for i_w in range(w_shape[0]): 102 | w = weights[i_w, :] 103 | ind = indices[i_w, :] 104 | im_slice = inimg[i_img, ind].astype(np.float64) 105 | outimg[i_img, i_w] = np.sum(np.multiply(np.squeeze(im_slice, axis=0), w.T), axis=0) 106 | if inimg.dtype == np.uint8: 107 | outimg = np.clip(outimg, 0, 255) 108 | return np.around(outimg).astype(np.uint8) 109 | else: 110 | return outimg 111 | 112 | 113 | def imresizevec(inimg, weights, indices, dim): 114 | wshape = weights.shape 115 | if dim == 0: 116 | weights = weights.reshape((wshape[0], wshape[2], 1, 1)) 117 | outimg = np.sum(weights * ((inimg[indices].squeeze(axis=1)).astype(np.float64)), axis=1) 118 | elif dim == 1: 119 | weights = weights.reshape((1, wshape[0], wshape[2], 1)) 120 | outimg = np.sum(weights * ((inimg[:, indices].squeeze(axis=2)).astype(np.float64)), axis=2) 121 | if inimg.dtype == np.uint8: 122 | outimg = np.clip(outimg, 0, 255) 123 | return np.around(outimg).astype(np.uint8) 124 | else: 125 | return outimg 126 | 127 | 128 | def resizeAlongDim(A, dim, weights, indices, mode="vec"): 129 | if mode == "org": 130 | out = imresizemex(A, weights, indices, dim) 131 | else: 132 | out = imresizevec(A, weights, indices, dim) 133 | return out 134 | 135 | 136 | def imresize(I, scalar_scale=None, method='bicubic', output_shape=None, mode="vec"): 137 | if method == 'bicubic':  # use ==, not `is`, for string comparison 138 | kernel = cubic 139 | elif method == 'bilinear': 140 | kernel = triangle 141 | else: 142 | raise ValueError('Unidentified method supplied')  # was print(); `kernel` would be undefined below 143 | 144 | kernel_width = 4.0 145 | # Fill scale and output_size 146 | if scalar_scale is not None: 147 | scalar_scale = float(scalar_scale) 148 | scale = [scalar_scale, scalar_scale] 149 | output_size = deriveSizeFromScale(I.shape, scale) 150 | elif output_shape is not None: 151 | scale = deriveScaleFromSize(I.shape, output_shape) 152 | output_size = list(output_shape) 153 | else: 154 | print('Error: scalar_scale OR output_shape should be defined!') 155 | return 156 | scale_np = np.array(scale) 157 | order = np.argsort(scale_np) 158 | weights = [] 159 | indices = [] 160 | for k in range(2): 161 | w, ind = contributions(I.shape[k], output_size[k], scale[k], kernel, kernel_width) 162 | weights.append(w) 163 | indices.append(ind) 164 | B = np.copy(I) 165 | flag2D = False 166 | if B.ndim == 2: 167 | B = np.expand_dims(B, axis=2) 168 | flag2D = True 169 | for k in range(2): 170 | dim = order[k] 171 | B = resizeAlongDim(B, dim, weights[dim], indices[dim], mode) 172 | if flag2D: 173 | B = np.squeeze(B, axis=2) 174 | return B 175 | 176 | 177 | def convertDouble2Byte(I): 178 | B = np.clip(I, 0.0, 1.0) 179 | B = 255 * B 180 | return np.around(B).astype(np.uint8) -------------------------------------------------------------------------------- /codes/models/modules/FlowNet_Rescaling_x4.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn as nn 4 | import torch.nn.functional as F 5 | 6 | from utils.util import opt_get 7 | from models.modules import Basic 8 | from models.modules.FlowStep import FlowStep 9 | from models.modules.ConditionalFlow import ConditionalFlow 10 | 11 | class FlowNet(nn.Module): 12 | def __init__(self, image_shape, opt=None): 13 | assert image_shape[2] == 1 or image_shape[2] == 3 14 | super().__init__() 15 | H, W, self.C = image_shape 16 | self.opt = opt 17 | self.L = opt_get(opt, ['network_G',
'flowDownsampler', 'L']) 18 | self.K = opt_get(opt, ['network_G', 'flowDownsampler', 'K']) 19 | if isinstance(self.K, int): self.K = [self.K] * (self.L + 1) 20 | 21 | squeeze = opt_get(self.opt, ['network_G', 'flowDownsampler', 'squeeze'], 'checkerboard') 22 | n_additionalFlowNoAffine = opt_get(self.opt, ['network_G', 'flowDownsampler', 'additionalFlowNoAffine'], 0) 23 | flow_permutation = opt_get(self.opt, ['network_G', 'flowDownsampler', 'flow_permutation'], 'invconv') 24 | flow_coupling = opt_get(self.opt, ['network_G', 'flowDownsampler', 'flow_coupling'], 'Affine') 25 | cond_channels = opt_get(self.opt, ['network_G', 'flowDownsampler', 'cond_channels'], None) 26 | enable_splitOff = opt_get(opt, ['network_G', 'flowDownsampler', 'splitOff', 'enable'], False) 27 | after_splitOff_flowStep = opt_get(opt, ['network_G', 'flowDownsampler', 'splitOff', 'after_flowstep'], 0) 28 | if isinstance(after_splitOff_flowStep, int): after_splitOff_flowStep = [after_splitOff_flowStep] * (self.L + 1) 29 | 30 | # construct flow 31 | self.layers = nn.ModuleList() 32 | self.output_shapes = [] 33 | 34 | for level in range(self.L): 35 | # 1. Squeeze 36 | if squeeze == 'checkerboard': 37 | self.layers.append(Basic.SqueezeLayer(factor=2)) # may need a better way for squeezing 38 | elif squeeze == 'haar': 39 | self.layers.append(Basic.HaarDownsampling(channel_in=self.C)) 40 | 41 | self.C, H, W = self.C * 4, H // 2, W // 2 42 | self.output_shapes.append([-1, self.C, H, W]) 43 | 44 | # 2. main FlowSteps (unconditional flow) 45 | for k in range(self.K[level]-after_splitOff_flowStep[level]): 46 | self.layers.append(FlowStep(in_channels=self.C, cond_channels=cond_channels, 47 | flow_permutation=flow_permutation, 48 | flow_coupling=flow_coupling, 49 | LRvsothers=(k % 2 == 0),  # alternate which half conditions the coupling 50 | opt=opt['network_G']['flowDownsampler'])) 51 | self.output_shapes.append([-1, self.C, H, W]) 52 | 53 | # 3.
additional FlowSteps (split + conditional flow) 54 | if enable_splitOff: 55 | if level == 0: 56 | self.layers.append(Basic.Split(num_channels_split=self.C // 2 if level < self.L-1 else 3, level=level)) 57 | self.level0_condFlow = ConditionalFlow(num_channels=self.C, 58 | num_channels_split=self.C // 2 if level < self.L-1 else 3, 59 | n_flow_step=after_splitOff_flowStep[level], 60 | opt=opt['network_G']['flowDownsampler']['splitOff'], 61 | num_levels_condition=1, SR=False) 62 | elif level == 1: 63 | self.layers.append(Basic.Split(num_channels_split=self.C // 2 if level < self.L-1 else 3, level=level)) 64 | self.level1_condFlow = (ConditionalFlow(num_channels=self.C, 65 | num_channels_split=self.C // 2 if level < self.L-1 else 3, 66 | n_flow_step=after_splitOff_flowStep[level], 67 | opt=opt['network_G']['flowDownsampler']['splitOff'], 68 | num_levels_condition=0, SR=False)) 69 | 70 | self.C = self.C // 2 if level < self.L-1 else 3 71 | self.output_shapes.append([-1, self.C, H, W]) 72 | 73 | 74 | self.H = H 75 | self.W = W 76 | self.scaleH = image_shape[0] / H 77 | self.scaleW = image_shape[1] / W 78 | print('shapes:', self.output_shapes) 79 | 80 | def forward(self, hr=None, z=None, u=None, eps_std=None, logdet=None, reverse=False, training=True): 81 | if not reverse: 82 | return self.normal_flow(hr, u=u, logdet=logdet, training=training) 83 | else: 84 | return self.reverse_flow(z, u=u, eps_std=eps_std, training=training) 85 | 86 | 87 | ''' 88 | hr->y1+z1->y2+z2 89 | ''' 90 | def normal_flow(self, z, u=None, logdet=None, training=True): 91 | for layer, shape in zip(self.layers, self.output_shapes): 92 | if isinstance(layer, FlowStep): 93 | z, _ = layer(z, u, logdet=logdet, reverse=False) 94 | elif isinstance(layer, Basic.SqueezeLayer) or isinstance(layer, Basic.HaarDownsampling): 95 | z, _ = layer(z, logdet=logdet, reverse=False) 96 | elif isinstance(layer, Basic.Split): 97 | if layer.level == 0: 98 | z, a1 = layer(z, reverse=False) 99 | y1 = z.clone() 100 | elif layer.level == 1: 101 | z, a2 = layer(z, reverse=False) 102 | fake_z2, conditional_feature2 = self.level1_condFlow(a2, z, logdet=logdet, reverse=False, training=training) 103 | 104 | conditional_feature1 = torch.cat([y1, F.interpolate(conditional_feature2, scale_factor=2, mode='nearest')],1) 105 | fake_z1, _ = self.level0_condFlow(a1, conditional_feature1, logdet=logdet, reverse=False, training=training) 106 | 107 | return z, fake_z1, fake_z2 108 | 109 | ''' 110 | y2+z2->y1+z1->hr 111 | ''' 112 | def reverse_flow(self, z, u=None, eps_std=None, training=True): 113 | for layer, shape in zip(reversed(self.layers), reversed(self.output_shapes)): 114 | if isinstance(layer, FlowStep): 115 | z, _ = layer(z, u, reverse=True) 116 | elif isinstance(layer, Basic.SqueezeLayer) or isinstance(layer, Basic.HaarDownsampling): 117 | z, _ = layer(z, reverse=True) 118 | elif isinstance(layer, Basic.Split): 119 | if layer.level == 1: 120 | a2, conditional_feature2 = self.level1_condFlow(None, z, eps_std=eps_std, reverse=True, training=training) 121 | z = layer(z, a2, reverse=True) 122 | elif layer.level == 0: 123 | conditional_feature1 = torch.cat([z, F.interpolate(conditional_feature2, scale_factor=2, mode='nearest')],1) 124 | a1, _ = self.level0_condFlow(None, conditional_feature1, eps_std=eps_std, reverse=True, training=training) 125 | z = layer(z, a1, reverse=True) 126 | 127 | 128 | return z 129 | 130 | -------------------------------------------------------------------------------- /codes/data/LRHR_PKL_dataset.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE 16 | 17 | import os 18 | # import subprocess 19 | import torch.utils.data as data 20 | import numpy as np 21 | import time 22 | import torch 23 | 24 | import pickle 25 | 26 | 27 | class LRHR_PKLDataset(data.Dataset): 28 | def __init__(self, opt): 29 | super(LRHR_PKLDataset, self).__init__() 30 | self.opt = opt 31 | self.crop_size = opt.get("GT_size", None) 32 | self.scale = None 33 | self.random_scale_list = [1] 34 | 35 | hr_file_path = opt["dataroot_GT"] 36 | lr_file_path = opt["dataroot_LQ"] 37 | y_labels_file_path = opt['dataroot_y_labels'] 38 | 39 | gpu = True 40 | augment = True 41 | 42 | self.use_flip = opt["use_flip"] if "use_flip" in opt.keys() else False 43 | self.use_rot = opt["use_rot"] if "use_rot" in opt.keys() else False 44 | self.use_crop = opt["use_crop"] if "use_crop" in opt.keys() else False 45 | self.center_crop_hr_size = opt.get("center_crop_hr_size", None) 46 | 47 | n_max = opt["n_max"] if "n_max" in opt.keys() else int(1e8) 48 | 49 | t = time.time() 50 | self.lr_images = self.load_pkls(lr_file_path, n_max) 51 | self.hr_images = self.load_pkls(hr_file_path, n_max) 52 | 53 | min_val_hr = np.min([i.min() for i in self.hr_images[:20]]) 54 | max_val_hr = np.max([i.max() for i in self.hr_images[:20]]) 55 | 56 | min_val_lr = np.min([i.min() for i in self.lr_images[:20]]) 57 | max_val_lr = np.max([i.max() for i in self.lr_images[:20]]) 58 | 59 | t = time.time() - t 60 | print("Loaded {} HR images with [{:.2f}, {:.2f}] in {:.2f}s from {}". 61 | format(len(self.hr_images), min_val_hr, max_val_hr, t, hr_file_path)) 62 | print("Loaded {} LR images with [{:.2f}, {:.2f}] in {:.2f}s from {}". 
63 | format(len(self.lr_images), min_val_lr, max_val_lr, t, lr_file_path)) 64 | 65 | self.gpu = gpu 66 | self.augment = augment 67 | 68 | self.measures = None 69 | 70 | # todo: this is very slow (~15min) when using nn.DistributedDataParallel(), and we have to set n_worker=0 71 | 72 | # # save as png 73 | # import cv2 74 | # for i in range(400): 75 | # img = self.hr_images[i] 76 | # img = np.transpose(img, [1,2,0]) 77 | # cv2.imwrite("/cluster/work/cvl/jinliang/log/srflow_experiments/ICCV21/baseline_results/CelebA_HR_8X/{}.png".format(i), img[:,:,[2, 1, 0]]) 78 | # img = self.lr_images[i] 79 | # img = np.transpose(img, [1,2,0]) 80 | # cv2.imwrite("/cluster/work/cvl/jinliang/log/srflow_experiments/ICCV21/baseline_results/CelebA_LR_8X/{}.png".format(i), img[:,:,[2, 1, 0]]) 81 | # raise NotImplementedError 82 | 83 | def load_pkls(self, path, n_max): 84 | assert os.path.isfile(path), path 85 | images = [] 86 | with open(path, "rb") as f: 87 | images += pickle.load(f) 88 | assert len(images) > 0, path 89 | images = images[:n_max] 90 | images = [np.transpose(image, [2, 0, 1]) for image in images] 91 | return images 92 | 93 | def __len__(self): 94 | return len(self.hr_images) 95 | 96 | def __getitem__(self, item): 97 | 98 | hr = self.hr_images[item] 99 | lr = self.lr_images[item] 100 | 101 | if self.scale == None: 102 | self.scale = hr.shape[1] // lr.shape[1] 103 | assert hr.shape[1] == self.scale * lr.shape[1], ('non-fractional ratio', lr.shape, hr.shape) 104 | 105 | if self.use_crop: 106 | hr, lr = random_crop(hr, lr, self.crop_size, self.scale, self.use_crop) 107 | 108 | if self.center_crop_hr_size: 109 | hr, lr = center_crop(hr, self.center_crop_hr_size), center_crop(lr, self.center_crop_hr_size // self.scale) 110 | 111 | if self.use_flip: 112 | hr, lr = random_flip(hr, lr) 113 | 114 | if self.use_rot: 115 | hr, lr = random_rotation(hr, lr) 116 | 117 | hr = hr / 255.0 118 | lr = lr / 255.0 119 | 120 | if self.measures is None or np.random.random() < 0.05: 121 | if self.measures is None: 122 | self.measures = {} 123 | self.measures['hr_means'] = np.mean(hr) 124 | self.measures['hr_stds'] = np.std(hr) 125 | self.measures['lr_means'] = np.mean(lr) 126 | self.measures['lr_stds'] = np.std(lr) 127 | 128 | hr = torch.Tensor(hr) 129 | lr = torch.Tensor(lr) 130 | 131 | # if self.gpu: 132 | # hr = hr.cuda() 133 | # lr = lr.cuda() 134 | 135 | return {'LQ': lr, 'GT': hr, 'LQ_path': str(item), 'GT_path': str(item)} 136 | 137 | def print_and_reset(self, tag): 138 | m = self.measures 139 | kvs = [] 140 | for k in sorted(m.keys()): 141 | kvs.append("{}={:.2f}".format(k, m[k])) 142 | print("[KPI] " + tag + ": " + ", ".join(kvs)) 143 | self.measures = None 144 | 145 | 146 | def random_flip(img, seg): 147 | random_choice = np.random.choice([True, False]) 148 | img = img if random_choice else np.flip(img, 2).copy() 149 | seg = seg if random_choice else np.flip(seg, 2).copy() 150 | return img, seg 151 | 152 | 153 | def random_rotation(img, seg): 154 | random_choice = np.random.choice([0, 1, 3]) 155 | img = np.rot90(img, random_choice, axes=(1, 2)).copy() 156 | seg = np.rot90(seg, random_choice, axes=(1, 2)).copy() 157 | return img, seg 158 | 159 | 160 | def random_crop(hr, lr, size_hr, scale, random): 161 | size_lr = size_hr // scale 162 | 163 | size_lr_x = lr.shape[1] 164 | size_lr_y = lr.shape[2] 165 | 166 | start_x_lr = np.random.randint(low=0, high=(size_lr_x - size_lr) + 1) if size_lr_x > size_lr else 0 167 | start_y_lr = np.random.randint(low=0, high=(size_lr_y - size_lr) + 1) if size_lr_y > size_lr else 0 
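    # the crop origin is sampled in LR coordinates; the HR origin below is just
    # the LR origin multiplied by `scale`, so the HR and LR patches stay
    # spatially aligned (e.g. size_hr=160, scale=4 gives size_lr=40 and
    # start_x_hr = 4 * start_x_lr)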
168 | 169 | # LR Patch 170 | lr_patch = lr[:, start_x_lr:start_x_lr + size_lr, start_y_lr:start_y_lr + size_lr] 171 | 172 | # HR Patch 173 | start_x_hr = start_x_lr * scale 174 | start_y_hr = start_y_lr * scale 175 | hr_patch = hr[:, start_x_hr:start_x_hr + size_hr, start_y_hr:start_y_hr + size_hr] 176 | 177 | return hr_patch, lr_patch 178 | 179 | 180 | def center_crop(img, size): 181 | assert img.shape[1] == img.shape[2], img.shape 182 | border_double = img.shape[1] - size 183 | assert border_double % 2 == 0, (img.shape, size) 184 | border = border_double // 2 185 | return img[:, border:-border, border:-border] 186 | 187 | 188 | def center_crop_tensor(img, size): 189 | assert img.shape[2] == img.shape[3], img.shape 190 | border_double = img.shape[2] - size 191 | assert border_double % 2 == 0, (img.shape, size) 192 | border = border_double // 2 193 | return img[:, :, border:-border, border:-border] 194 | -------------------------------------------------------------------------------- /codes/models/modules/discriminator_vgg_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | 5 | 6 | class Discriminator_VGG_128(nn.Module): 7 | def __init__(self, in_nc, nf): 8 | super(Discriminator_VGG_128, self).__init__() 9 | # [64, 128, 128] 10 | self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) 11 | self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False) 12 | self.bn0_1 = nn.BatchNorm2d(nf, affine=True) 13 | # [64, 64, 64] 14 | self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False) 15 | self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True) 16 | self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False) 17 | self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True) 18 | # [128, 32, 32] 19 | self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False) 20 | self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True) 21 | self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False) 22 | self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True) 23 | # [256, 16, 16] 24 | self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False) 25 | self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True) 26 | self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) 27 | self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True) 28 | # [512, 8, 8] 29 | self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) 30 | self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True) 31 | self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) 32 | self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True) 33 | 34 | self.linear1 = nn.Linear(512 * 4 * 4, 100) 35 | self.linear2 = nn.Linear(100, 1) 36 | 37 | # activation function 38 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 39 | 40 | def forward(self, x): 41 | fea = self.lrelu(self.conv0_0(x)) 42 | fea = self.lrelu(self.bn0_1(self.conv0_1(fea))) 43 | 44 | fea = self.lrelu(self.bn1_0(self.conv1_0(fea))) 45 | fea = self.lrelu(self.bn1_1(self.conv1_1(fea))) 46 | 47 | fea = self.lrelu(self.bn2_0(self.conv2_0(fea))) 48 | fea = self.lrelu(self.bn2_1(self.conv2_1(fea))) 49 | 50 | fea = self.lrelu(self.bn3_0(self.conv3_0(fea))) 51 | fea = self.lrelu(self.bn3_1(self.conv3_1(fea))) 52 | 53 | fea = self.lrelu(self.bn4_0(self.conv4_0(fea))) 54 | fea = self.lrelu(self.bn4_1(self.conv4_1(fea))) 55 | 56 | fea = fea.view(fea.size(0), -1) 57 | fea = self.lrelu(self.linear1(fea)) 58 | out = self.linear2(fea) 59 | return out 60 | 61 | def reset_parameters(self): 62 | for layer in self.children(): 63 | if hasattr(layer, 
'reset_parameters'): 64 | layer.reset_parameters() 65 | 66 | 67 | 68 | class Discriminator_VGG_160(nn.Module): 69 | def __init__(self, in_nc, nf): 70 | super(Discriminator_VGG_160, self).__init__() 71 | # [64, 160, 160] 72 | self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) 73 | self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False) 74 | self.bn0_1 = nn.BatchNorm2d(nf, affine=True) 75 | # [64, 80, 80] 76 | self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False) 77 | self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True) 78 | self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False) 79 | self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True) 80 | # [128, 40, 40] 81 | self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False) 82 | self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True) 83 | self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False) 84 | self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True) 85 | # [256, 20, 20] 86 | self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False) 87 | self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True) 88 | self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) 89 | self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True) 90 | # [512, 10, 10] 91 | self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) 92 | self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True) 93 | self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) 94 | self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True) 95 | # [512, 5, 5] 96 | 97 | self.linear1 = nn.Linear(512 * 5 * 5, 100) 98 | self.linear2 = nn.Linear(100, 1) 99 | 100 | # activation function 101 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 102 | 103 | def forward(self, x): 104 | fea = self.lrelu(self.conv0_0(x)) 105 | fea = self.lrelu(self.bn0_1(self.conv0_1(fea))) 106 | 107 | fea = self.lrelu(self.bn1_0(self.conv1_0(fea))) 108 | fea = self.lrelu(self.bn1_1(self.conv1_1(fea))) 109 | 110 | fea = self.lrelu(self.bn2_0(self.conv2_0(fea))) 111 | fea = self.lrelu(self.bn2_1(self.conv2_1(fea))) 112 | 113 | fea = self.lrelu(self.bn3_0(self.conv3_0(fea))) 114 | fea = self.lrelu(self.bn3_1(self.conv3_1(fea))) 115 | 116 | fea = self.lrelu(self.bn4_0(self.conv4_0(fea))) 117 | fea = self.lrelu(self.bn4_1(self.conv4_1(fea))) 118 | 119 | fea = fea.view(fea.size(0), -1) 120 | fea = self.lrelu(self.linear1(fea)) 121 | out = self.linear2(fea) 122 | return out 123 | 124 | def reset_parameters(self): 125 | for layer in self.children(): 126 | if hasattr(layer, 'reset_parameters'): 127 | layer.reset_parameters() 128 | 129 | 130 | class VGGFeatureExtractor(nn.Module): 131 | def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True, 132 | device=torch.device('cpu')): 133 | super(VGGFeatureExtractor, self).__init__() 134 | self.use_input_norm = use_input_norm 135 | if use_bn: 136 | model = torchvision.models.vgg19_bn(pretrained=True) 137 | else: 138 | model = torchvision.models.vgg19(pretrained=True) 139 | if self.use_input_norm: 140 | mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device) 141 | # [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1] 142 | std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device) 143 | # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1] 144 | self.register_buffer('mean', mean) 145 | self.register_buffer('std', std) 146 | self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)]) 147 | # No need to BP to variable 148 | for k, v in self.features.named_parameters(): 149 | v.requires_grad = False 150 | 151 | def 
forward(self, x): 152 | # Assume input range is [0, 1] 153 | if self.use_input_norm: 154 | x = (x - self.mean) / self.std 155 | output = self.features(x) 156 | return output 157 | 158 | 159 | class PatchGANDiscriminator(nn.Module): 160 | """Defines a PatchGAN discriminator""" 161 | 162 | def __init__(self, in_nc=3, ndf=64, n_layers=35, norm_layer=nn.BatchNorm2d): 163 | """Construct a PatchGAN discriminator 164 | Parameters: 165 | input_nc (int) -- the number of channels in input images 166 | ndf (int) -- the number of filters in the last conv layer 167 | n_layers (int) -- the number of conv layers in the discriminator 168 | norm_layer -- normalization layer 169 | """ 170 | super(PatchGANDiscriminator, self).__init__() 171 | use_bias = False 172 | kw = 3 173 | padw = 0 174 | sequence = [nn.Conv2d(in_nc, ndf, kernel_size=kw, stride=1, padding=padw), nn.LeakyReLU(0.2, True)] 175 | 176 | for i in range(0, n_layers): 177 | sequence += [ 178 | nn.Conv2d(ndf, ndf, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 179 | norm_layer(ndf), 180 | nn.LeakyReLU(0.2, True) 181 | ] 182 | 183 | sequence += [nn.Conv2d(ndf, 1, kernel_size=kw, stride=1, padding=padw, bias=use_bias)] # output 1 channel prediction map 184 | # TODO 185 | self.model = nn.Sequential(*sequence) 186 | 187 | def forward(self, x): 188 | """Standard forward.""" 189 | return self.model(x) 190 | -------------------------------------------------------------------------------- /codes/models/modules/FlowNet_SR_x8.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from torch import nn as nn 4 | import torch.nn.functional as F 5 | 6 | from utils.util import opt_get 7 | from models.modules import Basic 8 | from models.modules.FlowStep import FlowStep 9 | from models.modules.ConditionalFlow import ConditionalFlow 10 | 11 | class FlowNet(nn.Module): 12 | def __init__(self, image_shape, opt=None): 13 | assert image_shape[2] == 1 or image_shape[2] == 3 14 | super().__init__() 15 | H, W, self.C = image_shape 16 | self.opt = opt 17 | self.L = opt_get(opt, ['network_G', 'flowDownsampler', 'L']) 18 | self.K = opt_get(opt, ['network_G', 'flowDownsampler', 'K']) 19 | if isinstance(self.K, int): self.K = [self.K] * (self.L + 1) 20 | 21 | n_additionalFlowNoAffine = opt_get(self.opt, ['network_G', 'flowDownsampler', 'additionalFlowNoAffine'], 0) 22 | flow_permutation = opt_get(self.opt, ['network_G', 'flowDownsampler', 'flow_permutation'], 'invconv') 23 | flow_coupling = opt_get(self.opt, ['network_G', 'flowDownsampler', 'flow_coupling'], 'Affine') 24 | cond_channels = opt_get(self.opt, ['network_G', 'flowDownsampler', 'cond_channels'], None) 25 | enable_splitOff = opt_get(opt, ['network_G', 'flowDownsampler', 'splitOff', 'enable'], False) 26 | after_splitOff_flowStep = opt_get(opt, ['network_G', 'flowDownsampler', 'splitOff', 'after_flowstep'], 0) 27 | if isinstance(after_splitOff_flowStep, int): after_splitOff_flowStep = [after_splitOff_flowStep] * (self.L + 1) 28 | 29 | # construct flow 30 | self.layers = nn.ModuleList() 31 | self.output_shapes = [] 32 | 33 | for level in range(self.L): 34 | # 1. Squeeze 35 | self.layers.append(Basic.SqueezeLayer(factor=2)) # may need a better way for squeezing 36 | self.C, H, W = self.C * 4, H // 2, W // 2 37 | self.output_shapes.append([-1, self.C, H, W]) 38 | 39 | # 2. 
main FlowSteps (unconditional flow) 40 | for k in range(self.K[level]-after_splitOff_flowStep[level]): 41 | self.layers.append(FlowStep(in_channels=self.C, cond_channels=cond_channels, 42 | flow_permutation=flow_permutation, 43 | flow_coupling=flow_coupling, 44 | opt=opt['network_G']['flowDownsampler'])) 45 | self.output_shapes.append([-1, self.C, H, W]) 46 | 47 | # 3. additional FlowSteps (split + conditional flow) 48 | if enable_splitOff: 49 | if level == 0: 50 | self.layers.append(Basic.Split(num_channels_split=self.C // 2 if level < self.L-1 else 3, level=level)) 51 | self.level0_condFlow = ConditionalFlow(num_channels=self.C, 52 | num_channels_split=self.C // 2 if level < self.L-1 else 3, 53 | n_flow_step=after_splitOff_flowStep[level], 54 | opt=opt['network_G']['flowDownsampler']['splitOff'], 55 | num_levels_condition=2, SR=True) 56 | elif level == 1: 57 | self.layers.append(Basic.Split(num_channels_split=self.C // 2 if level < self.L-1 else 3, level=level)) 58 | self.level1_condFlow = ConditionalFlow(num_channels=self.C, 59 | num_channels_split=self.C // 2 if level < self.L-1 else 3, 60 | n_flow_step=after_splitOff_flowStep[level], 61 | opt=opt['network_G']['flowDownsampler']['splitOff'], 62 | num_levels_condition=1, SR=True) 63 | elif level == 2: 64 | self.layers.append(Basic.Split(num_channels_split=self.C // 2 if level < self.L-1 else 3, level=level)) 65 | self.level2_condFlow = ConditionalFlow(num_channels=self.C, 66 | num_channels_split=self.C // 2 if level < self.L-1 else 3, 67 | n_flow_step=after_splitOff_flowStep[level], 68 | opt=opt['network_G']['flowDownsampler']['splitOff'], 69 | num_levels_condition=0, SR=True) 70 | 71 | self.C = self.C // 2 if level < self.L-1 else 3 72 | self.output_shapes.append([-1, self.C, H, W]) 73 | 74 | self.H = H 75 | self.W = W 76 | self.scaleH = image_shape[0] / H 77 | self.scaleW = image_shape[1] / W 78 | print('shapes:', self.output_shapes) 79 | 80 | 81 | # nodetach version: # nodetach version; 0.05 better than detach version, 0.30 better when using only nll loss 82 | def forward(self, hr=None, z=None, u=None, eps_std=None, logdet=None, reverse=False, training=True): 83 | if not reverse: 84 | return self.normal_flow(hr, u=u, logdet=logdet, training=training) 85 | else: 86 | return self.reverse_flow(z, u=u, eps_std=eps_std, training=training) 87 | 88 | ''' 89 | hr->y1+a1(z1)->y2+a2(z2)->y3+z3 90 | ''' 91 | def normal_flow(self, z, u=None, logdet=None, training=True): 92 | for layer, shape in zip(self.layers, self.output_shapes): 93 | if isinstance(layer, FlowStep): 94 | z, logdet = layer(z, u, logdet=logdet, reverse=False) 95 | elif isinstance(layer, Basic.SqueezeLayer): 96 | z, logdet = layer(z, logdet=logdet, reverse=False) 97 | elif isinstance(layer, Basic.Split): 98 | if layer.level == 0: 99 | z, a1 = layer(z, reverse=False) 100 | y1 = z.clone() 101 | elif layer.level == 1: 102 | z, a2 = layer(z, reverse=False) 103 | y2 = z.clone() 104 | elif layer.level == 2: 105 | z, a3 = layer(z, reverse=False) 106 | 107 | logdet, conditional_feature3 = self.level2_condFlow(a3, z, logdet=logdet, reverse=False, training=training) 108 | 109 | conditional_feature2 = torch.cat([y2, F.interpolate(conditional_feature3, scale_factor=2, mode='nearest')],1) 110 | logdet, conditional_feature2 = self.level1_condFlow(a2, conditional_feature2, logdet=logdet, reverse=False, training=training) 111 | 112 | conditional_feature1 = torch.cat([y1, F.interpolate(conditional_feature2, scale_factor=2, mode='nearest'), 113 | F.interpolate(conditional_feature3, scale_factor=4, 
mode='nearest')],1) 114 | logdet, _ = self.level0_condFlow(a1, conditional_feature1, logdet=logdet, reverse=False, training=training) 115 | 116 | return z, logdet 117 | 118 | ''' 119 | y3+z3(a3)->y2+z2(a2)->y1+z1(a1)->hr 120 | ''' 121 | def reverse_flow(self, z, u=None, eps_std=None, training=True): 122 | for layer, shape in zip(reversed(self.layers), reversed(self.output_shapes)): 123 | if isinstance(layer, FlowStep): 124 | z, _ = layer(z, u, reverse=True) 125 | elif isinstance(layer, Basic.SqueezeLayer): 126 | z, _ = layer(z, reverse=True) 127 | elif isinstance(layer, Basic.Split): 128 | if layer.level == 2: 129 | a3, _, conditional_feature3 = self.level2_condFlow(None, z, eps_std=eps_std, reverse=True, training=training) 130 | z = layer(z, a3, reverse=True) 131 | elif layer.level == 1: 132 | conditional_feature2 = torch.cat([z, F.interpolate(conditional_feature3, scale_factor=2, mode='nearest')],1) 133 | a2, _, conditional_feature2 = self.level1_condFlow(None, conditional_feature2, eps_std=eps_std, reverse=True, training=training) 134 | z = layer(z, a2, reverse=True) 135 | elif layer.level == 0: 136 | conditional_feature1 = torch.cat([z, F.interpolate(conditional_feature2, scale_factor=2, mode='nearest'), 137 | F.interpolate(conditional_feature3, scale_factor=4, mode='nearest')],1) 138 | a1, _, _ = self.level0_condFlow(None, conditional_feature1, eps_std=eps_std, reverse=True, training=training) 139 | z = layer(z, a1, reverse=True) 140 | 141 | 142 | 143 | 144 | return z 145 | 146 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # Hierarchical Conditional Flow: A Unified Framework for Image Super-Resolution and Image Rescaling (HCFlow, ICCV2021) 3 | 4 | 5 | [![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2108.05301) 6 | [![GitHub Stars](https://img.shields.io/github/stars/JingyunLiang/HCFlow?style=social)](https://github.com/JingyunLiang/HCFlow) 7 | [![download](https://img.shields.io/github/downloads/JingyunLiang/HCFlow/total.svg)](https://github.com/JingyunLiang/HCFlow/releases) 8 | [ google colab logo](https://colab.research.google.com/gist/JingyunLiang/cdb3fef89ebd174eaa43794accb6f59d/hcflow-demo-on-x8-face-image-sr.ipynb) 9 | 10 | 11 | This repository is the official PyTorch implementation of Hierarchical Conditional Flow: A Unified Framework for Image Super-Resolution and Image Rescaling 12 | ([arxiv](https://arxiv.org/pdf/2108.05301.pdf), [supp](https://github.com/JingyunLiang/HCFlow/releases)). 13 | 14 | 15 | :rocket: :rocket: :rocket: **News**: 16 | - Sep. 07, 2021: We add an [online Colob demo google colab logo](https://colab.research.google.com/gist/JingyunLiang/cdb3fef89ebd174eaa43794accb6f59d/hcflow-demo-on-x8-face-image-sr.ipynb) for easy comparison of HCFlow and [SRFlow](https://github.com/andreas128/SRFlow). 
17 | - Sep.06, 2021: See our recent work [SwinIR: Transformer-based image restoration](https://github.com/JingyunLiang/SwinIR) [![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2108.10257)[![GitHub Stars](https://img.shields.io/github/stars/JingyunLiang/SwinIR?style=social)](https://github.com/JingyunLiang/SwinIR)[![download](https://img.shields.io/github/downloads/JingyunLiang/SwinIR/total.svg)](https://github.com/JingyunLiang/SwinIR/releases)[ google colab logo](https://colab.research.google.com/gist/JingyunLiang/a5e3e54bc9ef8d7bf594f6fee8208533/swinir-demo-on-real-world-image-sr.ipynb) 18 | - Aug. 17, 2021: See our recent work for [blind SR: Mutual Affine Network for Spatially Variant Kernel Estimation in Blind Image Super-Resolution (MANet), ICCV2021](https://github.com/JingyunLiang/MANet) [![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2108.05302)[![GitHub Stars](https://img.shields.io/github/stars/JingyunLiang/MANet?style=social)](https://github.com/JingyunLiang/MANet) 19 | [![download](https://img.shields.io/github/downloads/JingyunLiang/MANet/total.svg)](https://github.com/JingyunLiang/MANet/releases)[ google colab logo](https://colab.research.google.com/gist/JingyunLiang/4ed2524d6e08343710ee408a4d997e1c/manet-demo-on-spatially-variant-kernel-estimation.ipynb) 20 | - Aug. 17, 2021: See our recent work for [real-world image SR: Designing a Practical Degradation Model for Deep Blind Image Super-Resolution (BSRGAN), ICCV2021](https://github.com/cszn/BSRGAN) [![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2103.14006) 21 | [![GitHub Stars](https://img.shields.io/github/stars/cszn/BSRGAN?style=social)](https://github.com/cszn/BSRGAN) 22 | - Aug. 17, 2021: See our previous [flow-based kernel estimation: Flow-based Kernel Prior with Application to Blind Super-Resolution (FKP), CVPR2021](https://github.com/JingyunLiang/FKP) [![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2103.15977) 23 | [![GitHub Stars](https://img.shields.io/github/stars/JingyunLiang/FKP?style=social)](https://github.com/JingyunLiang/FKP) 24 | --- 25 | 26 | > Normalizing flows have recently demonstrated promising results for low-level vision tasks. For image super-resolution (SR), it learns to predict diverse photo-realistic high-resolution (HR) images from the low-resolution (LR) image rather than learning a deterministic mapping. For image rescaling, it achieves high accuracy by jointly modelling the downscaling and upscaling processes. While existing approaches employ specialized techniques for these two tasks, we set out to unify them in a single formulation. In this paper, we propose the hierarchical conditional flow (HCFlow) as a unified framework for image SR and image rescaling. More specifically, HCFlow learns a bijective mapping between HR and LR image pairs by modelling the distribution of the LR image and the rest high-frequency component simultaneously. In particular, the high-frequency component is conditional on the LR image in a hierarchical manner. To further enhance the performance, other losses such as perceptual loss and GAN loss are combined with the commonly used negative log-likelihood loss in training. Extensive experiments on general image SR, face image SR and image rescaling have demonstrated that the proposed HCFlow achieves state-of-the-art performance in terms of both quantitative metrics and visual quality. 27 | >

28 | > [figure placeholder: see ./illustrations/architecture.png and ./illustrations/face_result.png] 29 |

30 | 31 | ## Requirements 32 | - Python 3.7, PyTorch == 1.7.1 33 | - Requirements: opencv-python, lpips, natsort, etc. 34 | - Platforms: Ubuntu 16.04, cuda-11.0 35 | 36 | 37 | ```bash 38 | cd HCFlow-master 39 | pip install -r requirements.txt 40 | ``` 41 | 42 | ## Quick Run (takes ~1 minute) 43 | To run the code with one command (no data preparation needed), run one of the following commands, or try our [online Colab demo](https://colab.research.google.com/gist/JingyunLiang/cdb3fef89ebd174eaa43794accb6f59d/hcflow-demo-on-x8-face-image-sr.ipynb). 44 | ```bash 45 | cd codes 46 | # face image SR 47 | python test_HCFlow.py --opt options/test/test_SR_CelebA_8X_HCFlow.yml 48 | 49 | # general image SR 50 | python test_HCFlow.py --opt options/test/test_SR_DF2K_4X_HCFlow.yml 51 | 52 | # image rescaling 53 | python test_HCFlow.py --opt options/test/test_Rescaling_DF2K_4X_HCFlow.yml 54 | ``` 55 | --- 56 | 57 | ## Data Preparation 58 | The framework of this project is based on [MMSR](https://github.com/open-mmlab/mmediting) and [SRFlow](https://github.com/andreas128/SRFlow). To prepare data, put training and testing sets in `./datasets`, e.g. `./datasets/DIV2K/HR/0801.png`. Commonly used SR datasets can be downloaded [here](https://github.com/xinntao/BasicSR/blob/master/docs/DatasetPreparation.md#common-image-sr-datasets). 59 | There are two ways to accelerate data loading: first, use `./scripts/png2npy.py` to generate `.npy` files and load them with `data/GTLQnpy_dataset.py`; second (*recommended*), use a `.pklv4` dataset with `data/LRHR_PKL_dataset.py`. Please refer to [SRFlow](https://github.com/andreas128/SRFlow#dataset-how-to-train-on-your-own-data) for more details. Prepared datasets can be downloaded [here](http://data.vision.ee.ethz.ch/alugmayr/SRFlow/datasets.zip). 60 | 61 | ## Training 62 | 63 | To train HCFlow for general image SR, face image SR, or image rescaling, run one of these commands: 64 | 65 | ```bash 66 | cd codes 67 | 68 | # face image SR 69 | python train_HCFlow.py --opt options/train/train_SR_CelebA_8X_HCFlow.yml 70 | 71 | # general image SR 72 | python train_HCFlow.py --opt options/train/train_SR_DF2K_4X_HCFlow.yml 73 | 74 | # image rescaling 75 | python train_HCFlow.py --opt options/train/train_Rescaling_DF2K_4X_HCFlow.yml 76 | ``` 77 | All trained models can be downloaded from [here](https://github.com/JingyunLiang/HCFlow/releases). 78 | 79 | 80 | ## Testing 81 | 82 | Please follow the **Quick Run** section; just modify the dataset paths in the corresponding `test_*_HCFlow.yml` file. 83 | 84 | ## Results 85 | We achieved state-of-the-art performance on general image SR, face image SR and image rescaling. 86 | 87 | 88 | For more results, please refer to the [paper](https://arxiv.org/abs/2108.05301) and [supp](https://github.com/JingyunLiang/HCFlow/releases) for details. 89 | 90 | ## Citation 91 | @inproceedings{liang21hierarchical, 92 | title={Hierarchical Conditional Flow: A Unified Framework for Image Super-Resolution and Image Rescaling}, 93 | author={Liang, Jingyun and Lugmayr, Andreas and Zhang, Kai and Danelljan, Martin and Van Gool, Luc and Timofte, Radu}, 94 | booktitle={IEEE International Conference on Computer Vision}, 95 | year={2021} 96 | } 97 | 98 | 99 | ## License & Acknowledgement 100 | 101 | This project is released under the Apache 2.0 license.
The codes are based on [MMSR](https://github.com/open-mmlab/mmediting), [SRFlow](https://github.com/andreas128/SRFlow), [IRN](https://github.com/pkuxmq/Invertible-Image-Rescaling) and [Glow-pytorch](https://github.com/chaiyujin/glow-pytorch). Please also follow their licenses. Thanks for their great work. 102 | 103 | -------------------------------------------------------------------------------- /codes/models/modules/AffineCouplings.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | import torch.nn.functional as F 4 | 5 | from models.modules import thops 6 | from models.modules.Basic import Conv2d, Conv2dZeros, DenseBlock, FCN, RDN 7 | from utils.util import opt_get, register_hook, trunc_normal_ 8 | 9 | 10 | class AffineCoupling(nn.Module): 11 | def __init__(self, in_channels, cond_channels=None, opt=None): 12 | super().__init__() 13 | self.in_channels = in_channels 14 | self.hidden_channels = opt_get(opt, ['hidden_channels'], 64) 15 | self.n_hidden_layers = 1 16 | self.kernel_hidden = 1 17 | self.cond_channels = cond_channels 18 | f_in_channels = self.in_channels//2 if cond_channels is None else self.in_channels//2 + cond_channels 19 | f_out_channels = (self.in_channels - self.in_channels//2) * 2 20 | nn_module = opt_get(opt, ['nn_module'], 'FCN') 21 | if nn_module == 'DenseBlock': 22 | self.f = DenseBlock(in_channels=f_in_channels, out_channels=f_out_channels, gc=self.hidden_channels) 23 | elif nn_module == 'FCN': 24 | self.f = FCN(in_channels=f_in_channels, out_channels=f_out_channels, hidden_channels=self.hidden_channels, 25 | kernel_hidden=self.kernel_hidden, n_hidden_layers=self.n_hidden_layers) 26 | 27 | 28 | def forward(self, z, u=None, y=None, logdet=None, reverse=False): 29 | if not reverse: 30 | return self.normal_flow(z, u, y, logdet) 31 | else: 32 | return self.reverse_flow(z, u, y, logdet) 33 | 34 | def normal_flow(self, z, u=None, y=None, logdet=None): 35 | z1, z2 = thops.split_feature(z, "split") 36 | 37 | h = self.f(z1) if self.cond_channels is None else self.f(thops.cat_feature(z1, u)) 38 | shift, scale = thops.split_feature(h, "cross") 39 | # adding 1e-4 is crucial for numerical stability in torch.slogdet(), as used in Glow (omitting it leads to black rectangles in experiments). 40 | # see https://github.com/didriknielsen/survae_flows/issues/5 for discussion. 41 | # or use `torch.exp(2. * torch.tanh(s / 2.))` as in SurVAE (more unstable in practice). 42 | 43 | # version 1, SRFlow (use FCN) 44 | # scale = torch.sigmoid(scale + 2.) + 1e-4 45 | # z2 = (z2 + shift) * scale 46 | # logdet += thops.sum(torch.log(scale), dim=[1, 2, 3]) 47 | 48 | # version 2, SurVAE 49 | # logscale = 2. * torch.tanh(scale / 2.) 50 | # z2 = (z2 + shift) * torch.exp(logscale) # as in Glow, it's shift + scale! 51 | # logdet += thops.sum(logscale, dim=[1, 2, 3]) 52 | 53 | # version 3, FrEIA; currently has problems with FCN, but DenseBlock is OK (use FCN2/DenseBlock).
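# Note: 0.318 ~= 1/pi, so logscale = (1/pi) * atan(2 * scale) below is smoothly bounded to (-0.5, 0.5);
# the per-element scale factor exp(logscale) therefore stays within roughly (0.61, 1.65), which keeps
# both the forward pass and the exact inverse torch.exp(-logscale) numerically stable.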
54 | # logscale = 0.5 * 0.636 * torch.atan(scale / 0.5) # clamp it to be between [-0.5,0.5] 55 | logscale = 0.318 * torch.atan(2 * scale) 56 | # logscale = 1.0 * 0.636 * torch.atan(scale / 1.0) 57 | z2 = (z2 + shift) * torch.exp(logscale) 58 | if logdet is not None: 59 | logdet += thops.sum(logscale, dim=[1, 2, 3]) 60 | 61 | z = thops.cat_feature(z1, z2) 62 | 63 | return z, logdet 64 | 65 | def reverse_flow(self, z, u=None, y=None, logdet=None): 66 | z1, z2 = thops.split_feature(z, "split") 67 | 68 | h = self.f(z1) if self.cond_channels is None else self.f(thops.cat_feature(z1, u)) 69 | shift, scale = thops.split_feature(h, "cross") 70 | 71 | # version 1, SRFlow 72 | # scale = torch.sigmoid(scale + 2.) + 1e-4 73 | # z2 = (z2 / scale) - shift 74 | 75 | # version 2, SurVAE 76 | # logscale = 2. * torch.tanh(scale / 2.) 77 | # z2 = z2 * torch.exp(-logscale) - shift 78 | 79 | # version 3, FrEIA 80 | # logscale = 0.5 * 0.636 * torch.atan(scale / 0.5) 81 | logscale = 0.318 * torch.atan(2 * scale) 82 | # logscale = 1 * 0.636 * torch.atan(scale / 1.0) 83 | z2 = z2 * torch.exp(-logscale) - shift 84 | 85 | z = thops.cat_feature(z1, z2) 86 | 87 | return z, logdet 88 | 89 | 90 | '''Three channels conditioned on the remaining channels, or vice versa; the LR channels are only shifted. 91 | Used in image rescaling to separate the low and high frequencies from the early flow layers onwards.''' 92 | class AffineCoupling3shift(nn.Module): 93 | def __init__(self, in_channels, cond_channels=None, LRvsothers=True, opt=None): 94 | super().__init__() 95 | self.in_channels = in_channels 96 | self.hidden_channels = opt_get(opt, ['hidden_channels'], 64) 97 | self.n_hidden_layers = 1 98 | self.kernel_hidden = 1 99 | self.cond_channels = cond_channels 100 | self.LRvsothers = LRvsothers 101 | if LRvsothers: 102 | f_in_channels = 3 if cond_channels is None else 3 + cond_channels 103 | f_out_channels = (self.in_channels - 3) * 2 104 | else: 105 | f_in_channels = self.in_channels - 3 if cond_channels is None else self.in_channels - 3 + cond_channels 106 | f_out_channels = 3 107 | nn_module = opt_get(opt, ['nn_module'], 'FCN') 108 | 109 | if nn_module == 'DenseBlock': 110 | self.f = DenseBlock(in_channels=f_in_channels, out_channels=f_out_channels, gc=self.hidden_channels) 111 | elif nn_module == 'FCN': 112 | self.f = FCN(in_channels=f_in_channels, out_channels=f_out_channels, hidden_channels=self.hidden_channels, 113 | kernel_hidden=self.kernel_hidden, n_hidden_layers=self.n_hidden_layers) 114 | 115 | 116 | def forward(self, z, u=None, y=None, logdet=None, reverse=False): 117 | if not reverse: 118 | return self.normal_flow(z, u, y, logdet) 119 | else: 120 | return self.reverse_flow(z, u, y, logdet) 121 | 122 | def normal_flow(self, z, u=None, y=None, logdet=None): 123 | if self.LRvsothers: 124 | z1, z2 = z[:, :3, ...], z[:, 3:, ...] 125 | h = self.f(z1) if self.cond_channels is None else self.f(thops.cat_feature(z1, u)) 126 | shift, scale = thops.split_feature(h, "cross") 127 | logscale = 0.318 * torch.atan(2 * scale) 128 | z2 = (z2 + shift) * torch.exp(logscale) 129 | if logdet is not None: 130 | logdet += thops.sum(logscale, dim=[1, 2, 3]) 131 | else: 132 | z2, z1 = z[:, :3, ...], z[:, 3:, ...]
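# Note: this branch only shifts the 3 LR channels by a function of the high-frequency channels;
# the Jacobian is the identity, so the log-determinant needs no update (the "only shift LR" case).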
133 | shift = self.f(z1) if self.cond_channels is None else self.f(thops.cat_feature(z1, u)) 134 | z2 = z2 + shift 135 | 136 | if self.LRvsothers: 137 | z = thops.cat_feature(z1, z2) 138 | else: 139 | z = thops.cat_feature(z2, z1) 140 | 141 | return z, logdet 142 | 143 | def reverse_flow(self, z, u=None, y=None, logdet=None): 144 | if self.LRvsothers: 145 | z1, z2 = z[:, :3, ...], z[:, 3:, ...] 146 | h = self.f(z1) if self.cond_channels is None else self.f(thops.cat_feature(z1, u)) 147 | shift, scale = thops.split_feature(h, "cross") 148 | logscale = 0.318 * torch.atan(2 * scale) 149 | z2 = z2 * torch.exp(-logscale) - shift 150 | else: 151 | z2, z1 = z[:, :3, ...], z[:, 3:, ...] 152 | shift = self.f(z1) if self.cond_channels is None else self.f(thops.cat_feature(z1, u)) # mirror normal_flow so the inverse matches when a condition is used 153 | z2 = z2 - shift 154 | 155 | if self.LRvsothers: 156 | z = thops.cat_feature(z1, z2) 157 | else: 158 | z = thops.cat_feature(z2, z1) 159 | 160 | return z, logdet 161 | 162 | 163 | '''SRFlow's affine injector + original affine coupling; not used in this project.''' 164 | class AffineCouplingInjector(nn.Module): 165 | def __init__(self, in_channels, cond_channels=None, opt=None): 166 | super().__init__() 167 | self.in_channels = in_channels 168 | self.hidden_channels = opt_get(opt, ['hidden_channels'], 64) 169 | self.n_hidden_layers = 1 170 | self.kernel_hidden = 1 171 | self.cond_channels = cond_channels 172 | f_in_channels = self.in_channels//2 if cond_channels is None else self.in_channels//2 + cond_channels 173 | f_out_channels = (self.in_channels - self.in_channels//2) * 2 174 | nn_module = opt_get(opt, ['nn_module'], 'FCN') 175 | if nn_module == 'DenseBlock': 176 | self.f = DenseBlock(in_channels=f_in_channels, out_channels=f_out_channels, gc=self.hidden_channels) 177 | self.f_injector = DenseBlock(in_channels=cond_channels, out_channels=self.in_channels*2, gc=self.hidden_channels) 178 | elif nn_module == 'FCN': 179 | self.f = FCN(in_channels=f_in_channels, out_channels=f_out_channels, hidden_channels=self.hidden_channels, 180 | kernel_hidden=self.kernel_hidden, n_hidden_layers=self.n_hidden_layers) 181 | self.f_injector = FCN(in_channels=cond_channels, out_channels=self.in_channels*2, hidden_channels=self.hidden_channels, 182 | kernel_hidden=self.kernel_hidden, n_hidden_layers=self.n_hidden_layers) 183 | 184 | def forward(self, z, u=None, y=None, logdet=None, reverse=False): 185 | if not reverse: 186 | return self.normal_flow(z, u, y, logdet) 187 | else: 188 | return self.reverse_flow(z, u, y, logdet) 189 | 190 | def normal_flow(self, z, u=None, y=None, logdet=None): 191 | # overall-conditional 192 | h = self.f_injector(u) 193 | shift, scale = thops.split_feature(h, "cross") 194 | logscale = 0.318 * torch.atan(2 * scale) # soft clamp to (-0.5, 0.5) 195 | z = (z + shift) * torch.exp(logscale) 196 | logdet += thops.sum(logscale, dim=[1, 2, 3]) 197 | 198 | # self-conditional 199 | z1, z2 = thops.split_feature(z, "split") 200 | h = self.f(z1) if self.cond_channels is None else self.f(thops.cat_feature(z1, u)) 201 | shift, scale = thops.split_feature(h, "cross") 202 | logscale = 0.318 * torch.atan(2 * scale) # soft clamp to (-0.5, 0.5) 203 | z2 = (z2 + shift) * torch.exp(logscale) 204 | logdet += thops.sum(logscale, dim=[1, 2, 3]) 205 | z = thops.cat_feature(z1, z2) 206 | 207 | return z, logdet 208 | 209 | def reverse_flow(self, z, u=None, y=None, logdet=None): 210 | # self-conditional 211 | z1, z2 = thops.split_feature(z, "split") 212 | h = self.f(z1) if self.cond_channels is None else self.f(thops.cat_feature(z1, u)) 213 | shift, scale = thops.split_feature(h, "cross") 214 |
logscale = 0.318 * torch.atan(2 * scale) 215 | z2 = z2 * torch.exp(-logscale) - shift 216 | z = thops.cat_feature(z1, z2) 217 | 218 | # overall-conditional 219 | h = self.f_injector(u) 220 | shift, scale = thops.split_feature(h, "cross") 221 | logscale = 0.318 * torch.atan(2 * scale) 222 | z = z * torch.exp(-logscale) - shift 223 | 224 | return z, logdet 225 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [2021][HCFlow Authors] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | 203 | -------------------------------------------------------------------------------- /codes/test_HCFlow.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import logging 3 | import time 4 | import argparse 5 | from collections import OrderedDict 6 | import numpy as np 7 | import torch 8 | import options.options as option 9 | import utils.util as util 10 | from utils.imresize import imresize 11 | from data.util import bgr2ycbcr 12 | from data import create_dataset, create_dataloader 13 | from models import create_model 14 | import lpips 15 | 16 | 17 | 18 | #### options 19 | parser = argparse.ArgumentParser() # test_SR_CelebA_8X_HCFlow test_SR_DF2K_4X_HCFlow test_Rescaling_DF2K_4X_HCFlow 20 | parser.add_argument('--opt', type=str, default='options/test/test_SR_CelebA_8X_HCFlow.yml', help='Path to options YAML file.') 21 | parser.add_argument('--save_kernel', action='store_true', default=False, help='Save kernel estimation.') 22 | args = parser.parse_args() 23 | opt = option.parse(args.opt, is_train=False) 24 | opt = option.dict_to_nonedict(opt) 25 | device_id = torch.cuda.current_device() 26 | 27 | #### mkdir and logger 28 | util.mkdirs((path for key, path in opt['path'].items() if key != 'experiments_root' 29 | and 'pretrain_model' not in key and 'resume' not in key and 'load_submodule' not in key)) 30 | util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO, 31 | screen=True, tofile=True) 32 | logger = logging.getLogger('base') 33 | logger.info(option.dict2str(opt)) 34 | 35 | # set random seed 36 | util.set_random_seed(0) 37 | 38 | #### Create test dataset and dataloader 39 | test_loaders = [] 40 | for phase, dataset_opt in sorted(opt['datasets'].items()): 41 | test_set = create_dataset(dataset_opt) 42 | test_loader = create_dataloader(test_set, dataset_opt) 43 | logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set))) 44 | test_loaders.append(test_loader) 45 | 46 | # load pretrained model by default 47 | model = create_model(opt) 48 | loss_fn_alex = lpips.LPIPS(net='alex').to('cuda') 49 | crop_border = opt['crop_border'] if opt['crop_border'] else opt['scale'] 50 | 51 | for test_loader in test_loaders: 52 | test_set_name = test_loader.dataset.opt['name'] 53 | logger.info('\n\nTesting [{:s}]...'.format(test_set_name)) 54 | test_start_time = time.time() 55 | dataset_dir = os.path.join(opt['path']['results_root'], test_set_name) 56 | util.mkdir(dataset_dir) 57 | 58 | idx = 0 59 | psnr_dict = {} # for HR image 60 | ssim_dict = {} 61 | psnr_y_dict = {} 62 | ssim_y_dict = {} 63 | bic_hr_psnr_dict = {} # for bic(HR) 64 | bic_hr_ssim_dict = {} 65 | bic_hr_psnr_y_dict = {} 66 | bic_hr_ssim_y_dict = {} 67 | lpips_dict = {} 68 | diversity_dict = {} # pixel-wise variance 69 | avg_lr_psnr = 0.0 # for generated LR image 70 | avg_lr_ssim = 0.0 71 | avg_lr_psnr_y = 0.0 72 | avg_lr_ssim_y = 0.0 73 | avg_nll = 0.0 74 | 75 | for test_data in test_loader: 76 | idx += 1 77 | 78 | real_image = test_loader.dataset.opt['mode'] == 'LQ' 79 | img_path = test_data['LQ_path'][0] if real_image else test_data['GT_path'][0] 80 | img_name = os.path.splitext(os.path.basename(img_path))[0] 81 | 82 | model.feed_data(test_data, need_GT=not real_image) 83 | nll = model.test() 84 | avg_nll += nll 85 | visuals = model.get_current_visuals(need_GT=not real_image) 86 | 87 | # deal with real-world data (just save) 88 | if real_image: 89 | for heat in opt['val']['heats']: 90 | for
sample in range(opt['val']['n_sample']): 91 | sr_img = util.tensor2img(visuals['SR', heat, sample]) 92 | 93 | if opt['suffix']: 94 | save_img_path = os.path.join(dataset_dir, 'SR_{:s}_{:.1f}_{:d}_{:s}.png'.format(img_name, heat, sample, opt['suffix'])) 95 | else: 96 | save_img_path = os.path.join(dataset_dir, 'SR_{:s}_{:.1f}_{:d}.png'.format(img_name, heat, sample)) 97 | util.save_img(sr_img, save_img_path) 98 | 99 | # deal with synthetic data (calculate psnr and save) 100 | else: 101 | 102 | # calculate PSNR for LR 103 | gt_img_lr = util.tensor2img(visuals['LQ']) 104 | sr_img_lr = util.tensor2img(visuals['LQ_fromH']) 105 | # save_img_path = os.path.join(dataset_dir, 'LR_{:s}_{:.1f}_{:d}.png'.format(img_name, 1.0, 0)) 106 | # util.save_img(sr_img_lr, save_img_path) 107 | gt_img_lr = gt_img_lr / 255. 108 | sr_img_lr = sr_img_lr / 255. 109 | 110 | lr_psnr, lr_ssim, lr_psnr_y, lr_ssim_y = util.calculate_psnr_ssim(gt_img_lr, sr_img_lr, 0) 111 | avg_lr_psnr += lr_psnr 112 | avg_lr_ssim += lr_ssim 113 | avg_lr_psnr_y += lr_psnr_y 114 | avg_lr_ssim_y += lr_ssim_y 115 | 116 | for heat in opt['val']['heats']: 117 | psnr = 0.0 118 | ssim = 0.0 119 | psnr_y = 0.0 120 | ssim_y = 0.0 121 | lpips_value = 0.0 122 | bic_hr_psnr = 0.0 123 | bic_hr_ssim = 0.0 124 | bic_hr_psnr_y = 0.0 125 | bic_hr_ssim_y = 0.0 126 | 127 | sr_img_list =[] 128 | for sample in range(opt['val']['n_sample']): 129 | gt_img = visuals['GT'] 130 | sr_img = visuals['SR', heat, sample] 131 | sr_img_list.append(sr_img.unsqueeze(0)*255) 132 | lpips_dict[(idx, heat, sample)] = float(loss_fn_alex(2 * gt_img.to('cuda') - 1, 2 * sr_img.to('cuda') - 1).cpu()) 133 | lpips_value += lpips_dict[(idx, heat, sample)] 134 | 135 | gt_img = util.tensor2img(gt_img) # uint8 136 | sr_img = util.tensor2img(sr_img) # uint8 137 | if opt['suffix']: 138 | save_img_path = os.path.join(dataset_dir, 'SR_{:s}_{:.1f}_{:d}_{:s}.png'.format(img_name, heat, sample, opt['suffix'])) 139 | else: 140 | save_img_path = os.path.join(dataset_dir, 'SR_{:s}_{:.1f}_{:d}.png'.format(img_name, heat, sample)) 141 | util.save_img(sr_img, save_img_path) 142 | 143 | gt_img = gt_img / 255. 144 | sr_img = sr_img / 255. 
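# Note: the bic(HR) metrics below bicubically downscale both the ground truth and the SR output
# and compare them, i.e. they measure how consistent the SR image is with the LR input under an
# assumed bicubic degradation.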
145 | bic_hr_gt_img = imresize(gt_img, 1 / opt['scale']) 146 | bic_hr_sr_img = imresize(sr_img, 1 / opt['scale']) 147 | 148 | psnr_dict[(idx, heat, sample)], ssim_dict[(idx, heat, sample)], \ 149 | psnr_y_dict[(idx, heat, sample)], ssim_y_dict[(idx, heat, sample)] = util.calculate_psnr_ssim(gt_img, sr_img, crop_border) 150 | psnr += psnr_dict[(idx, heat, sample)] 151 | ssim += ssim_dict[(idx, heat, sample)] 152 | psnr_y += psnr_y_dict[(idx, heat, sample)] 153 | ssim_y += ssim_y_dict[(idx, heat, sample)] 154 | bic_hr_psnr_dict[(idx, heat, sample)], bic_hr_ssim_dict[(idx, heat, sample)], \ 155 | bic_hr_psnr_y_dict[(idx, heat, sample)], bic_hr_ssim_y_dict[(idx, heat, sample)] = util.calculate_psnr_ssim(bic_hr_gt_img, bic_hr_sr_img, 0) 156 | bic_hr_psnr += bic_hr_psnr_dict[(idx, heat, sample)] 157 | bic_hr_ssim += bic_hr_ssim_dict[(idx, heat, sample)] 158 | bic_hr_psnr_y += bic_hr_psnr_y_dict[(idx, heat, sample)] 159 | bic_hr_ssim_y += bic_hr_ssim_y_dict[(idx, heat, sample)] 160 | 161 | 162 | # average over samples; diversity is the mean pixel-wise std over samples 163 | psnr /= opt['val']['n_sample'] 164 | ssim /= opt['val']['n_sample'] 165 | psnr_y /= opt['val']['n_sample'] 166 | ssim_y /= opt['val']['n_sample'] 167 | diversity_dict[(idx, heat)] = float(torch.cat(sr_img_list, 0).std([0]).mean().cpu()) 168 | lpips_value /= opt['val']['n_sample'] 169 | bic_hr_psnr /= opt['val']['n_sample'] 170 | bic_hr_ssim /= opt['val']['n_sample'] 171 | bic_hr_psnr_y /= opt['val']['n_sample'] 172 | bic_hr_ssim_y /= opt['val']['n_sample'] 173 | 174 | 175 | logger.info('{:20s} ({} samples, heat:{:.1f}) ' 176 | 'HR:PSNR/SSIM/PSNR_Y/SSIM_Y/LPIPS/Diversity: {:.2f}/{:.4f}/{:.2f}/{:.4f}/{:.4f}/{:.4f}, ' 177 | 'bicHR:PSNR/SSIM/PSNR_Y/SSIM_Y: {:.2f}/{:.4f}/{:.2f}/{:.4f}, ' 178 | 'LR:PSNR/SSIM/PSNR_Y/SSIM_Y: {:.2f}/{:.4f}/{:.2f}/{:.4f}, NLL: {:.4f}'.format( 179 | img_name, opt['val']['n_sample'], heat, 180 | psnr, ssim, psnr_y, ssim_y, lpips_value, diversity_dict[(idx, heat)], 181 | bic_hr_psnr, bic_hr_ssim, bic_hr_psnr_y, bic_hr_ssim_y, 182 | lr_psnr, lr_ssim, lr_psnr_y, lr_ssim_y, nll)) 183 | 184 | # Average PSNR/SSIM results 185 | avg_lr_psnr /= idx 186 | avg_lr_ssim /= idx 187 | avg_lr_psnr_y /= idx 188 | avg_lr_ssim_y /= idx 189 | avg_nll = avg_nll / idx 190 | 191 | if real_image: 192 | logger.info('----{} ({} images), avg LR PSNR/SSIM/PSNR_Y/SSIM_Y: {:.2f}/{:.4f}/{:.2f}/{:.4f}\n'.format(test_set_name, idx, avg_lr_psnr, avg_lr_ssim, avg_lr_psnr_y, avg_lr_ssim_y)) 193 | else: 194 | logger.info('-------------------------------------------------------------------------------------') 195 | for heat in opt['val']['heats']: 196 | avg_psnr = 0.0 197 | avg_ssim = 0.0 198 | avg_psnr_y = 0.0 199 | avg_ssim_y = 0.0 200 | avg_lpips = 0.0 201 | avg_diversity = 0.0 202 | avg_bic_hr_psnr = 0.0 203 | avg_bic_hr_ssim = 0.0 204 | avg_bic_hr_psnr_y = 0.0 205 | avg_bic_hr_ssim_y = 0.0 206 | 207 | for iidx in range(1, idx+1): 208 | for sample in range(opt['val']['n_sample']): 209 | avg_psnr += psnr_dict[(iidx, heat, sample)] 210 | avg_ssim += ssim_dict[(iidx, heat, sample)] 211 | avg_psnr_y += psnr_y_dict[(iidx, heat, sample)] 212 | avg_ssim_y += ssim_y_dict[(iidx, heat, sample)] 213 | avg_lpips += lpips_dict[(iidx, heat, sample)] 214 | avg_bic_hr_psnr += bic_hr_psnr_dict[(iidx, heat, sample)] 215 | avg_bic_hr_ssim += bic_hr_ssim_dict[(iidx, heat, sample)] 216 | avg_bic_hr_psnr_y += bic_hr_psnr_y_dict[(iidx, heat, sample)] 217 | avg_bic_hr_ssim_y += bic_hr_ssim_y_dict[(iidx, heat, sample)] 218 | avg_diversity += diversity_dict[(iidx, heat)] 219 | 220 | avg_psnr = avg_psnr / idx
/ opt['val']['n_sample'] 221 | avg_ssim = avg_ssim / idx / opt['val']['n_sample'] 222 | avg_psnr_y = avg_psnr_y / idx / opt['val']['n_sample'] 223 | avg_ssim_y = avg_ssim_y / idx / opt['val']['n_sample'] 224 | avg_lpips = avg_lpips / idx / opt['val']['n_sample'] 225 | avg_diversity = avg_diversity / idx 226 | avg_bic_hr_psnr = avg_bic_hr_psnr / idx / opt['val']['n_sample'] 227 | avg_bic_hr_ssim = avg_bic_hr_ssim / idx / opt['val']['n_sample'] 228 | avg_bic_hr_psnr_y = avg_bic_hr_psnr_y / idx / opt['val']['n_sample'] 229 | avg_bic_hr_ssim_y = avg_bic_hr_ssim_y / idx / opt['val']['n_sample'] 230 | 231 | # log 232 | logger.info(opt['path']['pretrain_model_G']) 233 | logger.info('----{} ({} images, {} samples, heat:{:.1f}) ' 234 | 'average HR:PSNR/SSIM/PSNR_Y/SSIM_Y/LPIPS/Diversity: {:.2f}/{:.4f}/{:.2f}/{:.4f}/{:.4f}/{:.4f}, ' 235 | 'bicHR:PSNR/SSIM/PSNR_Y/SSIM_Y: {:.2f}/{:.4f}/{:.2f}/{:.4f}, ' 236 | 'LR:PSNR/SSIM/PSNR_Y/SSIM_Y: {:.2f}/{:.4f}/{:.2f}/{:.4f}, NLL: {:.4f}'.format( 237 | test_set_name, idx, opt['val']['n_sample'], heat, 238 | avg_psnr, avg_ssim, avg_psnr_y, avg_ssim_y, avg_lpips, avg_diversity, 239 | avg_bic_hr_psnr, avg_bic_hr_ssim, avg_bic_hr_psnr_y, avg_bic_hr_ssim_y, 240 | avg_lr_psnr, avg_lr_ssim, avg_lr_psnr_y, avg_lr_ssim_y, avg_nll)) 241 | --------------------------------------------------------------------------------