├── .github └── workflows │ └── pylint.yml ├── .gitignore ├── LICENSE ├── README.md ├── data ├── __pycache__ │ ├── dataset_plain.cpython-39.pyc │ └── select_dataset.cpython-39.pyc ├── dataset_multiin.py ├── dataset_multimodal.py ├── dataset_plain.py └── select_dataset.py ├── dataset └── gen_dataset_wz_fov.py ├── main_test_sample.py ├── main_train_sample.py ├── models ├── __pycache__ │ ├── basicblock.cpython-39.pyc │ ├── loss.cpython-39.pyc │ ├── model_base.cpython-39.pyc │ ├── model_plain.cpython-39.pyc │ ├── select_model.cpython-39.pyc │ ├── select_network.cpython-39.pyc │ └── unet.cpython-39.pyc ├── _loss.py ├── _lr_scheduler.py ├── archs │ ├── sr_network_RRDB.py │ └── unet_arch.py ├── basicblock.py ├── model_base.py ├── model_multiin.py ├── model_multiout.py ├── model_plain.py ├── model_progressive.py ├── network_discriminator.py ├── network_feature.py ├── network_fftformer.py ├── network_fsanet.py ├── network_mimounet.py ├── network_mprnet.py ├── network_nafnet.py ├── network_pqnet.py ├── network_restormer.py ├── network_rrdb.py ├── network_rrdbnet.py ├── network_stripformer.py ├── network_uformer.py ├── network_unet.py ├── network_vapsr.py ├── network_vit.py ├── select_model.py └── select_network.py ├── options ├── option.json ├── option_20230722.json ├── option_20230724.json ├── option_official_implementation_fftformer.json ├── option_official_implementation_fsanet.json ├── option_official_implementation_mimounet.json ├── option_official_implementation_nafnet.json ├── option_official_implementation_painter.json ├── option_official_implementation_restormer.json └── option_official_implementation_uformer.json ├── requirements.txt └── utils ├── __pycache__ ├── utils_bnorm.cpython-39.pyc ├── utils_dist.cpython-39.pyc ├── utils_image.cpython-39.pyc ├── utils_logger.cpython-39.pyc ├── utils_model.cpython-39.pyc ├── utils_option.cpython-39.pyc └── utils_regularizers.cpython-39.pyc ├── utils_bnorm.py ├── utils_dist.py ├── utils_filter.py ├── utils_image.py ├── utils_logger.py ├── utils_mask.py ├── utils_model.py ├── utils_option.py └── utils_regularizers.py /.github/workflows/pylint.yml: -------------------------------------------------------------------------------- 1 | name: Pylint 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | strategy: 9 | matrix: 10 | python-version: ["3.8", "3.9", "3.10"] 11 | steps: 12 | - uses: actions/checkout@v3 13 | - name: Set up Python ${{ matrix.python-version }} 14 | uses: actions/setup-python@v3 15 | with: 16 | python-version: ${{ matrix.python-version }} 17 | - name: Install dependencies 18 | run: | 19 | python -m pip install --upgrade pip 20 | pip install pylint 21 | - name: Analysing the code with pylint 22 | run: | 23 | pylint $(git ls-files '*.py') 24 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | results/ 3 | logs/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Shiqi Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, 
and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Toolbox for ImagingLab@ZJU

## Contents
0. New Features/Updates
1. Introduction
2. Package Dependencies
3. Model Categories
4. How To Use

🚩 **New Features/Updates**

- ✅ Sep. 16, 2023. Fixed incorrect gradient backpropagation when summing losses in model_multiout; improved validation and output visualization.
- ✅ Sep. 05, 2023. Fixed bugs in single-GPU training and in models with multi-scale outputs.
- ✅ Aug. 14, 2023. Added the official deployment configuration of the multimodal vision model [Painter](https://github.com/baaivision/Painter/tree/main/Painter).
- ✅ Aug. 02, 2023. Added local inference for some transformer-based models; added support for multi-input models, *e.g.*, with degradation PSFs, SNR, etc.; added support for optical aberration correction and open-sourced the aberration correction model [FSANet](https://opg.optica.org/oe/abstract.cfm?URI=oe-30-13-23485) (Frequency Self-Adaptive Network) together with its official deployment configuration, hope you enjoy it 🍻
- ✅ Jul. 28, 2023. Added official deployment configurations for restoration models such as Uformer, NAFNet, and FFTformer.
- ✅ Jul. 26, 2023. Added multi-GPU training and official deployment configurations for restoration models such as MIMO-UNet and Restormer.

## Introduction
- Supports a variety of low-level tasks and mainstream image restoration networks: tasks such as denoising/super-resolution/deblurring/deraining, and networks such as mimo-unet/restormer. Tasks and network architectures can be added according to your own needs.
- Every mainstream restoration model ships with its official deployment configuration (option_official_implementation_xxx.json), so experiments can be run quickly without refactoring any code.
- Convenient experiment management. Each experiment keeps its original configuration JSON file, training log file, tensorboard event file, and, for every validation stage, the corresponding checkpoints and visualization images.

## Package Dependencies
The toolbox was tested with PyTorch 2.0.1+cu118, Python 3.9.6, and CUDA 12.2 in a virtual environment (PyTorch 1.13.1+cu118 with Python 3.9.6 and CUDA 12.2 also works, although the distributed-training commands differ slightly). Install the required packages with:

```bash
pip install -r requirements.txt
```

## Model Categories
The following models are currently available for training:
- UNet
- RRDBNet
- MIMO-UNet / MIMO-UNet+ / MIMO-UNet-MFF
- MPRNet
- NAFNet
- Restormer
- Stripformer
- Uformer
- VapSR

## How To Use
We provide simple demos for training/testing/inference so you can get started quickly. These demos/commands cannot cover every case; more details will be provided in later updates.

*TODO*

### Data preparation
Data can be organized as in the paper [Painter](https://github.com/baaivision/Painter), or in any way you prefer~

### Project structure
The main components under the project root are described below; in most cases only the configuration files under options need to be modified.\
To change the data preprocessing, network architecture, loss functions, etc., refer to the notes below.

    toolbox
    |-- main_train_sample.py # training entry point
    |-- data                 # dataset definitions and preprocessing logic
    |-- logs                 # tensorboard event files
    |-- models               # network definitions and selection
    |-- options              # training configuration JSON files
    |-- results              # per-experiment results, named after the task
    |-- dataset              # datasets (symbolic links also work)
    |-- utils                # utility helpers

### Main parameters of the JSON file
    "task": experiment name; network name plus key parameters plus a date/ID is recommended, e.g., rrdb_batchsize64_20230507
    "model": how the model is optimized, as distinct from the network architecture; the losses differ, e.g., plain only supports pixel losses
    "gpu_ids": GPU indices used in single-/multi-GPU training, e.g., [0, 1, 2, 3] on a 4-GPU server
    "n_channels": number of channels when the dataset is loaded, usually 3
    "path/root": task name, e.g., results/superresolution. For example, for super-resolution with rrdb, the experiment results are found under results/superresolution/rrdb_batchsize64_20230507
    "datasets"
        "dataset_type": dataset type; paired or unpaired datasets can be defined as needed, and the default plain is a paired dataset
        "dataroot_H": dataset path
        "H_size": patch size of the ground truth
        "dataloader_num_workers": number of workers per GPU; keep it moderate, 2-8 is recommended
        "dataloader_batch_size": batch size per GPU
    "netG"
        "net_type": network type; currently supports rrdb rrdbnet unet mimounet mimounetplus mprnet nafnet restormer stripformer uformer vapsr
        "in_nc": number of input channels
        Note: the remaining parameters can be defined according to the specific network architecture
    "train"
        "checkpoint_test": validate every this many iterations
        "checkpoint_save": save a checkpoint every this many iterations
        "checkpoint_print": print the training status every this many iterations
        Note: the total amount of training is set by the epoch loop in main_train_sample.py (see opt['train']['total_epoch']) and must be adjusted there manually!
    ...
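For orientation, a stripped-down option file might look like the sketch below. All values and paths are illustrative only; the option_official_implementation_xxx.json files under options are the authoritative references, and each network adds its own keys:

```json
{
  "task": "rrdb_batchsize64_20230507",
  "model": "plain",
  "gpu_ids": [0],
  "n_channels": 3,
  "path": {
    "root": "results/superresolution"
  },
  "datasets": {
    "train": {
      "dataset_type": "plain",
      "dataroot_H": "dataset/superresolution/train/target",
      "dataroot_L": "dataset/superresolution/train/input",
      "H_size": 128,
      "dataloader_shuffle": true,
      "dataloader_num_workers": 4,
      "dataloader_batch_size": 64
    }
  },
  "netG": {
    "net_type": "rrdb",
    "in_nc": 3
  },
  "train": {
    "checkpoint_print": 200,
    "checkpoint_save": 5000,
    "checkpoint_test": 5000
  }
}
```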
"datasets" 73 | "dataset_type":数据集类型,可以自己定义paired数据或者not paired数据等,默认plain为成对数据集 74 | "dataroot_H":数据集路径 75 | "H_size":Ground Truth的patch size 76 | "dataloader_num_workers":每个GPU上的线程数,一般不要太大,2-8之间为宜 77 | "dataloader_batch_size":每个GPU上的batch_size 78 | "netG" 79 | "net_type": 网络种类,目前支持rrdb rrdbnet unet mimounet mimounetplus mprnet nafnet restormer stripformer uformer vapsr 80 | "in_nc":输入通道数 81 | 注:其余参数可根据具体的网络结构进行定义 82 | "train" 83 | "checkpoint_test": 每多少iteration验证一次 84 | "checkpoint_save": 每多少iteration存储一次checkpoint 85 | "checkpoint_print": 每多少iteration打印一次训练情况 86 | 注:训练总的iteration数目,需要去train_main_sample.py内line 160手动修改! 87 | ... 88 | 89 | ### 训练模型 90 | # 在项目根目录下直接运行训练脚本 91 | $ cd toolbox 92 | $ python main_train_sample.py --opt options/option_xxxxx.json (单卡训练,注意不要输入--dist及其后的信息) 93 | # pytorch2.0.0版本以后的多卡训练,注意此时option文件中的gpu_ids必须为list,例如:[0, 1, 2, 3] 94 | $ torchrun --nproc_per_node=${GPU_NUMs} main_train_sample.py --opt options/option_xxxxx.json --dist True 95 | # pytorch2.0.0版本以前的多卡训练,注意此时option文件中的gpu_ids必须为list,例如:[0, 1, 2, 3] 96 | $ python -m torch.distributed.launch --nproc_per_node=8 --master_port=4321 main_train_sample.py --opt options/option_xxxxx.json --dist True 97 | 98 | ### 测试模型 99 | # 在项目根目录下直接运行训练脚本 100 | $ cd toolbox 101 | $ python main_test_sample.py --opt options/option_xxxxx.json --dist False (单卡测试) 102 | 103 | #### 注意测试模型和训练模型使用的是同一个json文件哦~ 104 | 105 | ### 一些推荐使用习惯 106 | - 待补充 107 | 108 | ## 引用(Citations) 109 | 如果我们的工具帮助到了您,不妨给我们点个星并引用一下吧 110 | 下面是BibTex的形式,使用需要Latex的 `url` 包. 111 | 112 | ``` latex 113 | @misc{toolbox@zjuimaging, 114 | author = {Shiqi Chen and Zida Chen and Ziran Zhang and Wenguan Zhang and Peng Luo and Zhengyue Zhuge and Jinwen Zhou}, 115 | title = {toolbox@zjuimaging: Open Source Image Restoration Toolbox of ZJU Imaging Lab}, 116 | howpublished = {\url{https://github.com/TanGeeGo/toolbox}}, 117 | year = {2023} 118 | } 119 | ``` 120 | 121 | > Shiqi Chen, Zida Chen, Ziran Zhang, Wenguan Zhang, Peng Luo, Zhengyue Zhuge, Jinwen Zhou. toolbox@zjuimaging: Open Source Image Restoration Toolbox of ZJU Imaging Lab. , 2023. -------------------------------------------------------------------------------- /data/__pycache__/dataset_plain.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TanGeeGo/toolbox/bee56747d426b33c57381426f3dcb083f568fde8/data/__pycache__/dataset_plain.cpython-39.pyc -------------------------------------------------------------------------------- /data/__pycache__/select_dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TanGeeGo/toolbox/bee56747d426b33c57381426f3dcb083f568fde8/data/__pycache__/select_dataset.cpython-39.pyc -------------------------------------------------------------------------------- /data/dataset_multiin.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import random 3 | import torch 4 | import numpy as np 5 | import scipy.io as scio 6 | import torch.utils.data as data 7 | import utils.utils_image as util 8 | import utils.utils_filter as filter 9 | 10 | class DatasetMultiin(data.Dataset): 11 | ''' 12 | # ----------------------------------------- 13 | # Get L/H for image-to-image mapping. 14 | # Both "paths_L" and "paths_H" are needed. 
15 | # ----------------------------------------- 16 | # e.g., train denoiser with H and L 17 | # ----------------------------------------- 18 | ''' 19 | 20 | def __init__(self, opt): 21 | super(DatasetMultiin, self).__init__() 22 | self.opt = opt 23 | self.n_channels = opt['n_channels'] if opt['n_channels'] else 3 24 | self.patch_size = self.opt['H_size'] if self.opt['H_size'] else 64 25 | 26 | # ------------------------------------ 27 | # get the path of L/H 28 | # ------------------------------------ 29 | self.paths_H = util.get_image_paths(opt['dataroot_H']) 30 | self.paths_L = util.get_image_paths(opt['dataroot_L']) 31 | 32 | assert self.paths_H, 'Error: High path is empty.' 33 | assert self.paths_L, 'Error: L path is empty. Plain dataset assumes both L and H are given!' 34 | if self.paths_L and self.paths_H: 35 | assert len(self.paths_L) == len(self.paths_H), 'L/H mismatch - {}, {}.'.format(\ 36 | len(self.paths_L), len(self.paths_H)) 37 | 38 | # ------------------------------------ 39 | # get the kernel and sigma (for fsanet) 40 | # ------------------------------------ 41 | self.kernel_set = self.load_kernel() 42 | self.eptional = self.load_eptional() 43 | 44 | def load_kernel(self): 45 | kernel_file = sorted(glob.glob(self.opt['dataroot_K'] + "/*.mat")) 46 | 47 | kernel_set = list() 48 | for i in range(self.opt['kr_num']): 49 | kernel = scio.loadmat(kernel_file[i]) 50 | kernel = kernel['PSF'] 51 | # turning into gray scale (rgb make few differences, not a lot) 52 | kernel = kernel[:,:,0] + kernel[:,:,1] + kernel[:,:,2] 53 | kernel = kernel / np.sum(kernel) 54 | kernel = torch.FloatTensor(kernel).unsqueeze(0) 55 | # form a batch 56 | kernel_set.append(kernel) 57 | 58 | kernel_set = torch.cat(kernel_set, dim=0) 59 | 60 | return kernel_set 61 | 62 | def load_eptional(self): 63 | # different eptional fitting type 64 | if self.opt["eptional_name"] == 'sinefit': 65 | return torch.FloatTensor(filter.sinefit(self.opt['en_size'], self.opt['en_size'], self.opt['omega'], 66 | self.opt['theta_num'], self.opt['eptional_sigma'])) 67 | elif self.opt["eptional_name"] == 'gauss_fit': 68 | return torch.FloatTensor(filter.gauss_fit(self.opt['en_size'], self.opt['en_size'], self.opt['sig_low'], 69 | self.opt['sig_high'], self.opt['sig_num'])) 70 | elif self.opt["eptional_name"] == 'dir_map': 71 | return torch.ones((1, self.opt['en_size'], self.opt['en_size'])) 72 | 73 | def __getitem__(self, index): 74 | 75 | # ------------------------------------ 76 | # get H image 77 | # ------------------------------------ 78 | H_path = self.paths_H[index] 79 | img_H = util.imread_mat(H_path, 'hr', self.n_channels) 80 | 81 | # ------------------------------------ 82 | # get L image 83 | # ------------------------------------ 84 | L_path = self.paths_L[index] 85 | img_L = util.imread_mat(L_path, 'lr', self.n_channels+2) # with fov info 86 | 87 | # ------------------------------------ 88 | # if train, get L/H patch pair 89 | # ------------------------------------ 90 | if self.opt['phase'] == 'train': 91 | 92 | H, W, _ = img_H.shape 93 | 94 | # -------------------------------- 95 | # randomly crop the patch 96 | # -------------------------------- 97 | rnd_h = random.randint(0, max(0, H - self.patch_size)) 98 | rnd_w = random.randint(0, max(0, W - self.patch_size)) 99 | patch_L = img_L[rnd_h:rnd_h + self.patch_size, rnd_w:rnd_w + self.patch_size, :] 100 | patch_H = img_H[rnd_h:rnd_h + self.patch_size, rnd_w:rnd_w + self.patch_size, :] 101 | 102 | # -------------------------------- 103 | # augmentation - flip and/or 
rotate 104 | # -------------------------------- 105 | mode = random.randint(0, 7) 106 | patch_L, patch_H = util.augment_img(patch_L, mode=mode), util.augment_img(patch_H, mode=mode) 107 | 108 | # -------------------------------- 109 | # HWC to CHW, numpy(uint) to tensor 110 | # -------------------------------- 111 | img_L, img_H = util.uint2tensor3(patch_L), util.uint2tensor3(patch_H) 112 | 113 | else: 114 | 115 | # -------------------------------- 116 | # HWC to CHW, numpy(uint) to tensor 117 | # -------------------------------- 118 | img_L, img_H = util.uint2tensor3(img_L), util.uint2tensor3(img_H) 119 | 120 | return {'L': img_L, 'H': img_H, 'L_path': L_path, 'H_path': H_path, 121 | 'Kernels': self.kernel_set, 'Eptional': self.eptional} 122 | 123 | def __len__(self): 124 | return len(self.paths_H) 125 | -------------------------------------------------------------------------------- /data/dataset_plain.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch.utils.data as data 4 | import utils.utils_image as util 5 | 6 | 7 | class DatasetPlain(data.Dataset): 8 | ''' 9 | # ----------------------------------------- 10 | # Get L/H for image-to-image mapping. 11 | # Both "paths_L" and "paths_H" are needed. 12 | # ----------------------------------------- 13 | # e.g., train denoiser with H and L 14 | # ----------------------------------------- 15 | ''' 16 | 17 | def __init__(self, opt): 18 | super(DatasetPlain, self).__init__() 19 | self.opt = opt 20 | self.n_channels = opt['n_channels'] if opt['n_channels'] else 3 21 | self.patch_size = self.opt['H_size'] if self.opt['H_size'] else 64 22 | 23 | # ------------------------------------ 24 | # get the path of L/H 25 | # ------------------------------------ 26 | self.paths_H = util.get_image_paths(opt['dataroot_H']) 27 | self.paths_L = util.get_image_paths(opt['dataroot_L']) 28 | 29 | assert self.paths_H, 'Error: High path is empty.' 30 | assert self.paths_L, 'Error: L path is empty. Plain dataset assumes both L and H are given!' 
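        # paths_L and paths_H are assumed to be aligned index-by-index
        # (util.get_image_paths is expected to return sorted lists), so the
        # length check below guards against mismatched L/H folders.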
31 | if self.paths_L and self.paths_H: 32 | assert len(self.paths_L) == len(self.paths_H), 'L/H mismatch - {}, {}.'.format(\ 33 | len(self.paths_L), len(self.paths_H)) 34 | 35 | def __getitem__(self, index): 36 | 37 | # ------------------------------------ 38 | # get H image 39 | # ------------------------------------ 40 | H_path = self.paths_H[index] 41 | img_H = util.imread_uint(H_path, self.n_channels) 42 | 43 | # ------------------------------------ 44 | # get L image 45 | # ------------------------------------ 46 | L_path = self.paths_L[index] 47 | img_L = util.imread_uint(L_path, self.n_channels) 48 | 49 | # ------------------------------------ 50 | # if train, get L/H patch pair 51 | # ------------------------------------ 52 | if self.opt['phase'] == 'train': 53 | 54 | H, W, _ = img_H.shape 55 | 56 | # -------------------------------- 57 | # randomly crop the patch 58 | # -------------------------------- 59 | rnd_h = random.randint(0, max(0, H - self.patch_size)) 60 | rnd_w = random.randint(0, max(0, W - self.patch_size)) 61 | patch_L = img_L[rnd_h:rnd_h + self.patch_size, rnd_w:rnd_w + self.patch_size, :] 62 | patch_H = img_H[rnd_h:rnd_h + self.patch_size, rnd_w:rnd_w + self.patch_size, :] 63 | 64 | # -------------------------------- 65 | # augmentation - flip and/or rotate 66 | # -------------------------------- 67 | mode = random.randint(0, 7) 68 | patch_L, patch_H = util.augment_img(patch_L, mode=mode), util.augment_img(patch_H, mode=mode) 69 | 70 | # -------------------------------- 71 | # HWC to CHW, numpy(uint) to tensor 72 | # -------------------------------- 73 | img_L, img_H = util.uint2tensor3(patch_L), util.uint2tensor3(patch_H) 74 | 75 | else: 76 | 77 | # -------------------------------- 78 | # HWC to CHW, numpy(uint) to tensor 79 | # -------------------------------- 80 | img_L, img_H = util.uint2tensor3(img_L), util.uint2tensor3(img_H) 81 | 82 | return {'L': img_L, 'H': img_H, 'L_path': L_path, 'H_path': H_path} 83 | 84 | def __len__(self): 85 | return len(self.paths_H) 86 | -------------------------------------------------------------------------------- /data/select_dataset.py: -------------------------------------------------------------------------------- 1 | ''' 2 | # -------------------------------------------- 3 | # select dataset 4 | # -------------------------------------------- 5 | ''' 6 | 7 | def define_Dataset(dataset_opt): 8 | dataset_type = dataset_opt['dataset_type'].lower() 9 | if dataset_type in ['plain']: 10 | from data.dataset_plain import DatasetPlain as D # Low-quality image and High-quality image 11 | elif dataset_type in ['multimodal']: 12 | from data.dataset_multimodal import DatasetMultiModal as D 13 | elif dataset_type in ['multiin']: 14 | from data.dataset_multiin import DatasetMultiin as D 15 | 16 | dataset = D(dataset_opt) 17 | 18 | return dataset 19 | -------------------------------------------------------------------------------- /dataset/gen_dataset_wz_fov.py: -------------------------------------------------------------------------------- 1 | """ 2 | The frequency self-adaptive network for optical degradation correction 3 | paper link: 4 | @article{Lin_2022_OE, 5 | author = {Ting Lin and ShiQi Chen and Huajun Feng and Zhihai Xu and Qi Li and Yueting Chen}, 6 | journal = {Opt. 
Express}, 7 | keywords = {All optical devices; Blind deconvolution; Image processing; Image quality; Optical design; Ray tracing}, 8 | number = {13}, 9 | pages = {23485--23498}, 10 | publisher = {Optica Publishing Group}, 11 | title = {Non-blind optical degradation correction via frequency self-adaptive and finetune tactics}, 12 | volume = {30}, 13 | month = {Jun}, 14 | year = {2022}, 15 | url = {https://opg.optica.org/oe/abstract.cfm?URI=oe-30-13-23485}, 16 | doi = {10.1364/OE.458530}, 17 | abstract = {In mobile photography applications, limited volume constraints the diversity of optical design. In addition to the narrow space, the deviations introduced in mass production cause random bias to the real camera. In consequence, these factors introduce spatially varying aberration and stochastic degradation into the physical formation of an image. Many existing methods obtain excellent performance on one specific device but are not able to quickly adapt to mass production. To address this issue, we propose a frequency self-adaptive model to restore realistic features of the latent image. The restoration is mainly performed in the Fourier domain and two attention mechanisms are introduced to match the feature between Fourier and spatial domain. Our method applies a lightweight network, without requiring modification when the fields of view (FoV) changes. Considering the manufacturing deviations of a specific camera, we first pre-train a simulation-based model, then finetune it with additional manufacturing error, which greatly decreases the time and computational overhead consumption in implementation. Extensive results verify the promising applications of our technique for being integrated with the existing post-processing systems.}, 18 | } 19 | """ 20 | ##### Data preparation file for training Model on the Dataset with field of view information ######## 21 | 22 | import tifffile 23 | import numpy as np 24 | from glob import glob 25 | from natsort import natsorted 26 | import os 27 | from tqdm import tqdm 28 | from pdb import set_trace as stx 29 | from joblib import Parallel, delayed 30 | import scipy.io as scio 31 | 32 | def crop_files(file_): 33 | lr_file, hr_file = file_ 34 | filename = os.path.splitext(os.path.split(lr_file)[-1])[0] 35 | lr_img = tifffile.imread(lr_file) 36 | hr_img = tifffile.imread(hr_file) 37 | # normalize to 0~1 float 38 | lr_img = lr_img.astype(np.float32) / 65535. 39 | hr_img = hr_img.astype(np.float32) / 65535. 
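    # The field-of-view computation below appends two extra channels to the LR
    # image that encode each pixel's normalized position: fov_h and fov_w run
    # from -1 at one border, through 0 at the optical center, to +1 at the
    # opposite border (for a 5-pixel-wide row, fov_w is [-1, -0.5, 0, 0.5, 1]).
    # These per-pixel coordinates let the network condition its correction on
    # the spatially varying optical degradation.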
40 |     num_patch = 0
41 | 
42 |     # field of view calculation
43 |     h, w = lr_img.shape[:2]
44 |     h_range = np.arange(0, h, 1)
45 |     w_range = np.arange(0, w, 1)
46 |     fov_w, fov_h = np.meshgrid(w_range, h_range)
47 |     fov_h = ((fov_h - (h-1)/2) / ((h-1)/2)).astype(np.float32) # normalization
48 |     fov_w = ((fov_w - (w-1)/2) / ((w-1)/2)).astype(np.float32) # normalization
49 |     fov_h = np.expand_dims(fov_h, -1)
50 |     fov_w = np.expand_dims(fov_w, -1)
51 |     lr_wz_fov = np.concatenate([lr_img, fov_h, fov_w], 2)
52 | 
53 |     if w > p_max and h > p_max:
54 |         w1 = list(np.arange(0, w-patch_size, patch_size-overlap, dtype=np.int32))
55 |         h1 = list(np.arange(0, h-patch_size, patch_size-overlap, dtype=np.int32))
56 |         w1.append(w-patch_size)
57 |         h1.append(h-patch_size)
58 |         for i in h1:
59 |             for j in w1:
60 |                 num_patch += 1
61 | 
62 |                 lr_patch = lr_wz_fov[i:i+patch_size, j:j+patch_size,:]
63 |                 hr_patch = hr_img[i:i+patch_size, j:j+patch_size,:]
64 | 
65 |                 lr_savename = os.path.join(lr_tar, filename + '-' + str(num_patch) + '.mat')
66 |                 hr_savename = os.path.join(hr_tar, filename + '-' + str(num_patch) + '.mat')
67 | 
68 |                 scio.savemat(lr_savename, {'lr': lr_patch})
69 |                 scio.savemat(hr_savename, {'hr': hr_patch})
70 | 
71 |     else:
72 |         lr_savename = os.path.join(lr_tar, filename + '.mat')
73 |         hr_savename = os.path.join(hr_tar, filename + '.mat')
74 | 
75 |         scio.savemat(lr_savename, {'lr': lr_wz_fov})
76 |         scio.savemat(hr_savename, {'hr': hr_img})
77 | 
78 | ############ Prepare Training data ####################
79 | num_cores = 10
80 | patch_size = 512
81 | overlap = 256
82 | p_max = 0
83 | 
84 | src = '/data1/Aberration_Correction/synthetic_datasets/train'
85 | tar = '/data1/Aberration_Correction/train'
86 | 
87 | os.makedirs(tar, exist_ok=True)
88 | 
89 | lr_tar = os.path.join(tar, 'input_crops')
90 | hr_tar = os.path.join(tar, 'target_crops')
91 | 
92 | os.makedirs(lr_tar, exist_ok=True)
93 | os.makedirs(hr_tar, exist_ok=True)
94 | 
95 | lr_files = natsorted(glob(os.path.join(src, 'input', '*.tiff')) + glob(os.path.join(src, 'input', '*.png')))
96 | hr_files = natsorted(glob(os.path.join(src, 'target', '*.tiff')) + glob(os.path.join(src, 'target', '*.png')))
97 | 
98 | files = [(i, j) for i, j in zip(lr_files, hr_files)]
99 | 
100 | Parallel(n_jobs=num_cores)(delayed(crop_files)(file_) for file_ in tqdm(files))
101 | 
102 | ############ Prepare Validating data ####################
103 | num_cores = 10
104 | patch_size = 512
105 | overlap = 256
106 | p_max = 0
107 | src = '/data1/Aberration_Correction/synthetic_datasets/val'
108 | tar = '/data1/Aberration_Correction/val'
109 | 
110 | os.makedirs(tar, exist_ok=True)
111 | 
112 | lr_tar = os.path.join(tar, 'input_crops')
113 | hr_tar = os.path.join(tar, 'target_crops')
114 | 
115 | os.makedirs(lr_tar, exist_ok=True)
116 | os.makedirs(hr_tar, exist_ok=True)
117 | 
118 | lr_files = natsorted(glob(os.path.join(src, 'input', '*.tiff')) + glob(os.path.join(src, 'input', '*.png')))
119 | hr_files = natsorted(glob(os.path.join(src, 'target', '*.tiff')) + glob(os.path.join(src, 'target', '*.png')))
120 | 
121 | files = [(i, j) for i, j in zip(lr_files, hr_files)]
122 | 
123 | Parallel(n_jobs=num_cores)(delayed(crop_files)(file_) for file_ in tqdm(files))
--------------------------------------------------------------------------------
/main_test_sample.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import cv2
3 | import glob
4 | import numpy as np
5 | from collections import OrderedDict
6 | import os
7 | import logging
8 | 
import torch 9 | from torch.utils.data import DataLoader 10 | 11 | from utils import utils_image as util 12 | from utils import utils_option as option 13 | 14 | from utils import utils_logger 15 | from data.select_dataset import define_Dataset 16 | from models.select_model import define_Model 17 | 18 | def main(json_path='options/option.json'): 19 | 20 | ''' 21 | # ---------------------------------------- 22 | # Step--1 (prepare opt) 23 | # ---------------------------------------- 24 | ''' 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--opt', type=str, default=json_path, help='Path to option JSON file.') 27 | 28 | args = parser.parse_args() 29 | 30 | opt = option.parse(args.opt, is_train=False) 31 | 32 | iters, path_G = option.find_last_checkpoint(opt['path']['models'], net_type='G') 33 | opt['path']['pretrained_netG'] = path_G 34 | 35 | # ---------------------------------------- 36 | # configure logger 37 | # ---------------------------------------- 38 | logger_name = 'test' 39 | utils_logger.logger_info(logger_name, os.path.join(opt['path']['log'], logger_name+'.log')) 40 | logger = logging.getLogger(logger_name) 41 | logger.info(option.dict2str(opt)) 42 | 43 | ''' 44 | # ---------------------------------------- 45 | # Step--2 (creat dataloader) 46 | # ---------------------------------------- 47 | ''' 48 | 49 | # ---------------------------------------- 50 | # 1) create_dataset 51 | # 2) creat_dataloader for train and valid 52 | # ---------------------------------------- 53 | for phase, dataset_opt in opt['datasets'].items(): 54 | if phase == 'test': 55 | test_set = define_Dataset(dataset_opt) 56 | test_loader = DataLoader(test_set, 57 | batch_size=dataset_opt['dataloader_batch_size'], 58 | shuffle=False, 59 | num_workers=dataset_opt['dataloader_num_workers'], 60 | drop_last=False, 61 | pin_memory=True) 62 | else: 63 | # leave the phase of train and valid into the training 64 | pass 65 | 66 | ''' 67 | # ---------------------------------------- 68 | # Step--3 (initialize model) 69 | # ---------------------------------------- 70 | ''' 71 | 72 | model = define_Model(opt) 73 | model.init_test() 74 | 75 | ''' 76 | # ---------------------------------------- 77 | # Step--4 (main testing) 78 | # ---------------------------------------- 79 | ''' 80 | 81 | for i, test_data in enumerate(test_loader): 82 | 83 | # ------------------------------- 84 | # 1) feed patch pairs 85 | # ------------------------------- 86 | model.feed_data(test_data) 87 | 88 | # ------------------------------- 89 | # 2) evaluate data 90 | # ------------------------------- 91 | model.netG_forward() 92 | 93 | visuals = model.current_visuals() 94 | E_img = util.tensor2uint(visuals['E']) 95 | H_img = util.tensor2uint(visuals['H']) 96 | 97 | # ------------------------------- 98 | # 3) save tested image E 99 | # ------------------------------- 100 | image_name_ext = os.path.basename(test_data['L_path'][0]) 101 | img_name, ext = os.path.splitext(image_name_ext) 102 | 103 | save_E_img_path = os.path.join(opt['path']['images'], '{:s}_pred.png'.format(img_name)) 104 | save_H_img_path = os.path.join(opt['path']['images'], '{:s}_grdt.png'.format(img_name)) 105 | util.imsave(E_img, save_E_img_path) 106 | util.imsave(H_img, save_H_img_path) 107 | 108 | # ----------------------- 109 | # 4) calculate indicators 110 | # ----------------------- 111 | psnr = util.calculate_psnr(E_img, H_img) 112 | ssim = util.calculate_ssim(E_img, H_img) 113 | model.log_dict['psnr'].append(psnr) 114 | model.log_dict['ssim'].append(ssim) 115 
| if H_img.ndim == 3: 116 | E_img_y = util.bgr2ycbcr(E_img.astype(np.float32) / 255.) * 255. 117 | H_img_y = util.bgr2ycbcr(H_img.astype(np.float32) / 255.) * 255. 118 | psnr_y = util.calculate_psnr(E_img_y, H_img_y) 119 | ssim_y = util.calculate_ssim(E_img_y, H_img_y) 120 | model.log_dict['psnr_y'].append(psnr_y) 121 | model.log_dict['ssim_y'].append(ssim_y) 122 | 123 | logger.info('Image:{}, PSNR: {:<.4f}dB, SSIM: {:<.4f}, PSNR_Y: {:<.4f}dB, SSIM_Y: {:<.4f}\n'.format(\ 124 | img_name, psnr, ssim, psnr_y, ssim_y)) 125 | 126 | logger.info('PSNR_Average: {:<.4f}dB, SSIM_Average: {:<.4f}, PSNR_Y_Average: {:<.4f}dB, SSIM_Y_Average: {:<.4f}\n'.format(\ 127 | sum(model.log_dict['psnr'])/(i+1), sum(model.log_dict['ssim'])/(i+1), sum(model.log_dict['psnr_y'])/(i+1), sum(model.log_dict['ssim_y'])/(i+1))) 128 | 129 | if __name__ == '__main__': 130 | main() -------------------------------------------------------------------------------- /main_train_sample.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | import math 4 | import argparse 5 | import time 6 | import random 7 | import numpy as np 8 | from collections import OrderedDict 9 | import logging 10 | from torch.utils.data import DataLoader 11 | from torch.utils.data.distributed import DistributedSampler 12 | import torch 13 | 14 | from utils import utils_logger 15 | from utils import utils_image as util 16 | from utils import utils_option as option 17 | from utils.utils_dist import init_dist, get_dist_info 18 | from torch.utils.tensorboard import SummaryWriter 19 | 20 | from data.select_dataset import define_Dataset 21 | from models.select_model import define_Model 22 | 23 | def main(json_path='options/option.json'): 24 | 25 | ''' 26 | # ---------------------------------------- 27 | # Step--1 (prepare opt) 28 | # ---------------------------------------- 29 | ''' 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('--opt', type=str, default=json_path, help='Path to option JSON file.') 32 | parser.add_argument('--dist', type=bool, default=False) 33 | 34 | args = parser.parse_args() 35 | 36 | opt = option.parse(args.opt, is_train=True) 37 | 38 | torch.backends.cudnn.benchmark = True 39 | # torch.backends.cudnn.deterministic = True 40 | 41 | log_dir = Path(opt['path']['log']) 42 | log_dir.mkdir(exist_ok=True, parents=True) 43 | writer = SummaryWriter(log_dir=str(log_dir)) 44 | # ---------------------------------------- 45 | # distributed settings of training 46 | # ---------------------------------------- 47 | opt['dist'] = args.dist 48 | if opt['dist']: 49 | print(opt['dist']) 50 | init_dist('pytorch') 51 | opt['rank'], opt['word_size'] = get_dist_info() 52 | 53 | if opt['rank'] == 0: 54 | print('export CUDA_VISIBLE_DEVICES=' + ','.join(str(x) for x in opt['gpu_ids'])) 55 | print('number of GPUs is: ' + str(opt['num_gpu'])) 56 | util.mkdirs((path for key, path in opt['path'].items() if 'pretrained' not in key)) 57 | 58 | # ---------------------------------------- 59 | # update opt 60 | # ---------------------------------------- 61 | # -->-->-->-->-->-->-->-->-->-->-->-->-->- 62 | init_iter, init_path_G = option.find_last_checkpoint(opt['path']['models'], net_type='G') 63 | opt['path']['pretrained_netG'] = init_path_G 64 | current_step = init_iter 65 | # --<--<--<--<--<--<--<--<--<--<--<--<--<- 66 | 67 | # ---------------------------------------- 68 | # save opt to a '../option.json' file 69 | # ---------------------------------------- 70 | if opt['rank'] 
== 0: 71 | option.save(opt) 72 | 73 | # ---------------------------------------- 74 | # return None for missing key 75 | # ---------------------------------------- 76 | opt = option.dict_to_nonedict(opt) 77 | 78 | # ---------------------------------------- 79 | # configure logger 80 | # ---------------------------------------- 81 | logger_name = 'train' 82 | utils_logger.logger_info(logger_name, os.path.join(opt['path']['log'], logger_name+'.log')) 83 | logger = logging.getLogger(logger_name) 84 | # logger.info(option.dict2str(opt)) 85 | 86 | # ---------------------------------------- 87 | # seed 88 | # ---------------------------------------- 89 | seed = opt['train']['manual_seed'] 90 | if seed is None: 91 | seed = random.randint(1, 10000) 92 | logger.info('Random seed: {}'.format(seed)) 93 | random.seed(seed) 94 | np.random.seed(seed) 95 | torch.manual_seed(seed) 96 | torch.cuda.manual_seed_all(seed) 97 | 98 | ''' 99 | # ---------------------------------------- 100 | # Step--2 (creat dataloader) 101 | # ---------------------------------------- 102 | ''' 103 | 104 | # ---------------------------------------- 105 | # 1) create_dataset 106 | # 2) creat_dataloader for train and valid 107 | # ---------------------------------------- 108 | for phase, dataset_opt in opt['datasets'].items(): 109 | if phase == 'train': 110 | train_set = define_Dataset(dataset_opt) 111 | if opt['rank'] == 0: 112 | print('Dataset [{:s} - {:s}] is created.'.\ 113 | format(train_set.__class__.__name__, dataset_opt['name'])) 114 | train_size = int(math.ceil(len(train_set) / dataset_opt['dataloader_batch_size'])) 115 | if opt['rank'] == 0: 116 | logger.info('Number of train images: {:,d}, iters: {:,d}'.format(len(train_set), train_size)) 117 | if opt['dist']: 118 | train_sampler = DistributedSampler(train_set, shuffle=dataset_opt['dataloader_shuffle'], 119 | drop_last=True, seed=seed) 120 | train_loader = DataLoader(train_set, 121 | batch_size=dataset_opt['dataloader_batch_size']//opt['num_gpu'], 122 | shuffle=False, 123 | num_workers=dataset_opt['dataloader_num_workers']//opt['num_gpu'], 124 | drop_last=True, 125 | pin_memory=True, 126 | sampler=train_sampler) 127 | else: 128 | train_loader = DataLoader(train_set, 129 | batch_size=dataset_opt['dataloader_batch_size'], 130 | shuffle=dataset_opt['dataloader_shuffle'], 131 | num_workers=dataset_opt['dataloader_num_workers'], 132 | drop_last=True, 133 | pin_memory=True) 134 | 135 | elif phase == 'valid': 136 | valid_set = define_Dataset(dataset_opt) 137 | if opt['rank'] == 0: 138 | print('Dataset [{:s} - {:s}] is created.'.\ 139 | format(valid_set.__class__.__name__, dataset_opt['name'])) 140 | valid_loader = DataLoader(valid_set, 141 | batch_size=dataset_opt['dataloader_batch_size'], 142 | shuffle=False, 143 | num_workers=dataset_opt['dataloader_num_workers'], 144 | drop_last=False, 145 | pin_memory=True) 146 | else: 147 | # leave the phase of test into the evaluation 148 | pass 149 | 150 | ''' 151 | # ---------------------------------------- 152 | # Step--3 (initialize model) 153 | # ---------------------------------------- 154 | ''' 155 | 156 | model = define_Model(opt) 157 | if opt['rank'] == 0: 158 | print('Training model [{:s}] is created.'.format(model.__class__.__name__)) 159 | if opt['netG']['init_type'] not in ['default', 'none']: 160 | print('Initialization method [{:s} + {:s}], gain is [{:.2f}]'.format(\ 161 | opt['netG']['init_type'], opt['netG']['init_bn_type'], opt['netG']['init_gain'])) 162 | else: 163 | print('Pass this initialization! 
Initialization was done during network definition!')
164 | 
165 |     model.init_train()
166 |     # uncomment it if you want to see the details of the model
167 |     # if opt['rank'] == 0:
168 |     #     logger.info(model.info_network())
169 |     #     logger.info(model.info_params())
170 |     #     pass
171 | 
172 |     '''
173 |     # ----------------------------------------
174 |     # Step--4 (main training)
175 |     # ----------------------------------------
176 |     '''
177 |     for epoch in range(opt['train']['total_epoch']):  # TODO: the terminate condition
178 |         logger.info('EPOCH: {:3d}'.format(epoch))
179 |         if opt['dist']:
180 |             train_sampler.set_epoch(epoch) # set the sampler in data distribution
181 | 
182 |         for i, train_data in enumerate(train_loader):
183 | 
184 |             current_step += 1
185 | 
186 |             # -------------------------------
187 |             # 1) feed patch pairs
188 |             # -------------------------------
189 |             model.feed_data(train_data, epoch) if opt['model'] == 'progressive' else \
190 |                 model.feed_data(train_data)
191 | 
192 |             # -------------------------------
193 |             # 2) optimize parameters
194 |             # -------------------------------
195 |             model.optimize_parameters(current_step)
196 | 
197 |             # -------------------------------
198 |             # 3) training information
199 |             # -------------------------------
200 |             logs = model.current_log()
201 |             writer.add_scalar("lr", model.current_learning_rate(), epoch + 1)
202 |             for k, v in logs.items():  # merge log information into message
203 |                 writer.add_scalar(k, v, epoch + 1)
204 | 
205 |             if current_step % opt['train']['checkpoint_print'] == 0 and opt['rank'] == 0:
206 |                 logs = model.current_log()  # such as loss
207 |                 message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.\
208 |                     format(epoch, current_step, model.current_learning_rate())
209 |                 for k, v in logs.items():  # merge log information into message
210 |                     message += '{:s}: {:.6f} '.format(k, v)
211 |                 logger.info(message)
212 | 
213 |             # -------------------------------
214 |             # 4) update learning rate
215 |             # -------------------------------
216 |             model.update_learning_rate()
217 | 
218 |         # -------------------------------
219 |         # 5) saving model
220 |         # -------------------------------
221 |         if epoch % opt['train']['checkpoint_save'] == 0 and opt['rank'] == 0:
222 |             logger.info('Saving the model.')
223 |             model.save(current_step)
224 | 
225 |         # -------------------------------
226 |         # 6) validating
227 |         # -------------------------------
228 |         if epoch % opt['train']['checkpoint_valid'] == 0 and opt['rank'] == 0:
229 | 
230 |             avg_psnr = 0.0
231 |             avg_ssim = 0.0
232 |             idx = 0
233 | 
234 |             for valid_data in valid_loader:
235 |                 idx += 1
236 |                 image_name_ext = os.path.basename(valid_data['L_path'][0])
237 |                 img_name, ext = os.path.splitext(image_name_ext)
238 | 
239 |                 img_dir = os.path.join(opt['path']['images'], img_name)
240 |                 util.mkdir(img_dir)
241 | 
242 |                 model.feed_data(valid_data)
243 |                 model.valid()
244 | 
245 |                 visuals = model.current_visuals()
246 |                 E_img = util.tensor2uint(visuals['E'])
247 |                 H_img = util.tensor2uint(visuals['H'])
248 | 
249 |                 # -----------------------
250 |                 # save estimated image E
251 |                 # -----------------------
252 |                 save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(img_name, current_step))
253 |                 util.imsave(E_img, save_img_path)
254 | 
255 |                 # -----------------------
256 |                 # calculate PSNR and SSIM
257 |                 # -----------------------
258 |                 current_psnr = util.calculate_psnr(E_img, H_img)
259 |                 current_ssim = util.calculate_ssim(E_img, H_img)
260 | 
261 |                 # logger.info('{:->4d}--> {:>10s} | {:<4.2f}dB'.format(idx, image_name_ext, current_psnr))
262 | 
263 |                 avg_psnr += current_psnr
264 |                 avg_ssim += current_ssim
265 | 
266 |             avg_psnr = avg_psnr / idx
267 |             avg_ssim = avg_ssim / idx
268 | 
269 |             # validating log
270 |             writer.add_scalar('Average PSNR', avg_psnr, epoch + 1)
271 |             logger.info('<epoch:{:3d}, iter:{:8,d}, Average PSNR: {:<.2f}dB, Average SSIM: {:<.4f}>'.format(epoch, current_step, avg_psnr, avg_ssim))
--------------------------------------------------------------------------------
/models/archs/unet_arch.py:
--------------------------------------------------------------------------------
1 | """ Parts of the U-Net model
2 | (from https://github.com/milesial/Pytorch-UNet) """
3 | 
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | 
8 | 
9 | class DoubleConv(nn.Module):
10 |     """(convolution => [BN] => ReLU) * 2"""
11 | 
12 |     def __init__(self, in_channels, out_channels, mid_channels=None):
13 |         super().__init__()
14 |         if not mid_channels:
15 |             mid_channels = out_channels
16 |         self.double_conv = nn.Sequential(
17 |             nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
18 |             nn.BatchNorm2d(mid_channels),
19 |             nn.ReLU(inplace=True),
20 |             nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
21 |             nn.BatchNorm2d(out_channels),
22 |             nn.ReLU(inplace=True)
23 |         )
24 | 
25 |     def forward(self, x):
26 |         return self.double_conv(x)
27 | 
28 | 
29 | class Down(nn.Module):
30 |     """Downscaling with maxpool then double conv"""
31 | 
32 |     def __init__(self, in_channels, out_channels):
33 |         super().__init__()
34 |         self.maxpool_conv = nn.Sequential(
35 |             nn.MaxPool2d(2),
36 |             DoubleConv(in_channels, out_channels)
37 |         )
38 | 
39 |     def forward(self, x):
40 |         return self.maxpool_conv(x)
41 | 
42 | 
43 | class Up(nn.Module):
44 |     """Upscaling then double conv"""
45 | 
46 |     def __init__(self, in_channels, out_channels, bilinear=True):
47 |         super().__init__()
48 | 
49 |         # if bilinear, use the normal convolutions to reduce the number of channels
50 |         if bilinear:
51 |             self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
52 |             self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
53 |         else:
54 |             self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
55 |             self.conv = DoubleConv(in_channels, out_channels)
56 | 
57 |     def forward(self, x1, x2):
58 |         x1 = self.up(x1)
59 |         # input is CHW
60 |         diffY = x2.size()[2] - x1.size()[2]
61 |         diffX = x2.size()[3] - x1.size()[3]
62 | 
63 |         x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
64 |                         diffY // 2, diffY - diffY // 2])
65 |         # if you have padding issues, see
66 |         # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
67 |         # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
68 |         x = torch.cat([x2, x1], dim=1)
69 |         return self.conv(x)
70 | 
71 | 
72 | class OutConv(nn.Module):
73 |     def __init__(self, in_channels, out_channels):
74 |         super(OutConv, self).__init__()
75 |         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
76 | 
77 |     def forward(self, x):
78 |         return self.conv(x)
79 | 
80 | 
81 | class UNet(nn.Module):
82 |     def __init__(self, n_channels, n_classes, bilinear=False):
83 |         super(UNet, self).__init__()
84 |         self.n_channels = n_channels
85 |         self.n_classes = n_classes
86 |         self.bilinear = bilinear
87 | 
88 |         self.inc = DoubleConv(n_channels, 64)
89 |         self.down1 = Down(64, 128)
90 |         self.down2 = Down(128, 256)
91 |         self.down3 = Down(256, 512)
92 |         factor = 2 if bilinear else 1
93 |         self.down4 = Down(512, 1024 // factor)
94 |         self.up1 = Up(1024, 512 // factor, bilinear)
95 |         self.up2 = Up(512, 256 // factor, bilinear)
96 |         self.up3 = Up(256, 128 // factor, bilinear)
97 |         self.up4 = Up(128, 64, bilinear)
98 |         self.outc = OutConv(64, n_classes)
99 | 
100 |     def forward(self, x):
101 |         x1 = self.inc(x)
102 |         x2 = self.down1(x1)
103 |         x3 = self.down2(x2)
104 |         x4 = self.down3(x3)
105 |         x5 = self.down4(x4)
106 |         x = self.up1(x5, x4)
107 |         x = self.up2(x, x3)
108 |         x = self.up3(x, x2)
109 |         x = self.up4(x, x1)
110 |         logits = self.outc(x)
111 |         return logits
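# Usage sketch (illustrative, with assumed shapes): n_channels is the input
# channel count and n_classes the output channel count, so a 3-to-3 mapping
# restores an RGB image; odd spatial sizes are re-aligned by the F.pad in Up,
# although sizes divisible by 16 avoid any skip-connection padding.
#
#   net = UNet(n_channels=3, n_classes=3, bilinear=False)
#   x = torch.randn(1, 3, 64, 64)
#   print(net(x).shape)  # -> torch.Size([1, 3, 64, 64])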
-------------------------------------------------------------------------------- /models/model_base.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | from utils.utils_bnorm import merge_bn, tidy_sequential 5 | from torch.nn.parallel import DataParallel, DistributedDataParallel 6 | 7 | 8 | class ModelBase(): 9 | def __init__(self, opt): 10 | self.opt = opt # opt 11 | self.save_dir = opt['path']['models'] # save models 12 | self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu') 13 | self.is_train = opt['is_train'] # training or not 14 | self.schedulers = [] # schedulers 15 | 16 | """ 17 | # ---------------------------------------- 18 | # Preparation before training with data 19 | # Save model during training 20 | # ---------------------------------------- 21 | """ 22 | 23 | def init_train(self): 24 | pass 25 | 26 | def load(self): 27 | pass 28 | 29 | def save(self, label): 30 | pass 31 | 32 | def define_loss(self): 33 | pass 34 | 35 | def define_optimizer(self): 36 | pass 37 | 38 | def define_scheduler(self): 39 | pass 40 | 41 | """ 42 | # ---------------------------------------- 43 | # Optimization during training with data 44 | # Testing/evaluation 45 | # ---------------------------------------- 46 | """ 47 | 48 | def feed_data(self, data): 49 | pass 50 | 51 | def optimize_parameters(self): 52 | pass 53 | 54 | def current_visuals(self): 55 | pass 56 | 57 | def current_losses(self): 58 | pass 59 | 60 | def update_learning_rate(self): 61 | for scheduler in self.schedulers: 62 | scheduler.step() 63 | 64 | def current_learning_rate(self): 65 | return self.schedulers[0].get_last_lr()[0] 66 | 67 | def requires_grad(self, model, flag=True): 68 | for p in model.parameters(): 69 | p.requires_grad = flag 70 | 71 | """ 72 | # ---------------------------------------- 73 | # Information of net 74 | # ---------------------------------------- 75 | """ 76 | 77 | def print_network(self): 78 | pass 79 | 80 | def info_network(self): 81 | pass 82 | 83 | def print_params(self): 84 | pass 85 | 86 | def info_params(self): 87 | pass 88 | 89 | def get_bare_model(self, network): 90 | """Get bare model, especially under wrapping with 91 | DistributedDataParallel or DataParallel. 92 | """ 93 | if isinstance(network, (DataParallel, DistributedDataParallel)): 94 | network = network.module 95 | return network 96 | 97 | def model_to_device(self, network): 98 | """Model to device. It also warps models with DistributedDataParallel 99 | or DataParallel. 100 | Args: 101 | network (nn.Module) 102 | """ 103 | network = network.to(self.device) 104 | if self.opt['dist']: 105 | find_unused_parameters = self.opt.get('find_unused_parameters', True) 106 | use_static_graph = self.opt.get('use_static_graph', False) 107 | network = DistributedDataParallel(network, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters) 108 | if use_static_graph: 109 | print('Using static graph. 
Make sure that "unused parameters" will not change during training loop.') 110 | network._set_static_graph() 111 | else: 112 | network = DataParallel(network) 113 | return network 114 | 115 | # ---------------------------------------- 116 | # network name and number of parameters 117 | # ---------------------------------------- 118 | def describe_network(self, network): 119 | network = self.get_bare_model(network) 120 | msg = '\n' 121 | msg += 'Networks name: {}'.format(network.__class__.__name__) + '\n' 122 | msg += 'Params number: {}'.format(sum(map(lambda x: x.numel(), network.parameters()))) + '\n' 123 | msg += 'Net structure:\n{}'.format(str(network)) + '\n' 124 | return msg 125 | 126 | # ---------------------------------------- 127 | # parameters description 128 | # ---------------------------------------- 129 | def describe_params(self, network): 130 | network = self.get_bare_model(network) 131 | msg = '\n' 132 | msg += ' | {:^6s} | {:^6s} | {:^6s} | {:^6s} || {:<20s}'.format('mean', 'min', 'max', 'std', 'shape', 'param_name') + '\n' 133 | for name, param in network.state_dict().items(): 134 | if not 'num_batches_tracked' in name: 135 | v = param.data.clone().float() 136 | msg += ' | {:>6.3f} | {:>6.3f} | {:>6.3f} | {:>6.3f} | {} || {:s}'.format(v.mean(), v.min(), v.max(), v.std(), v.shape, name) + '\n' 137 | return msg 138 | 139 | """ 140 | # ---------------------------------------- 141 | # Save prameters 142 | # Load prameters 143 | # ---------------------------------------- 144 | """ 145 | 146 | # ---------------------------------------- 147 | # save the state_dict of the network 148 | # ---------------------------------------- 149 | def save_network(self, save_dir, network, network_label, iter_label): 150 | save_filename = '{}_{}.pth'.format(iter_label, network_label) 151 | save_path = os.path.join(save_dir, save_filename) 152 | network = self.get_bare_model(network) 153 | state_dict = network.state_dict() 154 | for key, param in state_dict.items(): 155 | state_dict[key] = param.cpu() 156 | torch.save(state_dict, save_path) 157 | 158 | # ---------------------------------------- 159 | # load the state_dict of the network 160 | # ---------------------------------------- 161 | def load_network(self, load_path, network, strict=True, param_key='params'): 162 | network = self.get_bare_model(network) 163 | if strict: 164 | state_dict = torch.load(load_path) 165 | if param_key in state_dict.keys(): 166 | state_dict = state_dict[param_key] 167 | network.load_state_dict(state_dict, strict=strict) 168 | else: 169 | state_dict_old = torch.load(load_path) 170 | if param_key in state_dict_old.keys(): 171 | state_dict_old = state_dict_old[param_key] 172 | state_dict = network.state_dict() 173 | for ((key_old, param_old),(key, param)) in zip(state_dict_old.items(), state_dict.items()): 174 | state_dict[key] = param_old 175 | network.load_state_dict(state_dict, strict=True) 176 | del state_dict_old, state_dict 177 | 178 | # ---------------------------------------- 179 | # save the state_dict of the optimizer 180 | # ---------------------------------------- 181 | def save_optimizer(self, save_dir, optimizer, optimizer_label, iter_label): 182 | save_filename = '{}_{}.pth'.format(iter_label, optimizer_label) 183 | save_path = os.path.join(save_dir, save_filename) 184 | torch.save(optimizer.state_dict(), save_path) 185 | 186 | # ---------------------------------------- 187 | # load the state_dict of the optimizer 188 | # ---------------------------------------- 189 | def load_optimizer(self, 
load_path, optimizer): 190 | optimizer.load_state_dict(torch.load(load_path, map_location=lambda storage, loc: storage.cuda(torch.cuda.current_device()))) 191 | 192 | def update_E(self, decay=0.999): 193 | netG = self.get_bare_model(self.netG) 194 | netG_params = dict(netG.named_parameters()) 195 | netE_params = dict(self.netE.named_parameters()) 196 | for k in netG_params.keys(): 197 | netE_params[k].data.mul_(decay).add_(netG_params[k].data, alpha=1-decay) 198 | 199 | """ 200 | # ---------------------------------------- 201 | # Merge Batch Normalization for training 202 | # Merge Batch Normalization for testing 203 | # ---------------------------------------- 204 | """ 205 | 206 | # ---------------------------------------- 207 | # merge bn during training 208 | # ---------------------------------------- 209 | def merge_bnorm_train(self): 210 | merge_bn(self.netG) 211 | tidy_sequential(self.netG) 212 | self.define_optimizer() 213 | self.define_scheduler() 214 | 215 | # ---------------------------------------- 216 | # merge bn before testing 217 | # ---------------------------------------- 218 | def merge_bnorm_test(self): 219 | merge_bn(self.netG) 220 | tidy_sequential(self.netG) 221 | -------------------------------------------------------------------------------- /models/model_multiin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from collections import OrderedDict 4 | 5 | from models.model_plain import ModelPlain 6 | 7 | from utils.utils_regularizers import regularizer_orth, regularizer_clip 8 | 9 | class ModelMultiin(ModelPlain): 10 | """Train with pixel loss and has multiple outputs""" 11 | def __init__(self, opt): 12 | super(ModelMultiin, self).__init__(opt) 13 | 14 | """ 15 | # ---------------------------------------- 16 | # Optimization during training with data 17 | # Testing/evaluation 18 | # ---------------------------------------- 19 | """ 20 | 21 | # ---------------------------------------- 22 | # feed L/H data 23 | # ---------------------------------------- 24 | def feed_data(self, data, need_H=True): 25 | self.L = data['L'].to(self.device) 26 | if self.opt['netG']['net_type'] == "fsanet": 27 | self.Kernels = data['Kernels'].to(self.device) 28 | self.Eptional = data['Eptional'].to(self.device) 29 | elif self.opt['netG']['net_type'] == 'painter': 30 | self.Mask = data['Mask'].to(self.device) 31 | self.Valid = data['Valid'].to(self.device) 32 | 33 | if need_H: 34 | self.H = data['H'].to(self.device) 35 | 36 | # ---------------------------------------- 37 | # feed L to netG 38 | # ---------------------------------------- 39 | def netG_forward(self): 40 | # multiple input to the models 41 | if self.opt['netG']['net_type'] == "fsanet": 42 | self.E = self.netG(self.L, self.Kernels, self.Eptional) 43 | elif self.opt['netG']['net_type'] == 'painter': 44 | loss, self.E, self.Mask = self.netG(self.L, self.H, self.Mask, self.Valid) 45 | return loss 46 | 47 | # ---------------------------------------- 48 | # update parameters and get loss 49 | # ---------------------------------------- 50 | def optimize_parameters(self, current_step): 51 | self.G_optimizer.zero_grad() 52 | 53 | if self.opt['netG']['net_type'] == 'painter': 54 | G_loss = self.netG_forward() 55 | else: 56 | self.netG_forward() 57 | if len(self.G_lossfn_weight) == 1: 58 | G_loss = self.G_lossfn_weight[0] * self.G_lossfn(self.E, self.H) 59 | elif len(self.G_lossfn_weight) == 2: 60 | G_loss = self.G_lossfn_weight[0] * 
self.G_lossfn(self.E, self.H) + \ 61 | self.G_lossfn_weight[1] * self.G_lossfn_aux(self.E, self.H) 62 | 63 | G_loss.backward() 64 | 65 | # ------------------------------------ 66 | # clip_grad 67 | # ------------------------------------ 68 | # `clip_grad_norm` helps prevent the exploding gradient problem. 69 | G_optimizer_clipgrad = self.opt_train['G_optimizer_clipgrad'] if self.opt_train['G_optimizer_clipgrad'] else 0 70 | if G_optimizer_clipgrad > 0: 71 | torch.nn.utils.clip_grad_norm_(self.netG.parameters(), max_norm=self.opt_train['G_optimizer_clipgrad'], norm_type=2) 72 | 73 | self.G_optimizer.step() 74 | 75 | # ------------------------------------ 76 | # regularizer 77 | # ------------------------------------ 78 | G_regularizer_orthstep = self.opt_train['G_regularizer_orthstep'] if self.opt_train['G_regularizer_orthstep'] else 0 79 | if G_regularizer_orthstep > 0 and current_step % G_regularizer_orthstep == 0 and current_step % self.opt['train']['checkpoint_save'] != 0: 80 | self.netG.apply(regularizer_orth) 81 | G_regularizer_clipstep = self.opt_train['G_regularizer_clipstep'] if self.opt_train['G_regularizer_clipstep'] else 0 82 | if G_regularizer_clipstep > 0 and current_step % G_regularizer_clipstep == 0 and current_step % self.opt['train']['checkpoint_save'] != 0: 83 | self.netG.apply(regularizer_clip) 84 | 85 | # self.log_dict['G_loss'] = G_loss.item()/self.E.size()[0] # if `reduction='sum'` 86 | self.log_dict['G_loss'] = G_loss.item() 87 | 88 | if self.opt_train['E_decay'] > 0: 89 | self.update_E(self.opt_train['E_decay']) 90 | -------------------------------------------------------------------------------- /models/model_multiout.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from collections import OrderedDict 4 | 5 | from models.model_plain import ModelPlain 6 | 7 | from utils.utils_regularizers import regularizer_orth, regularizer_clip 8 | 9 | class ModelMultiout(ModelPlain): 10 | """Train with pixel loss and has multiple outputs""" 11 | def __init__(self, opt): 12 | super(ModelMultiout, self).__init__(opt) 13 | 14 | """ 15 | # ---------------------------------------- 16 | # Optimization during training with data 17 | # Testing/evaluation 18 | # ---------------------------------------- 19 | """ 20 | 21 | # ---------------------------------------- 22 | # feed L/H data 23 | # ---------------------------------------- 24 | def feed_data(self, data, need_H=True): 25 | self.L = data['L'].to(self.device) 26 | if need_H: 27 | # multiple supervision 28 | self.H = list() 29 | H_1 = data['H'].to(self.device) 30 | self.H.append(F.interpolate(H_1, scale_factor=0.25)) 31 | self.H.append(F.interpolate(H_1, scale_factor=0.5)) 32 | self.H.append(H_1) 33 | 34 | # ---------------------------------------- 35 | # update parameters and get loss 36 | # ---------------------------------------- 37 | def optimize_parameters(self, current_step): 38 | self.G_optimizer.zero_grad() 39 | self.netG_forward() 40 | # ensure the length of E equals to H 41 | assert len(self.E) == len(self.H), ValueError('Output amount is not right') 42 | # G_loss_list = [] 43 | if len(self.G_lossfn_weight) == 1: 44 | G_loss = self.G_lossfn(self.E[0], self.H[0]) + self.G_lossfn(self.E[1], self.H[1]) + \ 45 | self.G_lossfn(self.E[2], self.H[2]) 46 | elif len(self.G_lossfn_weight) == 2: 47 | G_loss_main = self.G_lossfn(self.E[0], self.H[0]) + self.G_lossfn(self.E[1], self.H[1]) + \ 48 | self.G_lossfn(self.E[2], self.H[2]) 49 | 
G_loss_aux = self.G_lossfn_aux(self.E[0], self.H[0]) + self.G_lossfn_aux(self.E[1], self.H[1]) + \ 50 | self.G_lossfn_aux(self.E[2], self.H[2]) 51 | G_loss = self.G_lossfn_weight[0] * G_loss_main + \ 52 | self.G_lossfn_weight[1] * G_loss_aux 53 | # if self.opt_train['G_lossfn_type'] == 'l1+fft': 54 | # loss_content = self.G_lossfn(self.E[0], self.H[0]) + self.G_lossfn(self.E[1], self.H[1]) + \ 55 | # self.G_lossfn(self.E[2], self.H[2]) 56 | # loss_fft = self.G_lossfn(torch.fft.rfft2(self.E[0]), torch.fft.rfft2(self.H[0])) + \ 57 | # self.G_lossfn(torch.fft.rfft2(self.E[1]), torch.fft.rfft2(self.H[1])) + \ 58 | # self.G_lossfn(torch.fft.rfft2(self.E[2]), torch.fft.rfft2(self.H[2])) 59 | # G_loss = self.G_lossfn_weight[0] * loss_content + self.G_lossfn_weight[1] * loss_fft 60 | # else: 61 | # raise NotImplementedError('Loss type [{:s}] is not found.'.format(self.opt_train['G_lossfn_type'])) 62 | G_loss.backward() 63 | 64 | # ------------------------------------ 65 | # clip_grad 66 | # ------------------------------------ 67 | # `clip_grad_norm` helps prevent the exploding gradient problem. 68 | G_optimizer_clipgrad = self.opt_train['G_optimizer_clipgrad'] if self.opt_train['G_optimizer_clipgrad'] else 0 69 | if G_optimizer_clipgrad > 0: 70 | torch.nn.utils.clip_grad_norm_(self.netG.parameters(), max_norm=self.opt_train['G_optimizer_clipgrad'], norm_type=2) 71 | 72 | self.G_optimizer.step() 73 | 74 | # ------------------------------------ 75 | # regularizer 76 | # ------------------------------------ 77 | G_regularizer_orthstep = self.opt_train['G_regularizer_orthstep'] if self.opt_train['G_regularizer_orthstep'] else 0 78 | if G_regularizer_orthstep > 0 and current_step % G_regularizer_orthstep == 0 and current_step % self.opt['train']['checkpoint_save'] != 0: 79 | self.netG.apply(regularizer_orth) 80 | G_regularizer_clipstep = self.opt_train['G_regularizer_clipstep'] if self.opt_train['G_regularizer_clipstep'] else 0 81 | if G_regularizer_clipstep > 0 and current_step % G_regularizer_clipstep == 0 and current_step % self.opt['train']['checkpoint_save'] != 0: 82 | self.netG.apply(regularizer_clip) 83 | 84 | # self.log_dict['G_loss'] = G_loss.item()/self.E.size()[0] # if `reduction='sum'` 85 | if len(self.G_lossfn_weight) == 1: 86 | self.log_dict['G_loss'] = G_loss.item() 87 | elif len(self.G_lossfn_weight) == 2: 88 | self.log_dict[self.G_lossfn_type_[0]+'_loss'] = G_loss_main.item() 89 | self.log_dict[self.G_lossfn_type_[1]+'_loss'] = G_loss_aux.item() 90 | self.log_dict['G_loss'] = G_loss.item() 91 | 92 | if self.opt_train['E_decay'] > 0: 93 | self.update_E(self.opt_train['E_decay']) 94 | 95 | # ---------------------------------------- 96 | # get L, E, H image 97 | # ---------------------------------------- 98 | def current_visuals(self, need_H=True): 99 | out_dict = OrderedDict() 100 | out_dict['L'] = self.L.detach()[0].float().cpu() 101 | # multiscale outputs, the highest resolution is in the last 102 | out_dict['E'] = self.E[-1].detach()[0].float().cpu() 103 | if need_H: 104 | out_dict['H'] = self.H[-1].detach()[0].float().cpu() 105 | return out_dict 106 | 107 | # ---------------------------------------- 108 | # get L, E, H batch images 109 | # ---------------------------------------- 110 | def current_results(self, need_H=True): 111 | out_dict = OrderedDict() 112 | out_dict['L'] = self.L.detach().float().cpu() 113 | # multiscale outputs, the highest resolution is in the last 114 | out_dict['E'] = self.E[-1].detach()[0].float().cpu() 115 | if need_H: 116 | out_dict['H'] = 
self.H.detach().float().cpu() 117 | return out_dict 118 | -------------------------------------------------------------------------------- /models/model_progressive.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | from models.model_plain import ModelPlain 4 | 5 | class ModelProgressive(ModelPlain): 6 | """Train with pixel loss and has multiple outputs""" 7 | def __init__(self, opt): 8 | super(ModelProgressive, self).__init__(opt) 9 | # ------------------------------------ 10 | # check the configuration of progressive training 11 | # ------------------------------------ 12 | self.opt = opt 13 | self.dist = self.opt['dist'] 14 | self.num_gpu = self.opt['num_gpu'] 15 | self.scale = self.opt['scale'] 16 | self.batch_size = self.opt['datasets']['train'].get('dataloader_batch_size')//self.num_gpu if self.dist else \ 17 | self.opt['datasets']['train'].get('dataloader_batch_size') 18 | self.H_size = self.opt['datasets']['train'].get('H_size') 19 | mini_batch_sizes = self.opt['datasets']['train'].get('mini_batch_sizes') 20 | self.mini_batch_sizes = [bs // self.num_gpu for bs in mini_batch_sizes] if self.dist else mini_batch_sizes 21 | iters_milestones = self.opt['datasets']['train'].get('iters_milestones') 22 | self.mini_H_sizes = self.opt['datasets']['train'].get('mini_H_sizes') 23 | assert self.mini_batch_sizes and iters_milestones and self.mini_H_sizes, 'Error: Key progressive is empty.' 24 | assert len(self.mini_batch_sizes) == len(iters_milestones) and \ 25 | len(self.mini_H_sizes) == len(iters_milestones), 'Error: List mismatch - batch:{}, milestone:{}, H_size{}'.format(\ 26 | len(self.mini_batch_sizes), len(iters_milestones), len(self.mini_H_sizes)) 27 | 28 | self.iters_milestones = np.array([sum(iters_milestones[0:i + 1]) for i in range(0, len(iters_milestones))]) 29 | 30 | """ 31 | # ---------------------------------------- 32 | # Optimization during training with data 33 | # Testing/evaluation 34 | # ---------------------------------------- 35 | """ 36 | 37 | # ---------------------------------------- 38 | # feed L/H data 39 | # ---------------------------------------- 40 | def feed_data(self, data, iter, need_H=True): 41 | L = data['L'] 42 | if need_H: 43 | H = data['H'] 44 | 45 | # progressive changing the batch size and patch size during the training 46 | progressive_seq = ((iter > self.iters_milestones) != True).nonzero()[0] 47 | progressive_idx = len(self.iters_milestones) - 1 if len(progressive_seq) == 0 else progressive_seq[0] 48 | 49 | mini_batch_size = self.mini_batch_sizes[progressive_idx] 50 | mini_H_size = self.mini_H_sizes[progressive_idx] 51 | 52 | assert mini_batch_size < self.batch_size, ValueError('Mini Batch Size {} must smaller than Batch Size {}'.format(mini_batch_size, self.batch_size)) 53 | assert mini_H_size < self.H_size, ValueError('Mini H Size {} must smaller than H Size {}'.format(mini_H_size, self.H_size)) 54 | 55 | if mini_batch_size < self.batch_size: 56 | batch_pick = random.sample(range(0, self.batch_size), k=mini_batch_size) 57 | L = L[batch_pick] 58 | if need_H: 59 | H = H[batch_pick] 60 | 61 | if mini_H_size < self.H_size: 62 | x0 = int((self.H_size - mini_H_size) * random.random()) 63 | y0 = int((self.H_size - mini_H_size) * random.random()) 64 | L = L[..., x0 : x0 + mini_H_size, y0 : y0 + mini_H_size] 65 | if need_H: 66 | H = H[..., x0*self.scale : (x0+mini_H_size)*self.scale, y0*self.scale : (y0+mini_H_size)*self.scale] 67 | 68 | self.L = L.to(self.device) 69 | if need_H: 
70 | self.H = H.to(self.device) 71 | -------------------------------------------------------------------------------- /models/network_feature.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | 5 | 6 | """ 7 | # -------------------------------------------- 8 | # VGG Feature Extractor 9 | # -------------------------------------------- 10 | """ 11 | 12 | # -------------------------------------------- 13 | # VGG features 14 | # Assume input range is [0, 1] 15 | # -------------------------------------------- 16 | class VGGFeatureExtractor(nn.Module): 17 | def __init__(self, 18 | feature_layer=34, 19 | use_bn=False, 20 | use_input_norm=True, 21 | device=torch.device('cpu')): 22 | super(VGGFeatureExtractor, self).__init__() 23 | if use_bn: 24 | model = torchvision.models.vgg19_bn(pretrained=True) 25 | else: 26 | model = torchvision.models.vgg19(pretrained=True) 27 | self.use_input_norm = use_input_norm 28 | if self.use_input_norm: 29 | mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device) 30 | # [0.485-1, 0.456-1, 0.406-1] if input in range [-1,1] 31 | std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device) 32 | # [0.229*2, 0.224*2, 0.225*2] if input in range [-1,1] 33 | self.register_buffer('mean', mean) 34 | self.register_buffer('std', std) 35 | self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)]) 36 | # No need to BP to variable 37 | for k, v in self.features.named_parameters(): 38 | v.requires_grad = False 39 | 40 | def forward(self, x): 41 | if self.use_input_norm: 42 | x = (x - self.mean) / self.std 43 | output = self.features(x) 44 | return output 45 | 46 | 47 | -------------------------------------------------------------------------------- /models/network_fftformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numbers 5 | from einops import rearrange 6 | from torchstat import stat # complexity evaluation 7 | 8 | def to_3d(x): 9 | return rearrange(x, 'b c h w -> b (h w) c') 10 | 11 | 12 | def to_4d(x, h, w): 13 | return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 14 | 15 | 16 | class BiasFree_LayerNorm(nn.Module): 17 | def __init__(self, normalized_shape): 18 | super(BiasFree_LayerNorm, self).__init__() 19 | if isinstance(normalized_shape, numbers.Integral): 20 | normalized_shape = (normalized_shape,) 21 | normalized_shape = torch.Size(normalized_shape) 22 | 23 | assert len(normalized_shape) == 1 24 | 25 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 26 | self.normalized_shape = normalized_shape 27 | 28 | def forward(self, x): 29 | sigma = x.var(-1, keepdim=True, unbiased=False) 30 | return x / torch.sqrt(sigma + 1e-5) * self.weight 31 | 32 | 33 | class WithBias_LayerNorm(nn.Module): 34 | def __init__(self, normalized_shape): 35 | super(WithBias_LayerNorm, self).__init__() 36 | if isinstance(normalized_shape, numbers.Integral): 37 | normalized_shape = (normalized_shape,) 38 | normalized_shape = torch.Size(normalized_shape) 39 | 40 | assert len(normalized_shape) == 1 41 | 42 | self.weight = nn.Parameter(torch.ones(normalized_shape)) 43 | self.bias = nn.Parameter(torch.zeros(normalized_shape)) 44 | self.normalized_shape = normalized_shape 45 | 46 | def forward(self, x): 47 | mu = x.mean(-1, keepdim=True) 48 | sigma = x.var(-1, keepdim=True, unbiased=False) 49 | return (x - mu) / 
torch.sqrt(sigma + 1e-5) * self.weight + self.bias 50 | 51 | 52 | class LayerNorm(nn.Module): 53 | def __init__(self, dim, LayerNorm_type): 54 | super(LayerNorm, self).__init__() 55 | if LayerNorm_type == 'BiasFree': 56 | self.body = BiasFree_LayerNorm(dim) 57 | else: 58 | self.body = WithBias_LayerNorm(dim) 59 | 60 | def forward(self, x): 61 | h, w = x.shape[-2:] 62 | return to_4d(self.body(to_3d(x)), h, w) 63 | 64 | 65 | class DFFN(nn.Module): 66 | def __init__(self, dim, ffn_expansion_factor, bias): 67 | 68 | super(DFFN, self).__init__() 69 | 70 | hidden_features = int(dim * ffn_expansion_factor) 71 | 72 | self.patch_size = 8 73 | 74 | self.dim = dim 75 | self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias) 76 | 77 | self.dwconv = nn.Conv2d(hidden_features * 2, hidden_features * 2, kernel_size=3, stride=1, padding=1, 78 | groups=hidden_features * 2, bias=bias) 79 | 80 | self.fft = nn.Parameter(torch.ones((hidden_features * 2, 1, 1, self.patch_size, self.patch_size // 2 + 1))) 81 | self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias) 82 | 83 | def forward(self, x): 84 | x = self.project_in(x) 85 | x_patch = rearrange(x, 'b c (h patch1) (w patch2) -> b c h w patch1 patch2', patch1=self.patch_size, 86 | patch2=self.patch_size) 87 | x_patch_fft = torch.fft.rfft2(x_patch.float()) 88 | x_patch_fft = x_patch_fft * self.fft 89 | x_patch = torch.fft.irfft2(x_patch_fft, s=(self.patch_size, self.patch_size)) 90 | x = rearrange(x_patch, 'b c h w patch1 patch2 -> b c (h patch1) (w patch2)', patch1=self.patch_size, 91 | patch2=self.patch_size) 92 | x1, x2 = self.dwconv(x).chunk(2, dim=1) 93 | 94 | x = F.gelu(x1) * x2 95 | x = self.project_out(x) 96 | return x 97 | 98 | 99 | class FSAS(nn.Module): 100 | def __init__(self, dim, bias): 101 | super(FSAS, self).__init__() 102 | 103 | self.to_hidden = nn.Conv2d(dim, dim * 6, kernel_size=1, bias=bias) 104 | self.to_hidden_dw = nn.Conv2d(dim * 6, dim * 6, kernel_size=3, stride=1, padding=1, groups=dim * 6, bias=bias) 105 | 106 | self.project_out = nn.Conv2d(dim * 2, dim, kernel_size=1, bias=bias) 107 | 108 | self.norm = LayerNorm(dim * 2, LayerNorm_type='WithBias') 109 | 110 | self.patch_size = 8 111 | 112 | def forward(self, x): 113 | hidden = self.to_hidden(x) 114 | 115 | q, k, v = self.to_hidden_dw(hidden).chunk(3, dim=1) 116 | 117 | q_patch = rearrange(q, 'b c (h patch1) (w patch2) -> b c h w patch1 patch2', patch1=self.patch_size, 118 | patch2=self.patch_size) 119 | k_patch = rearrange(k, 'b c (h patch1) (w patch2) -> b c h w patch1 patch2', patch1=self.patch_size, 120 | patch2=self.patch_size) 121 | q_fft = torch.fft.rfft2(q_patch.float()) 122 | k_fft = torch.fft.rfft2(k_patch.float()) 123 | 124 | out = q_fft * k_fft 125 | out = torch.fft.irfft2(out, s=(self.patch_size, self.patch_size)) 126 | out = rearrange(out, 'b c h w patch1 patch2 -> b c (h patch1) (w patch2)', patch1=self.patch_size, 127 | patch2=self.patch_size) 128 | 129 | out = self.norm(out) 130 | 131 | output = v * out 132 | output = self.project_out(output) 133 | 134 | return output 135 | 136 | 137 | ########################################################################## 138 | class TransformerBlock(nn.Module): 139 | def __init__(self, dim, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias', att=False): 140 | super(TransformerBlock, self).__init__() 141 | 142 | self.att = att 143 | if self.att: 144 | self.norm1 = LayerNorm(dim, LayerNorm_type) 145 | self.attn = FSAS(dim, bias) 146 | 147 | self.norm2 = LayerNorm(dim, 
LayerNorm_type) 148 | self.ffn = DFFN(dim, ffn_expansion_factor, bias) 149 | 150 | def forward(self, x): 151 | if self.att: 152 | x = x + self.attn(self.norm1(x)) 153 | 154 | x = x + self.ffn(self.norm2(x)) 155 | 156 | return x 157 | 158 | 159 | class Fuse(nn.Module): 160 | def __init__(self, n_feat): 161 | super(Fuse, self).__init__() 162 | self.n_feat = n_feat 163 | self.att_channel = TransformerBlock(dim=n_feat * 2) 164 | 165 | self.conv = nn.Conv2d(n_feat * 2, n_feat * 2, 1, 1, 0) 166 | self.conv2 = nn.Conv2d(n_feat * 2, n_feat * 2, 1, 1, 0) 167 | 168 | def forward(self, enc, dnc): 169 | x = self.conv(torch.cat((enc, dnc), dim=1)) 170 | x = self.att_channel(x) 171 | x = self.conv2(x) 172 | e, d = torch.split(x, [self.n_feat, self.n_feat], dim=1) 173 | output = e + d 174 | 175 | return output 176 | 177 | 178 | ########################################################################## 179 | ## Overlapped image patch embedding with 3x3 Conv 180 | class OverlapPatchEmbed(nn.Module): 181 | def __init__(self, in_c=3, embed_dim=48, bias=False): 182 | super(OverlapPatchEmbed, self).__init__() 183 | 184 | self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=bias) 185 | 186 | def forward(self, x): 187 | x = self.proj(x) 188 | 189 | return x 190 | 191 | 192 | ########################################################################## 193 | ## Resizing modules 194 | class Downsample(nn.Module): 195 | def __init__(self, n_feat): 196 | super(Downsample, self).__init__() 197 | 198 | self.body = nn.Sequential(nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=False), 199 | nn.Conv2d(n_feat, n_feat * 2, 3, stride=1, padding=1, bias=False)) 200 | 201 | def forward(self, x): 202 | return self.body(x) 203 | 204 | 205 | class Upsample(nn.Module): 206 | def __init__(self, n_feat): 207 | super(Upsample, self).__init__() 208 | 209 | self.body = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False), 210 | nn.Conv2d(n_feat, n_feat // 2, 3, stride=1, padding=1, bias=False)) 211 | 212 | def forward(self, x): 213 | return self.body(x) 214 | 215 | 216 | ########################################################################## 217 | ##---------- FFTformer ----------------------- 218 | class fftformer(nn.Module): 219 | def __init__(self, 220 | inp_channels=3, 221 | out_channels=3, 222 | dim=48, 223 | num_blocks=[6, 6, 12, 8], 224 | num_refinement_blocks=4, 225 | ffn_expansion_factor=3, 226 | bias=False, 227 | ): 228 | super(fftformer, self).__init__() 229 | 230 | self.patch_embed = OverlapPatchEmbed(inp_channels, dim) 231 | 232 | self.encoder_level1 = nn.Sequential(*[ 233 | TransformerBlock(dim=dim, ffn_expansion_factor=ffn_expansion_factor, bias=bias) for i in 234 | range(num_blocks[0])]) 235 | 236 | self.down1_2 = Downsample(dim) 237 | self.encoder_level2 = nn.Sequential(*[ 238 | TransformerBlock(dim=int(dim * 2 ** 1), ffn_expansion_factor=ffn_expansion_factor, 239 | bias=bias) for i in range(num_blocks[1])]) 240 | 241 | self.down2_3 = Downsample(int(dim * 2 ** 1)) 242 | self.encoder_level3 = nn.Sequential(*[ 243 | TransformerBlock(dim=int(dim * 2 ** 2), ffn_expansion_factor=ffn_expansion_factor, 244 | bias=bias) for i in range(num_blocks[2])]) 245 | 246 | self.decoder_level3 = nn.Sequential(*[ 247 | TransformerBlock(dim=int(dim * 2 ** 2), ffn_expansion_factor=ffn_expansion_factor, 248 | bias=bias, att=True) for i in range(num_blocks[2])]) 249 | 250 | self.up3_2 = Upsample(int(dim * 2 ** 2)) 251 | self.reduce_chan_level2 = nn.Conv2d(int(dim * 2 ** 2), 
int(dim * 2 ** 1), kernel_size=1, bias=bias) 252 | self.decoder_level2 = nn.Sequential(*[ 253 | TransformerBlock(dim=int(dim * 2 ** 1), ffn_expansion_factor=ffn_expansion_factor, 254 | bias=bias, att=True) for i in range(num_blocks[1])]) 255 | 256 | self.up2_1 = Upsample(int(dim * 2 ** 1)) 257 | 258 | self.decoder_level1 = nn.Sequential(*[ 259 | TransformerBlock(dim=int(dim), ffn_expansion_factor=ffn_expansion_factor, 260 | bias=bias, att=True) for i in range(num_blocks[0])]) 261 | 262 | self.refinement = nn.Sequential(*[ 263 | TransformerBlock(dim=int(dim), ffn_expansion_factor=ffn_expansion_factor, 264 | bias=bias, att=True) for i in range(num_refinement_blocks)]) 265 | 266 | self.fuse2 = Fuse(dim * 2) 267 | self.fuse1 = Fuse(dim) 268 | self.output = nn.Conv2d(int(dim), out_channels, kernel_size=3, stride=1, padding=1, bias=bias) 269 | 270 | def forward(self, inp_img): 271 | inp_enc_level1 = self.patch_embed(inp_img) 272 | out_enc_level1 = self.encoder_level1(inp_enc_level1) 273 | 274 | inp_enc_level2 = self.down1_2(out_enc_level1) 275 | out_enc_level2 = self.encoder_level2(inp_enc_level2) 276 | 277 | inp_enc_level3 = self.down2_3(out_enc_level2) 278 | out_enc_level3 = self.encoder_level3(inp_enc_level3) 279 | 280 | out_dec_level3 = self.decoder_level3(out_enc_level3) 281 | 282 | inp_dec_level2 = self.up3_2(out_dec_level3) 283 | 284 | inp_dec_level2 = self.fuse2(inp_dec_level2, out_enc_level2) 285 | 286 | out_dec_level2 = self.decoder_level2(inp_dec_level2) 287 | 288 | inp_dec_level1 = self.up2_1(out_dec_level2) 289 | 290 | inp_dec_level1 = self.fuse1(inp_dec_level1, out_enc_level1) 291 | out_dec_level1 = self.decoder_level1(inp_dec_level1) 292 | 293 | out_dec_level1 = self.refinement(out_dec_level1) 294 | 295 | out_dec_level1 = self.output(out_dec_level1) + inp_img 296 | 297 | return out_dec_level1 298 | 299 | def get_parameter_number(model, input_size=(3, 256, 256)): 300 | stat(model, input_size) 301 | 302 | if __name__ == "__main__": 303 | net = fftformer() 304 | get_parameter_number(net) -------------------------------------------------------------------------------- /models/network_pqnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from torchstat import stat # complexity evaluation 6 | 7 | class BasicBlock(nn.Module): 8 | def __init__(self, in_channel, out_channel, kernel_size, stride, bias=True, norm=False, relu=True, transpose=False): 9 | super(BasicBlock, self).__init__() 10 | if bias and norm: 11 | bias = False 12 | 13 | padding = kernel_size // 2 14 | layers = list() 15 | if transpose: 16 | padding = kernel_size // 2 -1 17 | layers.append(nn.ConvTranspose2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias)) 18 | else: 19 | layers.append( 20 | nn.Conv2d(in_channel, out_channel, kernel_size, padding=padding, stride=stride, bias=bias)) 21 | if norm: 22 | layers.append(nn.BatchNorm2d(out_channel)) 23 | if relu: 24 | layers.append(nn.ReLU(inplace=True)) 25 | self.main = nn.Sequential(*layers) 26 | 27 | def forward(self, x): 28 | return self.main(x) 29 | 30 | class ResBlock(nn.Module): 31 | def __init__(self, in_channel, out_channel): 32 | super(ResBlock, self).__init__() 33 | self.main = nn.Sequential( 34 | BasicBlock(in_channel, out_channel, kernel_size=3, stride=1, relu=True), 35 | BasicBlock(out_channel, out_channel, kernel_size=3, stride=1, relu=False) 36 | ) 37 | 38 | def forward(self, x): 39 | return 
self.main(x) + x 40 | 41 | class EBlock(nn.Module): 42 | def __init__(self, out_channel, num_res=8): 43 | super(EBlock, self).__init__() 44 | 45 | layers = [ResBlock(out_channel, out_channel) for _ in range(num_res)] 46 | 47 | self.layers = nn.Sequential(*layers) 48 | 49 | def forward(self, x): 50 | return self.layers(x) 51 | 52 | 53 | class DBlock(nn.Module): 54 | def __init__(self, channel, num_res=8): 55 | super(DBlock, self).__init__() 56 | 57 | layers = [ResBlock(channel, channel) for _ in range(num_res)] 58 | self.layers = nn.Sequential(*layers) 59 | 60 | def forward(self, x): 61 | return self.layers(x) 62 | 63 | class MFF(nn.Module): 64 | def __init__(self, scale, out_channel, base_channel=64): 65 | super(MFF, self).__init__() 66 | self.conv_shuffle = nn.ModuleList([ 67 | BasicBlock(base_channel, base_channel, kernel_size=7, stride=1, relu=True), 68 | BasicBlock(base_channel*2, base_channel*4, kernel_size=5, stride=1, relu=True), 69 | BasicBlock(base_channel*8, base_channel*16, kernel_size=3, stride=1, relu=True), 70 | ]) 71 | 72 | self.shuffle = nn.ModuleList([ 73 | nn.PixelShuffle(1) if scale==1 else nn.PixelUnshuffle(scale), 74 | nn.PixelShuffle(int(2/scale)), 75 | nn.PixelShuffle(int(4/scale)) 76 | ]) 77 | 78 | self.conv_out = nn.Sequential( 79 | BasicBlock(3*(scale**2)*base_channel, out_channel, kernel_size=3, stride=1, relu=False) 80 | ) 81 | 82 | def forward(self, x1, x2, x4): 83 | x1_ = self.shuffle[0](self.conv_shuffle[0](x1)) 84 | x2_ = self.shuffle[1](self.conv_shuffle[1](x2)) 85 | x4_ = self.shuffle[2](self.conv_shuffle[2](x4)) 86 | 87 | x = torch.cat([x1_, x2_, x4_], dim=1) 88 | return self.conv_out(x) 89 | 90 | class Prior_Quantization(nn.Module): 91 | def __init__(self, in_channel, base_channel=64): 92 | super(Prior_Quantization, self).__init__() 93 | self.prior_extract = nn.Sequential( 94 | BasicBlock(in_channel, base_channel, kernel_size=7, relu=True, stride=1), 95 | BasicBlock(base_channel, base_channel*2, kernel_size=3, relu=True, stride=2), 96 | BasicBlock(base_channel*2, base_channel*4, kernel_size=3, relu=True, stride=2), 97 | ) 98 | 99 | def forward(self, p): 100 | return self.prior_extract(p) 101 | 102 | class PQNet(nn.Module): 103 | def __init__(self, num_res=8, base_channel=32): 104 | super(PQNet, self).__init__() 105 | 106 | self.Encoder = nn.ModuleList([ 107 | EBlock(base_channel, num_res), 108 | EBlock(base_channel*2, num_res), 109 | EBlock(base_channel*8, num_res), 110 | ]) 111 | 112 | self.feat_extract = nn.ModuleList([ 113 | BasicBlock(3, base_channel, kernel_size=7, relu=True, stride=1), 114 | BasicBlock(base_channel, base_channel*2, kernel_size=3, relu=True, stride=2), 115 | BasicBlock(base_channel*2, base_channel*4, kernel_size=3, relu=True, stride=2), 116 | BasicBlock(base_channel*8, base_channel*4, kernel_size=4, relu=True, stride=2, transpose=True), 117 | BasicBlock(base_channel*4, base_channel*2, kernel_size=4, relu=True, stride=2, transpose=True), 118 | BasicBlock(base_channel*2, 3, kernel_size=3, relu=False, stride=1) 119 | ]) 120 | 121 | self.Decoder = nn.ModuleList([ 122 | DBlock(base_channel * 8, num_res), 123 | DBlock(base_channel * 4, num_res), 124 | DBlock(base_channel * 2, num_res) 125 | ]) 126 | 127 | self.Convs = nn.ModuleList([ 128 | BasicBlock(base_channel * 8, base_channel * 4, kernel_size=1, relu=True, stride=1), 129 | BasicBlock(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1), 130 | ]) 131 | 132 | self.ConvsOut = nn.ModuleList( 133 | [ 134 | BasicBlock(base_channel * 8, 3, kernel_size=3, relu=False, 
stride=1), 135 | BasicBlock(base_channel * 4, 3, kernel_size=3, relu=False, stride=1), 136 | ] 137 | ) 138 | 139 | self.MFFs = nn.ModuleList([ 140 | MFF(scale=1, out_channel=base_channel*2, base_channel=base_channel), 141 | MFF(scale=2, out_channel=base_channel*4, base_channel=base_channel), 142 | ]) 143 | 144 | self.PQ = Prior_Quantization(in_channel=20, base_channel=base_channel) 145 | 146 | def forward(self, x, p=torch.rand(1, 20, 224, 224)): 147 | 148 | x_2 = F.interpolate(x, scale_factor=0.5) 149 | x_4 = F.interpolate(x_2, scale_factor=0.5) 150 | outputs = list() 151 | 152 | x_ = self.feat_extract[0](x) 153 | res1 = self.Encoder[0](x_) 154 | 155 | z = self.feat_extract[1](res1) 156 | res2 = self.Encoder[1](z) 157 | 158 | z = self.feat_extract[2](res2) 159 | p_ = self.PQ(p) 160 | z = torch.cat([z, p_], dim=1) 161 | res3 = self.Encoder[2](z) 162 | 163 | res1_ = self.MFFs[0](res1, res2, res3) 164 | res2_ = self.MFFs[1](res1, res2, res3) 165 | 166 | z = self.Decoder[0](res3) 167 | z_ = self.ConvsOut[0](z) 168 | outputs.append(z_ + x_4) 169 | 170 | z = self.feat_extract[3](z) 171 | z = torch.cat([z, res2_], dim=1) 172 | z = self.Convs[0](z) 173 | z = self.Decoder[1](z) 174 | z_ = self.ConvsOut[1](z) 175 | outputs.append(z_+x_2) 176 | 177 | z = self.feat_extract[4](z) 178 | z = torch.cat([z, res1_], dim=1) 179 | z = self.Convs[1](z) 180 | z = self.Decoder[2](z) 181 | z = self.feat_extract[5](z) 182 | outputs.append(z+x) 183 | 184 | return outputs 185 | 186 | class PQNetPlus(nn.Module): 187 | def __init__(self, num_res=20, base_channel=32): 188 | super(PQNetPlus, self).__init__() 189 | 190 | self.Encoder = nn.ModuleList([ 191 | EBlock(base_channel, num_res), 192 | EBlock(base_channel*2, num_res), 193 | EBlock(base_channel*8, num_res), 194 | ]) 195 | 196 | self.feat_extract = nn.ModuleList([ 197 | BasicBlock(3, base_channel, kernel_size=7, relu=True, stride=1), 198 | BasicBlock(base_channel, base_channel*2, kernel_size=3, relu=True, stride=2), 199 | BasicBlock(base_channel*2, base_channel*4, kernel_size=3, relu=True, stride=2), 200 | BasicBlock(base_channel*8, base_channel*4, kernel_size=4, relu=True, stride=2, transpose=True), 201 | BasicBlock(base_channel*4, base_channel*2, kernel_size=4, relu=True, stride=2, transpose=True), 202 | BasicBlock(base_channel*2, 3, kernel_size=3, relu=False, stride=1) 203 | ]) 204 | 205 | self.Decoder = nn.ModuleList([ 206 | DBlock(base_channel * 8, num_res), 207 | DBlock(base_channel * 4, num_res), 208 | DBlock(base_channel * 2, num_res) 209 | ]) 210 | 211 | self.Convs = nn.ModuleList([ 212 | BasicBlock(base_channel * 8, base_channel * 4, kernel_size=1, relu=True, stride=1), 213 | BasicBlock(base_channel * 4, base_channel * 2, kernel_size=1, relu=True, stride=1), 214 | ]) 215 | 216 | self.ConvsOut = nn.ModuleList( 217 | [ 218 | BasicBlock(base_channel * 8, 3, kernel_size=3, relu=False, stride=1), 219 | BasicBlock(base_channel * 4, 3, kernel_size=3, relu=False, stride=1), 220 | ] 221 | ) 222 | 223 | self.MFFs = nn.ModuleList([ 224 | MFF(scale=1, out_channel=base_channel*2, base_channel=base_channel), 225 | MFF(scale=2, out_channel=base_channel*4, base_channel=base_channel), 226 | ]) 227 | 228 | self.PQ = Prior_Quantization(in_channel=20, base_channel=base_channel) 229 | 230 | def forward(self, x, p=torch.rand(1, 20, 224, 224)): 231 | 232 | x_2 = F.interpolate(x, scale_factor=0.5) 233 | x_4 = F.interpolate(x_2, scale_factor=0.5) 234 | outputs = list() 235 | 236 | x_ = self.feat_extract[0](x) 237 | res1 = self.Encoder[0](x_) 238 | 239 | z = 
self.feat_extract[1](res1) 240 | res2 = self.Encoder[1](z) 241 | 242 | z = self.feat_extract[2](res2) 243 | p_ = self.PQ(p) 244 | z = torch.cat([z, p_], dim=1) 245 | res3 = self.Encoder[2](z) 246 | 247 | res1_ = self.MFFs[0](res1, res2, res3) 248 | res2_ = self.MFFs[1](res1, res2, res3) 249 | 250 | z = self.Decoder[0](res3) 251 | z_ = self.ConvsOut[0](z) 252 | outputs.append(z_ + x_4) 253 | 254 | z = self.feat_extract[3](z) 255 | z = torch.cat([z, res2_], dim=1) 256 | z = self.Convs[0](z) 257 | z = self.Decoder[1](z) 258 | z_ = self.ConvsOut[1](z) 259 | outputs.append(z_+x_2) 260 | 261 | z = self.feat_extract[4](z) 262 | z = torch.cat([z, res1_], dim=1) 263 | z = self.Convs[1](z) 264 | z = self.Decoder[2](z) 265 | z = self.feat_extract[5](z) 266 | outputs.append(z+x) 267 | 268 | return outputs 269 | 270 | def get_parameter_number(model, input_size=(3, 224, 224)): 271 | stat(model, input_size) 272 | 273 | if __name__ == '__main__': 274 | model = PQNet() 275 | get_parameter_number(model) 276 | -------------------------------------------------------------------------------- /models/network_rrdb.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | import models.basicblock as B 4 | 5 | 6 | """ 7 | # -------------------------------------------- 8 | # SR network with Residual in Residual Dense Block (RRDB) 9 | # "ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks" 10 | # -------------------------------------------- 11 | """ 12 | 13 | 14 | class RRDB(nn.Module): 15 | """ 16 | gc: number of growth channels 17 | nb: number of RRDB 18 | """ 19 | def __init__(self, in_nc=3, out_nc=3, nc=64, nb=23, gc=32, upscale=4, act_mode='L', upsample_mode='upconv'): 20 | super(RRDB, self).__init__() 21 | assert 'R' in act_mode or 'L' in act_mode, 'Examples of activation function: R, L, BR, BL, IR, IL' 22 | 23 | n_upscale = int(math.log(upscale, 2)) 24 | if upscale == 3: 25 | n_upscale = 1 26 | 27 | m_head = B.conv(in_nc, nc, mode='C') 28 | 29 | m_body = [B.RRDB(nc, gc=32, mode='C'+act_mode) for _ in range(nb)] 30 | m_body.append(B.conv(nc, nc, mode='C')) 31 | 32 | if upsample_mode == 'upconv': 33 | upsample_block = B.upsample_upconv 34 | elif upsample_mode == 'pixelshuffle': 35 | upsample_block = B.upsample_pixelshuffle 36 | elif upsample_mode == 'convtranspose': 37 | upsample_block = B.upsample_convtranspose 38 | else: 39 | raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode)) 40 | 41 | if upscale == 3: 42 | m_uper = upsample_block(nc, nc, mode='3'+act_mode) 43 | else: 44 | m_uper = [upsample_block(nc, nc, mode='2'+act_mode) for _ in range(n_upscale)] 45 | 46 | H_conv0 = B.conv(nc, nc, mode='C'+act_mode) 47 | H_conv1 = B.conv(nc, out_nc, mode='C') 48 | m_tail = B.sequential(H_conv0, H_conv1) 49 | 50 | self.model = B.sequential(m_head, B.ShortcutBlock(B.sequential(*m_body)), *m_uper, m_tail) 51 | 52 | def forward(self, x): 53 | x = self.model(x) 54 | return x 55 | -------------------------------------------------------------------------------- /models/network_rrdbnet.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | """ 7 | # -------------------------------------------- 8 | # RRDB for sr 9 | # -------------------------------------------- 10 | """ 11 | 12 | def make_layer(block, n_layers): 13 | layers = [] 14 | for _ in range(n_layers): 15 | 
layers.append(block()) 16 | return nn.Sequential(*layers) 17 | 18 | 19 | class ResidualDenseBlock_5C(nn.Module): 20 | def __init__(self, nc=64, gc=32, bias=True): 21 | super(ResidualDenseBlock_5C, self).__init__() 22 | # gc: growth channel, i.e. intermediate channels 23 | self.conv1 = nn.Conv2d(nc, gc, 3, 1, 1, bias=bias) 24 | self.conv2 = nn.Conv2d(nc + gc, gc, 3, 1, 1, bias=bias) 25 | self.conv3 = nn.Conv2d(nc + 2 * gc, gc, 3, 1, 1, bias=bias) 26 | self.conv4 = nn.Conv2d(nc + 3 * gc, gc, 3, 1, 1, bias=bias) 27 | self.conv5 = nn.Conv2d(nc + 4 * gc, nc, 3, 1, 1, bias=bias) 28 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 29 | 30 | # initialization 31 | # mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) 32 | 33 | def forward(self, x): 34 | x1 = self.lrelu(self.conv1(x)) 35 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) 36 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) 37 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) 38 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 39 | return x5 * 0.2 + x 40 | 41 | 42 | class RRDB(nn.Module): 43 | '''Residual in Residual Dense Block''' 44 | 45 | def __init__(self, nc, gc=32): 46 | super(RRDB, self).__init__() 47 | self.RDB1 = ResidualDenseBlock_5C(nc, gc) 48 | self.RDB2 = ResidualDenseBlock_5C(nc, gc) 49 | self.RDB3 = ResidualDenseBlock_5C(nc, gc) 50 | 51 | def forward(self, x): 52 | out = self.RDB1(x) 53 | out = self.RDB2(out) 54 | out = self.RDB3(out) 55 | return out * 0.2 + x 56 | 57 | 58 | class RRDBNet(nn.Module): 59 | def __init__(self, in_nc=3, out_nc=3, nc=64, nb=23, gc=32, sf=4): 60 | super(RRDBNet, self).__init__() 61 | RRDB_block_f = functools.partial(RRDB, nc=nc, gc=gc) 62 | self.sf = sf 63 | 64 | self.conv_first = nn.Conv2d(in_nc, nc, 3, 1, 1, bias=True) 65 | self.RRDB_trunk = make_layer(RRDB_block_f, nb) 66 | self.trunk_conv = nn.Conv2d(nc, nc, 3, 1, 1, bias=True) 67 | #### upsampling 68 | self.upconv1 = nn.Conv2d(nc, nc, 3, 1, 1, bias=True) 69 | if self.sf==4: 70 | self.upconv2 = nn.Conv2d(nc, nc, 3, 1, 1, bias=True) 71 | self.HRconv = nn.Conv2d(nc, nc, 3, 1, 1, bias=True) 72 | self.conv_last = nn.Conv2d(nc, out_nc, 3, 1, 1, bias=True) 73 | 74 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 75 | 76 | def forward(self, x): 77 | fea = self.conv_first(x) 78 | trunk = self.trunk_conv(self.RRDB_trunk(fea)) 79 | fea = fea + trunk 80 | 81 | fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) 82 | if self.sf == 4: 83 | fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) 84 | out = self.conv_last(self.lrelu(self.HRconv(fea))) 85 | 86 | return out 87 | -------------------------------------------------------------------------------- /models/network_unet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import models.basicblock as B 4 | import numpy as np 5 | 6 | ''' 7 | # ==================== 8 | # Residual U-Net 9 | # ==================== 10 | ''' 11 | 12 | 13 | class UNetRes(nn.Module): 14 | def __init__(self, in_nc=3, out_nc=3, nc=[64, 128, 256, 512], nb=4, act_mode='R', downsample_mode='strideconv', upsample_mode='convtranspose', bias=True): 15 | super(UNetRes, self).__init__() 16 | 17 | self.m_head = B.conv(in_nc, nc[0], bias=bias, mode='C') 18 | 19 | # downsample 20 | if downsample_mode == 'avgpool': 21 | downsample_block = B.downsample_avgpool 22 | elif downsample_mode == 'maxpool': 23 | 
downsample_block = B.downsample_maxpool 24 | elif downsample_mode == 'strideconv': 25 | downsample_block = B.downsample_strideconv 26 | else: 27 | raise NotImplementedError('downsample mode [{:s}] is not found'.format(downsample_mode)) 28 | 29 | self.m_down1 = B.sequential(*[B.ResBlock(nc[0], nc[0], bias=bias, mode='C'+act_mode+'C') for _ in range(nb)], downsample_block(nc[0], nc[1], bias=bias, mode='2')) 30 | self.m_down2 = B.sequential(*[B.ResBlock(nc[1], nc[1], bias=bias, mode='C'+act_mode+'C') for _ in range(nb)], downsample_block(nc[1], nc[2], bias=bias, mode='2')) 31 | self.m_down3 = B.sequential(*[B.ResBlock(nc[2], nc[2], bias=bias, mode='C'+act_mode+'C') for _ in range(nb)], downsample_block(nc[2], nc[3], bias=bias, mode='2')) 32 | 33 | self.m_body = B.sequential(*[B.ResBlock(nc[3], nc[3], bias=bias, mode='C'+act_mode+'C') for _ in range(nb)]) 34 | 35 | # upsample 36 | if upsample_mode == 'upconv': 37 | upsample_block = B.upsample_upconv 38 | elif upsample_mode == 'pixelshuffle': 39 | upsample_block = B.upsample_pixelshuffle 40 | elif upsample_mode == 'convtranspose': 41 | upsample_block = B.upsample_convtranspose 42 | else: 43 | raise NotImplementedError('upsample mode [{:s}] is not found'.format(upsample_mode)) 44 | 45 | self.m_up3 = B.sequential(upsample_block(nc[3], nc[2], bias=bias, mode='2'), *[B.ResBlock(nc[2], nc[2], bias=bias, mode='C'+act_mode+'C') for _ in range(nb)]) 46 | self.m_up2 = B.sequential(upsample_block(nc[2], nc[1], bias=bias, mode='2'), *[B.ResBlock(nc[1], nc[1], bias=bias, mode='C'+act_mode+'C') for _ in range(nb)]) 47 | self.m_up1 = B.sequential(upsample_block(nc[1], nc[0], bias=bias, mode='2'), *[B.ResBlock(nc[0], nc[0], bias=bias, mode='C'+act_mode+'C') for _ in range(nb)]) 48 | 49 | self.m_tail = B.conv(nc[0], out_nc, bias=bias, mode='C') 50 | 51 | def forward(self, x0): 52 | # h, w = x.size()[-2:] 53 | # paddingBottom = int(np.ceil(h/8)*8-h) 54 | # paddingRight = int(np.ceil(w/8)*8-w) 55 | # x = nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(x) 56 | 57 | x1 = self.m_head(x0) 58 | x2 = self.m_down1(x1) 59 | x3 = self.m_down2(x2) 60 | x4 = self.m_down3(x3) 61 | x = self.m_body(x4) 62 | x = self.m_up3(x+x4) 63 | x = self.m_up2(x+x3) 64 | x = self.m_up1(x+x2) 65 | x = self.m_tail(x+x1) 66 | # x = x[..., :h, :w] 67 | 68 | return x 69 | 70 | 71 | if __name__ == '__main__': 72 | x = torch.rand(1,3,256,256) 73 | net = UNetRes() 74 | net.eval() 75 | with torch.no_grad(): 76 | y = net(x) 77 | print(y.size()) 78 | 79 | # run models/network_unet.py 80 | -------------------------------------------------------------------------------- /models/network_vapsr.py: -------------------------------------------------------------------------------- 1 | # VAst-receptive-field Pixel attention network 2 | import torch.nn as nn 3 | # from basicsr.utils.registry import ARCH_REGISTRY 4 | # from basicsr.archs.arch_util import default_init_weights 5 | 6 | class Attention(nn.Module): 7 | def __init__(self, dim): 8 | super().__init__() 9 | self.pointwise = nn.Conv2d(dim, dim, 1) 10 | self.depthwise = nn.Conv2d(dim, dim, 5, padding=2, groups=dim) 11 | self.depthwise_dilated = nn.Conv2d(dim, dim, 5, stride=1, padding=6, groups=dim, dilation=3) 12 | 13 | def forward(self, x): 14 | u = x.clone() 15 | attn = self.pointwise(x) 16 | attn = self.depthwise(attn) 17 | attn = self.depthwise_dilated(attn) 18 | return u * attn 19 | 20 | class VAB(nn.Module): 21 | def __init__(self, d_model, d_atten): 22 | super().__init__() 23 | self.proj_1 = nn.Conv2d(d_model, d_atten, 1) 24 | 
self.activation = nn.GELU() 25 | self.atten_branch = Attention(d_atten) 26 | self.proj_2 = nn.Conv2d(d_atten, d_model, 1) 27 | self.pixel_norm = nn.LayerNorm(d_model) 28 | # default_init_weights([self.pixel_norm], 0.1) 29 | 30 | def forward(self, x): 31 | shorcut = x.clone() 32 | x = self.proj_1(x) 33 | x = self.activation(x) 34 | x = self.atten_branch(x) 35 | x = self.proj_2(x) 36 | x = x + shorcut 37 | 38 | x = x.permute(0, 2, 3, 1) #(B, H, W, C) 39 | x = self.pixel_norm(x) 40 | x = x.permute(0, 3, 1, 2).contiguous() #(B, C, H, W) 41 | 42 | return x 43 | 44 | def pixelshuffle(in_channels, out_channels, upscale_factor=4): 45 | upconv1 = nn.Conv2d(in_channels, 64, 3, 1, 1) 46 | pixel_shuffle = nn.PixelShuffle(2) 47 | upconv2 = nn.Conv2d(16, out_channels * 4, 3, 1, 1) 48 | lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 49 | return nn.Sequential(*[upconv1, pixel_shuffle, lrelu, upconv2, pixel_shuffle]) 50 | 51 | #both scale X2 and X3 use this version 52 | def pixelshuffle_single(in_channels, out_channels, upscale_factor=2): 53 | upconv1 = nn.Conv2d(in_channels, 56, 3, 1, 1) 54 | pixel_shuffle = nn.PixelShuffle(upscale_factor) 55 | upconv2 = nn.Conv2d(56, out_channels * upscale_factor * upscale_factor, 3, 1, 1) 56 | lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 57 | return nn.Sequential(*[upconv1, lrelu, upconv2, pixel_shuffle]) 58 | 59 | 60 | def make_layer(block, n_layers, *kwargs): 61 | layers = [] 62 | for _ in range(n_layers): 63 | layers.append(block(*kwargs)) 64 | return nn.Sequential(*layers) 65 | 66 | # @ARCH_REGISTRY.register() 67 | class vapsr(nn.Module): 68 | def __init__(self, num_in_ch, num_out_ch, scale=4, num_feat=64, num_block=23, d_atten=64, conv_groups=1): 69 | super(vapsr, self).__init__() 70 | self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1) 71 | self.body = make_layer(VAB, num_block, num_feat, d_atten) 72 | self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1, groups=conv_groups) #conv_groups=2 for VapSR-S 73 | 74 | # upsample 75 | if scale == 4: 76 | self.upsampler = pixelshuffle(num_feat, num_out_ch, upscale_factor=scale) 77 | else: 78 | self.upsampler = pixelshuffle_single(num_feat, num_out_ch, upscale_factor=scale) 79 | 80 | def forward(self, feat): 81 | feat = self.conv_first(feat) 82 | body_feat = self.body(feat) 83 | body_out = self.conv_body(body_feat) 84 | feat = feat + body_out 85 | out = self.upsampler(feat) 86 | return out -------------------------------------------------------------------------------- /models/select_model.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | # -------------------------------------------- 4 | # define training model 5 | # -------------------------------------------- 6 | """ 7 | 8 | 9 | def define_Model(opt): 10 | model = opt['model'] # one input: L 11 | 12 | if model == 'plain': 13 | from models.model_plain import ModelPlain as M 14 | elif model == 'multiout': 15 | from models.model_multiout import ModelMultiout as M 16 | elif model == 'multiin': 17 | from models.model_multiin import ModelMultiin as M 18 | elif model == 'progressive': 19 | from models.model_progressive import ModelProgressive as M 20 | else: 21 | raise NotImplementedError('Model [{:s}] is not defined.'.format(model)) 22 | 23 | m = M(opt) 24 | 25 | return m -------------------------------------------------------------------------------- /options/option.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "unet" // 
root/task/images-models-options 3 | , "model": "plain" // "plain" 4 | , "gpu_ids": [1] 5 | 6 | , "scale": 1 // broadcast to "netG" if SISR 7 | , "n_channels": 3 // broadcast to "datasets", 1 for grayscale, 3 for color 8 | 9 | , "path": { 10 | "root": "results/denoising" // "denoising" | "superresolution" 11 | , "pretrained_netG": null // path of pretrained model 12 | } 13 | 14 | , "datasets": { 15 | "train": { 16 | "name": "train_dataset" // just name 17 | , "dataset_type": "plain" // "dncnn" | "dnpatch" for dncnn, | "fdncnn" | "ffdnet" | "sr" | "srmd" | "dpsr" | "plain" | "plainpatch" 18 | , "dataroot_H": "./trainsets/trainH" // path of High-quality training dataset 19 | , "dataroot_L": "./trainsets/trainL" // path of Low-quality training dataset 20 | , "H_size": 40 // patch size 40 | 64 | 96 | 128 | 192 21 | 22 | , "sigma": 25 // 15, 25, 50 for DnCNN | [0, 75] for FFDNet and FDnCNN 23 | , "sigma_test": 25 // 15, 25, 50 for DnCNN and ffdnet 24 | 25 | , "dataloader_shuffle": true 26 | , "dataloader_num_workers": 8 27 | , "dataloader_batch_size": 1 // batch size 1 | 16 | 32 | 48 | 64 | 128 28 | } 29 | , "test": { 30 | "name": "test_dataset" // just name 31 | , "dataset_type": "plain" // "dncnn" | "dnpatch" for dncnn, | "fdncnn" | "ffdnet" | "sr" | "srmd" | "dpsr" | "plain" | "plainpatch" 32 | , "dataroot_H": "./testsets/testH" // path of High-quality testing dataset 33 | , "dataroot_L": "./testsets/testL" // path of Low-quality testing dataset 34 | , "H_size": 256 // patch size 40 | 64 | 96 | 128 | 192 | 224 | 256 | 384 35 | 36 | , "sigma": 25 // 15, 25, 50 for DnCNN | [0, 75] for FFDNet and FDnCNN 37 | , "sigma_test": 25 // 15, 25, 50 for DnCNN and ffdnet 38 | 39 | } 40 | } 41 | 42 | , "netG": { 43 | "net_type": "mimounet" // "dncnn" | "fdncnn" | "ffdnet" | "srmd" | "dpsr" | "msrresnet0" | "msrresnet1" | "rrdb" 44 | , "in_nc": 1 // input channel number 45 | , "out_nc": 1 // ouput channel number 46 | , "nc": 64 // 96 for "dpsr", 128 for "srmd", 64 for "dncnn" and "rrdb" 47 | , "nb": 23 // 23 for "rrdb", 12 for "srmd", 15 for "ffdnet", 20 for "dncnn", 16 for "srresnet" and "dpsr" 48 | , "gc": 32 // unused 49 | , "ng": 2 // unused 50 | , "reduction" : 16 // unused 51 | , "act_mode": "R" // "BR" for BN+ReLU | "R" for ReLU 52 | , "upsample_mode": "upconv" // "pixelshuffle" | "convtranspose" | "upconv" 53 | , "downsample_mode": "strideconv" // "strideconv" | "avgpool" | "maxpool" 54 | 55 | , "init_type": "orthogonal" // "orthogonal" | "normal" | "uniform" | "xavier_normal" | "xavier_uniform" | "kaiming_normal" | "kaiming_uniform" 56 | , "init_bn_type": "uniform" // "uniform" | "constant" 57 | , "init_gain": 0.2 58 | } 59 | 60 | , "train": { 61 | "G_lossfn_type": "l1" // "l1" preferred | "l2sum" | "l2" | "ssim" 62 | , "G_lossfn_weight": 1.0 // default 63 | 64 | , "G_optimizer_type": "adam" // fixed, adam is enough 65 | , "G_optimizer_lr": 1e-4 // learning rate 66 | , "G_optimizer_clipgrad": null // unused 67 | 68 | , "G_scheduler_type": "MultiStepLR" // "MultiStepLR" is enough 69 | , "G_scheduler_milestones": [200000, 400000, 600000, 800000, 1000000, 2000000] 70 | , "G_scheduler_gamma": 0.5 71 | 72 | , "G_regularizer_orthstep": null // unused 73 | , "G_regularizer_clipstep": null // unused 74 | 75 | , "checkpoint_test": 5000 // for testing 76 | , "checkpoint_save": 5000 // for saving model 77 | , "checkpoint_print": 200 // for print 78 | } 79 | } 80 | -------------------------------------------------------------------------------- /options/option_20230722.json: 
-------------------------------------------------------------------------------- 1 | { 2 | "task": "test_mff_module_in_mimounet" // root/task/images-models-options 3 | , "model": "plain_multiout" // "plain" | plain_multiout 4 | , "gpu_ids": [2] 5 | 6 | , "scale": 1 // broadcast to "netG" if SISR 7 | , "n_channels": 3 // broadcast to "datasets", 1 for grayscale, 3 for color 8 | 9 | , "path": { 10 | "root": "results/deblurring" // "denoising" | "superresolution" 11 | , "pretrained_netG": null // path of pretrained model 12 | } 13 | 14 | , "datasets": { 15 | "train": { 16 | "name": "train_dataset" // just name 17 | , "dataset_type": "plain" // "dncnn" | "dnpatch" for dncnn, | "fdncnn" | "ffdnet" | "sr" | "srmd" | "dpsr" | "plain" | "plainpatch" 18 | , "dataroot_H": "/data1/Motion_Deblurring/GoPro/train/target_crops" // path of High-quality training dataset 19 | , "dataroot_L": "/data1/Motion_Deblurring/GoPro/train/input_crops" // path of Low-quality training dataset 20 | , "H_size": 224 // patch size 40 | 64 | 96 | 128 | 192 | 224 | 256 21 | 22 | , "sigma": 25 // unused 23 | , "sigma_test": 25 // unused 24 | 25 | , "dataloader_shuffle": true 26 | , "dataloader_num_workers": 8 27 | , "dataloader_batch_size": 32 // batch size 1 | 16 | 32 | 48 | 64 | 128 28 | } 29 | , "test": { 30 | "name": "test_dataset" // just name 31 | , "dataset_type": "plain" // "dncnn" | "dnpatch" for dncnn, | "fdncnn" | "ffdnet" | "sr" | "srmd" | "dpsr" | "plain" | "plainpatch" 32 | , "dataroot_H": "/data1/Motion_Deblurring/GoPro/val/target_crops" // path of High-quality testing dataset 33 | , "dataroot_L": "/data1/Motion_Deblurring/GoPro/val/input_crops" // path of Low-quality testing dataset 34 | 35 | , "sigma": 25 // unused 36 | , "sigma_test": 25 // unused 37 | 38 | , "dataloader_num_workers": 8 39 | , "dataloader_batch_size": 16 // batch size 1 | 16 | 32 | 48 | 64 | 128 40 | 41 | } 42 | } 43 | 44 | , "netG": { 45 | "net_type": "mimounet-mff" // "dncnn" | "fdncnn" | "ffdnet" | "srmd" | "dpsr" | "msrresnet0" | "msrresnet1" | "rrdb" 46 | , "in_nc": 3 // input channel number 47 | , "out_nc": 3 // ouput channel number 48 | , "nc": 32 // 96 for "dpsr", 128 for "srmd", 64 for "dncnn" and "rrdb" 49 | , "nb": 8 // 23 for "rrdb", 12 for "srmd", 15 for "ffdnet", 20 for "dncnn", 16 for "srresnet" and "dpsr" 50 | , "gc": 32 // unused 51 | , "ng": 2 // unused 52 | , "reduction" : 16 // unused 53 | , "act_mode": "R" // unused "BR" for BN+ReLU | "R" for ReLU 54 | , "upsample_mode": "upconv" // unused "pixelshuffle" | "convtranspose" | "upconv" 55 | , "downsample_mode": "strideconv" // unused "strideconv" | "avgpool" | "maxpool" 56 | 57 | , "init_type": "orthogonal" // unused "orthogonal" | "normal" | "uniform" | "xavier_normal" | "xavier_uniform" | "kaiming_normal" | "kaiming_uniform" 58 | , "init_bn_type": "uniform" // unused "uniform" | "constant" 59 | , "init_gain": 0.2 60 | } 61 | 62 | , "train": { 63 | "total_epoch": 300000 64 | ,"G_lossfn_type": "l1+fft" // "l1" preferred | "l2sum" | "l2" | "ssim" 65 | , "G_lossfn_weight": [1.0, 0.1] // default 66 | 67 | , "G_optimizer_type": "adam" // fixed, adam is enough 68 | , "G_optimizer_lr": 1e-4 // learning rate 69 | , "G_optimizer_clipgrad": null // unused 70 | 71 | , "G_scheduler_type": "MultiStepLR" // "MultiStepLR" is enough 72 | , "G_scheduler_milestones": [20000, 40000, 60000, 80000, 100000, 200000] 73 | , "G_scheduler_gamma": 0.5 74 | 75 | , "G_regularizer_orthstep": null // unused 76 | , "G_regularizer_clipstep": null // unused 77 | 78 | , "checkpoint_test": 5000 // for 
testing 79 | , "checkpoint_save": 5000 // for saving model 80 | , "checkpoint_print": 200 // for print 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /options/option_20230724.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "test_mimounet_20230724" // root/task/images-models-options 3 | , "model": "plain_multiout" // "plain" | plain_multiout 4 | , "gpu_ids": [3] 5 | 6 | , "scale": 1 // broadcast to "netG" if SISR 7 | , "n_channels": 3 // broadcast to "datasets", 1 for grayscale, 3 for color 8 | 9 | , "path": { 10 | "root": "results/deblurring" // "denoising" | "superresolution" 11 | , "pretrained_netG": null // path of pretrained model 12 | } 13 | 14 | , "datasets": { 15 | "train": { 16 | "name": "train_dataset" // just name 17 | , "dataset_type": "plain" // "dncnn" | "dnpatch" for dncnn, | "fdncnn" | "ffdnet" | "sr" | "srmd" | "dpsr" | "plain" | "plainpatch" 18 | , "dataroot_H": "/data1/Motion_Deblurring/GoPro/train/target_crops" // path of High-quality training dataset 19 | , "dataroot_L": "/data1/Motion_Deblurring/GoPro/train/input_crops" // path of Low-quality training dataset 20 | , "H_size": 224 // patch size 40 | 64 | 96 | 128 | 192 | 224 | 256 21 | 22 | , "sigma": 25 // unused 23 | , "sigma_test": 25 // unused 24 | 25 | , "dataloader_shuffle": true 26 | , "dataloader_num_workers": 8 27 | , "dataloader_batch_size": 32 // batch size 1 | 16 | 32 | 48 | 64 | 128 28 | } 29 | , "test": { 30 | "name": "test_dataset" // just name 31 | , "dataset_type": "plain" // "dncnn" | "dnpatch" for dncnn, | "fdncnn" | "ffdnet" | "sr" | "srmd" | "dpsr" | "plain" | "plainpatch" 32 | , "dataroot_H": "/data1/Motion_Deblurring/GoPro/val/target_crops" // path of High-quality testing dataset 33 | , "dataroot_L": "/data1/Motion_Deblurring/GoPro/val/input_crops" // path of Low-quality testing dataset 34 | 35 | , "sigma": 25 // unused 36 | , "sigma_test": 25 // unused 37 | 38 | , "dataloader_num_workers": 8 39 | , "dataloader_batch_size": 16 // batch size 1 | 16 | 32 | 48 | 64 | 128 40 | 41 | } 42 | } 43 | 44 | , "netG": { 45 | "net_type": "mimounet" // "dncnn" | "fdncnn" | "ffdnet" | "srmd" | "dpsr" | "msrresnet0" | "msrresnet1" | "rrdb" 46 | , "in_nc": 3 // input channel number 47 | , "out_nc": 3 // ouput channel number 48 | , "nc": 32 // 96 for "dpsr", 128 for "srmd", 64 for "dncnn" and "rrdb" 49 | , "nb": 8 // 23 for "rrdb", 12 for "srmd", 15 for "ffdnet", 20 for "dncnn", 16 for "srresnet" and "dpsr" 50 | , "gc": 32 // unused 51 | , "ng": 2 // unused 52 | , "reduction" : 16 // unused 53 | , "act_mode": "R" // unused "BR" for BN+ReLU | "R" for ReLU 54 | , "upsample_mode": "upconv" // unused "pixelshuffle" | "convtranspose" | "upconv" 55 | , "downsample_mode": "strideconv" // unused "strideconv" | "avgpool" | "maxpool" 56 | 57 | , "init_type": "orthogonal" // unused "orthogonal" | "normal" | "uniform" | "xavier_normal" | "xavier_uniform" | "kaiming_normal" | "kaiming_uniform" 58 | , "init_bn_type": "uniform" // unused "uniform" | "constant" 59 | , "init_gain": 0.2 60 | } 61 | 62 | , "train": { 63 | "total_epoch": 300000 64 | ,"G_lossfn_type": "l1+fft" // "l1" preferred | "l2sum" | "l2" | "ssim" 65 | , "G_lossfn_weight": [1.0, 0.1] // default 66 | 67 | , "G_optimizer_type": "adam" // fixed, adam is enough 68 | , "G_optimizer_lr": 1e-4 // learning rate 69 | , "G_optimizer_clipgrad": null // unused 70 | 71 | , "G_scheduler_type": "MultiStepLR" // "MultiStepLR" is enough 72 | , "G_scheduler_milestones": 
[20000, 40000, 60000, 80000, 100000, 200000] 73 | , "G_scheduler_gamma": 0.5 74 | 75 | , "G_regularizer_orthstep": null // unused 76 | , "G_regularizer_clipstep": null // unused 77 | 78 | , "checkpoint_test": 5000 // for testing 79 | , "checkpoint_save": 5000 // for saving model 80 | , "checkpoint_print": 200 // for print 81 | } 82 | } 83 | -------------------------------------------------------------------------------- /options/option_official_implementation_fftformer.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "official_implementation_of_fftformer" // root/task/images-models-options 3 | , "model": "plain" // "plain" | "multiout" | "progressive" 4 | , "gpu_ids": [4, 5] 5 | 6 | , "scale": 1 // broadcast to "netG" if SISR 7 | , "n_channels": 3 // broadcast to "datasets", 1 for grayscale, 3 for color 8 | 9 | , "path": { 10 | "root": "results/deblurring" // "denoising" | "superresolution" | "deblurring" 11 | , "pretrained_netG": null // path of pretrained model 12 | } 13 | 14 | , "datasets": { 15 | "train": { 16 | "name": "train_dataset" // just name 17 | , "dataset_type": "plain" // "plain" 18 | , "dataroot_H": "/data1/Motion_Deblurring/GoPro/train/target_crops" // path of High-quality training dataset (prefered full path) 19 | , "dataroot_L": "/data1/Motion_Deblurring/GoPro/train/input_crops" // path of Low-quality training dataset (prefered full path) 20 | , "H_size": 128 // patch size 40 | 64 | 96 | 128 | 192 | 224 | 256 | 384 21 | 22 | , "dataloader_shuffle": true 23 | , "dataloader_num_workers": 8 24 | , "dataloader_batch_size": 8 // batch size 1 | 8 | 16 | 32 | 48 | 64 | 128 25 | 26 | , "mini_batch_sizes": [8, 5, 4, 2, 1, 1] // mini batch size for progressive training 27 | , "iters_milestones": [92000, 64000, 48000, 36000, 36000, 24000] // milestone of iteration for progressive training 28 | , "mini_H_sizes" : [128, 160, 192, 256, 320, 384] // varing H_size for progressive training 29 | } 30 | , "valid": { 31 | "name": "valid_dataset" // just name 32 | , "dataset_type": "plain" // "dncnn" | "dnpatch" for dncnn, | "fdncnn" | "ffdnet" | "sr" | "srmd" | "dpsr" | "plain" | "plainpatch" 33 | , "dataroot_H": "/data1/Motion_Deblurring/GoPro/val/target_crops" // path of High-quality testing dataset (prefered full path) 34 | , "dataroot_L": "/data1/Motion_Deblurring/GoPro/val/input_crops" // path of Low-quality testing dataset (prefered full path) 35 | 36 | , "sigma": 25 // unused 37 | , "sigma_valid": 25 // unused 38 | 39 | , "dataloader_num_workers": 8 40 | , "dataloader_batch_size": 16 // batch size 1 | 16 | 32 | 48 | 64 | 128 41 | } 42 | , "test": { 43 | "name": "test_dataset" // unused 44 | , "dataset_type": "plain" // dataset type 45 | , "dataroot_H": "/data1/Motion_Deblurring/GoPro/test/target" // path of High-quality testing dataset 46 | , "dataroot_L": "/data1/Motion_Deblurring/GoPro/test/input" // path of Low-quality testing dataset 47 | 48 | , "sigma": 25 // unused 49 | , "sigma_valid": 25 // unused 50 | 51 | , "dataloader_num_workers": 0 52 | , "dataloader_batch_size": 1 // batch size 1 53 | } 54 | } 55 | 56 | , "netG": { 57 | "net_type": "fftformer" // "mimounet" | "mimounetplus" | "restormer" | "uformer" | "nafnet"("nafnet local" for test) | "fftformer" 58 | , "in_nc": 3 // input channel number 59 | , "out_nc": 3 // ouput channel number 60 | , "nc": 48 // basic hidden dim or base channel and 16 for uformer_tiny 61 | , "nb": [6, 6, 12] // number of blocks (list if different scales) 62 | , "n_refine_b": 4 // number of 
refinement blocks 63 | , "heads": [1, 2, 4, 8] // heads of multi-head attention 64 | , "ffn_expansion_factor": 3 // hidden dim expanded in Gated-Dconv Network 65 | , "bias": false // bias in qkv generation 66 | , "LayerNorm_type": "WithBias" // Other option 'BiasFree' 67 | , "dual_pixel_task": false // ## True for dual-pixel defocus deblurring only. Also set inp_channels=6 68 | , "token_mlp": "leff" // token of mlp in uformer 69 | , "enc_blk_nums": [1, 1, 1, 28]// encoder block number of nafnet 70 | , "dec_blk_nums": [1, 1, 1, 1] // decoder block number of nafnet 71 | 72 | , "init_type": "orthogonal" // unused "orthogonal" | "normal" | "uniform" | "xavier_normal" | "xavier_uniform" | "kaiming_normal" | "kaiming_uniform" 73 | , "init_bn_type": "uniform" // unused "uniform" | "constant" 74 | , "init_gain": 0.2 75 | } 76 | 77 | , "train": { 78 | "total_epoch": 30000 79 | ,"G_lossfn_type": "psnr" // "l1" | "l2sum" | "l2" | "ssim" | "psnr" | "charbonnier" | 'l1+ssim' | 'l1+fft' 80 | , "G_lossfn_weight": [1.0] // default 81 | 82 | , "G_optimizer_type": "adamw" // fixed, adam is enough 83 | , "G_optimizer_lr": 1e-3 // learning rate 84 | , "G_optimizer_betas": [0.9, 0.9] // beta 85 | , "G_optimizer_wd": 1e-3 // weight decay 86 | , "G_optimizer_clipgrad": 0.01 // the max norm of grad for clipping (negative for unclipping) 87 | 88 | , "G_scheduler_type": "CosineAnnealingLR" // "MultiStepLR" | "CosineAnnealingWarmRestarts" | "CosineAnnealingRestartCyclicLR" | "GradualWarmupScheduler" | "CosineAnnealingLR" 89 | , "G_scheduler_milestones": [5000, 10000, 20000, 25000, 30000] // for "MultiStepLR" 90 | , "G_scheduler_gamma": 0.5 // for "MultiStepLR" 91 | , "G_scheduler_period": 5000 // for "CosineAnnealingWarmRestarts" 92 | , "G_scheduler_eta_min": 1e-7 // for "CosineAnnealingWarmRestarts" 93 | , "G_scheduler_periods": [9200, 20800] // for "CosineAnnealingRestartCyclicLR" 94 | , "G_scheduler_restart_weights": [1, 1] // for "CosineAnnealingRestartCyclicLR" 95 | , "G_scheduler_eta_mins": [0.0003,0.000001] // for "CosineAnnealingRestartCyclicLR" 96 | , "G_scheduler_multiplier": 1 // for "GradualWarmupScheduler" 97 | , "G_scheduler_warmup_epochs": 3 // for "GradualWarmupScheduler" 98 | 99 | , "G_regularizer_orthstep": null // unused 100 | , "G_regularizer_clipstep": null // unused 101 | 102 | , "checkpoint_valid": 1 // for validating per N epoch 103 | , "checkpoint_save": 1 // for saving model per N epoch 104 | , "checkpoint_print": 200 // for print every iteration 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /options/option_official_implementation_fsanet.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "official_implementation_of_fsanet" // root/task/images-models-options 3 | , "model": "multiin" // "plain" | "multiout" | "progressive" | "multiin" 4 | , "gpu_ids": [3] 5 | 6 | , "scale": 1 // broadcast to "netG" if SISR 7 | , "n_channels": 3 // broadcast to "datasets", 1 for grayscale, 3 for color 8 | 9 | , "path": { 10 | "root": "results/aberration_correction" // "denoising" | "superresolution" | "deblurring" 11 | , "pretrained_netG": null // path of pretrained model 12 | } 13 | 14 | , "datasets": { 15 | "train": { 16 | "name": "train_dataset" // just name 17 | , "dataset_type": "multiin" // "plain" | "multimodal" | "multiin" 18 | , "dataroot_H": "/data1/Aberration_Correction/train/target_crops" // path of High-quality training dataset (prefered full path) 19 | , "dataroot_L": 
"/data1/Aberration_Correction/train/input_crops" // path of Low-quality training dataset (prefered full path) 20 | , "H_size": 200 // patch size 40 | 64 | 96 | 128 | 192 | 224 | 256 | 384 21 | 22 | , "dataloader_shuffle": true 23 | , "dataloader_num_workers": 8 24 | , "dataloader_batch_size": 8 // batch size 1 | 8 | 16 | 32 | 48 | 64 | 128 25 | 26 | // hyper-parameters for progressive training 27 | , "mini_batch_sizes": [8, 5, 4, 2, 1, 1] // mini batch size for progressive training 28 | , "iters_milestones": [92000, 64000, 48000, 36000, 36000, 24000] // milestone of iteration for progressive training 29 | , "mini_H_sizes" : [128, 160, 192, 256, 320, 384] // varing H_size for progressive training 30 | 31 | // hyper-parameters for fsanet 32 | , "dataroot_K": "/data1/Aberration_Correction/kernels" // path of the kernels 33 | , "kr_num": 30 // kernel number of fsanet 34 | , "eptional_name": "gauss_fit" // basis to fit the SNR in weiner filter 35 | , "en_size": 75 // encode size of the eptional, higher for higher frequency but larger computation 36 | , "sig_low": 0.2 // eptional fitted by gaussian,the lower boundary 37 | , "sig_high": 2.5 // eptional fitted by gaussian,the upper boundary 38 | , "sig_num": 29 // number of gaussian basis 39 | } 40 | , "valid": { 41 | "name": "valid_dataset" // just name 42 | , "dataset_type": "multiin" // "plain" | "multimodal" | "multiin" 43 | , "dataroot_H": "/data1/Aberration_Correction/val/target_crops" // path of High-quality testing dataset (prefered full path) 44 | , "dataroot_L": "/data1/Aberration_Correction/val/input_crops" // path of Low-quality testing dataset (prefered full path) 45 | 46 | , "sigma": 25 // unused 47 | , "sigma_valid": 25 // unused 48 | 49 | , "dataloader_num_workers": 8 50 | , "dataloader_batch_size": 8 // batch size 1 | 8 | 16 | 32 | 48 | 64 | 128 51 | 52 | // hyper-parameters for fsanet 53 | , "dataroot_K": "/data1/Aberration_Correction/kernels" // path of the kernels 54 | , "kr_num": 30 // kernel number of fsanet 55 | , "eptional_name": "gauss_fit" // basis to fit the SNR in weiner filter 56 | , "en_size": 75 // encode size of the eptional, higher for higher frequency but larger computation 57 | , "sig_low": 0.2 // eptional fitted by gaussian,the lower boundary 58 | , "sig_high": 2.5 // eptional fitted by gaussian,the upper boundary 59 | , "sig_num": 29 // number of gaussian basis 60 | } 61 | , "test": { 62 | "name": "test_dataset" // unused 63 | , "dataset_type": "multiin" // dataset type 64 | , "dataroot_H": "/data1/Aberration_Correction/test/target" // path of High-quality testing dataset 65 | , "dataroot_L": "/data1/Aberration_Correction/test/input" // path of Low-quality testing dataset 66 | 67 | , "sigma": 25 // unused 68 | , "sigma_valid": 25 // unused 69 | 70 | , "dataloader_num_workers": 0 71 | , "dataloader_batch_size": 1 // batch size 1 72 | 73 | // hyper-parameters for fsanet 74 | , "dataroot_K": "/data1/Aberration_Correction/kernels" // path of the kernels 75 | , "kr_num": 30 // kernel number of fsanet 76 | , "eptional_name": "gauss_fit" // basis to fit the SNR in weiner filter 77 | , "en_size": 75 // encode size of the eptional, higher for higher frequency but larger computation 78 | , "sig_low": 0.2 // eptional fitted by gaussian,the lower boundary 79 | , "sig_high": 2.5 // eptional fitted by gaussian,the upper boundary 80 | , "sig_num": 29 // number of gaussian basis 81 | } 82 | } 83 | 84 | , "netG": { 85 | "net_type": "fsanet" // "mimounet" | "mimounetplus" | "restormer" | "uformer" | "nafnet"("nafnet local" for 
test) | "fftformer" | "fsanet" 86 | , "in_nc": 5 // input channel number 87 | , "out_nc": 3 // ouput channel number 88 | , "nc": 48 // basic hidden dim or base channel and 16 for uformer_tiny 89 | , "nb": [6, 6, 12] // number of blocks (list if different scales) 90 | , "n_refine_b": 4 // number of refinement blocks 91 | , "heads": [1, 2, 4, 8] // heads of multi-head attention 92 | , "ffn_expansion_factor": 3 // hidden dim expanded in Gated-Dconv Network 93 | , "bias": false // bias in qkv generation 94 | , "LayerNorm_type": "WithBias" // Other option 'BiasFree' 95 | , "dual_pixel_task": false // ## True for dual-pixel defocus deblurring only. Also set inp_channels=6 96 | , "token_mlp": "leff" // token of mlp in uformer 97 | , "enc_blk_nums": [1, 1, 1, 28]// encoder block number of nafnet 98 | , "dec_blk_nums": [1, 1, 1, 1] // decoder block number of nafnet 99 | 100 | , "kr_num": 30 // kernel number of fsanet 101 | , "sig_num": 30 // number of gaussian basis for fsanet 102 | 103 | , "init_type": "orthogonal" // unused "orthogonal" | "normal" | "uniform" | "xavier_normal" | "xavier_uniform" | "kaiming_normal" | "kaiming_uniform" 104 | , "init_bn_type": "uniform" // unused "uniform" | "constant" 105 | , "init_gain": 0.2 106 | } 107 | 108 | , "train": { 109 | "total_epoch": 3000 110 | ,"G_lossfn_type": "l2+perc" // "l1" | "l2sum" | "l2" | "ssim" | "psnr" | "charbonnier" | 'l1+ssim' | 'l1+fft' | 'l2+perc' 111 | , "G_lossfn_weight": [1.0, 1e-4] // default 112 | 113 | , "G_optimizer_type": "adam" // fixed, adam is enough, adamw is for transformer 114 | , "G_optimizer_lr": 1e-4 // learning rate 115 | , "G_optimizer_betas": [0.9, 0.999] // beta 116 | , "G_optimizer_wd": 1e-3 // weight decay 117 | , "G_optimizer_clipgrad": 0.01 // the max norm of grad for clipping (negative for unclipping) 118 | 119 | , "G_scheduler_type": "MultiStepLR" // "MultiStepLR" | "CosineAnnealingWarmRestarts" | "CosineAnnealingRestartCyclicLR" | "GradualWarmupScheduler" | "CosineAnnealingLR" 120 | , "G_scheduler_milestones": [500, 1000, 2000, 2500, 3000] // for "MultiStepLR" 121 | , "G_scheduler_gamma": 0.5 // for "MultiStepLR" 122 | , "G_scheduler_period": 5000 // for "CosineAnnealingWarmRestarts" 123 | , "G_scheduler_eta_min": 1e-7 // for "CosineAnnealingWarmRestarts" 124 | , "G_scheduler_periods": [9200, 20800] // for "CosineAnnealingRestartCyclicLR" 125 | , "G_scheduler_restart_weights": [1, 1] // for "CosineAnnealingRestartCyclicLR" 126 | , "G_scheduler_eta_mins": [0.0003,0.000001] // for "CosineAnnealingRestartCyclicLR" 127 | , "G_scheduler_multiplier": 1 // for "GradualWarmupScheduler" 128 | , "G_scheduler_warmup_epochs": 3 // for "GradualWarmupScheduler" 129 | 130 | , "G_regularizer_orthstep": null // unused 131 | , "G_regularizer_clipstep": null // unused 132 | 133 | , "checkpoint_valid": 1 // for validating per N epoch 134 | , "checkpoint_save": 1 // for saving model per N epoch 135 | , "checkpoint_print": 200 // for print every iteration 136 | } 137 | } 138 | -------------------------------------------------------------------------------- /options/option_official_implementation_mimounet.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "official_implementation_of_mimounet" // root/task/images-models-options 3 | , "model": "multiout" // "plain" | "multiout" | "progressive" 4 | , "gpu_ids": [7] 5 | 6 | , "scale": 1 // broadcast to "netG" if SISR 7 | , "n_channels": 3 // broadcast to "datasets", 1 for grayscale, 3 for color 8 | 9 | , "path": { 10 | "root": 
"results/deblurring" // "denoising" | "superresolution" | "deblurring" 11 | , "pretrained_netG": null // path of pretrained model 12 | } 13 | 14 | , "datasets": { 15 | "train": { 16 | "name": "train_dataset" // just name 17 | , "dataset_type": "plain" // "dncnn" | "dnpatch" for dncnn, | "fdncnn" | "ffdnet" | "sr" | "srmd" | "dpsr" | "plain" | "plainpatch" 18 | , "dataroot_H": "/data1/Motion_Deblurring/GoPro/train/target_crops" // path of High-quality training dataset 19 | , "dataroot_L": "/data1/Motion_Deblurring/GoPro/train/input_crops" // path of Low-quality training dataset 20 | , "H_size": 224 // patch size 40 | 64 | 96 | 128 | 192 | 224 | 256 21 | 22 | , "sigma": 25 // unused 23 | , "sigma_test": 25 // unused 24 | 25 | , "dataloader_shuffle": true 26 | , "dataloader_num_workers": 8 27 | , "dataloader_batch_size": 4 // batch size 1 | 16 | 32 | 48 | 64 | 128 28 | } 29 | , "valid": { 30 | "name": "valid_dataset" // just name 31 | , "dataset_type": "plain" // "dncnn" | "dnpatch" for dncnn, | "fdncnn" | "ffdnet" | "sr" | "srmd" | "dpsr" | "plain" | "plainpatch" 32 | , "dataroot_H": "/data1/Motion_Deblurring/GoPro/val/target_crops" // path of High-quality testing dataset 33 | , "dataroot_L": "/data1/Motion_Deblurring/GoPro/val/input_crops" // path of Low-quality testing dataset 34 | 35 | , "sigma": 25 // unused 36 | , "sigma_valid": 25 // unused 37 | 38 | , "dataloader_num_workers": 8 39 | , "dataloader_batch_size": 4 // batch size 1 | 16 | 32 | 48 | 64 | 128 40 | } 41 | , "test": { 42 | "name": "test_dataset" // unused 43 | , "dataset_type": "plain" // dataset type 44 | , "dataroot_H": "/data1/Motion_Deblurring/GoPro/test/target" // path of High-quality testing dataset 45 | , "dataroot_L": "/data1/Motion_Deblurring/GoPro/test/input" // path of Low-quality testing dataset 46 | 47 | , "sigma": 25 // unused 48 | , "sigma_valid": 25 // unused 49 | 50 | , "dataloader_num_workers": 0 51 | , "dataloader_batch_size": 1 // batch size 1 52 | } 53 | } 54 | 55 | , "netG": { 56 | "net_type": "mimounet" // "mimounet" | "mimounetplus" | "restormer" 57 | , "in_nc": 3 // input channel number 58 | , "out_nc": 3 // ouput channel number 59 | , "nc": 32 // basic hidden dim or base channel 60 | , "nb": 8 // number of blocks (list if different scales) 61 | , "gc": 32 62 | , "ng": 2 63 | , "act_mode": "R" // unused "BR" for BN+ReLU | "R" for ReLU 64 | , "upsample_mode": "upconv" // unused "pixelshuffle" | "convtranspose" | "upconv" 65 | , "downsample_mode": "strideconv" // unused "strideconv" | "avgpool" | "maxpool" 66 | 67 | , "init_type": "orthogonal" // unused "orthogonal" | "normal" | "uniform" | "xavier_normal" | "xavier_uniform" | "kaiming_normal" | "kaiming_uniform" 68 | , "init_bn_type": "uniform" // unused "uniform" | "constant" 69 | , "init_gain": 0.2 70 | } 71 | 72 | , "train": { 73 | "total_epoch": 3000000 74 | ,"G_lossfn_type": "l1+fft" // "l1" | "l2sum" | "l2" | "ssim" 75 | , "G_lossfn_weight": [1.0, 0.1] // default 76 | 77 | , "G_optimizer_type": "adam" // fixed, adam is enough 78 | , "G_optimizer_lr": 1e-4 // learning rate 79 | , "G_optimizer_betas": [0.9, 0.999] // beta 80 | , "G_optimizer_clipgrad": null // the max norm of grad for clipping (negative for unclipping) 81 | 82 | , "G_scheduler_type": "MultiStepLR" // "MultiStepLR" is enough 83 | , "G_scheduler_milestones": [200000, 400000, 600000, 800000, 1000000, 2000000] // for "MultiStepLR" 84 | , "G_scheduler_gamma": 0.5 // for "MultiStepLR" 85 | , "G_scheduler_period": 500 // for "CosineAnnealingWarmRestarts" 86 | , "G_scheduler_eta_min": 
-------------------------------------------------------------------------------- /options/option_official_implementation_nafnet.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "official_implementation_of_nafnet" // root/task/images-models-options 3 | , "model": "plain" // "plain" | "multiout" | "progressive" 4 | , "gpu_ids": [4, 5] 5 | 6 | , "scale": 1 // broadcast to "netG" if SISR 7 | , "n_channels": 3 // broadcast to "datasets", 1 for grayscale, 3 for color 8 | 9 | , "path": { 10 | "root": "results/deblurring" // "denoising" | "superresolution" | "deblurring" 11 | , "pretrained_netG": null // path of pretrained model 12 | } 13 | 14 | , "datasets": { 15 | "train": { 16 | "name": "train_dataset" // just name 17 | , "dataset_type": "plain" // "plain" 18 | , "dataroot_H": "/data1/Motion_Deblurring/GoPro/train/target_crops" // path of High-quality training dataset (preferred full path) 19 | , "dataroot_L": "/data1/Motion_Deblurring/GoPro/train/input_crops" // path of Low-quality training dataset (preferred full path) 20 | , "H_size": 256 // patch size 40 | 64 | 96 | 128 | 192 | 224 | 256 | 384 21 | 22 | , "dataloader_shuffle": true 23 | , "dataloader_num_workers": 8 24 | , "dataloader_batch_size": 8 // batch size 1 | 8 | 16 | 32 | 48 | 64 | 128 25 | 26 | , "mini_batch_sizes": [8, 5, 4, 2, 1, 1] // mini batch size for progressive training 27 | , "iters_milestones": [92000, 64000, 48000, 36000, 36000, 24000] // milestone of iteration for progressive training 28 | , "mini_H_sizes" : [128, 160, 192, 256, 320, 384] // varying H_size for progressive training 29 | } 30 | , "valid": { 31 | "name": "valid_dataset" // just name 32 | , "dataset_type": "plain" // "dncnn" | "dnpatch" for dncnn, | "fdncnn" | "ffdnet" | "sr" | "srmd" | "dpsr" | "plain" | "plainpatch" 33 | , "dataroot_H": "/data1/Motion_Deblurring/GoPro/val/target_crops" // path of High-quality testing dataset (preferred full path) 34 | , "dataroot_L": "/data1/Motion_Deblurring/GoPro/val/input_crops" // path of Low-quality testing dataset (preferred full path) 35 | 36 | , "sigma": 25 // unused 37 | , "sigma_valid": 25 // unused 38 | 39 | , "dataloader_num_workers": 8 40 | , "dataloader_batch_size": 16 // batch size 1 | 16 | 32 | 48 | 64 | 128 41 | } 42 | , "test": { 43 | "name": "test_dataset" // unused 44 | , "dataset_type": "plain" // dataset type 45 | , "dataroot_H": "/data1/Motion_Deblurring/GoPro/test/target" // path of High-quality testing dataset 46 | , "dataroot_L": "/data1/Motion_Deblurring/GoPro/test/input" // path of Low-quality testing dataset 47 | 48 | , "sigma": 25 // unused 49 | , "sigma_valid": 25 // unused 50 | 51 | , "dataloader_num_workers": 0 52 | , "dataloader_batch_size": 1 // batch size 1 53 | } 54 | } 55 | 56 | , "netG": { 57 | "net_type": "nafnet" // "mimounet" | "mimounetplus" | "restormer" | "uformer" | "nafnet"("nafnet local" for test) 58 | , "in_nc": 3 // input
channel number 59 | , "out_nc": 3 // output channel number 60 | , "nc": 64 // basic hidden dim / base channel (16 for uformer_tiny) 61 | , "nb": [2, 2, 2, 2] // number of blocks (list if different scales) 62 | , "n_refine_b": 4 // number of refinement blocks 63 | , "heads": [1, 2, 4, 8] // heads of multi-head attention 64 | , "ffn_expansion_factor": 2.66 // hidden dim expansion in the Gated-Dconv Network 65 | , "bias": false // bias in qkv generation 66 | , "LayerNorm_type": "WithBias" // other option 'BiasFree' 67 | , "dual_pixel_task": false // ## True for dual-pixel defocus deblurring only. Also set inp_channels=6 68 | , "token_mlp": "leff" // token mlp in uformer 69 | , "enc_blk_nums": [1, 1, 1, 28] // encoder block number of nafnet 70 | , "dec_blk_nums": [1, 1, 1, 1] // decoder block number of nafnet 71 | 72 | , "init_type": "orthogonal" // unused "orthogonal" | "normal" | "uniform" | "xavier_normal" | "xavier_uniform" | "kaiming_normal" | "kaiming_uniform" 73 | , "init_bn_type": "uniform" // unused "uniform" | "constant" 74 | , "init_gain": 0.2 75 | } 76 | 77 | , "train": { 78 | "total_epoch": 30000 79 | , "G_lossfn_type": "psnr" // "l1" | "l2sum" | "l2" | "ssim" | "psnr" | "charbonnier" | 'l1+ssim' | 'l1+fft' 80 | , "G_lossfn_weight": [1.0] // default 81 | 82 | , "G_optimizer_type": "adamw" // adam for CNNs, adamw for transformers 83 | , "G_optimizer_lr": 1e-3 // learning rate 84 | , "G_optimizer_betas": [0.9, 0.9] // betas 85 | , "G_optimizer_wd": 1e-3 // weight decay 86 | , "G_optimizer_clipgrad": 0.01 // the max norm of grad for clipping (negative disables clipping) 87 | 88 | , "G_scheduler_type": "CosineAnnealingLR" // "MultiStepLR" | "CosineAnnealingWarmRestarts" | "CosineAnnealingRestartCyclicLR" | "GradualWarmupScheduler" | "CosineAnnealingLR" 89 | , "G_scheduler_milestones": [5000, 10000, 20000, 25000, 30000] // for "MultiStepLR" 90 | , "G_scheduler_gamma": 0.5 // for "MultiStepLR" 91 | , "G_scheduler_period": 5000 // for "CosineAnnealingWarmRestarts" 92 | , "G_scheduler_eta_min": 1e-7 // for "CosineAnnealingWarmRestarts" 93 | , "G_scheduler_periods": [9200, 20800] // for "CosineAnnealingRestartCyclicLR" 94 | , "G_scheduler_restart_weights": [1, 1] // for "CosineAnnealingRestartCyclicLR" 95 | , "G_scheduler_eta_mins": [0.0003, 0.000001] // for "CosineAnnealingRestartCyclicLR" 96 | , "G_scheduler_multiplier": 1 // for "GradualWarmupScheduler" 97 | , "G_scheduler_warmup_epochs": 3 // for "GradualWarmupScheduler" 98 | 99 | , "G_regularizer_orthstep": null // unused 100 | , "G_regularizer_clipstep": null // unused 101 | 102 | , "checkpoint_valid": 1 // validate every N epochs 103 | , "checkpoint_save": 1 // save the model every N epochs 104 | , "checkpoint_print": 200 // print every N iterations 105 | } 106 | } 107 | -------------------------------------------------------------------------------- /options/option_official_implementation_painter.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "official_implementation_of_painter" // root/task/images-models-options 3 | , "model": "multiin" // "plain" | "multiout" | "progressive" | "multiin" 4 | , "gpu_ids": [6, 7] 5 | 6 | , "scale": 1 // broadcast to "netG" if SISR 7 | , "n_channels": 3 // broadcast to "datasets", 1 for grayscale, 3 for color 8 | 9 | , "path": { 10 | "root": "results/multimodal" // "denoising" | "superresolution" | "deblurring" | "multimodal" 11 | , "pretrained_netG": null // path of pretrained model 12 | } 13 | 14 | , "datasets": { 15 | "train": { 16 | "name": "train_dataset" // 
just name 17 | , "dataset_type": "multimodal" // "plain" | "plainpatch" | "multimodal" 18 | , "dataroot_H": null // path of High-quality training dataset (prefered full path), for painter, use json_path_list 19 | , "dataroot_L": null // path of Low-quality training dataset (prefered full path) , for painter, use json_path_list 20 | , "json_path_list": ["/data1/Denoising/SIDD_Srgb/denoise_ssid_train.json", 21 | "/data1/Deraining/Rain100H/derain_rain100h_train.json", 22 | "/data1/Low_Light_Enhancement/LOL/enhance_lol_train.json", 23 | "/data1/Depth_Estimation/NYU_Depth_V2/nyuv2_sync_image_depth.json"] 24 | , "H_size": [896, 448] // patch size 40 | 64 | 96 | 128 | 192 | 224 | 256 | 384 25 | , "use_two_pairs": true // in multimodal datasets, if you need one target pairs to guide the mapping 26 | , "P_size": 16 // patch size in ViT, used to calculate the window 27 | , "half_mask_ratio": 0.1 // half mask ratio for MAE 28 | 29 | , "dataloader_shuffle": true 30 | , "dataloader_num_workers": 8 31 | , "dataloader_batch_size": 2 // batch size 1 | 8 | 16 | 32 | 48 | 64 | 128 32 | 33 | , "mini_batch_sizes": [8, 5, 4, 2, 1, 1] // mini batch size for progressive training 34 | , "iters_milestones": [92000, 64000, 48000, 36000, 36000, 24000] // milestone of iteration for progressive training 35 | , "mini_H_sizes" : [128, 160, 192, 256, 320, 384] // varing H_size for progressive training 36 | } 37 | , "valid": { 38 | "name": "valid_dataset" // just name 39 | , "dataset_type": "multimodal" // "plain" | "plainpatch" | "multimodal" 40 | , "dataroot_H": null // path of High-quality training dataset (prefered full path), for painter, use json_path_list 41 | , "dataroot_L": null // path of Low-quality training dataset (prefered full path) , for painter, use json_path_list 42 | , "json_path_list": ["/data1/Denoising/SIDD_Srgb/denoise_ssid_val.json", 43 | "/data1/Deraining/Rain100H/derain_rain100h_val.json", 44 | "/data1/Low_Light_Enhancement/LOL/enhance_lol_val.json", 45 | "/data1/Depth_Estimation/NYU_Depth_V2/nyuv2_test_image_depth.json"] 46 | , "H_size": [896, 448] // patch size 40 | 64 | 96 | 128 | 192 | 224 | 256 | 384 47 | , "use_two_pairs": true // in multimodal datasets, if you need one target pairs to guide the mapping 48 | , "P_size": 16 // patch size in ViT, used to calculate the window 49 | , "half_mask_ratio": 0.1 // half mask ratio for MAE 50 | 51 | , "sigma": 25 // unused 52 | , "sigma_valid": 25 // unused 53 | 54 | , "dataloader_num_workers": 8 55 | , "dataloader_batch_size": 2 // batch size 1 | 16 | 32 | 48 | 64 | 128 56 | } 57 | , "test": { 58 | "name": "test_dataset" // unused 59 | , "dataset_type": "multimodal" // dataset type 60 | , "dataroot_H": "/data1/Motion_Deblurring/GoPro/test/target" // path of High-quality testing dataset 61 | , "dataroot_L": "/data1/Motion_Deblurring/GoPro/test/input" // path of Low-quality testing dataset 62 | 63 | , "sigma": 25 // unused 64 | , "sigma_valid": 25 // unused 65 | 66 | , "dataloader_num_workers": 0 67 | , "dataloader_batch_size": 1 // batch size 1 68 | } 69 | } 70 | 71 | , "netG": { 72 | "net_type": "painter" // "mimounet" | "mimounetplus" | "restormer" | "uformer" | "nafnet"("nafnet local" for test) | "fftformer" | "painter" 73 | , "in_nc": 3 // input channel number 74 | , "out_nc": 3 // ouput channel number 75 | , "nc": 48 // basic hidden dim or base channel and 16 for uformer_tiny 76 | , "nb": [6, 6, 12] // number of blocks (list if different scales) 77 | , "n_refine_b": 4 // number of refinement blocks 78 | , "heads": [1, 2, 4, 8] // heads of 
multi-head attention 79 | , "ffn_expansion_factor": 3 // hidden dim expanded in Gated-Dconv Network 80 | , "bias": false // bias in qkv generation 81 | , "LayerNorm_type": "WithBias" // Other option 'BiasFree' 82 | , "dual_pixel_task": false // ## True for dual-pixel defocus deblurring only. Also set inp_channels=6 83 | , "token_mlp": "leff" // token of mlp in uformer 84 | , "enc_blk_nums": [1, 1, 1, 28]// encoder block number of nafnet 85 | , "dec_blk_nums": [1, 1, 1, 1] // decoder block number of nafnet 86 | 87 | , "init_type": "orthogonal" // unused "orthogonal" | "normal" | "uniform" | "xavier_normal" | "xavier_uniform" | "kaiming_normal" | "kaiming_uniform" 88 | , "init_bn_type": "uniform" // unused "uniform" | "constant" 89 | , "init_gain": 0.2 90 | } 91 | 92 | , "train": { 93 | "total_epoch": 15 94 | ,"G_lossfn_type": "smoothl1" // "l1" | "l2sum" | "l2" | "ssim" | "psnr" | "charbonnier" | "smoothl1" | 'l1+ssim' | 'l1+fft' 95 | , "G_lossfn_weight": [1.0] // default 96 | 97 | , "G_optimizer_type": "adamw" // adam for CNN and adamw for transformer 98 | , "G_optimizer_lr": 1e-3 // learning rate 99 | , "G_optimizer_betas": [0.9, 0.999] // beta 100 | , "G_optimizer_wd": 1e-3 // weight decay 101 | , "G_optimizer_clipgrad": 0.01 // the max norm of grad for clipping (negative for unclipping) 102 | 103 | , "G_scheduler_type": "GradualWarmupScheduler" // "MultiStepLR" | "CosineAnnealingWarmRestarts" | "CosineAnnealingRestartCyclicLR" | "GradualWarmupScheduler" | "CosineAnnealingLR" 104 | , "G_scheduler_milestones": [5000, 10000, 20000, 25000, 30000] // for "MultiStepLR" 105 | , "G_scheduler_gamma": 0.5 // for "MultiStepLR" 106 | , "G_scheduler_period": 5000 // for "CosineAnnealingWarmRestarts" 107 | , "G_scheduler_eta_min": 1e-7 // for "CosineAnnealingWarmRestarts" 108 | , "G_scheduler_periods": [9200, 20800] // for "CosineAnnealingRestartCyclicLR" 109 | , "G_scheduler_restart_weights": [1, 1] // for "CosineAnnealingRestartCyclicLR" 110 | , "G_scheduler_eta_mins": [0.0003,0.000001] // for "CosineAnnealingRestartCyclicLR" 111 | , "G_scheduler_multiplier": 1 // for "GradualWarmupScheduler" 112 | , "G_scheduler_warmup_epochs": 1 // for "GradualWarmupScheduler" 113 | 114 | , "G_regularizer_orthstep": null // unused 115 | , "G_regularizer_clipstep": null // unused 116 | 117 | , "checkpoint_valid": 1 // for validating per N epoch 118 | , "checkpoint_save": 1 // for saving model per N epoch 119 | , "checkpoint_print": 200 // for print every iteration 120 | } 121 | } 122 | -------------------------------------------------------------------------------- /options/option_official_implementation_restormer.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "official_implementation_of_restormer" // root/task/images-models-options 3 | , "model": "progressive" // "plain" | "multiout" | "progressive" 4 | , "gpu_ids": [2, 5] 5 | 6 | , "scale": 1 // broadcast to "netG" if SISR 7 | , "n_channels": 3 // broadcast to "datasets", 1 for grayscale, 3 for color 8 | 9 | , "path": { 10 | "root": "results/deblurring" // "denoising" | "superresolution" | "deblurring" 11 | , "pretrained_netG": null // path of pretrained model 12 | } 13 | 14 | , "datasets": { 15 | "train": { 16 | "name": "train_dataset" // just name 17 | , "dataset_type": "plain" // "plain" 18 | , "dataroot_H": "/data1/Motion_Deblurring/GoPro/train/target_crops" // path of High-quality training dataset (prefered full path) 19 | , "dataroot_L": "/data1/Motion_Deblurring/GoPro/train/input_crops" // path of 
Low-quality training dataset (prefered full path) 20 | , "H_size": 384 // patch size 40 | 64 | 96 | 128 | 192 | 224 | 256 | 384 21 | 22 | , "dataloader_shuffle": true 23 | , "dataloader_num_workers": 8 24 | , "dataloader_batch_size": 64 // batch size 1 | 16 | 32 | 48 | 64 | 128 25 | 26 | , "mini_batch_sizes": [8, 5, 4, 2, 1, 1] // mini batch size for progressive training 27 | , "iters_milestones": [92000, 64000, 48000, 36000, 36000, 24000] // milestone of iteration for progressive training 28 | , "mini_H_sizes" : [128, 160, 192, 256, 320, 384] // varing H_size for progressive training 29 | } 30 | , "valid": { 31 | "name": "valid_dataset" // just name 32 | , "dataset_type": "plain" // "dncnn" | "dnpatch" for dncnn, | "fdncnn" | "ffdnet" | "sr" | "srmd" | "dpsr" | "plain" | "plainpatch" 33 | , "dataroot_H": "/data1/Motion_Deblurring/GoPro/val/target_crops" // path of High-quality testing dataset (prefered full path) 34 | , "dataroot_L": "/data1/Motion_Deblurring/GoPro/val/input_crops" // path of Low-quality testing dataset (prefered full path) 35 | 36 | , "sigma": 25 // unused 37 | , "sigma_valid": 25 // unused 38 | 39 | , "dataloader_num_workers": 8 40 | , "dataloader_batch_size": 16 // batch size 1 | 16 | 32 | 48 | 64 | 128 41 | } 42 | , "test": { 43 | "name": "test_dataset" // unused 44 | , "dataset_type": "plain" // dataset type 45 | , "dataroot_H": "/data1/Motion_Deblurring/GoPro/test/target" // path of High-quality testing dataset 46 | , "dataroot_L": "/data1/Motion_Deblurring/GoPro/test/input" // path of Low-quality testing dataset 47 | 48 | , "sigma": 25 // unused 49 | , "sigma_valid": 25 // unused 50 | 51 | , "dataloader_num_workers": 0 52 | , "dataloader_batch_size": 1 // batch size 1 53 | } 54 | } 55 | 56 | , "netG": { 57 | "net_type": "restormer" // "mimounet" | "mimounetplus" | "restormer" | "uformer" | "nafnet"("nafnet local") 58 | , "in_nc": 3 // input channel number 59 | , "out_nc": 3 // ouput channel number 60 | , "nc": 48 // basic hidden dim or base channel 61 | , "nb": [4, 6, 6, 8] // number of blocks (list if different scales) 62 | , "n_refine_b": 4 // number of refinement blocks 63 | , "heads": [1, 2, 4, 8] // heads of multi-head attention 64 | , "ffn_expansion_factor": 2.66 // hidden dim expanded in Gated-Dconv Network 65 | , "bias": false // bias in qkv generation 66 | , "LayerNorm_type": "WithBias" // Other option 'BiasFree' 67 | , "dual_pixel_task": false // ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 68 | 69 | , "init_type": "orthogonal" // unused "orthogonal" | "normal" | "uniform" | "xavier_normal" | "xavier_uniform" | "kaiming_normal" | "kaiming_uniform" 70 | , "init_bn_type": "uniform" // unused "uniform" | "constant" 71 | , "init_gain": 0.2 72 | } 73 | 74 | , "train": { 75 | "total_epoch": 30000 76 | ,"G_lossfn_type": "l1" // "l1" | "l2sum" | "l2" | "ssim" | "charbonnier" | 'l1+ssim' | 'l1+fft' 77 | , "G_lossfn_weight": [1.0] // default 78 | 79 | , "G_optimizer_type": "adamw" // fixed, adam is enough 80 | , "G_optimizer_lr": 3e-4 // learning rate 81 | , "G_optimizer_betas": [0.9, 0.999] // beta 82 | , "G_optimizer_wd": 1e-4 // weight decay 83 | , "G_optimizer_clipgrad": 0.01 // the max norm of grad for clipping (negative for unclipping) 84 | 85 | , "G_scheduler_type": "CosineAnnealingRestartCyclicLR" // "MultiStepLR" | "CosineAnnealingWarmRestarts" | "CosineAnnealingRestartCyclicLR" 86 | , "G_scheduler_milestones": [5000, 10000, 20000, 25000, 30000] // for "MultiStepLR" 87 | , "G_scheduler_gamma": 0.5 // for "MultiStepLR" 88 | , "G_scheduler_period": 5000 // for "CosineAnnealingWarmRestarts" 89 | , "G_scheduler_eta_min": 0.000001 // for "CosineAnnealingWarmRestarts" 90 | , "G_scheduler_periods": [9200, 20800] // for "CosineAnnealingRestartCyclicLR" 91 | , "G_scheduler_restart_weights": [1, 1] // for "CosineAnnealingRestartCyclicLR" 92 | , "G_scheduler_eta_mins": [0.0003,0.000001] // for "CosineAnnealingRestartCyclicLR" 93 | 94 | , "G_regularizer_orthstep": null // unused 95 | , "G_regularizer_clipstep": null // unused 96 | 97 | , "checkpoint_valid": 1 // for validating per N epoch 98 | , "checkpoint_save": 1 // for saving model per N epoch 99 | , "checkpoint_print": 200 // for print every iteration 100 | } 101 | } 102 | -------------------------------------------------------------------------------- /options/option_official_implementation_uformer.json: -------------------------------------------------------------------------------- 1 | { 2 | "task": "official_implementation_of_uformer" // root/task/images-models-options 3 | , "model": "plain" // "plain" | "multiout" | "progressive" 4 | , "gpu_ids": [0, 1] 5 | 6 | , "scale": 1 // broadcast to "netG" if SISR 7 | , "n_channels": 3 // broadcast to "datasets", 1 for grayscale, 3 for color 8 | 9 | , "path": { 10 | "root": "results/deblurring" // "denoising" | "superresolution" | "deblurring" 11 | , "pretrained_netG": null // path of pretrained model 12 | } 13 | 14 | , "datasets": { 15 | "train": { 16 | "name": "train_dataset" // just name 17 | , "dataset_type": "plain" // "plain" 18 | , "dataroot_H": "/data1/Motion_Deblurring/GoPro/train/target_crops" // path of High-quality training dataset (prefered full path) 19 | , "dataroot_L": "/data1/Motion_Deblurring/GoPro/train/input_crops" // path of Low-quality training dataset (prefered full path) 20 | , "H_size": 256 // patch size 40 | 64 | 96 | 128 | 192 | 224 | 256 | 384 21 | 22 | , "dataloader_shuffle": true 23 | , "dataloader_num_workers": 8 24 | , "dataloader_batch_size": 8 // batch size 1 | 8 | 16 | 32 | 48 | 64 | 128 25 | 26 | , "mini_batch_sizes": [8, 5, 4, 2, 1, 1] // mini batch size for progressive training 27 | , "iters_milestones": [92000, 64000, 48000, 36000, 36000, 24000] // milestone of iteration for progressive training 28 | , "mini_H_sizes" : [128, 160, 192, 256, 320, 384] // varing H_size for progressive training 29 | } 30 | , "valid": { 31 | "name": "valid_dataset" // just name 32 | , "dataset_type": "plain" // "dncnn" | "dnpatch" for dncnn, | 
"fdncnn" | "ffdnet" | "sr" | "srmd" | "dpsr" | "plain" | "plainpatch" 33 | , "dataroot_H": "/data1/Motion_Deblurring/GoPro/val/target_crops" // path of High-quality testing dataset (prefered full path) 34 | , "dataroot_L": "/data1/Motion_Deblurring/GoPro/val/input_crops" // path of Low-quality testing dataset (prefered full path) 35 | 36 | , "sigma": 25 // unused 37 | , "sigma_valid": 25 // unused 38 | 39 | , "dataloader_num_workers": 8 40 | , "dataloader_batch_size": 16 // batch size 1 | 16 | 32 | 48 | 64 | 128 41 | } 42 | , "test": { 43 | "name": "test_dataset" // unused 44 | , "dataset_type": "plain" // dataset type 45 | , "dataroot_H": "/data1/Motion_Deblurring/GoPro/test/target" // path of High-quality testing dataset 46 | , "dataroot_L": "/data1/Motion_Deblurring/GoPro/test/input" // path of Low-quality testing dataset 47 | 48 | , "sigma": 25 // unused 49 | , "sigma_valid": 25 // unused 50 | 51 | , "dataloader_num_workers": 0 52 | , "dataloader_batch_size": 1 // batch size 1 53 | } 54 | } 55 | 56 | , "netG": { 57 | "net_type": "uformer" // "mimounet" | "mimounetplus" | "restormer" 58 | , "in_nc": 3 // input channel number 59 | , "out_nc": 3 // ouput channel number 60 | , "nc": 32 // basic hidden dim or base channel and 16 for uformer_tiny 61 | , "nb": [2, 2, 2, 2, 2, 2, 2, 2, 2] // number of blocks (list if different scales) and [1, 2, 8, 8, 2, 8, 8, 2, 1] for uformer_b 62 | , "n_refine_b": 4 // number of refinement blocks 63 | , "heads": [1, 2, 4, 8] // heads of multi-head attention 64 | , "ffn_expansion_factor": 2.66 // hidden dim expanded in Gated-Dconv Network 65 | , "bias": false // bias in qkv generation 66 | , "LayerNorm_type": "WithBias" // Other option 'BiasFree' 67 | , "dual_pixel_task": false // ## True for dual-pixel defocus deblurring only. 
Also set inp_channels=6 68 | , "token_mlp": "leff" // token of mlp in uformer 69 | 70 | , "init_type": "orthogonal" // unused "orthogonal" | "normal" | "uniform" | "xavier_normal" | "xavier_uniform" | "kaiming_normal" | "kaiming_uniform" 71 | , "init_bn_type": "uniform" // unused "uniform" | "constant" 72 | , "init_gain": 0.2 73 | } 74 | 75 | , "train": { 76 | "total_epoch": 30000 77 | ,"G_lossfn_type": "l1" // "l1" | "l2sum" | "l2" | "ssim" | "charbonnier" | 'l1+ssim' | 'l1+fft' 78 | , "G_lossfn_weight": [1.0] // default 79 | 80 | , "G_optimizer_type": "adamw" // fixed, adam is enough 81 | , "G_optimizer_lr": 3e-4 // learning rate 82 | , "G_optimizer_betas": [0.9, 0.999] // beta 83 | , "G_optimizer_wd": 1e-4 // weight decay 84 | , "G_optimizer_clipgrad": 0.01 // the max norm of grad for clipping (negative for unclipping) 85 | 86 | , "G_scheduler_type": "GradualWarmupScheduler" // "MultiStepLR" | "CosineAnnealingWarmRestarts" | "CosineAnnealingRestartCyclicLR" | "GradualWarmupScheduler" 87 | , "G_scheduler_milestones": [5000, 10000, 20000, 25000, 30000] // for "MultiStepLR" 88 | , "G_scheduler_gamma": 0.5 // for "MultiStepLR" 89 | , "G_scheduler_period": 5000 // for "CosineAnnealingWarmRestarts" 90 | , "G_scheduler_eta_min": 1e-6 // for "CosineAnnealingWarmRestarts" 91 | , "G_scheduler_periods": [9200, 20800] // for "CosineAnnealingRestartCyclicLR" 92 | , "G_scheduler_restart_weights": [1, 1] // for "CosineAnnealingRestartCyclicLR" 93 | , "G_scheduler_eta_mins": [0.0003,0.000001] // for "CosineAnnealingRestartCyclicLR" 94 | , "G_scheduler_multiplier": 1 // for "GradualWarmupScheduler" 95 | , "G_scheduler_warmup_epochs": 3 // for "GradualWarmupScheduler" 96 | 97 | , "G_regularizer_orthstep": null // unused 98 | , "G_regularizer_clipstep": null // unused 99 | 100 | , "checkpoint_valid": 1 // for validating per N epoch 101 | , "checkpoint_save": 1 // for saving model per N epoch 102 | , "checkpoint_print": 200 // for print every iteration 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | einops=0.6.1=pypi_0 5 | h5py=3.9.0=pypi_0 6 | huggingface-hub=0.15.1=pypi_0 7 | numpy=1.25.0=pypi_0 8 | opencv-python=4.7.0.72=pypi_0 9 | opendatalab=0.0.9=pypi_0 10 | pillow=9.5.0=pypi_0 11 | pip=23.1.2=py39h06a4308_0 12 | python=3.9.6=h12debd9_1 13 | pyyaml=6.0=pypi_0 14 | scikit-image=0.21.0=pypi_0 15 | scipy=1.11.1=pypi_0 16 | tensorboard=2.13.0=pypi_0 17 | tensorboard-data-server=0.7.1=pypi_0 18 | tifffile=2023.7.4=pypi_0 19 | timm=0.9.2=pypi_0 20 | torch=2.0.1+cu118=pypi_0 21 | torchstat=0.0.7=pypi_0 22 | torchvision=0.15.2+cu118=pypi_0 23 | tqdm=4.65.0=pypi_0 24 | -------------------------------------------------------------------------------- /utils/__pycache__/utils_bnorm.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TanGeeGo/toolbox/bee56747d426b33c57381426f3dcb083f568fde8/utils/__pycache__/utils_bnorm.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_dist.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/TanGeeGo/toolbox/bee56747d426b33c57381426f3dcb083f568fde8/utils/__pycache__/utils_dist.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_image.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TanGeeGo/toolbox/bee56747d426b33c57381426f3dcb083f568fde8/utils/__pycache__/utils_image.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_logger.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TanGeeGo/toolbox/bee56747d426b33c57381426f3dcb083f568fde8/utils/__pycache__/utils_logger.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TanGeeGo/toolbox/bee56747d426b33c57381426f3dcb083f568fde8/utils/__pycache__/utils_model.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_option.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TanGeeGo/toolbox/bee56747d426b33c57381426f3dcb083f568fde8/utils/__pycache__/utils_option.cpython-39.pyc -------------------------------------------------------------------------------- /utils/__pycache__/utils_regularizers.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TanGeeGo/toolbox/bee56747d426b33c57381426f3dcb083f568fde8/utils/__pycache__/utils_regularizers.cpython-39.pyc -------------------------------------------------------------------------------- /utils/utils_bnorm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | """ 6 | # -------------------------------------------- 7 | # Merge Batch Normalization with convolution 8 | # to accelerate the inference of model and training 9 | # see https://zhuanlan.zhihu.com/p/49329030 for details 10 | # -------------------------------------------- 11 | """ 12 | 13 | 14 | # -------------------------------------------- 15 | # remove/delete specified layer 16 | # -------------------------------------------- 17 | def deleteLayer(model, layer_type=nn.BatchNorm2d): 18 | ''' Kai Zhang, 11/Jan/2019. 19 | ''' 20 | for k, m in list(model.named_children()): 21 | if isinstance(m, layer_type): 22 | del model._modules[k] 23 | deleteLayer(m, layer_type) 24 | 25 | 26 | # -------------------------------------------- 27 | # merge bn, "conv+bn" --> "conv" 28 | # -------------------------------------------- 29 | def merge_bn(model): 30 | ''' Kai Zhang, 11/Jan/2019. 
31 | merge all 'Conv+BN' (or 'TConv+BN') into 'Conv' (or 'TConv') 32 | based on https://github.com/pytorch/pytorch/pull/901 33 | ''' 34 | prev_m = None 35 | for k, m in list(model.named_children()): 36 | if (isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d)) and (isinstance(prev_m, nn.Conv2d) or isinstance(prev_m, nn.Linear) or isinstance(prev_m, nn.ConvTranspose2d)): 37 | 38 | w = prev_m.weight.data 39 | 40 | if prev_m.bias is None: 41 | zeros = torch.Tensor(prev_m.out_channels).zero_().type(w.type()) 42 | prev_m.bias = nn.Parameter(zeros) 43 | b = prev_m.bias.data 44 | 45 | invstd = m.running_var.clone().add_(m.eps).pow_(-0.5) 46 | if isinstance(prev_m, nn.ConvTranspose2d): 47 | w.mul_(invstd.view(1, w.size(1), 1, 1).expand_as(w)) 48 | else: 49 | w.mul_(invstd.view(w.size(0), 1, 1, 1).expand_as(w)) 50 | b.add_(-m.running_mean).mul_(invstd) 51 | if m.affine: 52 | if isinstance(prev_m, nn.ConvTranspose2d): 53 | w.mul_(m.weight.data.view(1, w.size(1), 1, 1).expand_as(w)) 54 | else: 55 | w.mul_(m.weight.data.view(w.size(0), 1, 1, 1).expand_as(w)) 56 | b.mul_(m.weight.data).add_(m.bias.data) 57 | 58 | del model._modules[k] 59 | prev_m = m 60 | merge_bn(m) 61 | 62 | 63 | # -------------------------------------------- 64 | # add bn, "conv" --> "conv+bn" 65 | # -------------------------------------------- 66 | def add_bn(model): 67 | for k, m in list(model.named_children()): 68 | if (isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose2d)): 69 | b = nn.BatchNorm1d(m.out_features, momentum=0.1, affine=True) if isinstance(m, nn.Linear) else nn.BatchNorm2d(m.out_channels, momentum=0.1, affine=True)  # Linear layers have out_features, not out_channels 70 | b.weight.data.fill_(1) 71 | new_m = nn.Sequential(model._modules[k], b) 72 | model._modules[k] = new_m 73 | add_bn(m) 74 | 75 | 76 | # -------------------------------------------- 77 | # tidy model after removing bn 78 | # -------------------------------------------- 79 | def tidy_sequential(model): 80 | for k, m in list(model.named_children()): 81 | if isinstance(m, nn.Sequential): 82 | if m.__len__() == 1: 83 | model._modules[k] = m.__getitem__(0) 84 | tidy_sequential(m) 85 |
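A hedged usage sketch for the helpers above: fold BatchNorm into the preceding convolutions for inference (the fold uses the BN running statistics, so the merged network matches the original in eval mode):

import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.BatchNorm2d(16), nn.ReLU())
net.eval()
merge_bn(net)         # rewrites the conv weights/bias and deletes each BN child
tidy_sequential(net)  # collapses any length-1 nn.Sequential left behind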
-------------------------------------------------------------------------------- /utils/utils_dist.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 2 | import os 3 | import functools 4 | import pickle 5 | import subprocess 6 | import torch 7 | import torch.distributed as dist 8 | import torch.multiprocessing as mp 9 | 10 | 11 | # ---------------------------------- 12 | # init 13 | # ---------------------------------- 14 | def init_dist(launcher, backend='nccl', **kwargs): 15 | if mp.get_start_method(allow_none=True) is None: 16 | mp.set_start_method('spawn') 17 | if launcher == 'pytorch': 18 | _init_dist_pytorch(backend, **kwargs) 19 | elif launcher == 'slurm': 20 | _init_dist_slurm(backend, **kwargs) 21 | else: 22 | raise ValueError(f'Invalid launcher type: {launcher}') 23 | 24 | 25 | def _init_dist_pytorch(backend, **kwargs): 26 | rank = int(os.environ['RANK']) 27 | num_gpus = torch.cuda.device_count() 28 | torch.cuda.set_device(rank % num_gpus) 29 | dist.init_process_group(backend=backend, **kwargs) 30 | 31 | 32 | def _init_dist_slurm(backend, port=None): 33 | """Initialize slurm distributed training environment. 34 | If argument ``port`` is not specified, then the master port will be the 35 | system environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not set 36 | in the environment, then the default port ``29500`` will be used. 37 | Args: 38 | backend (str): Backend of torch.distributed. 39 | port (int, optional): Master port. Defaults to None. 40 | """ 41 | proc_id = int(os.environ['SLURM_PROCID']) 42 | ntasks = int(os.environ['SLURM_NTASKS']) 43 | node_list = os.environ['SLURM_NODELIST'] 44 | num_gpus = torch.cuda.device_count() 45 | torch.cuda.set_device(proc_id % num_gpus) 46 | addr = subprocess.getoutput( 47 | f'scontrol show hostname {node_list} | head -n1') 48 | # specify master port 49 | if port is not None: 50 | os.environ['MASTER_PORT'] = str(port) 51 | elif 'MASTER_PORT' in os.environ: 52 | pass # use MASTER_PORT in the environment variable 53 | else: 54 | # 29500 is torch.distributed default port 55 | os.environ['MASTER_PORT'] = '29500' 56 | os.environ['MASTER_ADDR'] = addr 57 | os.environ['WORLD_SIZE'] = str(ntasks) 58 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 59 | os.environ['RANK'] = str(proc_id) 60 | dist.init_process_group(backend=backend) 61 | 62 | 63 | 64 | # ---------------------------------- 65 | # get rank and world_size 66 | # ---------------------------------- 67 | def get_dist_info(): 68 | if dist.is_available(): 69 | initialized = dist.is_initialized() 70 | else: 71 | initialized = False 72 | if initialized: 73 | rank = dist.get_rank() 74 | world_size = dist.get_world_size() 75 | else: 76 | rank = 0 77 | world_size = 1 78 | return rank, world_size 79 | 80 | 81 | def get_rank(): 82 | if not dist.is_available(): 83 | return 0 84 | 85 | if not dist.is_initialized(): 86 | return 0 87 | 88 | return dist.get_rank() 89 | 90 | 91 | def get_world_size(): 92 | if not dist.is_available(): 93 | return 1 94 | 95 | if not dist.is_initialized(): 96 | return 1 97 | 98 | return dist.get_world_size() 99 | 100 | 101 | def master_only(func): 102 | 103 | @functools.wraps(func) 104 | def wrapper(*args, **kwargs): 105 | rank, _ = get_dist_info() 106 | if rank == 0: 107 | return func(*args, **kwargs) 108 | 109 | return wrapper 110 | 111 | 112 | 113 | 114 | 115 | 116 | # ---------------------------------- 117 | # operation across ranks 118 | # ---------------------------------- 119 | def reduce_sum(tensor): 120 | if not dist.is_available(): 121 | return tensor 122 | 123 | if not dist.is_initialized(): 124 | return tensor 125 | 126 | tensor = tensor.clone() 127 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM) 128 | 129 | return tensor 130 | 131 | 132 | def gather_grad(params): 133 | world_size = get_world_size() 134 | 135 | if world_size == 1: 136 | return 137 | 138 | for param in params: 139 | if param.grad is not None: 140 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) 141 | param.grad.data.div_(world_size) 142 | 143 | 144 | def all_gather(data): 145 | world_size = get_world_size() 146 | 147 | if world_size == 1: 148 | return [data] 149 | 150 | buffer = pickle.dumps(data) 151 | storage = torch.ByteStorage.from_buffer(buffer) 152 | tensor = torch.ByteTensor(storage).to('cuda') 153 | 154 | local_size = torch.IntTensor([tensor.numel()]).to('cuda') 155 | size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)] 156 | dist.all_gather(size_list, local_size) 157 | size_list = [int(size.item()) for size in size_list] 158 | max_size = max(size_list) 159 | 160 | tensor_list = [] 161 | for _ in size_list: 162 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda')) 163 | 164 | if local_size != max_size: 165 | padding = torch.ByteTensor(size=(max_size - local_size,)).to('cuda') 166 | tensor = torch.cat((tensor, padding), 0) 167 | 168 | dist.all_gather(tensor_list, tensor) 169 | 170 | data_list = [] 171 | 172 | for size, tensor in zip(size_list, tensor_list): 173 | buffer = tensor.cpu().numpy().tobytes()[:size] 174 | data_list.append(pickle.loads(buffer)) 175 | 176 | return data_list 177 | 178 | 179 | def reduce_loss_dict(loss_dict): 180 | world_size = get_world_size() 181 | 182 | if world_size < 2: 183 | return loss_dict 184 | 185 | with torch.no_grad(): 186 | keys = [] 187 | losses = [] 188 | 189 | for k in sorted(loss_dict.keys()): 190 | keys.append(k) 191 | losses.append(loss_dict[k]) 192 | 193 | losses = torch.stack(losses, 0) 194 | dist.reduce(losses, dst=0) 195 | 196 | if dist.get_rank() == 0: 197 | losses /= world_size 198 | 199 | reduced_losses = {k: v for k, v in zip(keys, losses)} 200 | 201 | return reduced_losses 202 | 203 |
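Typical wiring of these helpers, sketched under the assumption of a torchrun launch (save_checkpoint is a made-up illustration, not a function in this repo):

# e.g. launched with: torchrun --nproc_per_node=2 main_train_sample.py
import torch

init_dist('pytorch')               # reads RANK, pins a GPU, inits the NCCL group
rank, world_size = get_dist_info()

@master_only
def save_checkpoint(state, path):  # body executes on rank 0 only
    torch.save(state, path)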
-------------------------------------------------------------------------------- /utils/utils_filter.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import numpy as np 4 | 5 | #------------------------------------------------------- 6 | # filter utils 7 | #------------------------------------------------------- 8 | 9 | # define a 2D Gaussian kernel 10 | def gaussian_kernel_2d(ksize, sigma): 11 | 12 | return cv2.getGaussianKernel(ksize, sigma) * np.transpose(cv2.getGaussianKernel(ksize, sigma)) 13 | 14 | # Wiener-filter inverse of a kernel 15 | def kernel_inv(kernel): 16 | 17 | fft = np.fft.fft2(kernel) 18 | k_inv = np.fft.ifft2(np.conj(fft) / (np.abs(fft)*np.abs(fft)+1e-2)) 19 | 20 | return np.abs(k_inv) / np.sum(np.abs(k_inv)) 21 | 22 | # generate inverse kernels for different sigmas 23 | def gen_gausskernel_ivs(ksize, sigma_range): 24 | 25 | k_ivs = np.zeros((len(sigma_range), ksize, ksize)) 26 | for i in range(len(sigma_range)): 27 | 28 | temp = gaussian_kernel_2d(ksize, sigma_range[i]) 29 | k_ivs[i, :, :] = kernel_inv(temp) 30 | 31 | return k_ivs 32 | 33 | # kernel fft2, Wiener-filter style 34 | def kernel_fft(kernel, patch_size, eptional): 35 | 36 | # generate fft kernel with the same size as the image 37 | fft = np.fft.fft2(kernel, (patch_size, patch_size)) 38 | k_size = kernel.shape[-1] 39 | k_fft = np.zeros((k_size, k_size), dtype=complex) 40 | k_fft = np.conj(fft) / (np.abs(fft) * np.abs(fft) + eptional) 41 | 42 | return k_fft 43 | 44 | def kernel_fft_t(kernel, x_shape, eptional): 45 | 46 | fft = torch.fft.fft2(kernel, (x_shape[-2], x_shape[-1])) # generate fft kernel with the same size as the image 47 | 48 | k_fft = torch.conj(fft)/(torch.abs(fft)*torch.abs(fft)+eptional) 49 | 50 | return k_fft 51 | 52 | def eptional_fft(eptional_map, x_shape): 53 | 54 | # calculate eptional fft; the size fits different input scales 55 | eptional_map_fft = list() 56 | for i in range(eptional_map.shape[0]): 57 | eptional_map_fft.append(torch.fft.fft2(eptional_map[i,:,:], (x_shape[-2], x_shape[-1])).unsqueeze(0)) 58 | 59 | eptional_map_fft.append(torch.ones_like(eptional_map_fft[0]).to(eptional_map.device)) 60 | 61 | return torch.cat(eptional_map_fft, dim=0) 62 | 63 | # least-squares eptional 64 | def leastSquare(patch_size): 65 | # Laplace operator 66 | la = [[0, -1, 0], [-1, 4, -1], [0, -1, 0]] 67 | la_fft = np.fft.fft2(la, (patch_size, patch_size)) 68 | 69 | return la_fft 70 | 71 | def sinefit(H, W, omega_num=10, theta_num=20, sigma=2): 72 | """ 73 | fit the epsilon with sine waves when the epsilon map is missing 74 | """ 75 | print('sine
fit: omega_num={}, theta_num={}, sigma={}\n'.format(omega_num, theta_num, sigma)) 76 | # low frequency: generate a Gaussian kernel 77 | low_mat = gaussian_kernel_2d(H, sigma) 78 | low_mat_f = np.fft.fft2(low_mat) 79 | 80 | # middle frequency: omega in [0, 1); default omega_num=10, theta_num=20 81 | 82 | omega = np.random.uniform(0.8, 0.9, omega_num) 83 | theta = 360*np.random.uniform(size=theta_num) 84 | mid_mat = np.zeros((H, W)) 85 | mid_mat_f = np.zeros((omega_num*theta_num, H, W)) 86 | # fit with sine waves of different omegas and thetas 87 | for i in range(omega_num): 88 | for j in range(theta_num): 89 | 90 | w1 = np.sin(theta[j]) 91 | w2 = np.cos(theta[j]) 92 | 93 | # meshgrid 94 | h = np.linspace(1, H, H) 95 | w = np.linspace(1, W, W) 96 | w_mat, h_mat = np.meshgrid(w, h) 97 | mid_mat = np.cos(omega[i]*(w1*h_mat+w2*w_mat)) 98 | mid_mat_f[i*theta_num+j, :, :] = np.fft.fft2(mid_mat) 99 | 100 | # high frequency: the Laplacian operator, as in leastSquare above 101 | la = [[0, -1, 0], [-1, 4, -1], [0, -1, 0]] 102 | high_mat_f = np.fft.fft2(la, (H, W)) 103 | 104 | return low_mat_f, mid_mat_f, high_mat_f # return the low-, middle-, and high-frequency bases 105 | 106 | def gauss_fit(H, W, sig_low=0.2, sig_high=2.5, sig_num=25): 107 | """ 108 | fit the epsilon with Gaussian kernels when the epsilon map is missing 109 | """ 110 | sig_interval = (sig_high-sig_low-0.001)/(sig_num-1) 111 | eption_map = np.zeros((sig_num, H, W)) 112 | for i in range(sig_num-1): 113 | 114 | sigma = sig_low + sig_interval*i 115 | eption_map[i,:,:] = gaussian_kernel_2d(H, sigma) 116 | # eption_map[i,:,:] = np.fft.fft2(eption_map[i,:,:]) 117 | eption_map[sig_num-1,:,:] = np.ones((H, W)) 118 | 119 | return eption_map 120 | 121 | -------------------------------------------------------------------------------- /utils/utils_logger.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import datetime 3 | import logging 4 | 5 | def log(*args, **kwargs): 6 | print(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S:"), *args, **kwargs) 7 | 8 | ''' 9 | # -------------------------------------------- 10 | # logger 11 | # -------------------------------------------- 12 | ''' 13 | 14 | def logger_info(logger_name, log_path='default_logger.log'): 15 | ''' set up logger 16 | modified by Kai Zhang (github: https://github.com/cszn) 17 | ''' 18 | log = logging.getLogger(logger_name) 19 | if log.hasHandlers(): 20 | print('LogHandlers exist!') 21 | else: 22 | print('LogHandlers setup!') 23 | level = logging.INFO 24 | formatter = logging.Formatter('%(asctime)s.%(msecs)03d : %(message)s', datefmt='%y-%m-%d %H:%M:%S') 25 | fh = logging.FileHandler(log_path, mode='a') 26 | fh.setFormatter(formatter) 27 | log.setLevel(level) 28 | log.addHandler(fh) 29 | # print(len(log.handlers)) 30 | 31 | sh = logging.StreamHandler() 32 | sh.setFormatter(formatter) 33 | log.addHandler(sh) 34 | 35 | ''' 36 | # -------------------------------------------- 37 | # print to file and std_out simultaneously 38 | # -------------------------------------------- 39 | ''' 40 | 41 | class logger_print(object): 42 | def __init__(self, log_path="default.log"): 43 | self.terminal = sys.stdout 44 | self.log = open(log_path, 'a') 45 | 46 | def write(self, message): 47 | self.terminal.write(message) 48 | self.log.write(message) # write the message 49 | 50 | def flush(self): 51 | pass 52 |
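Usage is two steps: register the handlers once under a name, then fetch that logger anywhere (a hedged sketch; the logger name and messages are arbitrary):

import logging

logger_info('train_log', log_path='train.log')  # attach file + stream handlers once
logger = logging.getLogger('train_log')
logger.info('epoch 1 | loss 0.123')             # written to train.log and stdout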
-------------------------------------------------------------------------------- /utils/utils_mask.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Images Speak in Images: A Generalist Painter for In-Context Visual Learning (https://arxiv.org/abs/2212.02499) 3 | # Github source: https://github.com/baaivision/Painter 4 | # Copyright (c) 2022 Beijing Academy of Artificial Intelligence (BAAI) 5 | # Licensed under The MIT License [see LICENSE for details] 6 | # By Xinlong Wang, Wen Wang 7 | # Based on MAE, BEiT, detectron2, Mask2Former, bts, mmcv, mmdetection, mmpose, MIRNet, MPRNet, and Uformer codebases 8 | # -------------------------------------------------------- 9 | 10 | import random 11 | import math 12 | import numpy as np 13 | 14 | 15 | class MaskingGenerator: 16 | def __init__( 17 | self, input_size, num_masking_patches, min_num_patches=4, max_num_patches=None, 18 | min_aspect=0.3, max_aspect=None): 19 | if not isinstance(input_size, tuple): 20 | input_size = (input_size,) * 2 21 | self.height, self.width = input_size 22 | 23 | self.num_patches = self.height * self.width 24 | self.num_masking_patches = num_masking_patches 25 | 26 | self.min_num_patches = min_num_patches 27 | self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches 28 | 29 | max_aspect = max_aspect or 1 / min_aspect 30 | self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) 31 | 32 | def __repr__(self): 33 | repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % ( 34 | self.height, self.width, self.min_num_patches, self.max_num_patches, 35 | self.num_masking_patches, self.log_aspect_ratio[0], self.log_aspect_ratio[1]) 36 | return repr_str 37 | 38 | def get_shape(self): 39 | return self.height, self.width 40 | 41 | def _mask(self, mask, max_mask_patches): 42 | delta = 0 43 | for attempt in range(10): 44 | target_area = random.uniform(self.min_num_patches, max_mask_patches) 45 | aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) 46 | h = int(round(math.sqrt(target_area * aspect_ratio))) 47 | w = int(round(math.sqrt(target_area / aspect_ratio))) 48 | if w < self.width and h < self.height: 49 | top = random.randint(0, self.height - h) 50 | left = random.randint(0, self.width - w) 51 | 52 | num_masked = mask[top: top + h, left: left + w].sum() 53 | # Overlap 54 | if 0 < h * w - num_masked <= max_mask_patches: 55 | for i in range(top, top + h): 56 | for j in range(left, left + w): 57 | if mask[i, j] == 0: 58 | mask[i, j] = 1 59 | delta += 1 60 | 61 | if delta > 0: 62 | break 63 | return delta 64 | 65 | def __call__(self): 66 | mask = np.zeros(shape=self.get_shape(), dtype=np.int32) 67 | mask_count = 0 68 | while mask_count < self.num_masking_patches: 69 | max_mask_patches = self.num_masking_patches - mask_count 70 | max_mask_patches = min(max_mask_patches, self.max_num_patches) 71 | 72 | delta = self._mask(mask, max_mask_patches) 73 | if delta == 0: 74 | break 75 | else: 76 | mask_count += delta 77 | 78 | # maintain a fixed number {self.num_masking_patches} of masked patches 79 | if mask_count > self.num_masking_patches: 80 | delta = mask_count - self.num_masking_patches 81 | mask_x, mask_y = mask.nonzero() 82 | to_vis = np.random.choice(mask_x.shape[0], delta, replace=False) 83 | mask[mask_x[to_vis], mask_y[to_vis]] = 0 84 | 85 | elif mask_count < self.num_masking_patches: 86 | delta = self.num_masking_patches - mask_count 87 | mask_x, mask_y = (mask == 0).nonzero() 88 | to_mask = np.random.choice(mask_x.shape[0], delta, replace=False) 89 | mask[mask_x[to_mask],
mask_y[to_mask]] = 1 90 | 91 | assert mask.sum() == self.num_masking_patches, f"mask: {mask}, mask count {mask.sum()}" 92 | 93 | return mask 94 | 95 | 96 | if __name__ == '__main__': 97 | import pdb 98 | 99 | generator = MaskingGenerator(input_size=14, num_masking_patches=118, min_num_patches=16, ) 100 | for i in range(10000000): 101 | # for i in range(1): 102 | mask = generator() 103 | if mask.sum() != 118: 104 | pdb.set_trace() 105 | print(mask) 106 | print(mask.sum()) -------------------------------------------------------------------------------- /utils/utils_model.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | import torch 4 | from utils import utils_image as util 5 | import re 6 | import glob 7 | import os 8 | 9 | 10 | ''' 11 | # -------------------------------------------- 12 | # Model evaluation scripts 13 | # -------------------------------------------- 14 | ''' 15 | 16 | 17 | def find_last_checkpoint(save_dir, net_type='G', pretrained_path=None): 18 | """ 19 | # --------------------------------------- 20 | # Kai Zhang (github: https://github.com/cszn) 21 | # 03/Mar/2019 22 | # --------------------------------------- 23 | Args: 24 | save_dir: model folder 25 | net_type: 'G' or 'D' or 'optimizerG' or 'optimizerD' 26 | pretrained_path: pretrained model path. If save_dir does not have any model, load from pretrained_path 27 | 28 | Return: 29 | init_iter: iteration number 30 | init_path: model path 31 | # --------------------------------------- 32 | """ 33 | 34 | file_list = glob.glob(os.path.join(save_dir, '*_{}.pth'.format(net_type))) 35 | if file_list: 36 | iter_exist = [] 37 | for file_ in file_list: 38 | iter_current = re.findall(r"(\d+)_{}.pth".format(net_type), file_) 39 | iter_exist.append(int(iter_current[0])) 40 | init_iter = max(iter_exist) 41 | init_path = os.path.join(save_dir, '{}_{}.pth'.format(init_iter, net_type)) 42 | else: 43 | init_iter = 0 44 | init_path = pretrained_path 45 | return init_iter, init_path 46 | 47 | 48 | def test_mode(model, L, mode=0, refield=32, min_size=256, sf=1, modulo=1): 49 | ''' 50 | # --------------------------------------- 51 | # Kai Zhang (github: https://github.com/cszn) 52 | # 03/Mar/2019 53 | # --------------------------------------- 54 | Args: 55 | model: trained model 56 | L: input Low-quality image 57 | mode: 58 | (0) normal: test(model, L) 59 | (1) pad: test_pad(model, L, modulo=16) 60 | (2) split: test_split(model, L, refield=32, min_size=256, sf=1, modulo=1) 61 | (3) x8: test_x8(model, L, modulo=1) ^_^ 62 | (4) split and x8: test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1) 63 | refield: effective receptive field of the network, 32 is enough 64 | useful when split, i.e., mode=2, 4 65 | min_size: min_size x min_size image, e.g., a 256 x 256 image 66 | useful when split, i.e., mode=2, 4 67 | sf: scale factor for super-resolution, otherwise 1 68 | modulo: 1 if split 69 | useful when pad, i.e., mode=1 70 | 71 | Returns: 72 | E: estimated image 73 | # --------------------------------------- 74 | ''' 75 | if mode == 0: 76 | E = test(model, L) 77 | elif mode == 1: 78 | E = test_pad(model, L, modulo, sf) 79 | elif mode == 2: 80 | E = test_split(model, L, refield, min_size, sf, modulo) 81 | elif mode == 3: 82 | E = test_x8(model, L, modulo, sf) 83 | elif mode == 4: 84 | E = test_split_x8(model, L, refield, min_size, sf, modulo) 85 | return E 86 | 87 |
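# A hedged usage sketch of test_mode (shapes and device are illustrative only):
#
#   L = torch.randn(1, 3, 720, 1280).cuda()   # low-quality input
#   with torch.no_grad():
#       E = test_mode(model.eval().cuda(), L, mode=2, refield=32, min_size=256)
#
# mode=2 recursively quarters any input larger than min_size**2 pixels, so
# memory stays bounded on large images at the cost of extra forward passes.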
88 | '''
89 | # --------------------------------------------
90 | # normal (0)
91 | # --------------------------------------------
92 | '''
93 |
94 |
95 | def test(model, L):
96 |     E = model(L)
97 |     return E
98 |
99 |
100 | '''
101 | # --------------------------------------------
102 | # pad (1)
103 | # --------------------------------------------
104 | '''
105 |
106 |
107 | def test_pad(model, L, modulo=16, sf=1):
108 |     h, w = L.size()[-2:]
109 |     paddingBottom = int(np.ceil(h/modulo)*modulo-h)
110 |     paddingRight = int(np.ceil(w/modulo)*modulo-w)
111 |     L = torch.nn.ReplicationPad2d((0, paddingRight, 0, paddingBottom))(L)
112 |     E = model(L)
113 |     E = E[..., :h*sf, :w*sf]
114 |     return E
115 |
116 |
117 | '''
118 | # --------------------------------------------
119 | # split (function)
120 | # --------------------------------------------
121 | '''
122 |
123 |
124 | def test_split_fn(model, L, refield=32, min_size=256, sf=1, modulo=1):
125 |     """
126 |     Args:
127 |         model: trained model
128 |         L: input Low-quality image
129 |         refield: effective receptive field of the network, 32 is enough
130 |         min_size: min_size x min_size image, e.g., a 256x256 image
131 |         sf: scale factor for super-resolution, otherwise 1
132 |         modulo: 1 if split
133 |
134 |     Returns:
135 |         E: estimated result
136 |     """
137 |     h, w = L.size()[-2:]
138 |     if h*w <= min_size**2:
139 |         L = torch.nn.ReplicationPad2d((0, int(np.ceil(w/modulo)*modulo-w), 0, int(np.ceil(h/modulo)*modulo-h)))(L)
140 |         E = model(L)
141 |         E = E[..., :h*sf, :w*sf]
142 |     else:
143 |         top = slice(0, (h//2//refield+1)*refield)
144 |         bottom = slice(h - (h//2//refield+1)*refield, h)
145 |         left = slice(0, (w//2//refield+1)*refield)
146 |         right = slice(w - (w//2//refield+1)*refield, w)
147 |         Ls = [L[..., top, left], L[..., top, right], L[..., bottom, left], L[..., bottom, right]]
148 |
149 |         if h * w <= 4*(min_size**2):
150 |             Es = [model(Ls[i]) for i in range(4)]
151 |         else:
152 |             Es = [test_split_fn(model, Ls[i], refield=refield, min_size=min_size, sf=sf, modulo=modulo) for i in range(4)]
153 |
154 |         b, c = Es[0].size()[:2]
155 |         E = torch.zeros(b, c, sf * h, sf * w).type_as(L)
156 |
157 |         E[..., :h//2*sf, :w//2*sf] = Es[0][..., :h//2*sf, :w//2*sf]
158 |         E[..., :h//2*sf, w//2*sf:w*sf] = Es[1][..., :h//2*sf, (-w + w//2)*sf:]
159 |         E[..., h//2*sf:h*sf, :w//2*sf] = Es[2][..., (-h + h//2)*sf:, :w//2*sf]
160 |         E[..., h//2*sf:h*sf, w//2*sf:w*sf] = Es[3][..., (-h + h//2)*sf:, (-w + w//2)*sf:]
161 |     return E
162 |
163 |
164 | '''
165 | # --------------------------------------------
166 | # split (2)
167 | # --------------------------------------------
168 | '''
169 |
170 |
171 | def test_split(model, L, refield=32, min_size=256, sf=1, modulo=1):
172 |     E = test_split_fn(model, L, refield=refield, min_size=min_size, sf=sf, modulo=modulo)
173 |     return E
174 |
175 |
176 | '''
177 | # --------------------------------------------
178 | # x8 (3)
179 | # --------------------------------------------
180 | '''
181 |
182 |
183 | def test_x8(model, L, modulo=1, sf=1):
184 |     E_list = [test_pad(model, util.augment_img_tensor4(L, mode=i), modulo=modulo, sf=sf) for i in range(8)]
185 |     for i in range(len(E_list)):
186 |         if i == 3 or i == 5:
187 |             E_list[i] = util.augment_img_tensor4(E_list[i], mode=8 - i)
188 |         else:
189 |             E_list[i] = util.augment_img_tensor4(E_list[i], mode=i)
190 |     output_cat = torch.stack(E_list, dim=0)
191 |     E = output_cat.mean(dim=0, keepdim=False)
192 |     return E
193 |
194 |
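# Sketch of the geometry behind the x8 self-ensemble used here, assuming
# util.augment_img_tensor4 realizes the 8 dihedral transforms (identity,
# rotations, flips); the exact mode numbering of that helper is an assumption,
# and _demo_dihedral is illustrative, not part of this module.
def _demo_dihedral(x, mode):
    # x: (B, C, H, W) tensor; mode in [0, 7]
    if mode >= 4:
        x = torch.flip(x, dims=[-1])                   # mirror for modes 4-7
    return torch.rot90(x, k=mode % 4, dims=[-2, -1])   # then rotate k*90 degrees
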
195 | '''
196 | # --------------------------------------------
197 | # split and x8 (4)
198 | # --------------------------------------------
199 | '''
200 |
201 |
202 | def test_split_x8(model, L, refield=32, min_size=256, sf=1, modulo=1):
203 |     E_list = [test_split_fn(model, util.augment_img_tensor4(L, mode=i), refield=refield, min_size=min_size, sf=sf, modulo=modulo) for i in range(8)]
204 |     for i in range(len(E_list)):
205 |         if i == 3 or i == 5:
206 |             E_list[i] = util.augment_img_tensor4(E_list[i], mode=8 - i)
207 |         else:
208 |             E_list[i] = util.augment_img_tensor4(E_list[i], mode=i)
209 |     output_cat = torch.stack(E_list, dim=0)
210 |     E = output_cat.mean(dim=0, keepdim=False)
211 |     return E
212 |
213 |
214 | '''
215 | # --------------------------------------------
216 | # print
217 | # --------------------------------------------
218 | '''
219 |
220 |
221 | # --------------------------------------------
222 | # print model
223 | # --------------------------------------------
224 | def print_model(model):
225 |     msg = describe_model(model)
226 |     print(msg)
227 |
228 |
229 | # --------------------------------------------
230 | # print params
231 | # --------------------------------------------
232 | def print_params(model):
233 |     msg = describe_params(model)
234 |     print(msg)
235 |
236 |
237 | '''
238 | # --------------------------------------------
239 | # information
240 | # --------------------------------------------
241 | '''
242 |
243 |
244 | # --------------------------------------------
245 | # model information
246 | # --------------------------------------------
247 | def info_model(model):
248 |     msg = describe_model(model)
249 |     return msg
250 |
251 |
252 | # --------------------------------------------
253 | # params information
254 | # --------------------------------------------
255 | def info_params(model):
256 |     msg = describe_params(model)
257 |     return msg
258 |
259 |
260 | '''
261 | # --------------------------------------------
262 | # description
263 | # --------------------------------------------
264 | '''
265 |
266 |
267 | # --------------------------------------------
268 | # model name and total number of parameters
269 | # --------------------------------------------
270 | def describe_model(model):
271 |     if isinstance(model, torch.nn.DataParallel):
272 |         model = model.module
273 |     msg = '\n'
274 |     msg += 'model name: {}'.format(model.__class__.__name__) + '\n'
275 |     msg += 'Params number: {}'.format(sum(map(lambda x: x.numel(), model.parameters()))) + '\n'
276 |     msg += 'Net structure:\n{}'.format(str(model)) + '\n'
277 |     return msg
278 |
279 |
280 | # --------------------------------------------
281 | # parameters description
282 | # --------------------------------------------
283 | def describe_params(model):
284 |     if isinstance(model, torch.nn.DataParallel):
285 |         model = model.module
286 |     msg = '\n'
287 |     msg += ' | {:^6s} | {:^6s} | {:^6s} | {:^6s} | {:^6s} || {:s}'.format('mean', 'min', 'max', 'std', 'shape', 'param_name') + '\n'
288 |     for name, param in model.state_dict().items():
289 |         if 'num_batches_tracked' not in name:
290 |             v = param.data.clone().float()
291 |             msg += ' | {:>6.3f} | {:>6.3f} | {:>6.3f} | {:>6.3f} | {} || {:s}'.format(v.mean(), v.min(), v.max(), v.std(), v.shape, name) + '\n'
292 |     return msg
293 |
294 |
295 | if __name__ == '__main__':
296 |
297 |     class Net(torch.nn.Module):
298 |         def __init__(self, in_channels=3, out_channels=3):
299 |             super(Net, self).__init__()
300 |             self.conv = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)
301 |
302 |         def forward(self, x):
303 |             x = self.conv(x)
304 |             return x
305 |
306 |     start = torch.cuda.Event(enable_timing=True)
307 |     end = torch.cuda.Event(enable_timing=True)
308 |
309 |     model = Net()
310 |     model = model.eval()
311 |     print_model(model)
312 |     print_params(model)
313 |     x = torch.randn((2, 3, 401, 401))
314 |     torch.cuda.empty_cache()
315 |     with torch.no_grad():
316 |         for mode in range(5):
317 |             y = test_mode(model, x, mode, refield=32, min_size=256, sf=1, modulo=1)
318 |             print(y.shape)
319 |
320 |     # run utils/utils_model.py
321 |
--------------------------------------------------------------------------------
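
A quick sketch of find_last_checkpoint's contract using throwaway files (the iteration numbers and paths below are illustrative only):

import os
import tempfile
from utils.utils_model import find_last_checkpoint

with tempfile.TemporaryDirectory() as d:
    for it in (5000, 10000):
        open(os.path.join(d, '{}_G.pth'.format(it)), 'w').close()
    print(find_last_checkpoint(d, net_type='G'))
    # -> (10000, '<d>/10000_G.pth'): resume from the latest saved iteration
print(find_last_checkpoint('/nonexistent', 'G', pretrained_path='pre.pth'))
# -> (0, 'pre.pth'): no checkpoints found, fall back to the pretrained path
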
/utils/utils_option.py:
--------------------------------------------------------------------------------
1 | import os
2 | from collections import OrderedDict
3 | from datetime import datetime
4 | import json
5 | import re
6 | import glob
7 |
8 | """
9 | # ----------------------------------------------------
10 | # functions for parsing the option
11 | # ----------------------------------------------------
12 | """
13 |
14 | def parse(opt_path, is_train=True):
15 |
16 |     # ----------------------------------------
17 |     # remove comments starting with '//'
18 |     # ----------------------------------------
19 |     json_str = ''
20 |     with open(opt_path, 'r') as f:
21 |         for line in f:
22 |             line = line.split('//')[0] + '\n'
23 |             json_str += line
24 |
25 |     # ----------------------------------------
26 |     # initialize opt
27 |     # ----------------------------------------
28 |
29 |     opt = json.loads(json_str, object_pairs_hook=OrderedDict)
30 |
31 |     opt['opt_path'] = opt_path
32 |     opt['is_train'] = is_train
33 |
34 |     # ----------------------------------------
35 |     # set default
36 |     # ----------------------------------------
37 |
38 |     if 'scale' not in opt:
39 |         opt['scale'] = 1
40 |
41 |     # ----------------------------------------
42 |     # datasets
43 |     # ----------------------------------------
44 |     for phase, dataset in opt['datasets'].items():
45 |         phase = phase.split('_')[0]
46 |         dataset['phase'] = phase
47 |         dataset['scale'] = opt['scale']  # broadcast
48 |         dataset['n_channels'] = opt['n_channels']  # broadcast
49 |         if 'dataroot_GT' in dataset and dataset['dataroot_GT'] is not None:
50 |             dataset['dataroot_GT'] = os.path.expanduser(dataset['dataroot_GT'])
51 |         if 'dataroot_IP' in dataset and dataset['dataroot_IP'] is not None:
52 |             dataset['dataroot_IP'] = os.path.expanduser(dataset['dataroot_IP'])
53 |
54 |     # ----------------------------------------
55 |     # path
56 |     # ----------------------------------------
57 |     for key, path in opt['path'].items():
58 |         if path:
59 |             opt['path'][key] = os.path.expanduser(path)
60 |
61 |     path_task = os.path.join(opt['path']['root'], opt['task'])
62 |     opt['path']['task'] = path_task
63 |     opt['path']['log'] = os.path.join(path_task, 'log')
64 |     opt['path']['options'] = os.path.join(path_task, 'options')
65 |     opt['path']['models'] = os.path.join(path_task, 'models')
66 |
67 |     if is_train:
68 |         opt['path']['images'] = os.path.join(path_task, 'images')
69 |     else:  # test
70 |         opt['path']['images'] = os.path.join(path_task, 'test_images')
71 |
72 |     # ----------------------------------------
73 |     # network
74 |     # ----------------------------------------
75 |     opt['netG']['scale'] = opt['scale'] if 'scale' in opt else 1
76 |
77 |     # ----------------------------------------
78 |     # GPU devices
79 |     # ----------------------------------------
80 |     gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
81 |     os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
82 |
83 |     # ----------------------------------------
84 |     # default setting for DistributedDataParallel
85 |     # ----------------------------------------
86 |     if 'find_unused_parameters' not in opt:
87 |         opt['find_unused_parameters'] = True
88 |     if 'use_static_graph' not in opt:
89 |         opt['use_static_graph'] = False
90 |     if 'dist' not in opt:
91 |         opt['dist'] = False
92 |     opt['num_gpu'] = len(opt['gpu_ids'])
93 |
94 |     # ----------------------------------------
95 |     # default setting for perceptual loss
96 |     # ----------------------------------------
97 |     if 'F_feature_layer' not in opt['train']:
98 |         opt['train']['F_feature_layer'] = 34  # 25; [2,7,16,25,34]
99 |     if 'F_weights' not in opt['train']:
100 |         opt['train']['F_weights'] = 1.0  # 1.0; [0.1,0.1,1.0,1.0,1.0]
101 |     if 'F_lossfn_type' not in opt['train']:
102 |         opt['train']['F_lossfn_type'] = 'l1'
103 |     if 'F_use_input_norm' not in opt['train']:
104 |         opt['train']['F_use_input_norm'] = True
105 |     if 'F_use_range_norm' not in opt['train']:
106 |         opt['train']['F_use_range_norm'] = False
107 |
108 |     # ----------------------------------------
109 |     # default setting for optimizer
110 |     # ----------------------------------------
111 |     if 'G_optimizer_type' not in opt['train']:
112 |         opt['train']['G_optimizer_type'] = "adam"
113 |     if 'G_optimizer_betas' not in opt['train']:
114 |         opt['train']['G_optimizer_betas'] = [0.9, 0.999]
115 |     if 'G_scheduler_restart_weights' not in opt['train']:
116 |         opt['train']['G_scheduler_restart_weights'] = 1
117 |     if 'G_optimizer_wd' not in opt['train']:
118 |         opt['train']['G_optimizer_wd'] = 1e-4
119 |     if 'G_optimizer_reuse' not in opt['train']:
120 |         opt['train']['G_optimizer_reuse'] = False
121 |     if 'netD' in opt and 'D_optimizer_reuse' not in opt['train']:
122 |         opt['train']['D_optimizer_reuse'] = False
123 |
124 |     # ----------------------------------------
125 |     # default setting of strict for model loading
126 |     # ----------------------------------------
127 |     if 'G_param_strict' not in opt['train']:
128 |         opt['train']['G_param_strict'] = True
129 |     if 'netD' in opt and 'D_param_strict' not in opt['train']:
130 |         opt['train']['D_param_strict'] = True
131 |     if 'E_param_strict' not in opt['train']:
132 |         opt['train']['E_param_strict'] = True
133 |
134 |     # ----------------------------------------
135 |     # Exponential Moving Average
136 |     # ----------------------------------------
137 |     if 'E_decay' not in opt['train']:
138 |         opt['train']['E_decay'] = 0
139 |
140 |     # ----------------------------------------
141 |     # default setting for discriminator
142 |     # ----------------------------------------
143 |     if 'netD' in opt:
144 |         if 'net_type' not in opt['netD']:
145 |             opt['netD']['net_type'] = 'discriminator_patchgan'  # discriminator_unet
146 |         if 'in_nc' not in opt['netD']:
147 |             opt['netD']['in_nc'] = 3
148 |         if 'base_nc' not in opt['netD']:
149 |             opt['netD']['base_nc'] = 64
150 |         if 'n_layers' not in opt['netD']:
151 |             opt['netD']['n_layers'] = 3
152 |         if 'norm_type' not in opt['netD']:
153 |             opt['netD']['norm_type'] = 'spectral'
154 |
155 |
156 |     return opt
157 |
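# Sketch of the '//' comment stripping that parse() performs above; the keys
# and _demo_strip_json_comments are illustrative, not part of this module.
def _demo_strip_json_comments():
    json_str = '{\n"task": "demo",  // experiment name\n"scale": 1\n}'
    clean = '\n'.join(line.split('//')[0] for line in json_str.splitlines())
    opt = json.loads(clean, object_pairs_hook=OrderedDict)
    print(opt['task'], opt['scale'])  # demo 1
    # caveat: a '//' inside a string value (e.g. a URL) would also be cut
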
158 | def find_last_checkpoint(save_dir, net_type='G', pretrained_path=None):
159 |     """
160 |     Args:
161 |         save_dir: model folder
162 |         net_type: 'G' or 'D' or 'optimizerG' or 'optimizerD'
163 |         pretrained_path: pretrained model path. If save_dir does not have any model, load from pretrained_path
164 |
165 |     Return:
166 |         init_iter: iteration number
167 |         init_path: model path
168 |     """
169 |     file_list = glob.glob(os.path.join(save_dir, '*_{}.pth'.format(net_type)))
170 |     if file_list:
171 |         iter_exist = []
172 |         for file_ in file_list:
173 |             iter_current = re.findall(r"(\d+)_{}\.pth".format(net_type), file_)
174 |             iter_exist.append(int(iter_current[0]))
175 |         init_iter = max(iter_exist)
176 |         init_path = os.path.join(save_dir, '{}_{}.pth'.format(init_iter, net_type))
177 |     else:
178 |         init_iter = 0
179 |         init_path = pretrained_path
180 |     return init_iter, init_path
181 |
182 | """
183 | # ----------------------------------------------------
184 | # dump the configuration to a timestamped file
185 | # ----------------------------------------------------
186 | """
187 |
188 | def get_timestamp():
189 |     return datetime.now().strftime('_%y%m%d_%H%M%S')
190 |
191 | def save(opt):
192 |     opt_path = opt['opt_path']
193 |     opt_path_copy = opt['path']['options']
194 |     dirname, filename_ext = os.path.split(opt_path)
195 |     filename, ext = os.path.splitext(filename_ext)
196 |     dump_path = os.path.join(opt_path_copy, filename+get_timestamp()+ext)
197 |     with open(dump_path, 'w') as dump_file:
198 |         json.dump(opt, dump_file, indent=2)
199 |
200 | '''
201 | # --------------------------------------------
202 | # dict to string for logger
203 | # --------------------------------------------
204 | '''
205 |
206 |
207 | def dict2str(opt, indent_l=1):
208 |     msg = ''
209 |     for k, v in opt.items():
210 |         if isinstance(v, dict):
211 |             msg += ' ' * (indent_l * 2) + k + ':[\n'
212 |             msg += dict2str(v, indent_l + 1)
213 |             msg += ' ' * (indent_l * 2) + ']\n'
214 |         else:
215 |             msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n'
216 |     return msg
217 |
218 |
219 | '''
220 | # --------------------------------------------
221 | # convert OrderedDict to NoneDict,
222 | # return None for missing key
223 | # --------------------------------------------
224 | '''
225 |
226 |
227 | def dict_to_nonedict(opt):
228 |     if isinstance(opt, dict):
229 |         new_opt = dict()
230 |         for key, sub_opt in opt.items():
231 |             new_opt[key] = dict_to_nonedict(sub_opt)
232 |         return NoneDict(**new_opt)
233 |     elif isinstance(opt, list):
234 |         return [dict_to_nonedict(sub_opt) for sub_opt in opt]
235 |     else:
236 |         return opt
237 |
238 |
239 | class NoneDict(dict):
240 |     def __missing__(self, key):
241 |         return None
242 |
--------------------------------------------------------------------------------
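
A short sketch of why dict_to_nonedict/NoneDict are convenient downstream: missing keys read as None instead of raising KeyError (the keys here are illustrative):

from utils.utils_option import dict_to_nonedict

opt = dict_to_nonedict({'train': {'G_lossfn_type': 'l1'}})
print(opt['train']['G_lossfn_type'])   # 'l1'
print(opt['train']['not_configured'])  # None, so callers can simply test truthiness
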
/utils/utils_regularizers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | # --------------------------------------------
5 | # SVD Orthogonal Regularization
6 | # --------------------------------------------
7 | def regularizer_orth(m):
8 |     """
9 |     # ----------------------------------------
10 |     # SVD Orthogonal Regularization
11 |     # ----------------------------------------
12 |     # Applies regularization during training by nudging singular values
13 |     # of each conv kernel that fall outside [0.5, 1.5] back toward that range.
14 |     # This function is to be called by the torch.nn.Module.apply() method,
15 |     # which applies regularizer_orth() to every submodule of the model.
16 |     # usage: net.apply(regularizer_orth)
17 |     # ----------------------------------------
18 |     """
19 |     classname = m.__class__.__name__
20 |     if classname.find('Conv') != -1:
21 |         w = m.weight.data.clone()
22 |         c_out, c_in, f1, f2 = w.size()
23 |         # dtype = m.weight.data.type()
24 |         w = w.permute(2, 3, 1, 0).contiguous().view(f1*f2*c_in, c_out)
25 |         # self.netG.apply(svd_orthogonalization)
26 |         u, s, v = torch.svd(w)
27 |         s[s > 1.5] = s[s > 1.5] - 1e-4
28 |         s[s < 0.5] = s[s < 0.5] + 1e-4
29 |         w = torch.mm(torch.mm(u, torch.diag(s)), v.t())
30 |         m.weight.data = w.view(f1, f2, c_in, c_out).permute(3, 2, 0, 1)  # .type(dtype)
31 |     else:
32 |         pass
33 |
34 |
35 | # --------------------------------------------
36 | # SVD Orthogonal Regularization
37 | # --------------------------------------------
38 | def regularizer_orth2(m):
39 |     """
40 |     # ----------------------------------------
41 |     # Applies regularization during training by nudging singular values
42 |     # of each conv kernel that fall outside [0.5, 1.5] * mean(s) back toward that band.
43 |     # This function is to be called by the torch.nn.Module.apply() method,
44 |     # which applies regularizer_orth2() to every submodule of the model.
45 |     # usage: net.apply(regularizer_orth2)
46 |     # ----------------------------------------
47 |     """
48 |     classname = m.__class__.__name__
49 |     if classname.find('Conv') != -1:
50 |         w = m.weight.data.clone()
51 |         c_out, c_in, f1, f2 = w.size()
52 |         # dtype = m.weight.data.type()
53 |         w = w.permute(2, 3, 1, 0).contiguous().view(f1*f2*c_in, c_out)
54 |         u, s, v = torch.svd(w)
55 |         s_mean = s.mean()
56 |         s[s > 1.5*s_mean] = s[s > 1.5*s_mean] - 1e-4
57 |         s[s < 0.5*s_mean] = s[s < 0.5*s_mean] + 1e-4
58 |         w = torch.mm(torch.mm(u, torch.diag(s)), v.t())
59 |         m.weight.data = w.view(f1, f2, c_in, c_out).permute(3, 2, 0, 1)  # .type(dtype)
60 |     else:
61 |         pass
62 |
63 |
64 |
65 | def regularizer_clip(m):
66 |     """
67 |     # ----------------------------------------
68 |     # usage: net.apply(regularizer_clip)
69 |     # ----------------------------------------
70 |     """
71 |     eps = 1e-4
72 |     c_min = -1.5
73 |     c_max = 1.5
74 |
75 |     classname = m.__class__.__name__
76 |     if classname.find('Conv') != -1 or classname.find('Linear') != -1:
77 |         w = m.weight.data.clone()
78 |         w[w > c_max] -= eps
79 |         w[w < c_min] += eps
80 |         m.weight.data = w
81 |
82 |         if m.bias is not None:
83 |             b = m.bias.data.clone()
84 |             b[b > c_max] -= eps
85 |             b[b < c_min] += eps
86 |             m.bias.data = b
87 |
88 |     # elif classname.find('BatchNorm2d') != -1:
89 |     #
90 |     #    rv = m.running_var.data.clone()
91 |     #    rm = m.running_mean.data.clone()
92 |     #
93 |     #    if m.affine:
94 |     #        m.weight.data
95 |     #        m.bias.data
96 |
--------------------------------------------------------------------------------
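
A minimal sketch of how these hooks are typically applied in a training loop (the network and the every-20-steps cadence are illustrative assumptions):

import torch.nn as nn
from utils.utils_regularizers import regularizer_orth, regularizer_clip

net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.Conv2d(8, 3, 3, padding=1))
for step in range(1, 101):
    # ... forward/backward and optimizer.step() would go here ...
    if step % 20 == 0:
        net.apply(regularizer_orth)   # nudge conv singular values toward [0.5, 1.5]
        net.apply(regularizer_clip)   # nudge weights/biases back toward [-1.5, 1.5]
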