├── .gitignore ├── .idea └── .gitignore ├── LICENSE ├── README.md ├── README └── image-20230818115344581.png ├── __init__.py ├── data ├── JSH_dataset_train.py ├── JSH_dataset_val.py ├── __init__.py ├── data_sampler.py └── util.py ├── models ├── CSNorm_model.py ├── __init__.py ├── base_model.py ├── ckpts │ └── NAF_LOL.pth ├── lr_scheduler.py ├── modules │ ├── NAFNet │ │ ├── Baseline_arch.py │ │ ├── NAFNet.py │ │ ├── arch_util.py │ │ └── local_arch.py │ ├── __init__.py │ ├── common.py │ ├── loss.py │ ├── loss_new.py │ └── module_util.py └── networks.py ├── options ├── __init__.py ├── options.py ├── test │ └── test.yml └── train │ └── train_InvDN.yml ├── test.py ├── train.py └── utils ├── __init__.py ├── pytorch_ssim └── __init__.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 
109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | *.iml 162 | *.xml 163 | -------------------------------------------------------------------------------- /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # Editor-based HTTP Client requests 5 | /httpRequests/ 6 | # Datasource local storage ignored files 7 | /dataSources/ 8 | /dataSources.local.xml 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 mingde-yao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |
3 | 4 | # [ICCV 2023 :fire:] Generalized Lightness Adaptation with Channel Selective Normalization. 5 | 6 | [Mingde Yao](https://scholar.google.com/citations?user=fsE3MzwAAAAJ&hl=en)\*, [Jie Huang](https://huangkevinj.github.io/)\*, [Xin Jin](http://home.ustc.edu.cn/~jinxustc/), [Ruikang Xu](https://scholar.google.com/citations?user=PulrrscAAAAJ&hl=en), Shenglong Zhou, [Man Zhou](https://manman1995.github.io/), [Zhiwei Xiong](http://staff.ustc.edu.cn/~zwxiong/) 7 | 8 | University of Science and Technology of China 9 | 10 | Eastern Institute of Technology 11 | 12 | Nanyang Technological University 13 | 14 | 15 | [[`Paper`](https://arxiv.org/pdf/2308.13783.pdf)] [[`BibTeX`](#heart-citing-us)] :zap: :rocket: :fire: 16 | 17 | [![python](https://img.shields.io/badge/-Python_3.8_%7C_3.9_%7C_3.10-blue?logo=python&logoColor=white)](https://github.com/pre-commit/pre-commit) 18 | [![pytorch](https://img.shields.io/badge/PyTorch-ee4c2c?logo=pytorch&logoColor=white)](https://pytorch.org/get-started/locally/) 19 | [![license](https://img.shields.io/badge/License-MIT-green.svg?labelColor=gray)](#license) 20 | 21 | :rocket: Welcome! This is the official repository of [ICCV'23] Generalized Lightness Adaptation with Channel Selective Normalization. 22 | 23 |
24 | 25 | 26 | 27 | ## 📌 Overview 28 | 29 | >Lightness adaptation is vital to the success of image processing to avoid unexpected visual deterioration, which covers multiple aspects, e.g., low-light image enhancement, image retouching, and inverse tone mapping. Existing methods typically work well on their trained lightness conditions but perform poorly in unknown ones due to their limited generalization ability. To address this limitation, we propose a novel generalized lightness adaptation algorithm that extends conventional normalization techniques through a channel filtering design, dubbed Channel Selective Normalization (CSNorm). The proposed CSNorm purposely normalizes the statistics of lightness-relevant channels and keeps other channels unchanged, so as to improve feature generalization and discrimination. To optimize CSNorm, we propose an alternating training strategy that effectively identifies lightness-relevant channels. The model equipped with our CSNorm only needs to be trained on one lightness condition and can be well generalized to unknown lightness conditions. Experimental results on multiple benchmark datasets demonstrate the effectiveness of CSNorm in enhancing the generalization ability for the existing lightness adaptation methods. 30 | 31 | 32 | ![image](https://github.com/mdyao/CSNorm/assets/33108887/f4c9b327-51fa-4832-8069-ab6919100277) 33 | 34 | Overview of our proposed method. (a) Channel selective normalization (CSNorm), which consists of an instance-level normalization module and a differential gating module. (b) Differential gating module. It outputs a series of on-off switch gates for binarized channel selection in CSNorm. (c) Alternating training strategy. In the first step, we optimize the parameters outside CSNorm to keep an essential ability for lightness adaptation. In the second step, we only update the parameters inside the CSNorm (see (a)&(b)) with lightness-perturbed images. The two steps drive the CSNorm to select channels sensitive to lightness changes, which are normalized in $x_{n+1}$. 35 | 36 | 37 | 43 | 44 | 45 | 46 | ## :sunflower: Results 47 | 48 | ![image-20230818115344581](README/image-20230818115344581.png) 49 | 50 | Visual comparisons of the generalized image retouching on the MIT-Adobe FiveK dataset. The models are trained on the original dataset and tested on the unseen lightness condition. 51 | 52 | 53 | 54 | ## :rocket: Usage 55 | 56 | 57 | 58 | 59 | To train the model equipped with CSNorm: 60 | 61 | 1. Modify the paths for training and testing in the configuration file (options/train/train_InvDN.yml). 62 | 2. Execute the command "python train.py -opt options/train/train_InvDN.yml". 63 | 3. Drink a cup of coffee or have a nice sleep. 64 | 4. Get the trained model. 65 | 66 | 67 | We employ the [NAFNet](https://github.com/mdyao/CSNorm/blob/62056d2ba45c6ab356a29e4a155d2f72c4c87beb/models/modules/NAFNet/NAFNet.py) as our base model, demonstrating the integration of CSNorm. 68 | 69 | Feel free to replace NAFNet with your preferred backbone when incorporating CSNorm: 70 | 71 | 72 | 1. Define the on-off switch gate function, where CHANNEL_NUM should be pre-defined. 
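The gate below acts as a differentiable on-off switch. For the channel-wise response $\alpha$ produced by the pooled two-layer projection, the gate value is

$$g = \frac{\alpha^{2}}{\alpha^{2} + \epsilon}, \qquad \epsilon = 10^{-8},$$

which is exactly $0$ when $\alpha = 0$ and close to $1$ whenever $\alpha^{2} \gg \epsilon$, so channels are selected in a near-binary fashion in the forward pass while gradients still flow through $\alpha$ during training.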
73 | 74 | ``` 75 | 76 | class Generate_gate(nn.Module): 77 | def __init__(self): 78 | super(Generate_gate, self).__init__() 79 | self.proj = nn.Sequential(nn.AdaptiveAvgPool2d(1), 80 | nn.Conv2d(CHANNEL_NUM,CHANNEL_NUM, 1), 81 | nn.ReLU(), 82 | nn.Conv2d(CHANNEL_NUM,CHANNEL_NUM, 1), 83 | nn.ReLU()) 84 | 85 | self.epsilon = 1e-8 86 | def forward(self, x): 87 | 88 | 89 | alpha = self.proj(x) 90 | gate = (alpha**2) / (alpha**2 + self.epsilon) 91 | 92 | return gate 93 | 94 | def freeze(layer): 95 | for child in layer.children(): 96 | for param in child.parameters(): 97 | param.requires_grad = False 98 | 99 | 100 | def freeze_direct(layer): 101 | for param in layer.parameters(): 102 | param.requires_grad = False 103 | 104 | ``` 105 | 106 | 2. Initialize CSNorm in the `__init__()` method, where CHANNEL_NUM should be pre-defined: 107 | 108 | ``` 109 | self.gate = Generate_gate() 110 | for i in range(CHANNEL_NUM): 111 | setattr(self, 'CSN_' + str(i), nn.InstanceNorm2d(1, affine=True)) 112 | freeze_direct(getattr(self, 'CSN_' + str(i))) 113 | freeze(self.gate) 114 | ``` 115 | 116 | 3. Integrate the code in the `forward()` method of your backbone, where CHANNEL_NUM should be pre-defined: 117 | 118 | ``` 119 | x = conv(x) 120 | ... 121 | gate = self.gate(x) 122 | lq_copy = torch.cat([getattr(self, 'CSN_' + str(i))(x[:,i,:,:][:,None,:,:]) for i in range(CHANNEL_NUM)], dim=1) 123 | x = gate * lq_copy + (1-gate) * x 124 | ``` 125 | 126 | 4. The two-step alternating training strategy is implemented in https://github.com/mdyao/CSNorm/blob/main/models/CSNorm_model.py. 127 | 128 | ## :heart: Citing Us 129 | If you find this repository or our work useful, please consider giving a star :star: and a citation :t-rex:, which would be greatly appreciated: 130 | 131 | ```bibtex 132 | @inproceedings{yao2023csnorm, 133 | title={Generalized Lightness Adaptation with Channel Selective Normalization}, 134 | author={Yao, Mingde and Huang, Jie and Jin, Xin and Xu, Ruikang and Zhou, Shenglong and Zhou, Man and Xiong, Zhiwei}, 135 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 136 | year={2023} 137 | } 138 | ``` 139 | 140 | 141 | ## :email: Contact 142 | 143 | 144 | 145 | For any inquiries or questions, please contact me by email (mdyao@mail.ustc.edu.cn). 146 | -------------------------------------------------------------------------------- /README/image-20230818115344581.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mdyao/CSNorm/49bf5a07ac1c58c8d2c221ac86022698c7f1c897/README/image-20230818115344581.png -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mdyao/CSNorm/49bf5a07ac1c58c8d2c221ac86022698c7f1c897/__init__.py -------------------------------------------------------------------------------- /data/JSH_dataset_train.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import cv2 4 | import h5py 5 | import torch 6 | import torch.utils.data as data 7 | import data.util as util 8 | import glob 9 | import os 10 | 11 | class JSHDataset(data.Dataset): 12 | ''' 13 | Read LQ (low quality; here, LR), GT, and noisy image pairs. 14 | If only GT and noisy images are provided, generate the LQ image on-the-fly. 15 | Pairing is ensured by the 'sorted' function, so please check the naming convention.
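Expected keys in opt, as read by __init__ and __getitem__ below: 'data_type', 'dataroot_gt', 'dataroot_lq', 'GT_size', 'phase', 'use_flip', 'use_rot'.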
16 | ''' 17 | 18 | def __init__(self, opt): 19 | super(JSHDataset, self).__init__() 20 | self.opt = opt 21 | self.data_type = self.opt['data_type'] 22 | self.gtimglist = sorted(glob.glob(os.path.join(self.opt['dataroot_gt'], '*'))) 23 | self.inputimglist = sorted(glob.glob(os.path.join(self.opt['dataroot_lq'], '*'))) 24 | self.length = len(self.gtimglist) 25 | 26 | def __getitem__(self, index): 27 | self.input = cv2.imread(self.inputimglist[index]) 28 | self.gt = cv2.imread(self.gtimglist[index]) 29 | GT_size = self.opt['GT_size'] 30 | 31 | # get GT image 32 | input_img = self.input/255.0 33 | gt_img = self.gt/255.0 34 | input_img = input_img.transpose(2,0,1) 35 | gt_img = gt_img.transpose(2,0,1) 36 | 37 | if self.opt['phase'] == 'train': 38 | C, H, W = input_img.shape 39 | x = random.randint(0, W - GT_size) 40 | y = random.randint(0, H - GT_size) 41 | input_img = input_img[:, y:y + GT_size, x:x + GT_size] 42 | gt_img = gt_img[:, y:y + GT_size, x:x + GT_size] 43 | 44 | # augmentation - flip, rotate 45 | input_img, gt_img = util.augment([input_img, gt_img], self.opt['use_flip'], 46 | self.opt['use_rot']) 47 | 48 | # BGR to RGB, HWC to CHW, numpy to tensor 49 | input_img = torch.from_numpy(np.ascontiguousarray(input_img)).float() 50 | gt_img = torch.from_numpy(np.ascontiguousarray(gt_img)).float() 51 | 52 | return {'gt_img': gt_img, 'lq_img': input_img} 53 | 54 | def __len__(self): 55 | return self.length 56 | 57 | 58 | -------------------------------------------------------------------------------- /data/JSH_dataset_val.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import cv2 4 | import h5py 5 | import torch 6 | import torch.utils.data as data 7 | import data.util as util 8 | import glob 9 | import os 10 | 11 | class JSHDataset(data.Dataset): 12 | ''' 13 | Read LQ (Low Quality, here is LR), GT and noisy image pairs. 14 | If only GT and noisy images are provided, generate LQ image on-the-fly. 15 | The pair is ensured by 'sorted' function, so please check the name convention. 
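This validation variant mirrors JSH_dataset_train but keeps inputs at full resolution: the random crop below is commented out (its coordinates are still computed), leaving only the flip/rotation augmentation inside the 'train' phase branch.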
16 | ''' 17 | 18 | def __init__(self, opt): 19 | super(JSHDataset, self).__init__() 20 | self.opt = opt 21 | self.data_type = self.opt['data_type'] 22 | self.gtimglist = sorted(glob.glob(os.path.join(self.opt['dataroot_gt'], '*'))) 23 | self.inputimglist = sorted(glob.glob(os.path.join(self.opt['dataroot_lq'], '*'))) 24 | self.length = len(self.gtimglist) 25 | 26 | def __getitem__(self, index): 27 | self.input = cv2.imread(self.inputimglist[index]) 28 | self.gt = cv2.imread(self.gtimglist[index]) 29 | GT_size = self.opt['GT_size'] 30 | 31 | # get GT image 32 | input_img = self.input/255.0 33 | gt_img = self.gt/255.0 34 | input_img = input_img.transpose(2,0,1) 35 | gt_img = gt_img.transpose(2,0,1) 36 | 37 | if self.opt['phase'] == 'train': 38 | C, H, W = input_img.shape 39 | x = random.randint(0, W - GT_size) 40 | y = random.randint(0, H - GT_size) 41 | # input_img = input_img[:, y:y + GT_size, x:x + GT_size] 42 | 43 | # augmentation - flip, rotate 44 | input_img, gt_img = util.augment([input_img, gt_img], self.opt['use_flip'], 45 | self.opt['use_rot']) 46 | # BGR to RGB, HWC to CHW, numpy to tensor 47 | input_img = torch.from_numpy(np.ascontiguousarray(input_img)).float() 48 | gt_img = torch.from_numpy(np.ascontiguousarray(gt_img)).float() 49 | 50 | return {'gt_img': gt_img, 'lq_img': input_img} 51 | 52 | def __len__(self): 53 | return self.length 54 | 55 | 56 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | '''create dataset and dataloader''' 2 | import logging 3 | import torch 4 | import torch.utils.data 5 | 6 | 7 | def create_dataloader(dataset, dataset_opt, opt=None, sampler=None): 8 | phase = dataset_opt['phase'] 9 | if phase == 'train': 10 | if opt['dist']: 11 | world_size = torch.distributed.get_world_size() 12 | num_workers = dataset_opt['n_workers'] 13 | assert dataset_opt['batch_size'] % world_size == 0 14 | batch_size = dataset_opt['batch_size'] // world_size 15 | shuffle = False 16 | else: 17 | num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids']) 18 | batch_size = dataset_opt['batch_size'] 19 | shuffle = True 20 | return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, 21 | num_workers=num_workers, sampler=sampler, drop_last=True, 22 | pin_memory=False) 23 | else: 24 | return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=6, 25 | pin_memory=False) 26 | 27 | 28 | def create_dataset(dataset_opt): 29 | mode = dataset_opt['mode'] 30 | if mode == 'JSH_train': 31 | from data.JSH_dataset_train import JSHDataset as D 32 | elif mode =='JSH_val': 33 | from data.JSH_dataset_val import JSHDataset as D 34 | else: 35 | raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode)) 36 | dataset = D(dataset_opt) 37 | 38 | logger = logging.getLogger('base') 39 | logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__, 40 | dataset_opt['name'])) 41 | return dataset 42 | -------------------------------------------------------------------------------- /data/data_sampler.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from torch.utils.data.distributed.DistributedSampler 3 | Support enlarging the dataset for *iter-oriented* training, for saving time when restart the 4 | dataloader after each epoch 5 | """ 6 | import math 7 | import torch 8 | from torch.utils.data.sampler import Sampler 9 | import 
torch.distributed as dist 10 | 11 | 12 | class DistIterSampler(Sampler): 13 | """Sampler that restricts data loading to a subset of the dataset. 14 | 15 | It is especially useful in conjunction with 16 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 17 | process can pass a DistributedSampler instance as a DataLoader sampler, 18 | and load a subset of the original dataset that is exclusive to it. 19 | 20 | .. note:: 21 | Dataset is assumed to be of constant size. 22 | 23 | Arguments: 24 | dataset: Dataset used for sampling. 25 | num_replicas (optional): Number of processes participating in 26 | distributed training. 27 | rank (optional): Rank of the current process within num_replicas. 28 | """ 29 | 30 | def __init__(self, dataset, num_replicas=None, rank=None, ratio=100): 31 | if num_replicas is None: 32 | if not dist.is_available(): 33 | raise RuntimeError("Requires distributed package to be available") 34 | num_replicas = dist.get_world_size() 35 | if rank is None: 36 | if not dist.is_available(): 37 | raise RuntimeError("Requires distributed package to be available") 38 | rank = dist.get_rank() 39 | self.dataset = dataset 40 | self.num_replicas = num_replicas 41 | self.rank = rank 42 | self.epoch = 0 43 | self.num_samples = int(math.ceil(len(self.dataset) * ratio / self.num_replicas)) 44 | self.total_size = self.num_samples * self.num_replicas 45 | 46 | def __iter__(self): 47 | # deterministically shuffle based on epoch 48 | g = torch.Generator() 49 | g.manual_seed(self.epoch) 50 | indices = torch.randperm(self.total_size, generator=g).tolist() 51 | 52 | dsize = len(self.dataset) 53 | indices = [v % dsize for v in indices] 54 | 55 | # subsample 56 | indices = indices[self.rank:self.total_size:self.num_replicas] 57 | assert len(indices) == self.num_samples 58 | 59 | return iter(indices) 60 | 61 | def __len__(self): 62 | return self.num_samples 63 | 64 | def set_epoch(self, epoch): 65 | self.epoch = epoch 66 | -------------------------------------------------------------------------------- /data/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import pickle 4 | import random 5 | import numpy as np 6 | import torch 7 | import cv2 8 | import h5py 9 | 10 | #################### 11 | # Files & IO 12 | #################### 13 | 14 | ###################### get image path list ###################### 15 | IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP'] 16 | 17 | 18 | def is_image_file(filename): 19 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 20 | 21 | 22 | def _get_paths_from_images(path): 23 | '''get image path list from image folder''' 24 | assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) 25 | images = [] 26 | for dirpath, _, fnames in sorted(os.walk(path)): 27 | for fname in sorted(fnames): 28 | if is_image_file(fname): 29 | img_path = os.path.join(dirpath, fname) 30 | images.append(img_path) 31 | assert images, '{:s} has no valid image file'.format(path) 32 | return images 33 | 34 | 35 | def _get_paths_from_lmdb(dataroot): 36 | '''get image path list from lmdb meta info''' 37 | meta_info = pickle.load(open(os.path.join(dataroot, 'meta_info.pkl'), 'rb')) 38 | paths = meta_info['keys'] 39 | sizes = meta_info['resolution'] 40 | if len(sizes) == 1: 41 | sizes = sizes * len(paths) 42 | return paths, sizes 43 | 44 | def _get_paths_from_mat(dataroot): 45 | '''get image path list from lmdb meta 
info''' 46 | meta_info = h5py.File(os.path.join(dataroot), 'r') 47 | key = meta_info.keys()[0] 48 | sizes = meta_info[0][-2::].size() 49 | return key, sizes 50 | 51 | def get_image_paths(data_type, dataroot): 52 | '''get image path list 53 | support lmdb or image files''' 54 | paths, sizes = None, None 55 | if dataroot is not None: 56 | if data_type == 'mat': 57 | paths, sizes = _get_paths_from_mat(dataroot) 58 | elif data_type == 'img': 59 | paths = sorted(_get_paths_from_images(dataroot)) 60 | else: 61 | raise NotImplementedError('data_type [{:s}] is not recognized.'.format(data_type)) 62 | return paths, sizes 63 | 64 | 65 | ###################### read images ###################### 66 | def _read_img_lmdb(env, key, size): 67 | '''read image from lmdb with key (w/ and w/o fixed size) 68 | size: (C, H, W) tuple''' 69 | with env.begin(write=False) as txn: 70 | buf = txn.get(key.encode('ascii')) 71 | img_flat = np.frombuffer(buf, dtype=np.uint8) 72 | C, H, W = size 73 | img = img_flat.reshape(H, W, C) 74 | return img 75 | 76 | 77 | def read_img(env, path, size=None): 78 | '''read image by cv2 or from lmdb 79 | return: Numpy float32, HWC, BGR, [0,1]''' 80 | if env is None: # img 81 | #img = cv2.imread(path, cv2.IMREAD_UNCHANGED) 82 | img = cv2.imread(path, cv2.IMREAD_COLOR) 83 | else: 84 | img = _read_img_lmdb(env, path, size) 85 | img = img.astype(np.float32) / 255. 86 | if img.ndim == 2: 87 | img = np.expand_dims(img, axis=2) 88 | # some images have 4 channels 89 | if img.shape[2] > 3: 90 | img = img[:, :, :3] 91 | return img 92 | 93 | def read_img_array(img): 94 | '''read image array and preprocess 95 | return: Numpy float32, HWC, BGR, [0,1]''' 96 | img = img.astype(np.float32) / 255. 97 | if img.ndim == 2: 98 | img = np.expand_dims(img, axis=2) 99 | return img 100 | 101 | #################### 102 | # image processing 103 | # process on numpy image 104 | #################### 105 | 106 | 107 | def augment(img_list, hflip=True, rot=True): 108 | # horizontal flip OR rotate 109 | hflip = hflip and random.random() < 0.5 110 | vflip = rot and random.random() < 0.5 111 | rot90 = rot and random.random() < 0.5 112 | 113 | def _augment(img): 114 | if isinstance(img, list): 115 | if hflip: 116 | img = [image[:, ::-1, :] for image in img] 117 | if vflip: 118 | img = [image[:, :, ::-1] for image in img] 119 | if rot90: 120 | img = [image.transpose(0, 2, 1) for image in img] 121 | else: 122 | if hflip: 123 | img = img[:, ::-1, :] 124 | if vflip: 125 | img = img[:, :, ::-1] 126 | if rot90: 127 | img = img.transpose(0, 2, 1) 128 | return img 129 | 130 | return [_augment(img) for img in img_list] 131 | 132 | 133 | def augment_flow(img_list, flow_list, hflip=True, rot=True): 134 | # horizontal flip OR rotate 135 | hflip = hflip and random.random() < 0.5 136 | vflip = rot and random.random() < 0.5 137 | rot90 = rot and random.random() < 0.5 138 | 139 | def _augment(img): 140 | if hflip: 141 | img = img[:, ::-1, :] 142 | if vflip: 143 | img = img[::-1, :, :] 144 | if rot90: 145 | img = img.transpose(1, 0, 2) 146 | return img 147 | 148 | def _augment_flow(flow): 149 | if hflip: 150 | flow = flow[:, ::-1, :] 151 | flow[:, :, 0] *= -1 152 | if vflip: 153 | flow = flow[::-1, :, :] 154 | flow[:, :, 1] *= -1 155 | if rot90: 156 | flow = flow.transpose(1, 0, 2) 157 | flow = flow[:, :, [1, 0]] 158 | return flow 159 | 160 | rlt_img_list = [_augment(img) for img in img_list] 161 | rlt_flow_list = [_augment_flow(flow) for flow in flow_list] 162 | 163 | return rlt_img_list, rlt_flow_list 164 | 165 | 166 | def 
channel_convert(in_c, tar_type, img_list): 167 | # conversion among BGR, gray and y 168 | if in_c == 3 and tar_type == 'gray': # BGR to gray 169 | gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list] 170 | return [np.expand_dims(img, axis=2) for img in gray_list] 171 | elif in_c == 3 and tar_type == 'y': # BGR to y 172 | y_list = [bgr2ycbcr(img, only_y=False) for img in img_list] 173 | return y_list 174 | # return [np.expand_dims(img, axis=2) for img in y_list] 175 | elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR 176 | return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list] 177 | else: 178 | return img_list 179 | 180 | 181 | def rgb2ycbcr(img, only_y=True): 182 | '''same as matlab rgb2ycbcr 183 | only_y: only return Y channel 184 | Input: 185 | uint8, [0, 255] 186 | float, [0, 1] 187 | ''' 188 | in_img_type = img.dtype 189 | img.astype(np.float32) 190 | if in_img_type != np.uint8: 191 | img *= 255. 192 | # convert 193 | if only_y: 194 | rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 195 | else: 196 | rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], 197 | [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] 198 | if in_img_type == np.uint8: 199 | rlt = rlt.round() 200 | else: 201 | rlt /= 255. 202 | return rlt.astype(in_img_type) 203 | 204 | 205 | def bgr2ycbcr(img, only_y=True): 206 | '''bgr version of rgb2ycbcr 207 | only_y: only return Y channel 208 | Input: 209 | uint8, [0, 255] 210 | float, [0, 1] 211 | ''' 212 | in_img_type = img.dtype 213 | img.astype(np.float32) 214 | if in_img_type != np.uint8: 215 | img *= 255. 216 | # convert 217 | if only_y: 218 | rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 219 | else: 220 | rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], 221 | [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] 222 | if in_img_type == np.uint8: 223 | rlt = rlt.round() 224 | else: 225 | rlt /= 255. 226 | return rlt.astype(in_img_type) 227 | 228 | 229 | def ycbcr2rgb(img): 230 | '''same as matlab ycbcr2rgb 231 | Input: 232 | uint8, [0, 255] 233 | float, [0, 1] 234 | ''' 235 | in_img_type = img.dtype 236 | img.astype(np.float32) 237 | if in_img_type != np.uint8: 238 | img *= 255. 239 | # convert 240 | rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], 241 | [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] 242 | if in_img_type == np.uint8: 243 | rlt = rlt.round() 244 | else: 245 | rlt /= 255. 
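# Note: the matrix above is the inverse of the BT.601 rgb2ycbcr transform defined
# earlier in this file; e.g. 0.00456621 = 1/219 undoes the [16, 235] luma range,
# and the offset [-222.921, 135.576, -276.836] cancels the forward [16, 128, 128] shift.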
246 | return rlt.astype(in_img_type) 247 | 248 | 249 | def modcrop(img_in, scale): 250 | # img_in: Numpy, CHW or HW 251 | img = np.copy(img_in) 252 | if img.ndim == 2: 253 | H, W = img.shape 254 | H_r, W_r = H % scale, W % scale 255 | img = img[:H - H_r, :W - W_r] 256 | elif img.ndim == 3: 257 | C, H, W = img.shape 258 | H_r, W_r = H % scale, W % scale 259 | img = img[:, :H - H_r, :W - W_r] 260 | else: 261 | raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim)) 262 | return img 263 | 264 | 265 | #################### 266 | # Functions 267 | #################### 268 | 269 | 270 | # matlab 'imresize' function, now only support 'bicubic' 271 | def cubic(x): 272 | absx = torch.abs(x) 273 | absx2 = absx**2 274 | absx3 = absx**3 275 | return (1.5 * absx3 - 2.5 * absx2 + 1) * ( 276 | (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (( 277 | (absx > 1) * (absx <= 2)).type_as(absx)) 278 | 279 | 280 | def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing): 281 | if (scale < 1) and (antialiasing): 282 | # Use a modified kernel to simultaneously interpolate and antialias- larger kernel width 283 | kernel_width = kernel_width / scale 284 | 285 | # Output-space coordinates 286 | x = torch.linspace(1, out_length, out_length) 287 | 288 | # Input-space coordinates. Calculate the inverse mapping such that 0.5 289 | # in output space maps to 0.5 in input space, and 0.5+scale in output 290 | # space maps to 1.5 in input space. 291 | u = x / scale + 0.5 * (1 - 1 / scale) 292 | 293 | # What is the left-most pixel that can be involved in the computation? 294 | left = torch.floor(u - kernel_width / 2) 295 | 296 | # What is the maximum number of pixels that can be involved in the 297 | # computation? Note: it's OK to use an extra pixel here; if the 298 | # corresponding weights are all zero, it will be eliminated at the end 299 | # of this function. 300 | P = math.ceil(kernel_width) + 2 301 | 302 | # The indices of the input pixels involved in computing the k-th output 303 | # pixel are in row k of the indices matrix. 304 | indices = left.view(out_length, 1).expand(out_length, P) + torch.linspace(0, P - 1, P).view( 305 | 1, P).expand(out_length, P) 306 | 307 | # The weights used to compute the k-th output pixel are in row k of the 308 | # weights matrix. 309 | distance_to_center = u.view(out_length, 1).expand(out_length, P) - indices 310 | # apply cubic kernel 311 | if (scale < 1) and (antialiasing): 312 | weights = scale * cubic(distance_to_center * scale) 313 | else: 314 | weights = cubic(distance_to_center) 315 | # Normalize the weights matrix so that each row sums to 1. 316 | weights_sum = torch.sum(weights, 1).view(out_length, 1) 317 | weights = weights / weights_sum.expand(out_length, P) 318 | 319 | # If a column in weights is all zero, get rid of it. only consider the first and last column. 
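# Because P = ceil(kernel_width) + 2 deliberately over-allocates taps (see the note
# above), the outermost weight columns may hold only zeros; the code below counts
# zero entries per column and trims the first/last column when any are found.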
320 | weights_zero_tmp = torch.sum((weights == 0), 0) 321 | if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): 322 | indices = indices.narrow(1, 1, P - 2) 323 | weights = weights.narrow(1, 1, P - 2) 324 | if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): 325 | indices = indices.narrow(1, 0, P - 2) 326 | weights = weights.narrow(1, 0, P - 2) 327 | weights = weights.contiguous() 328 | indices = indices.contiguous() 329 | sym_len_s = -indices.min() + 1 330 | sym_len_e = indices.max() - in_length 331 | indices = indices + sym_len_s - 1 332 | return weights, indices, int(sym_len_s), int(sym_len_e) 333 | 334 | 335 | def imresize(img, scale, antialiasing=True): 336 | # Now the scale should be the same for H and W 337 | # input: img: CHW RGB [0,1] 338 | # output: CHW RGB [0,1] w/o round 339 | 340 | in_C, in_H, in_W = img.size() 341 | _, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) 342 | kernel_width = 4 343 | kernel = 'cubic' 344 | 345 | # Return the desired dimension order for performing the resize. The 346 | # strategy is to perform the resize first along the dimension with the 347 | # smallest scale factor. 348 | # Now we do not support this. 349 | 350 | # get weights and indices 351 | weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( 352 | in_H, out_H, scale, kernel, kernel_width, antialiasing) 353 | weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( 354 | in_W, out_W, scale, kernel, kernel_width, antialiasing) 355 | # process H dimension 356 | # symmetric copying 357 | img_aug = torch.FloatTensor(in_C, in_H + sym_len_Hs + sym_len_He, in_W) 358 | img_aug.narrow(1, sym_len_Hs, in_H).copy_(img) 359 | 360 | sym_patch = img[:, :sym_len_Hs, :] 361 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() 362 | sym_patch_inv = sym_patch.index_select(1, inv_idx) 363 | img_aug.narrow(1, 0, sym_len_Hs).copy_(sym_patch_inv) 364 | 365 | sym_patch = img[:, -sym_len_He:, :] 366 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() 367 | sym_patch_inv = sym_patch.index_select(1, inv_idx) 368 | img_aug.narrow(1, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) 369 | 370 | out_1 = torch.FloatTensor(in_C, out_H, in_W) 371 | kernel_width = weights_H.size(1) 372 | for i in range(out_H): 373 | idx = int(indices_H[i][0]) 374 | out_1[0, i, :] = img_aug[0, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) 375 | out_1[1, i, :] = img_aug[1, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) 376 | out_1[2, i, :] = img_aug[2, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_H[i]) 377 | 378 | # process W dimension 379 | # symmetric copying 380 | out_1_aug = torch.FloatTensor(in_C, out_H, in_W + sym_len_Ws + sym_len_We) 381 | out_1_aug.narrow(2, sym_len_Ws, in_W).copy_(out_1) 382 | 383 | sym_patch = out_1[:, :, :sym_len_Ws] 384 | inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() 385 | sym_patch_inv = sym_patch.index_select(2, inv_idx) 386 | out_1_aug.narrow(2, 0, sym_len_Ws).copy_(sym_patch_inv) 387 | 388 | sym_patch = out_1[:, :, -sym_len_We:] 389 | inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() 390 | sym_patch_inv = sym_patch.index_select(2, inv_idx) 391 | out_1_aug.narrow(2, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) 392 | 393 | out_2 = torch.FloatTensor(in_C, out_H, out_W) 394 | kernel_width = weights_W.size(1) 395 | for i in range(out_W): 396 | idx = int(indices_W[i][0]) 397 | out_2[0, :, i] = out_1_aug[0, :, idx:idx + kernel_width].mv(weights_W[i]) 398 | 
out_2[1, :, i] = out_1_aug[1, :, idx:idx + kernel_width].mv(weights_W[i]) 399 | out_2[2, :, i] = out_1_aug[2, :, idx:idx + kernel_width].mv(weights_W[i]) 400 | 401 | return out_2 402 | 403 | 404 | def imresize_np(img, scale, antialiasing=True): 405 | # Now the scale should be the same for H and W 406 | # input: img: Numpy, HWC BGR [0,1] 407 | # output: HWC BGR [0,1] w/o round 408 | img = torch.from_numpy(img) 409 | 410 | in_H, in_W, in_C = img.size() 411 | _, out_H, out_W = in_C, math.ceil(in_H * scale), math.ceil(in_W * scale) 412 | kernel_width = 4 413 | kernel = 'cubic' 414 | 415 | # Return the desired dimension order for performing the resize. The 416 | # strategy is to perform the resize first along the dimension with the 417 | # smallest scale factor. 418 | # Now we do not support this. 419 | 420 | # get weights and indices 421 | weights_H, indices_H, sym_len_Hs, sym_len_He = calculate_weights_indices( 422 | in_H, out_H, scale, kernel, kernel_width, antialiasing) 423 | weights_W, indices_W, sym_len_Ws, sym_len_We = calculate_weights_indices( 424 | in_W, out_W, scale, kernel, kernel_width, antialiasing) 425 | # process H dimension 426 | # symmetric copying 427 | img_aug = torch.FloatTensor(in_H + sym_len_Hs + sym_len_He, in_W, in_C) 428 | img_aug.narrow(0, sym_len_Hs, in_H).copy_(img) 429 | 430 | sym_patch = img[:sym_len_Hs, :, :] 431 | inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() 432 | sym_patch_inv = sym_patch.index_select(0, inv_idx) 433 | img_aug.narrow(0, 0, sym_len_Hs).copy_(sym_patch_inv) 434 | 435 | sym_patch = img[-sym_len_He:, :, :] 436 | inv_idx = torch.arange(sym_patch.size(0) - 1, -1, -1).long() 437 | sym_patch_inv = sym_patch.index_select(0, inv_idx) 438 | img_aug.narrow(0, sym_len_Hs + in_H, sym_len_He).copy_(sym_patch_inv) 439 | 440 | out_1 = torch.FloatTensor(out_H, in_W, in_C) 441 | kernel_width = weights_H.size(1) 442 | for i in range(out_H): 443 | idx = int(indices_H[i][0]) 444 | out_1[i, :, 0] = img_aug[idx:idx + kernel_width, :, 0].transpose(0, 1).mv(weights_H[i]) 445 | out_1[i, :, 1] = img_aug[idx:idx + kernel_width, :, 1].transpose(0, 1).mv(weights_H[i]) 446 | out_1[i, :, 2] = img_aug[idx:idx + kernel_width, :, 2].transpose(0, 1).mv(weights_H[i]) 447 | 448 | # process W dimension 449 | # symmetric copying 450 | out_1_aug = torch.FloatTensor(out_H, in_W + sym_len_Ws + sym_len_We, in_C) 451 | out_1_aug.narrow(1, sym_len_Ws, in_W).copy_(out_1) 452 | 453 | sym_patch = out_1[:, :sym_len_Ws, :] 454 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() 455 | sym_patch_inv = sym_patch.index_select(1, inv_idx) 456 | out_1_aug.narrow(1, 0, sym_len_Ws).copy_(sym_patch_inv) 457 | 458 | sym_patch = out_1[:, -sym_len_We:, :] 459 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() 460 | sym_patch_inv = sym_patch.index_select(1, inv_idx) 461 | out_1_aug.narrow(1, sym_len_Ws + in_W, sym_len_We).copy_(sym_patch_inv) 462 | 463 | out_2 = torch.FloatTensor(out_H, out_W, in_C) 464 | kernel_width = weights_W.size(1) 465 | for i in range(out_W): 466 | idx = int(indices_W[i][0]) 467 | out_2[:, i, 0] = out_1_aug[:, idx:idx + kernel_width, 0].mv(weights_W[i]) 468 | out_2[:, i, 1] = out_1_aug[:, idx:idx + kernel_width, 1].mv(weights_W[i]) 469 | out_2[:, i, 2] = out_1_aug[:, idx:idx + kernel_width, 2].mv(weights_W[i]) 470 | 471 | return out_2.numpy() 472 | 473 | 474 | if __name__ == '__main__': 475 | # test imresize function 476 | # read images 477 | img = cv2.imread('test.png') 478 | img = img * 1.0 / 255 479 | img = torch.from_numpy(np.transpose(img[:, 
:, [2, 1, 0]], (2, 0, 1))).float() 480 | # imresize 481 | scale = 1 / 4 482 | import time 483 | total_time = 0 484 | for i in range(10): 485 | start_time = time.time() 486 | rlt = imresize(img, scale, antialiasing=True) 487 | use_time = time.time() - start_time 488 | total_time += use_time 489 | print('average time: {}'.format(total_time / 10)) 490 | 491 | import torchvision.utils 492 | torchvision.utils.save_image((rlt * 255).round() / 255, 'rlt.png', nrow=1, padding=0, 493 | normalize=False) 494 | -------------------------------------------------------------------------------- /models/CSNorm_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import OrderedDict 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn.parallel import DataParallel, DistributedDataParallel 7 | import models.networks as networks 8 | import models.lr_scheduler as lr_scheduler 9 | from .base_model import BaseModel 10 | from models.modules.loss import FFT_Loss 11 | import numpy as np 12 | import time 13 | from models.modules.loss_new import SSIMLoss 14 | import re 15 | 16 | logger = logging.getLogger('base') 17 | 18 | 19 | class CSNorm_Model(BaseModel): 20 | def __init__(self, opt): 21 | super(CSNorm_Model, self).__init__(opt) 22 | 23 | 24 | if opt['dist']: 25 | self.rank = torch.distributed.get_rank() 26 | else: 27 | self.rank = -1 # non dist training 28 | train_opt = opt['train'] 29 | test_opt = opt['test'] 30 | self.train_opt = train_opt 31 | self.test_opt = test_opt 32 | 33 | self.netG = networks.define_G(opt).to(self.device) 34 | if opt['dist']: 35 | self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()]) 36 | else: 37 | self.netG = DataParallel(self.netG) 38 | 39 | ######################### set parameters in CSNorm ############################### 40 | target_layer_patterns = re.compile(r'module\.(gate\.proj|CSN_\d+)\.') 41 | # target_layer_patterns = re.compile(r'(gate\.proj|CSN_\d+)\.') 42 | 43 | self.layer_aug = [ 44 | name for name, param in self.netG.named_parameters() 45 | if target_layer_patterns.search(name) 46 | ] 47 | # print('parameters in CSNorm:',self.layer_aug) 48 | ######################### set parameters in CSNorm ############################### 49 | 50 | # loss 51 | self.Back_rec = torch.nn.L1Loss() 52 | self.ssim_loss = SSIMLoss() 53 | self.fft_loss = FFT_Loss() 54 | # self.print_network() 55 | self.load() 56 | 57 | if self.is_train: 58 | self.netG.train() 59 | 60 | # optimizers 61 | wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 62 | optim_params = [] 63 | optim_params_aug = [] 64 | for k, v in self.netG.named_parameters(): 65 | if v.requires_grad: 66 | optim_params.append(v) 67 | else: 68 | if self.rank <= 0: 69 | logger.warning('Params [{:s}] will not optimize.'.format(k)) 70 | 71 | for k, v in self.netG.named_parameters(): 72 | if k in self.layer_aug: 73 | optim_params_aug.append(v) 74 | else: 75 | if self.rank <= 0: 76 | logger.warning('Params [{:s}] will not optimize in aug.'.format(k)) 77 | 78 | 79 | self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], 80 | weight_decay=wd_G, 81 | betas=(train_opt['beta1'], train_opt['beta2'])) 82 | 83 | self.optimizer_G_aug = torch.optim.Adam(optim_params_aug, lr=train_opt['lr_G'], 84 | weight_decay=wd_G, 85 | betas=(train_opt['beta1'], train_opt['beta2'])) 86 | 87 | self.optimizers.append(self.optimizer_G) 88 | self.optimizers.append(self.optimizer_G_aug) 89 | 90 | # schedulers 91 | 
if train_opt['lr_scheme'] == 'MultiStepLR': 92 | for optimizer in self.optimizers: 93 | self.schedulers.append( 94 | lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'], 95 | restarts=train_opt['restarts'], 96 | weights=train_opt['restart_weights'], 97 | gamma=train_opt['lr_gamma'], 98 | clear_state=train_opt['clear_state'])) 99 | elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': 100 | for optimizer in self.optimizers: 101 | self.schedulers.append( 102 | lr_scheduler.CosineAnnealingLR_Restart( 103 | optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], 104 | restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) 105 | else: 106 | raise NotImplementedError('MultiStepLR learning rate scheme is enough.') 107 | 108 | self.log_dict = OrderedDict() 109 | 110 | def amp_aug(self, x, y): 111 | x = x + 1e-8 112 | y = y + 1e-8 113 | x_freq= torch.fft.rfft2(x, norm='backward') 114 | x_amp = torch.abs(x_freq) 115 | x_phase = torch.angle(x_freq) 116 | 117 | y_freq= torch.fft.rfft2(y, norm='backward') 118 | y_amp = torch.abs(y_freq) 119 | y_phase = torch.angle(y_freq) 120 | 121 | mix_alpha = torch.rand(1).to(self.device)/0.5 122 | mix_alpha = torch.clip(mix_alpha, 0,0.5) 123 | y_amp = mix_alpha * y_amp + (1-mix_alpha) * x_amp 124 | 125 | real = y_amp * torch.cos(y_phase) 126 | imag = y_amp * torch.sin(y_phase) 127 | y_out = torch.complex(real, imag) + 1e-8 128 | y_out = torch.fft.irfft2(y_out) + 1e-8 129 | 130 | return y_out 131 | 132 | def feed_data(self, data): 133 | self.img_gt = data['gt_img'].to(self.device) 134 | self.img_input = data['lq_img'].to(self.device) 135 | self.img_input_aug = self.amp_aug(self.img_gt, self.img_input) 136 | 137 | def feed_data_test(self, data): 138 | # self.ref_L = data['LQ'].to(self.device) 139 | self.img_gt = data['gt_img'].to(self.device) 140 | self.img_input = data['lq_img'].to(self.device) 141 | 142 | def loss_forward(self,img, gt): 143 | loss = 1 * self.Back_rec(img, gt) 144 | loss_ssim = self.ssim_loss(img, gt) 145 | 146 | return loss, loss_ssim 147 | 148 | 149 | def loss_forward_aug(self,img, gt): 150 | loss = 1 * self.Back_rec(img, gt) 151 | loss_ssim = self.ssim_loss(img, gt) 152 | 153 | l_amp, _ = self.fft_loss(img, gt) 154 | return loss, loss_ssim, l_amp 155 | 156 | def optimize_parameters(self, step): 157 | 158 | ############## optimize parameters outside CSNorm ############################ 159 | for k, v in self.netG.named_parameters(): 160 | if k not in self.layer_aug: 161 | v.requires_grad = True 162 | else: 163 | v.requires_grad = False 164 | self.optimizer_G.zero_grad() 165 | 166 | # forward 167 | self.img_pred = self.netG(self.img_input, aug=True) 168 | loss, l_ssim = self.loss_forward(self.img_pred, self.img_gt) 169 | loss = loss + l_ssim 170 | 171 | # backward 172 | loss.backward() 173 | 174 | # gradient clipping 175 | if self.train_opt['gradient_clipping']: 176 | nn.utils.clip_grad_norm_(self.netG.parameters(), self.train_opt['gradient_clipping']) 177 | 178 | self.optimizer_G.step() 179 | 180 | 181 | ############## optimize parameters inside CSNorm ############################ 182 | for k, v in self.netG.named_parameters(): 183 | if k in self.layer_aug: 184 | v.requires_grad = True 185 | else: 186 | v.requires_grad = False 187 | 188 | self.optimizer_G_aug.zero_grad() 189 | 190 | # forward 191 | self.img_pred = self.netG(self.img_input_aug, aug=True) 192 | loss_back, l_ssim, l_amp = self.loss_forward_aug(self.img_pred, self.img_gt) 193 | loss_aug = loss_back + l_ssim + l_amp 194 | # backward 195 | 
loss_aug.backward() 196 | 197 | # gradient clipping 198 | if self.train_opt['gradient_clipping']: 199 | nn.utils.clip_grad_norm_(self.netG.parameters(), self.train_opt['gradient_clipping']) 200 | 201 | self.optimizer_G_aug.step() 202 | 203 | 204 | # set log 205 | self.log_dict['loss'] = loss.item() 206 | self.log_dict['l_amp'] = l_amp.item() 207 | self.log_dict['l_ssim'] = l_ssim.item() 208 | 209 | def test(self): 210 | 211 | self.netG.eval() 212 | with torch.no_grad(): 213 | self.img_pred = self.netG(self.img_input, aug=True) 214 | 215 | self.netG.train() 216 | 217 | def get_current_log(self): 218 | return self.log_dict 219 | 220 | def get_current_visuals(self): 221 | out_dict = OrderedDict() 222 | out_dict['img_pred'] = self.img_pred.detach()[0].float().cpu() 223 | out_dict['img_input'] = self.img_input.detach()[0].float().cpu() 224 | out_dict['img_gt'] = self.img_gt.detach()[0].float().cpu() 225 | return out_dict 226 | 227 | def print_network(self): 228 | s, n = self.get_network_description(self.netG) 229 | if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel): 230 | net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, 231 | self.netG.module.__class__.__name__) 232 | else: 233 | net_struc_str = '{}'.format(self.netG.__class__.__name__) 234 | if self.rank <= 0: 235 | logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) 236 | logger.info(s) 237 | 238 | def load(self): 239 | 240 | load_path_G = self.opt['path']['pretrain_model_G'] 241 | if load_path_G is not None: 242 | logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) 243 | self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) 244 | 245 | def save(self, iter_label): 246 | self.save_network(self.netG, 'G', iter_label) 247 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logger = logging.getLogger('base') 3 | 4 | 5 | def create_model(opt): 6 | model = opt['model'] 7 | 8 | if model == 'CSNorm': 9 | from .CSNorm_model import CSNorm_Model as M 10 | else: 11 | raise NotImplementedError('Model [{:s}] not recognized.'.format(model)) 12 | m = M(opt) 13 | logger.info('Model [{:s}] is created.'.format(m.__class__.__name__)) 14 | return m 15 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | from collections import OrderedDict 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn.parallel import DistributedDataParallel 6 | 7 | 8 | class BaseModel(): 9 | def __init__(self, opt): 10 | self.opt = opt 11 | self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu') 12 | self.is_train = opt['is_train'] 13 | self.schedulers = [] 14 | self.optimizers = [] 15 | 16 | def feed_data(self, data): 17 | pass 18 | 19 | def optimize_parameters(self): 20 | pass 21 | 22 | def get_current_visuals(self): 23 | pass 24 | 25 | def get_current_losses(self): 26 | pass 27 | 28 | def print_network(self): 29 | pass 30 | 31 | def save(self, label): 32 | pass 33 | 34 | def load(self): 35 | pass 36 | 37 | def _set_lr(self, lr_groups_l): 38 | ''' set learning rate for warmup, 39 | lr_groups_l: list for lr_groups. 
each for a optimizer''' 40 | for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): 41 | for param_group, lr in zip(optimizer.param_groups, lr_groups): 42 | param_group['lr'] = lr 43 | 44 | def _get_init_lr(self): 45 | # get the initial lr, which is set by the scheduler 46 | init_lr_groups_l = [] 47 | for optimizer in self.optimizers: 48 | init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups]) 49 | return init_lr_groups_l 50 | 51 | def update_learning_rate(self, cur_iter, warmup_iter=-1): 52 | for scheduler in self.schedulers: 53 | scheduler.step() 54 | #### set up warm up learning rate 55 | if cur_iter < warmup_iter: 56 | # get initial lr for each group 57 | init_lr_g_l = self._get_init_lr() 58 | # modify warming-up learning rates 59 | warm_up_lr_l = [] 60 | for init_lr_g in init_lr_g_l: 61 | warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g]) 62 | # set learning rate 63 | self._set_lr(warm_up_lr_l) 64 | 65 | def get_current_learning_rate(self): 66 | # return self.schedulers[0].get_lr()[0] 67 | return self.optimizers[0].param_groups[0]['lr'] 68 | 69 | def get_network_description(self, network): 70 | '''Get the string and total parameters of the network''' 71 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 72 | network = network.module 73 | s = str(network) 74 | n = sum(map(lambda x: x.numel(), network.parameters())) 75 | return s, n 76 | 77 | def save_network(self, network, network_label, iter_label): 78 | save_filename = '{}_{}.pth'.format(iter_label, network_label) 79 | save_path = os.path.join(self.opt['path']['models'], save_filename) 80 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 81 | network = network.module 82 | state_dict = network.state_dict() 83 | for key, param in state_dict.items(): 84 | state_dict[key] = param.cpu() 85 | torch.save(state_dict, save_path) 86 | 87 | def load_network(self, load_path, network, strict=True): 88 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 89 | network = network.module 90 | load_net = torch.load(load_path) 91 | load_net_clean = OrderedDict() # remove unnecessary 'module.' 
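# Checkpoints saved from a DataParallel/DistributedDataParallel wrapper prefix every
# parameter key with 'module.'; stripping that prefix lets the same checkpoint load
# into the bare (unwrapped) network handled above.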
92 | for k, v in load_net.items(): 93 | if k.startswith('module.'): 94 | load_net_clean[k[7:]] = v 95 | else: 96 | load_net_clean[k] = v 97 | network.load_state_dict(load_net_clean, strict=strict) 98 | 99 | def save_training_state(self, epoch, iter_step): 100 | '''Saves training state during training, which will be used for resuming''' 101 | state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': []} 102 | for s in self.schedulers: 103 | state['schedulers'].append(s.state_dict()) 104 | for o in self.optimizers: 105 | state['optimizers'].append(o.state_dict()) 106 | save_filename = '{}.state'.format(iter_step) 107 | save_path = os.path.join(self.opt['path']['training_state'], save_filename) 108 | torch.save(state, save_path) 109 | 110 | def resume_training(self, resume_state): 111 | '''Resume the optimizers and schedulers for training''' 112 | resume_optimizers = resume_state['optimizers'] 113 | resume_schedulers = resume_state['schedulers'] 114 | assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers' 115 | assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers' 116 | for i, o in enumerate(resume_optimizers): 117 | self.optimizers[i].load_state_dict(o) 118 | for i, s in enumerate(resume_schedulers): 119 | self.schedulers[i].load_state_dict(s) 120 | -------------------------------------------------------------------------------- /models/ckpts/NAF_LOL.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mdyao/CSNorm/49bf5a07ac1c58c8d2c221ac86022698c7f1c897/models/ckpts/NAF_LOL.pth -------------------------------------------------------------------------------- /models/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import Counter 3 | from collections import defaultdict 4 | import torch 5 | from torch.optim.lr_scheduler import _LRScheduler 6 | 7 | 8 | class MultiStepLR_Restart(_LRScheduler): 9 | def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1, 10 | clear_state=False, last_epoch=-1): 11 | self.milestones = Counter(milestones) 12 | self.gamma = gamma 13 | self.clear_state = clear_state 14 | self.restarts = restarts if restarts else [0] 15 | self.restart_weights = weights if weights else [1] 16 | assert len(self.restarts) == len( 17 | self.restart_weights), 'restarts and their weights do not match.' 
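# Note: milestones is stored as a Counter, so a step listed k times multiplies the
# lr by gamma**k when it is reached (see gamma**self.milestones[self.last_epoch] in
# get_lr below).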
18 | super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch) 19 | 20 | def get_lr(self): 21 | if self.last_epoch in self.restarts: 22 | if self.clear_state: 23 | self.optimizer.state = defaultdict(dict) 24 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 25 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 26 | if self.last_epoch not in self.milestones: 27 | return [group['lr'] for group in self.optimizer.param_groups] 28 | return [ 29 | group['lr'] * self.gamma**self.milestones[self.last_epoch] 30 | for group in self.optimizer.param_groups 31 | ] 32 | 33 | 34 | class CosineAnnealingLR_Restart(_LRScheduler): 35 | def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1): 36 | self.T_period = T_period 37 | self.T_max = self.T_period[0] # current T period 38 | self.eta_min = eta_min 39 | self.restarts = restarts if restarts else [0] 40 | self.restart_weights = weights if weights else [1] 41 | self.last_restart = 0 42 | assert len(self.restarts) == len( 43 | self.restart_weights), 'restarts and their weights do not match.' 44 | super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch) 45 | 46 | def get_lr(self): 47 | if self.last_epoch == 0: 48 | return self.base_lrs 49 | elif self.last_epoch in self.restarts: 50 | self.last_restart = self.last_epoch 51 | self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1] 52 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 53 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 54 | elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0: 55 | return [ 56 | group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 57 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 58 | ] 59 | return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) / 60 | (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) * 61 | (group['lr'] - self.eta_min) + self.eta_min 62 | for group in self.optimizer.param_groups] 63 | 64 | 65 | if __name__ == "__main__": 66 | optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0, 67 | betas=(0.9, 0.99)) 68 | ############################## 69 | # MultiStepLR_Restart 70 | ############################## 71 | ## Original 72 | lr_steps = [200000, 400000, 600000, 800000] 73 | restarts = None 74 | restart_weights = None 75 | 76 | ## two 77 | lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000] 78 | restarts = [500000] 79 | restart_weights = [1] 80 | 81 | ## four 82 | lr_steps = [ 83 | 50000, 100000, 150000, 200000, 240000, 300000, 350000, 400000, 450000, 490000, 550000, 84 | 600000, 650000, 700000, 740000, 800000, 850000, 900000, 950000, 990000 85 | ] 86 | restarts = [250000, 500000, 750000] 87 | restart_weights = [1, 1, 1] 88 | 89 | scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5, 90 | clear_state=False) 91 | 92 | ############################## 93 | # Cosine Annealing Restart 94 | ############################## 95 | ## two 96 | T_period = [500000, 500000] 97 | restarts = [500000] 98 | restart_weights = [1] 99 | 100 | ## four 101 | T_period = [250000, 250000, 250000, 250000] 102 | restarts = [250000, 500000, 750000] 103 | restart_weights = [1, 1, 1] 104 | 105 | scheduler = CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, 
restarts=restarts, 106 | weights=restart_weights) 107 | 108 | ############################## 109 | # Draw figure 110 | ############################## 111 | N_iter = 1000000 112 | lr_l = list(range(N_iter)) 113 | for i in range(N_iter): 114 | scheduler.step() 115 | current_lr = optimizer.param_groups[0]['lr'] 116 | lr_l[i] = current_lr 117 | 118 | import matplotlib as mpl 119 | from matplotlib import pyplot as plt 120 | import matplotlib.ticker as mtick 121 | mpl.style.use('default') 122 | import seaborn 123 | seaborn.set(style='whitegrid') 124 | seaborn.set_context('paper') 125 | 126 | plt.figure(1) 127 | plt.subplot(111) 128 | plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0)) 129 | plt.title('Title', fontsize=16, color='k') 130 | plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label='learning rate scheme') 131 | legend = plt.legend(loc='upper right', shadow=False) 132 | ax = plt.gca() 133 | labels = ax.get_xticks().tolist() 134 | for k, v in enumerate(labels): 135 | labels[k] = str(int(v / 1000)) + 'K' 136 | ax.set_xticklabels(labels) 137 | ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e')) 138 | 139 | ax.set_ylabel('Learning rate') 140 | ax.set_xlabel('Iteration') 141 | fig = plt.gcf() 142 | plt.show() 143 | -------------------------------------------------------------------------------- /models/modules/NAFNet/Baseline_arch.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | 5 | ''' 6 | Simple Baselines for Image Restoration 7 | 8 | @article{chen2022simple, 9 | title={Simple Baselines for Image Restoration}, 10 | author={Chen, Liangyu and Chu, Xiaojie and Zhang, Xiangyu and Sun, Jian}, 11 | journal={arXiv preprint arXiv:2204.04676}, 12 | year={2022} 13 | } 14 | ''' 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | from basicsr.models.archs.arch_util import LayerNorm2d 20 | from basicsr.models.archs.local_arch import Local_Base 21 | 22 | class BaselineBlock(nn.Module): 23 | def __init__(self, c, DW_Expand=1, FFN_Expand=2, drop_out_rate=0.): 24 | super().__init__() 25 | dw_channel = c * DW_Expand 26 | self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True) 27 | self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel, 28 | bias=True) 29 | self.conv3 = nn.Conv2d(in_channels=dw_channel, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True) 30 | 31 | # Channel Attention 32 | self.se = nn.Sequential( 33 | nn.AdaptiveAvgPool2d(1), 34 | nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1, 35 | groups=1, bias=True), 36 | nn.ReLU(inplace=True), 37 | nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, 38 | groups=1, bias=True), 39 | nn.Sigmoid() 40 | ) 41 | 42 | # GELU 43 | self.gelu = nn.GELU() 44 | 45 | ffn_channel = FFN_Expand * c 46 | self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True) 47 | self.conv5 = nn.Conv2d(in_channels=ffn_channel, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True) 48 | 49 | self.norm1 = LayerNorm2d(c) 50 | self.norm2 = 
LayerNorm2d(c) 51 | 52 | self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity() 53 | self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity() 54 | 55 | self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) 56 | self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) 57 | 58 | def forward(self, inp): 59 | x = inp 60 | 61 | x = self.norm1(x) 62 | 63 | x = self.conv1(x) 64 | x = self.conv2(x) 65 | x = self.gelu(x) 66 | x = x * self.se(x) 67 | x = self.conv3(x) 68 | 69 | x = self.dropout1(x) 70 | 71 | y = inp + x * self.beta 72 | 73 | x = self.conv4(self.norm2(y)) 74 | x = self.gelu(x) 75 | x = self.conv5(x) 76 | 77 | x = self.dropout2(x) 78 | 79 | return y + x * self.gamma 80 | 81 | 82 | class Baseline(nn.Module): 83 | 84 | def __init__(self, img_channel=3, width=16, middle_blk_num=1, enc_blk_nums=[], dec_blk_nums=[], dw_expand=1, ffn_expand=2): 85 | super().__init__() 86 | 87 | self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1, 88 | bias=True) 89 | self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1, 90 | bias=True) 91 | 92 | self.encoders = nn.ModuleList() 93 | self.decoders = nn.ModuleList() 94 | self.middle_blks = nn.ModuleList() 95 | self.ups = nn.ModuleList() 96 | self.downs = nn.ModuleList() 97 | 98 | chan = width 99 | for num in enc_blk_nums: 100 | self.encoders.append( 101 | nn.Sequential( 102 | *[BaselineBlock(chan, dw_expand, ffn_expand) for _ in range(num)] 103 | ) 104 | ) 105 | self.downs.append( 106 | nn.Conv2d(chan, 2*chan, 2, 2) 107 | ) 108 | chan = chan * 2 109 | 110 | self.middle_blks = \ 111 | nn.Sequential( 112 | *[BaselineBlock(chan, dw_expand, ffn_expand) for _ in range(middle_blk_num)] 113 | ) 114 | 115 | for num in dec_blk_nums: 116 | self.ups.append( 117 | nn.Sequential( 118 | nn.Conv2d(chan, chan * 2, 1, bias=False), 119 | nn.PixelShuffle(2) 120 | ) 121 | ) 122 | chan = chan // 2 123 | self.decoders.append( 124 | nn.Sequential( 125 | *[BaselineBlock(chan, dw_expand, ffn_expand) for _ in range(num)] 126 | ) 127 | ) 128 | 129 | self.padder_size = 2 ** len(self.encoders) 130 | 131 | def forward(self, inp): 132 | B, C, H, W = inp.shape 133 | inp = self.check_image_size(inp) 134 | 135 | x = self.intro(inp) 136 | 137 | encs = [] 138 | 139 | for encoder, down in zip(self.encoders, self.downs): 140 | x = encoder(x) 141 | encs.append(x) 142 | x = down(x) 143 | 144 | x = self.middle_blks(x) 145 | 146 | for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]): 147 | x = up(x) 148 | x = x + enc_skip 149 | x = decoder(x) 150 | 151 | x = self.ending(x) 152 | x = x + inp 153 | 154 | return x[:, :, :H, :W] 155 | 156 | def check_image_size(self, x): 157 | _, _, h, w = x.size() 158 | mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size 159 | mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size 160 | x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h)) 161 | return x 162 | 163 | class BaselineLocal(Local_Base, Baseline): 164 | def __init__(self, *args, train_size=(1, 3, 256, 256), fast_imp=False, **kwargs): 165 | Local_Base.__init__(self) 166 | Baseline.__init__(self, *args, **kwargs) 167 | 168 | N, C, H, W = train_size 169 | base_size = (int(H * 1.5), int(W * 1.5)) 170 | 171 | self.eval() 172 | with torch.no_grad(): 173 | self.convert(base_size=base_size, train_size=train_size, fast_imp=fast_imp) 174 | 175 | if __name__ == 
'__main__': 176 | img_channel = 3 177 | width = 32 178 | 179 | dw_expand = 1 180 | ffn_expand = 2 181 | 182 | # enc_blks = [2, 2, 4, 8] 183 | # middle_blk_num = 12 184 | # dec_blks = [2, 2, 2, 2] 185 | 186 | enc_blks = [1, 1, 1, 28] 187 | middle_blk_num = 1 188 | dec_blks = [1, 1, 1, 1] 189 | 190 | net = Baseline(img_channel=img_channel, width=width, middle_blk_num=middle_blk_num, 191 | enc_blk_nums=enc_blks, dec_blk_nums=dec_blks, dw_expand=dw_expand, ffn_expand=ffn_expand) 192 | 193 | inp_shape = (3, 256, 256) 194 | 195 | from ptflops import get_model_complexity_info 196 | 197 | macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=False) 198 | 199 | params = float(params[:-3]) 200 | macs = float(macs[:-4]) 201 | 202 | print(macs, params) 203 | -------------------------------------------------------------------------------- /models/modules/NAFNet/NAFNet.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | 5 | ''' 6 | Simple Baselines for Image Restoration 7 | 8 | @article{chen2022simple, 9 | title={Simple Baselines for Image Restoration}, 10 | author={Chen, Liangyu and Chu, Xiaojie and Zhang, Xiangyu and Sun, Jian}, 11 | journal={arXiv preprint arXiv:2204.04676}, 12 | year={2022} 13 | } 14 | ''' 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | from models.modules.NAFNet.arch_util import LayerNorm2d 20 | from models.modules.NAFNet.local_arch import Local_Base 21 | 22 | class SimpleGate(nn.Module): 23 | def forward(self, x): 24 | x1, x2 = x.chunk(2, dim=1) 25 | return x1 * x2 26 | 27 | class NAFBlock(nn.Module): 28 | def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.): 29 | super().__init__() 30 | dw_channel = c * DW_Expand 31 | self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True) 32 | self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel, 33 | bias=True) 34 | self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True) 35 | 36 | # Simplified Channel Attention 37 | self.sca = nn.Sequential( 38 | nn.AdaptiveAvgPool2d(1), 39 | nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1, 40 | groups=1, bias=True), 41 | ) 42 | 43 | # SimpleGate 44 | self.sg = SimpleGate() 45 | 46 | ffn_channel = FFN_Expand * c 47 | self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True) 48 | self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True) 49 | 50 | self.norm1 = LayerNorm2d(c) 51 | self.norm2 = LayerNorm2d(c) 52 | 53 | self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity() 54 | self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. 
else nn.Identity() 55 | 56 | self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) 57 | self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) 58 | 59 | def forward(self, inp): 60 | x = inp 61 | 62 | x = self.norm1(x) 63 | 64 | x = self.conv1(x) 65 | x = self.conv2(x) 66 | x = self.sg(x) 67 | x = x * self.sca(x) 68 | x = self.conv3(x) 69 | 70 | x = self.dropout1(x) 71 | 72 | y = inp + x * self.beta 73 | 74 | x = self.conv4(self.norm2(y)) 75 | x = self.sg(x) 76 | x = self.conv5(x) 77 | 78 | x = self.dropout2(x) 79 | 80 | return y + x * self.gamma 81 | 82 | 83 | class Generate_gate(nn.Module): 84 | def __init__(self): 85 | super(Generate_gate, self).__init__() 86 | self.proj = nn.Sequential(nn.AdaptiveAvgPool2d(1), 87 | nn.Conv2d(512,256, 1), 88 | nn.ReLU(), 89 | nn.Conv2d(256,512, 1), 90 | nn.ReLU()) 91 | 92 | self.epsilon = 1e-8 93 | def forward(self, x): 94 | 95 | 96 | alpha = self.proj(x) 97 | gate = (alpha**2) / (alpha**2 + self.epsilon) 98 | 99 | return gate 100 | 101 | def freeze(layer): 102 | for child in layer.children(): 103 | for param in child.parameters(): 104 | param.requires_grad = False 105 | 106 | 107 | def unfreeze(layer): 108 | for child in layer.children(): 109 | for param in child.parameters(): 110 | param.requires_grad = True 111 | 112 | def freeze_direct(layer): 113 | for param in layer.parameters(): 114 | param.requires_grad = False 115 | 116 | 117 | class NAFNet(nn.Module): 118 | 119 | def __init__(self, img_channel=3, width=16, middle_blk_num=1, enc_blk_nums=[], dec_blk_nums=[]): 120 | super().__init__() 121 | 122 | self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1, 123 | bias=True) 124 | self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1, 125 | bias=True) 126 | 127 | self.encoders = nn.ModuleList() 128 | self.decoders = nn.ModuleList() 129 | self.middle_blks = nn.ModuleList() 130 | self.ups = nn.ModuleList() 131 | self.downs = nn.ModuleList() 132 | 133 | chan = width 134 | for num in enc_blk_nums: 135 | self.encoders.append( 136 | nn.Sequential( 137 | *[NAFBlock(chan) for _ in range(num)] 138 | ) 139 | ) 140 | self.downs.append( 141 | nn.Conv2d(chan, 2*chan, 2, 2) 142 | ) 143 | chan = chan * 2 144 | 145 | self.middle_blks = \ 146 | nn.Sequential( 147 | *[NAFBlock(chan) for _ in range(middle_blk_num)] 148 | ) 149 | 150 | for num in dec_blk_nums: 151 | self.ups.append( 152 | nn.Sequential( 153 | nn.Conv2d(chan, chan * 2, 1, bias=False), 154 | nn.PixelShuffle(2) 155 | ) 156 | ) 157 | chan = chan // 2 158 | self.decoders.append( 159 | nn.Sequential( 160 | *[NAFBlock(chan) for _ in range(num)] 161 | ) 162 | ) 163 | 164 | self.padder_size = 2 ** len(self.encoders) 165 | 166 | 167 | ###################### init CSNorm ################## 168 | self.gate = Generate_gate() 169 | for i in range(512): 170 | setattr(self, 'CSN_' + str(i), nn.InstanceNorm2d(1, affine=True)) 171 | freeze_direct(getattr(self, 'CSN_' + str(i))) 172 | freeze(self.gate) 173 | ###################### init CSNorm ################## 174 | 175 | 176 | def forward(self, inp, aug=False): 177 | B, C, H, W = inp.shape 178 | inp = self.check_image_size(inp) 179 | 180 | x = self.intro(inp) 181 | 182 | encs = [] 183 | 184 | for encoder, down in zip(self.encoders, self.downs): 185 | x = encoder(x) 186 | encs.append(x) 187 | x = down(x) 188 | 189 | ##################### add CSNorm in the network ################# 190 | if aug: 191 | gate = self.gate(x) 192 
| lq_copy = torch.cat([getattr(self, 'CSN_' + str(i))(x[:,i,:,:][:,None,:,:]) for i in range(512)], dim=1) 193 | x = gate * (lq_copy) + (1-gate) * x 194 | ##################### add CSNorm in the network ################# 195 | 196 | 197 | x = self.middle_blks(x) 198 | for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]): 199 | x = up(x) 200 | x = x + enc_skip 201 | x = decoder(x) 202 | 203 | x = self.ending(x) 204 | x = x + inp 205 | 206 | return x[:, :, :H, :W] 207 | 208 | def check_image_size(self, x): 209 | _, _, h, w = x.size() 210 | mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size 211 | mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size 212 | x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h)) 213 | return x 214 | 215 | class NAFNetLocal(Local_Base, NAFNet): 216 | def __init__(self, *args, train_size=(1, 3, 256, 256), fast_imp=False, **kwargs): 217 | Local_Base.__init__(self) 218 | NAFNet.__init__(self, *args, **kwargs) 219 | 220 | N, C, H, W = train_size 221 | base_size = (int(H * 1.5), int(W * 1.5)) 222 | 223 | self.eval() 224 | with torch.no_grad(): 225 | self.convert(base_size=base_size, train_size=train_size, fast_imp=fast_imp) 226 | 227 | if __name__ == '__main__': 228 | img_channel = 3 229 | width = 32 230 | enc_blks = [1, 1, 1, 1] 231 | middle_blk_num = 1 232 | dec_blks = [1, 1, 1, 1] 233 | 234 | model = NAFNet(img_channel=img_channel, width=width, middle_blk_num=middle_blk_num, 235 | enc_blk_nums=enc_blks, dec_blk_nums=dec_blks) 236 | 237 | inp_shape = (1, 3, 64, 64) 238 | 239 | device = "cpu" 240 | if torch.cuda.is_available(): 241 | device = "cuda" 242 | input1 = torch.randn(inp_shape).to(device) 243 | model = model.to(device) 244 | 245 | import re 246 | # layer_pattern = re.compile(r'module\.(gate\.proj|bn_\d+)\.') 247 | layer_pattern = re.compile(r'(gate\.proj|CSN_\d+)\.') 248 | 249 | selected_params = [ 250 | name for name, param in model.named_parameters() 251 | if layer_pattern.search(name) 252 | ] 253 | print(selected_params) 254 | -------------------------------------------------------------------------------- /models/modules/NAFNet/arch_util.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | # Modified from BasicSR (https://github.com/xinntao/BasicSR) 5 | # Copyright 2018-2020 BasicSR Authors 6 | # ------------------------------------------------------------------------ 7 | import math 8 | import torch 9 | from torch import nn as nn 10 | from torch.nn import functional as F 11 | from torch.nn import init as init 12 | from torch.nn.modules.batchnorm import _BatchNorm 13 | 14 | # from basicsr.utils import get_root_logger 15 | 16 | # try: 17 | # from basicsr.models.ops.dcn import (ModulatedDeformConvPack, 18 | # modulated_deform_conv) 19 | # except ImportError: 20 | # # print('Cannot import dcn. Ignore this warning if dcn is not used. ' 21 | # # 'Otherwise install BasicSR with compiling dcn.') 22 | # 23 | 24 | @torch.no_grad() 25 | def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs): 26 | """Initialize network weights. 27 | 28 | Args: 29 | module_list (list[nn.Module] | nn.Module): Modules to be initialized. 30 | scale (float): Scale initialized weights, especially for residual 31 | blocks. Default: 1. 
32 | bias_fill (float): The value to fill bias. Default: 0 33 | kwargs (dict): Other arguments for initialization function. 34 | """ 35 | if not isinstance(module_list, list): 36 | module_list = [module_list] 37 | for module in module_list: 38 | for m in module.modules(): 39 | if isinstance(m, nn.Conv2d): 40 | init.kaiming_normal_(m.weight, **kwargs) 41 | m.weight.data *= scale 42 | if m.bias is not None: 43 | m.bias.data.fill_(bias_fill) 44 | elif isinstance(m, nn.Linear): 45 | init.kaiming_normal_(m.weight, **kwargs) 46 | m.weight.data *= scale 47 | if m.bias is not None: 48 | m.bias.data.fill_(bias_fill) 49 | elif isinstance(m, _BatchNorm): 50 | init.constant_(m.weight, 1) 51 | if m.bias is not None: 52 | m.bias.data.fill_(bias_fill) 53 | 54 | 55 | def make_layer(basic_block, num_basic_block, **kwarg): 56 | """Make layers by stacking the same blocks. 57 | 58 | Args: 59 | basic_block (nn.module): nn.module class for basic block. 60 | num_basic_block (int): number of blocks. 61 | 62 | Returns: 63 | nn.Sequential: Stacked blocks in nn.Sequential. 64 | """ 65 | layers = [] 66 | for _ in range(num_basic_block): 67 | layers.append(basic_block(**kwarg)) 68 | return nn.Sequential(*layers) 69 | 70 | 71 | class ResidualBlockNoBN(nn.Module): 72 | """Residual block without BN. 73 | 74 | It has a style of: 75 | ---Conv-ReLU-Conv-+- 76 | |________________| 77 | 78 | Args: 79 | num_feat (int): Channel number of intermediate features. 80 | Default: 64. 81 | res_scale (float): Residual scale. Default: 1. 82 | pytorch_init (bool): If set to True, use pytorch default init, 83 | otherwise, use default_init_weights. Default: False. 84 | """ 85 | 86 | def __init__(self, num_feat=64, res_scale=1, pytorch_init=False): 87 | super(ResidualBlockNoBN, self).__init__() 88 | self.res_scale = res_scale 89 | self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) 90 | self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True) 91 | self.relu = nn.ReLU(inplace=True) 92 | 93 | if not pytorch_init: 94 | default_init_weights([self.conv1, self.conv2], 0.1) 95 | 96 | def forward(self, x): 97 | identity = x 98 | out = self.conv2(self.relu(self.conv1(x))) 99 | return identity + out * self.res_scale 100 | 101 | 102 | class Upsample(nn.Sequential): 103 | """Upsample module. 104 | 105 | Args: 106 | scale (int): Scale factor. Supported scales: 2^n and 3. 107 | num_feat (int): Channel number of intermediate features. 108 | """ 109 | 110 | def __init__(self, scale, num_feat): 111 | m = [] 112 | if (scale & (scale - 1)) == 0: # scale = 2^n 113 | for _ in range(int(math.log(scale, 2))): 114 | m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) 115 | m.append(nn.PixelShuffle(2)) 116 | elif scale == 3: 117 | m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) 118 | m.append(nn.PixelShuffle(3)) 119 | else: 120 | raise ValueError(f'scale {scale} is not supported. ' 121 | 'Supported scales: 2^n and 3.') 122 | super(Upsample, self).__init__(*m) 123 | 124 | 125 | def flow_warp(x, 126 | flow, 127 | interp_mode='bilinear', 128 | padding_mode='zeros', 129 | align_corners=True): 130 | """Warp an image or feature map with optical flow. 131 | 132 | Args: 133 | x (Tensor): Tensor with size (n, c, h, w). 134 | flow (Tensor): Tensor with size (n, h, w, 2), normal value. 135 | interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'. 136 | padding_mode (str): 'zeros' or 'border' or 'reflection'. 137 | Default: 'zeros'. 138 | align_corners (bool): Before pytorch 1.3, the default value is 139 | align_corners=True. 
After pytorch 1.3, the default value is 140 | align_corners=False. Here, we use the True as default. 141 | 142 | Returns: 143 | Tensor: Warped image or feature map. 144 | """ 145 | assert x.size()[-2:] == flow.size()[1:3] 146 | _, _, h, w = x.size() 147 | # create mesh grid 148 | grid_y, grid_x = torch.meshgrid( 149 | torch.arange(0, h).type_as(x), 150 | torch.arange(0, w).type_as(x)) 151 | grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 152 | grid.requires_grad = False 153 | 154 | vgrid = grid + flow 155 | # scale grid to [-1,1] 156 | vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0 157 | vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0 158 | vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) 159 | output = F.grid_sample( 160 | x, 161 | vgrid_scaled, 162 | mode=interp_mode, 163 | padding_mode=padding_mode, 164 | align_corners=align_corners) 165 | 166 | # TODO, what if align_corners=False 167 | return output 168 | 169 | 170 | def resize_flow(flow, 171 | size_type, 172 | sizes, 173 | interp_mode='bilinear', 174 | align_corners=False): 175 | """Resize a flow according to ratio or shape. 176 | 177 | Args: 178 | flow (Tensor): Precomputed flow. shape [N, 2, H, W]. 179 | size_type (str): 'ratio' or 'shape'. 180 | sizes (list[int | float]): the ratio for resizing or the final output 181 | shape. 182 | 1) The order of ratio should be [ratio_h, ratio_w]. For 183 | downsampling, the ratio should be smaller than 1.0 (i.e., ratio 184 | < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e., 185 | ratio > 1.0). 186 | 2) The order of output_size should be [out_h, out_w]. 187 | interp_mode (str): The mode of interpolation for resizing. 188 | Default: 'bilinear'. 189 | align_corners (bool): Whether align corners. Default: False. 190 | 191 | Returns: 192 | Tensor: Resized flow. 193 | """ 194 | _, _, flow_h, flow_w = flow.size() 195 | if size_type == 'ratio': 196 | output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1]) 197 | elif size_type == 'shape': 198 | output_h, output_w = sizes[0], sizes[1] 199 | else: 200 | raise ValueError( 201 | f'Size type should be ratio or shape, but got type {size_type}.') 202 | 203 | input_flow = flow.clone() 204 | ratio_h = output_h / flow_h 205 | ratio_w = output_w / flow_w 206 | input_flow[:, 0, :, :] *= ratio_w 207 | input_flow[:, 1, :, :] *= ratio_h 208 | resized_flow = F.interpolate( 209 | input=input_flow, 210 | size=(output_h, output_w), 211 | mode=interp_mode, 212 | align_corners=align_corners) 213 | return resized_flow 214 | 215 | 216 | # TODO: may write a cpp file 217 | def pixel_unshuffle(x, scale): 218 | """ Pixel unshuffle. 219 | 220 | Args: 221 | x (Tensor): Input feature with shape (b, c, hh, hw). 222 | scale (int): Downsample ratio. 223 | 224 | Returns: 225 | Tensor: the pixel unshuffled feature. 226 | """ 227 | b, c, hh, hw = x.size() 228 | out_channel = c * (scale**2) 229 | assert hh % scale == 0 and hw % scale == 0 230 | h = hh // scale 231 | w = hw // scale 232 | x_view = x.view(b, c, h, scale, w, scale) 233 | return x_view.permute(0, 1, 3, 5, 2, 4).reshape(b, out_channel, h, w) 234 | 235 | 236 | # class DCNv2Pack(ModulatedDeformConvPack): 237 | # """Modulated deformable conv for deformable alignment. 238 | # 239 | # Different from the official DCNv2Pack, which generates offsets and masks 240 | # from the preceding features, this DCNv2Pack takes another different 241 | # features to generate offsets and masks. 
242 | # 243 | #     Ref: 244 | #         Delving Deep into Deformable Alignment in Video Super-Resolution. 245 | #     """ 246 | # 247 | #     def forward(self, x, feat): 248 | #         out = self.conv_offset(feat) 249 | #         o1, o2, mask = torch.chunk(out, 3, dim=1) 250 | #         offset = torch.cat((o1, o2), dim=1) 251 | #         mask = torch.sigmoid(mask) 252 | # 253 | #         offset_absmean = torch.mean(torch.abs(offset)) 254 | #         if offset_absmean > 50: 255 | #             logger = get_root_logger() 256 | #             logger.warning( 257 | #                 f'Offset abs mean is {offset_absmean}, larger than 50.') 258 | # 259 | #         return modulated_deform_conv(x, offset, mask, self.weight, self.bias, 260 | #                                      self.stride, self.padding, self.dilation, 261 | #                                      self.groups, self.deformable_groups) 262 | 263 | 264 | class LayerNormFunction(torch.autograd.Function): 265 | 266 | @staticmethod 267 | def forward(ctx, x, weight, bias, eps): 268 | ctx.eps = eps 269 | N, C, H, W = x.size() 270 | mu = x.mean(1, keepdim=True) 271 | var = (x - mu).pow(2).mean(1, keepdim=True) 272 | y = (x - mu) / (var + eps).sqrt() 273 | ctx.save_for_backward(y, var, weight) 274 | y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1) 275 | return y 276 | 277 | @staticmethod 278 | def backward(ctx, grad_output): 279 | eps = ctx.eps 280 | 281 | N, C, H, W = grad_output.size() 282 | y, var, weight = ctx.saved_tensors 283 | g = grad_output * weight.view(1, C, 1, 1) 284 | mean_g = g.mean(dim=1, keepdim=True) 285 | 286 | mean_gy = (g * y).mean(dim=1, keepdim=True) 287 | gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g) 288 | return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum( 289 | dim=0), None 290 | 291 | class LayerNorm2d(nn.Module): 292 | 293 | def __init__(self, channels, eps=1e-6): 294 | super(LayerNorm2d, self).__init__() 295 | self.register_parameter('weight', nn.Parameter(torch.ones(channels))) 296 | self.register_parameter('bias', nn.Parameter(torch.zeros(channels))) 297 | self.eps = eps 298 | 299 | def forward(self, x): 300 | return LayerNormFunction.apply(x, self.weight, self.bias, self.eps) 301 | 302 | # handle multiple inputs 303 | class MySequential(nn.Sequential): 304 | def forward(self, *inputs): 305 | for module in self._modules.values(): 306 | if type(inputs) == tuple: 307 | inputs = module(*inputs) 308 | else: 309 | inputs = module(inputs) 310 | return inputs 311 | 312 | import time 313 | def measure_inference_speed(model, data, max_iter=200, log_interval=50): 314 | model.eval() 315 | 316 | # the first several iterations may be very slow, so skip them 317 | num_warmup = 5 318 | pure_inf_time = 0 319 | fps = 0 320 | 321 | # benchmark over max_iter images and take the average 322 | for i in range(max_iter): 323 | 324 | torch.cuda.synchronize() 325 | start_time = time.perf_counter() 326 | 327 | with torch.no_grad(): 328 | model(*data) 329 | 330 | torch.cuda.synchronize() 331 | elapsed = time.perf_counter() - start_time 332 | 333 | if i >= num_warmup: 334 | pure_inf_time += elapsed 335 | if (i + 1) % log_interval == 0: 336 | fps = (i + 1 - num_warmup) / pure_inf_time 337 | print( 338 | f'Done image [{i + 1:<3}/ {max_iter}], ' 339 | f'fps: {fps:.1f} img / s, ' 340 | f'time per image: {1000 / fps:.1f} ms / img', 341 | flush=True) 342 | 343 | if (i + 1) == max_iter: 344 | fps = (i + 1 - num_warmup) / pure_inf_time 345 | print( 346 | f'Overall fps: {fps:.1f} img / s, ' 347 | f'time per image: {1000 / fps:.1f} ms / img', 348 | flush=True) 349 | break 350 | return fps
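A quick sanity check for two of the building blocks above: with an all-zero flow, flow_warp should return the input (numerically) unchanged, since align_corners=True maps the grid exactly onto pixel centers; and LayerNorm2d should produce zero mean across channels at every pixel while its affine parameters are still at their identity init. A minimal sketch, assuming the module path from this repository's layout:

import torch
from models.modules.NAFNet.arch_util import LayerNorm2d, flow_warp

x = torch.rand(1, 3, 8, 8)
flow = torch.zeros(1, 8, 8, 2)           # zero displacement field, shape (n, h, w, 2)
out = flow_warp(x, flow)                 # identity warp
assert torch.allclose(out, x, atol=1e-5)

ln = LayerNorm2d(3)                      # weight=1, bias=0 at init
y = ln(x)
print(y.mean(dim=1).abs().max())         # ~0: normalized across channels per pixel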
-------------------------------------------------------------------------------- /models/modules/NAFNet/local_arch.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 3 | # ------------------------------------------------------------------------ 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | class AvgPool2d(nn.Module): 11 | def __init__(self, kernel_size=None, base_size=None, auto_pad=True, fast_imp=False, train_size=None): 12 | super().__init__() 13 | self.kernel_size = kernel_size 14 | self.base_size = base_size 15 | self.auto_pad = auto_pad 16 | 17 | # only used for fast implementation 18 | self.fast_imp = fast_imp 19 | self.rs = [5, 4, 3, 2, 1] 20 | self.max_r1 = self.rs[0] 21 | self.max_r2 = self.rs[0] 22 | self.train_size = train_size 23 | 24 | def extra_repr(self) -> str: 25 | return 'kernel_size={}, base_size={}, stride={}, fast_imp={}'.format( 26 | self.kernel_size, self.base_size, self.kernel_size, self.fast_imp 27 | ) 28 | 29 | def forward(self, x): 30 | if self.kernel_size is None and self.base_size: 31 | train_size = self.train_size 32 | if isinstance(self.base_size, int): 33 | self.base_size = (self.base_size, self.base_size) 34 | self.kernel_size = list(self.base_size) 35 | self.kernel_size[0] = x.shape[2] * self.base_size[0] // train_size[-2] 36 | self.kernel_size[1] = x.shape[3] * self.base_size[1] // train_size[-1] 37 | 38 | # only used for fast implementation 39 | self.max_r1 = max(1, self.rs[0] * x.shape[2] // train_size[-2]) 40 | self.max_r2 = max(1, self.rs[0] * x.shape[3] // train_size[-1]) 41 | 42 | if self.kernel_size[0] >= x.size(-2) and self.kernel_size[1] >= x.size(-1): 43 | return F.adaptive_avg_pool2d(x, 1) 44 | 45 | if self.fast_imp: # Non-equivalent implementation but faster 46 | h, w = x.shape[2:] 47 | if self.kernel_size[0] >= h and self.kernel_size[1] >= w: 48 | out = F.adaptive_avg_pool2d(x, 1) 49 | else: 50 | r1 = [r for r in self.rs if h % r == 0][0] 51 | r2 = [r for r in self.rs if w % r == 0][0] 52 | # reduction_constraint 53 | r1 = min(self.max_r1, r1) 54 | r2 = min(self.max_r2, r2) 55 | s = x[:, :, ::r1, ::r2].cumsum(dim=-1).cumsum(dim=-2) 56 | n, c, h, w = s.shape 57 | k1, k2 = min(h - 1, self.kernel_size[0] // r1), min(w - 1, self.kernel_size[1] // r2) 58 | out = (s[:, :, :-k1, :-k2] - s[:, :, :-k1, k2:] - s[:, :, k1:, :-k2] + s[:, :, k1:, k2:]) / (k1 * k2) 59 | out = torch.nn.functional.interpolate(out, scale_factor=(r1, r2)) 60 | else: 61 | n, c, h, w = x.shape 62 | s = x.cumsum(dim=-1).cumsum_(dim=-2) 63 | s = torch.nn.functional.pad(s, (1, 0, 1, 0)) # pad 0 for convenience 64 | k1, k2 = min(h, self.kernel_size[0]), min(w, self.kernel_size[1]) 65 | s1, s2, s3, s4 = s[:, :, :-k1, :-k2], s[:, :, :-k1, k2:], s[:, :, k1:, :-k2], s[:, :, k1:, k2:] 66 | out = s4 + s1 - s2 - s3 67 | out = out / (k1 * k2) 68 | 69 | if self.auto_pad: 70 | n, c, h, w = x.shape 71 | _h, _w = out.shape[2:] 72 | # print(x.shape, self.kernel_size) 73 | pad2d = ((w - _w) // 2, (w - _w + 1) // 2, (h - _h) // 2, (h - _h + 1) // 2) 74 | out = torch.nn.functional.pad(out, pad2d, mode='replicate') 75 | 76 | return out 77 | 78 | def replace_layers(model, base_size, train_size, fast_imp, **kwargs): 79 | for n, m in model.named_children(): 80 | if len(list(m.children())) > 0: 81 | ## compound module, go inside it 82 | replace_layers(m, base_size, 
train_size, fast_imp, **kwargs) 83 | 84 | if isinstance(m, nn.AdaptiveAvgPool2d): 85 | pool = AvgPool2d(base_size=base_size, fast_imp=fast_imp, train_size=train_size) 86 | assert m.output_size == 1 87 | setattr(model, n, pool) 88 | 89 | 90 | ''' 91 | ref. 92 | @article{chu2021tlsc, 93 | title={Revisiting Global Statistics Aggregation for Improving Image Restoration}, 94 | author={Chu, Xiaojie and Chen, Liangyu and and Chen, Chengpeng and Lu, Xin}, 95 | journal={arXiv preprint arXiv:2112.04491}, 96 | year={2021} 97 | } 98 | ''' 99 | class Local_Base(): 100 | def convert(self, *args, train_size, **kwargs): 101 | replace_layers(self, *args, train_size=train_size, **kwargs) 102 | imgs = torch.rand(train_size) 103 | with torch.no_grad(): 104 | self.forward(imgs) 105 | -------------------------------------------------------------------------------- /models/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mdyao/CSNorm/49bf5a07ac1c58c8d2c221ac86022698c7f1c897/models/modules/__init__.py -------------------------------------------------------------------------------- /models/modules/common.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 9 | return nn.Conv2d( 10 | in_channels, out_channels, kernel_size, 11 | padding=(kernel_size // 2), bias=bias) 12 | 13 | 14 | class MeanShift(nn.Conv2d): 15 | def __init__(self, rgb_range, rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): 16 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 17 | std = torch.Tensor(rgb_std) 18 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) 19 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std 20 | for p in self.parameters(): 21 | p.requires_grad = False 22 | 23 | 24 | class BasicBlock(nn.Sequential): 25 | def __init__( 26 | self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False, 27 | bn=True, act=nn.ReLU(True)): 28 | 29 | m = [conv(in_channels, out_channels, kernel_size, bias=bias)] 30 | if bn: 31 | m.append(nn.BatchNorm2d(out_channels)) 32 | if act is not None: 33 | m.append(act) 34 | 35 | super(BasicBlock, self).__init__(*m) 36 | 37 | 38 | class ResBlock(nn.Module): 39 | def __init__( 40 | self, conv, n_feats, kernel_size, 41 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 42 | 43 | super(ResBlock, self).__init__() 44 | m = [] 45 | for i in range(2): 46 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 47 | if bn: 48 | m.append(nn.BatchNorm2d(n_feats)) 49 | if i == 0: 50 | m.append(act) 51 | 52 | self.body = nn.Sequential(*m) 53 | self.res_scale = res_scale 54 | 55 | def forward(self, x): 56 | res = self.body(x) 57 | res += x 58 | 59 | return res 60 | 61 | 62 | class Upsampler(nn.Sequential): 63 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): 64 | 65 | m = [] 66 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
67 | for _ in range(int(math.log(scale, 2))): 68 | m.append(conv(n_feats, 4 * n_feats, 3, bias)) 69 | m.append(nn.PixelShuffle(2)) 70 | if bn: 71 | m.append(nn.BatchNorm2d(n_feats)) 72 | if act == 'relu': 73 | m.append(nn.ReLU(True)) 74 | elif act == 'prelu': 75 | m.append(nn.PReLU(n_feats)) 76 | 77 | elif scale == 3: 78 | m.append(conv(n_feats, 9 * n_feats, 3, bias)) 79 | m.append(nn.PixelShuffle(3)) 80 | if bn: 81 | m.append(nn.BatchNorm2d(n_feats)) 82 | if act == 'relu': 83 | m.append(nn.ReLU(True)) 84 | elif act == 'prelu': 85 | m.append(nn.PReLU(n_feats)) 86 | else: 87 | raise NotImplementedError 88 | 89 | super(Upsampler, self).__init__(*m) -------------------------------------------------------------------------------- /models/modules/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from torchvision.models.vgg import vgg16 5 | from torch.nn import functional as F 6 | import torch.fft as fft 7 | 8 | class ReconstructionLoss(nn.Module): 9 | def __init__(self, losstype='l2', eps=1e-3): 10 | super(ReconstructionLoss, self).__init__() 11 | self.losstype = losstype 12 | self.eps = eps 13 | 14 | def forward(self, x, target): 15 | if self.losstype == 'l2': 16 | return torch.mean(torch.sum((x - target)**2, (1, 2, 3))) 17 | elif self.losstype == 'l1': 18 | diff = x - target 19 | return torch.mean(torch.sum(torch.sqrt(diff * diff + self.eps), (1, 2, 3))) 20 | elif self.losstype == 'l_log': 21 | diff = x - target 22 | eps = 1e-6 23 | return torch.mean(torch.sum(-torch.log(1-diff.abs()+eps), (1, 2, 3))) 24 | else: 25 | print("reconstruction loss type error!") 26 | return 0 27 | 28 | 29 | class FFT_Loss(nn.Module): 30 | def __init__(self, losstype='l2', eps=1e-3): 31 | super(FFT_Loss, self).__init__() 32 | # self.fpre = 33 | def forward(self, x, gt): 34 | x = x + 1e-8 35 | gt = gt + 1e-8 36 | x_freq= torch.fft.rfft2(x, norm='backward') 37 | x_amp = torch.abs(x_freq) 38 | x_phase = torch.angle(x_freq) 39 | 40 | gt_freq= torch.fft.rfft2(gt, norm='backward') 41 | gt_amp = torch.abs(gt_freq) 42 | gt_phase = torch.angle(gt_freq) 43 | 44 | loss_amp = torch.mean(torch.sum((x_amp - gt_amp) ** 2)) 45 | loss_phase = torch.mean(torch.sum((x_phase - gt_phase) ** 2)) 46 | return loss_amp, loss_phase 47 | 48 | # Gradient Loss 49 | class Gradient_Loss(nn.Module): 50 | def __init__(self, losstype='l2'): 51 | super(Gradient_Loss, self).__init__() 52 | a = np.array([[1, 0, -1], [2, 0, -2], [1, 0, -1]]) 53 | conv1 = nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1, bias=False, groups=3) 54 | a = torch.from_numpy(a).float().unsqueeze(0) 55 | a = torch.stack((a, a, a)) 56 | conv1.weight = nn.Parameter(a, requires_grad=False) 57 | self.conv1 = conv1.cuda() 58 | 59 | b = np.array([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]) 60 | conv2 = nn.Conv2d(3, 3, kernel_size=3, stride=1, padding=1, bias=False, groups=3) 61 | b = torch.from_numpy(b).float().unsqueeze(0) 62 | b = torch.stack((b, b, b)) 63 | conv2.weight = nn.Parameter(b, requires_grad=False) 64 | self.conv2 = conv2.cuda() 65 | 66 | # self.Loss_criterion = ReconstructionLoss(losstype) 67 | self.Loss_criterion = nn.L1Loss() 68 | 69 | def forward(self, x, y): 70 | x1 = self.conv1(x) 71 | x2 = self.conv2(x) 72 | # x_total = torch.sqrt(torch.pow(x1, 2) + torch.pow(x2, 2)) 73 | 74 | y1 = self.conv1(y) 75 | y2 = self.conv2(y) 76 | # y_total = torch.sqrt(torch.pow(y1, 2) + torch.pow(y2, 2)) 77 | 78 | l_h = self.Loss_criterion(x1, y1) 79 | l_v = self.Loss_criterion(x2, 
y2) 80 | # l_total = self.Loss_criterion(x_total, y_total) 81 | return l_h + l_v #+ l_total 82 | 83 | 84 | class SSIM_Loss(nn.Module): 85 | """Layer to compute the SSIM loss between a pair of images 86 | """ 87 | def __init__(self): 88 | super(SSIM_Loss, self).__init__() 89 | self.mu_x_pool = nn.AvgPool2d(3, 1) 90 | self.mu_y_pool = nn.AvgPool2d(3, 1) 91 | self.sig_x_pool = nn.AvgPool2d(3, 1) 92 | self.sig_y_pool = nn.AvgPool2d(3, 1) 93 | self.sig_xy_pool = nn.AvgPool2d(3, 1) 94 | 95 | self.refl = nn.ReflectionPad2d(1) 96 | 97 | self.C1 = 0.01 ** 2 98 | self.C2 = 0.03 ** 2 99 | 100 | def forward(self, x, y): 101 | x = self.refl(x) 102 | y = self.refl(y) 103 | 104 | mu_x = self.mu_x_pool(x) 105 | mu_y = self.mu_y_pool(y) 106 | 107 | sigma_x = self.sig_x_pool(x ** 2) - mu_x ** 2 108 | sigma_y = self.sig_y_pool(y ** 2) - mu_y ** 2 109 | sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y 110 | 111 | SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2) 112 | SSIM_d = (mu_x ** 2 + mu_y ** 2 + self.C1) * (sigma_x + sigma_y + self.C2) 113 | 114 | return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1) 115 | 116 | # Define GAN loss: [vanilla | lsgan | wgan-gp] 117 | class GANLoss(nn.Module): 118 | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0): 119 | super(GANLoss, self).__init__() 120 | self.gan_type = gan_type.lower() 121 | self.real_label_val = real_label_val 122 | self.fake_label_val = fake_label_val 123 | 124 | if self.gan_type == 'gan' or self.gan_type == 'ragan': 125 | self.loss = nn.BCEWithLogitsLoss() 126 | elif self.gan_type == 'lsgan': 127 | self.loss = nn.MSELoss() 128 | elif self.gan_type == 'wgan-gp': 129 | 130 | def wgan_loss(input, target): 131 | # target is boolean 132 | return -1 * input.mean() if target else input.mean() 133 | 134 | self.loss = wgan_loss 135 | else: 136 | raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type)) 137 | 138 | def get_target_label(self, input, target_is_real): 139 | if self.gan_type == 'wgan-gp': 140 | return target_is_real 141 | if target_is_real: 142 | return torch.empty_like(input).fill_(self.real_label_val) 143 | else: 144 | return torch.empty_like(input).fill_(self.fake_label_val) 145 | 146 | def forward(self, input, target_is_real): 147 | target_label = self.get_target_label(input, target_is_real) 148 | loss = self.loss(input, target_label) 149 | return loss 150 | 151 | 152 | class GradientPenaltyLoss(nn.Module): 153 | def __init__(self, device=torch.device('cpu')): 154 | super(GradientPenaltyLoss, self).__init__() 155 | self.register_buffer('grad_outputs', torch.Tensor()) 156 | self.grad_outputs = self.grad_outputs.to(device) 157 | 158 | def get_grad_outputs(self, input): 159 | if self.grad_outputs.size() != input.size(): 160 | self.grad_outputs.resize_(input.size()).fill_(1.0) 161 | return self.grad_outputs 162 | 163 | def forward(self, interp, interp_crit): 164 | grad_outputs = self.get_grad_outputs(interp_crit) 165 | grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp, 166 | grad_outputs=grad_outputs, create_graph=True, 167 | retain_graph=True, only_inputs=True)[0] 168 | grad_interp = grad_interp.view(grad_interp.size(0), -1) 169 | grad_interp_norm = grad_interp.norm(2, dim=1) 170 | 171 | loss = ((grad_interp_norm - 1)**2).mean() 172 | return loss 173 | 174 | class TVLoss(nn.Module): 175 | def __init__(self, TVLoss_weight=1): 176 | super(TVLoss, self).__init__() 177 | self.TVLoss_weight = TVLoss_weight 178 | 179 | def forward(self, x): 180 | batch_size = x.size()[0] 181 | 
h_x = x.size()[2] 182 | w_x = x.size()[3] 183 | count_h = self._tensor_size(x[:, :, 1:, :]) 184 | count_w = self._tensor_size(x[:, :, :, 1:]) 185 | h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum() 186 | w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum() 187 | return self.TVLoss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size 188 | 189 | def _tensor_size(self, t): 190 | return t.size()[1] * t.size()[2] * t.size()[3] 191 | 192 | class TV_extractor(nn.Module): 193 | def __init__(self, TVLoss_weight=1): 194 | super(TV_extractor, self).__init__() 195 | self.TVLoss_weight = TVLoss_weight 196 | self.fil = nn.Parameter(torch.ones(1, 1, 3, 3)/9, requires_grad=False) 197 | 198 | def forward(self, x): 199 | batch_size = x.size()[0] 200 | h_x = x.size()[2] 201 | w_x = x.size()[3] 202 | count_h = self._tensor_size(x[:, :, 1:, :]) 203 | count_w = self._tensor_size(x[:, :, :, 1:]) 204 | h_tv = torch.abs((x[:, :, 1:, :] - x[:, :, :h_x - 1, :])) 205 | w_tv = torch.abs((x[:, :, :, 1:] - x[:, :, :, :w_x - 1])) 206 | h_tv = F.pad(h_tv, [0,0,0,1], "constant", 0) 207 | w_tv = F.pad(w_tv, [0,1,0,0], "constant", 0) 208 | 209 | h_tv = F.conv2d(h_tv, self.fil, stride=1, padding=1, groups=1) 210 | w_tv = F.conv2d(w_tv, self.fil, stride=1, padding=1, groups=1) 211 | 212 | # print(h_tv.shape, w_tv.shape) 213 | tv = torch.abs(h_tv)+torch.abs(w_tv) 214 | return tv 215 | 216 | def _tensor_size(self, t): 217 | return t.size()[1] * t.size()[2] * t.size()[3] 218 | 219 | class CL_Loss(nn.Module): 220 | def __init__(self, opt): 221 | super(CL_Loss, self).__init__() 222 | self.opt = opt 223 | self.d = nn.MSELoss(size_average=True) 224 | vgg = vgg16(pretrained=False).cuda() 225 | vgg.load_state_dict(torch.load(self.opt['vgg16_model'])) 226 | self.loss_network = nn.Sequential(*list(vgg.features)[:31]).eval() 227 | for param in self.loss_network.parameters(): 228 | param.requires_grad = False 229 | 230 | def forward(self, anchor, postive, negative): 231 | anchor_f = self.loss_network(anchor) 232 | positive_f = self.loss_network(postive) 233 | negative_f = self.loss_network(negative) 234 | 235 | loss = self.d(anchor_f, positive_f)/self.d(anchor_f, negative_f) 236 | return loss 237 | 238 | class Percep_Loss(nn.Module): 239 | def __init__(self, opt): 240 | super(Percep_Loss, self).__init__() 241 | self.opt = opt 242 | self.d = nn.MSELoss(size_average=True) 243 | vgg = vgg16(pretrained=True).cuda() 244 | # vgg.load_state_dict(torch.load(self.opt['vgg16_model'])) 245 | # self.loss_network = nn.Sequential(*list(vgg.features)[:31]).eval() 246 | # for param in self.loss_network.parameters(): 247 | # param.requires_grad = False 248 | 249 | blocks = [] 250 | blocks.append(vgg.features[:4].eval()) 251 | blocks.append(vgg.features[4:9].eval()) 252 | blocks.append(vgg.features[9:16].eval()) 253 | blocks.append(vgg.features[16:23].eval()) 254 | for bl in blocks: 255 | for p in bl.parameters(): 256 | p.requires_grad = False 257 | 258 | self.blocks = torch.nn.ModuleList(blocks) 259 | self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) 260 | self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) 261 | 262 | def forward(self, input, target,feature_layers=[0, 1, 2, 3],weights=[1,1,1,1]): 263 | if input.shape[1] != 3: 264 | input = input.repeat(1, 3, 1, 1) 265 | target = target.repeat(1, 3, 1, 1) 266 | # input = (input-self.mean) / self.std 267 | # target = (target-self.mean) / self.std 268 | loss = 0.0 269 | x = input 270 | y = target 271 | for 
i, block in enumerate(self.blocks): 272 | x = block(x) 273 | y = block(y) 274 | if i in feature_layers: 275 | loss += weights[i] * self.d(x, y) 276 | return loss 277 | 278 | class SID_loss(nn.Module): 279 | def __init__(self): 280 | super(SID_loss, self).__init__() 281 | self.eps = 1e-8 282 | 283 | def forward(self, x, y): 284 | # normalize both inputs into probability-like distributions; eps avoids log(0) 285 | p = x / (torch.sum(x) + self.eps) + self.eps 286 | q = y / (torch.sum(y) + self.eps) + self.eps 287 | Sid = 0 288 | for j in range(len(x)): 289 | Sid += torch.sum(p[j] * torch.log10(p[j] / q[j]) + q[j] * torch.log10(q[j] / p[j])) 290 | return Sid 291 | -------------------------------------------------------------------------------- /models/modules/loss_new.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import Variable 5 | from math import exp 6 | import numpy as np 7 | from torchvision import models 8 | 9 | 10 | ######################################################################################################################################### 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | from torch.autograd import Variable 15 | import numpy as np 16 | from math import exp 17 | 18 | 19 | def gaussian(window_size, sigma): 20 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 21 | return gauss / gauss.sum() 22 | 23 | 24 | def create_window(window_size, channel): 25 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 26 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 27 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 28 | return window 29 | 30 | 31 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 32 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 33 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 34 | 35 | mu1_sq = mu1.pow(2) 36 | mu2_sq = mu2.pow(2) 37 | mu1_mu2 = mu1 * mu2 38 | 39 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 40 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 41 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 42 | 43 | C1 = 0.01 ** 2 44 | C2 = 0.03 ** 2 45 | 46 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 47 | 48 | if size_average: 49 | return (-1) * ssim_map.mean() 50 | else: 51 | return (-1) * ssim_map.mean(1).mean(1).mean(1) 52 | 53 | 54 | class SSIMLoss(torch.nn.Module): 55 | def __init__(self, window_size=11, size_average=True): 56 | super(SSIMLoss, self).__init__() 57 | self.window_size = window_size 58 | self.size_average = size_average 59 | self.channel = 1 60 | self.window = create_window(window_size, self.channel) 61 | 62 | def forward(self, img1, img2): 63 | (_, channel, _, _) = img1.size() 64 | 65 | if channel == self.channel and self.window.data.type() == img1.data.type(): 66 | window = self.window 67 | else: 68 | window = create_window(self.window_size, channel) 69 | 70 | if img1.is_cuda: 71 | window = window.cuda(img1.get_device()) 72 | window = window.type_as(img1) 73 | 74 | self.window = window 75 | self.channel = channel 76 | 77 | return _ssim(img1, img2, window,
self.window_size, channel, self.size_average) 78 | 79 | 80 | def ssim(img1, img2, window_size=11, size_average=True): 81 | (_, channel, _, _) = img1.size() 82 | window = create_window(window_size, channel) 83 | 84 | if img1.is_cuda: 85 | window = window.cuda(img1.get_device()) 86 | window = window.type_as(img1) 87 | 88 | return _ssim(img1, img2, window, window_size, channel, size_average) 89 | 90 | 91 | ########################################################################################################################### 92 | 93 | 94 | 95 | class Vgg19(nn.Module): 96 | def __init__(self, id, requires_grad=False): 97 | super(Vgg19, self).__init__() 98 | vgg = models.vgg19(pretrained=False) 99 | vgg.load_state_dict(torch.load('/model/1760921465/NewWork2021/vgg19-dcbb9e9d.pth')) 100 | vgg.eval() 101 | vgg_pretrained_features = vgg.features 102 | self.slice1 = torch.nn.Sequential() 103 | self.slice2 = torch.nn.Sequential() 104 | self.slice3 = torch.nn.Sequential() 105 | self.slice4 = torch.nn.Sequential() 106 | self.slice5 = torch.nn.Sequential() 107 | for x in range(3): 108 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 109 | for x in range(3, 7): 110 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 111 | for x in range(7, 12): 112 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 113 | for x in range(12, 21): 114 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 115 | for x in range(21, 30): 116 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 117 | self.id = id 118 | if not requires_grad: 119 | for param in self.parameters(): 120 | param.requires_grad = False 121 | 122 | def forward(self, X): 123 | h_relu1 = self.slice1(X) 124 | h_relu2 = self.slice2(h_relu1) 125 | h_relu3 = self.slice3(h_relu2) 126 | h_relu4 = self.slice4(h_relu3) 127 | h_relu5 = self.slice5(h_relu4) 128 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] 129 | return out[self.id] 130 | 131 | 132 | class VGGLoss(nn.Module): 133 | def __init__(self, id, gpu_id=0): 134 | super(VGGLoss, self).__init__() 135 | self.vgg = Vgg19(id).cuda(gpu_id) 136 | self.criterion = nn.MSELoss() 137 | self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0] 138 | self.downsample = nn.AvgPool2d(2, stride=2, count_include_pad=False) 139 | 140 | def forward(self, x, y): 141 | while x.size()[3] > 4096: 142 | x, y = self.downsample(x), self.downsample(y) 143 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 144 | # loss = 0 145 | # for i in range(len(x_vgg)): 146 | loss = self.criterion(x_vgg, y_vgg.detach()) 147 | return loss 148 | 149 | 150 | ############################################################################################################################3 151 | 152 | 153 | class GradientLoss(nn.Module): 154 | """Gradient Histogram Loss""" 155 | def __init__(self): 156 | super(GradientLoss, self).__init__() 157 | self.bin_num = 64 158 | self.delta = 0.2 159 | self.clip_radius = 0.2 160 | assert(self.clip_radius>0 and self.clip_radius<=1) 161 | self.bin_width = 2*self.clip_radius/self.bin_num 162 | if self.bin_width*255<1: 163 | raise RuntimeError("bin width is too small") 164 | self.bin_mean = np.arange(-self.clip_radius+self.bin_width*0.5, self.clip_radius, self.bin_width) 165 | self.gradient_hist_loss_function = 'L2' 166 | # default is KL loss 167 | if self.gradient_hist_loss_function == 'L2': 168 | self.criterion = nn.MSELoss() 169 | elif self.gradient_hist_loss_function == 'L1': 170 | self.criterion = nn.L1Loss() 171 | else: 172 | self.criterion = nn.KLDivLoss() 173 | 174 | def 
get_response(self, gradient, mean): 175 | # tmp = torch.mul(torch.pow(torch.add(gradient, -mean), 2), self.delta_square_inverse) 176 | s = (-1) / (self.delta ** 2) 177 | tmp = ((gradient - mean) ** 2) * s 178 | return torch.mean(torch.exp(tmp)) 179 | 180 | def get_gradient(self, src): 181 | right_src = src[:, :, 1:, 0:-1] # shift src image right by one pixel 182 | down_src = src[:, :, 0:-1, 1:] # shift src image down by one pixel 183 | clip_src = src[:, :, 0:-1, 0:-1] # make src same size as shift version 184 | d_x = right_src - clip_src 185 | d_y = down_src - clip_src 186 | 187 | return d_x, d_y 188 | 189 | def get_gradient_hist(self, gradient_x, gradient_y): 190 | lx = None 191 | ly = None 192 | for ind_bin in range(self.bin_num): 193 | fx = self.get_response(gradient_x, self.bin_mean[ind_bin]) 194 | fy = self.get_response(gradient_y, self.bin_mean[ind_bin]) 195 | fx = torch.cuda.FloatTensor([fx]) 196 | fy = torch.cuda.FloatTensor([fy]) 197 | 198 | if lx is None: 199 | lx = fx 200 | ly = fy 201 | else: 202 | lx = torch.cat((lx, fx), 0) 203 | ly = torch.cat((ly, fy), 0) 204 | # lx = torch.div(lx, torch.sum(lx)) 205 | # ly = torch.div(ly, torch.sum(ly)) 206 | return lx, ly 207 | 208 | def forward(self, output, target): 209 | output_gradient_x, output_gradient_y = self.get_gradient(output) 210 | target_gradient_x, target_gradient_y = self.get_gradient(target) 211 | 212 | output_gradient_x_hist, output_gradient_y_hist = self.get_gradient_hist(output_gradient_x, output_gradient_y) 213 | target_gradient_x_hist, target_gradient_y_hist = self.get_gradient_hist(target_gradient_x, target_gradient_y) 214 | # loss = self.criterion(output_gradient_x_hist, target_gradient_x_hist) + self.criterion(output_gradient_y_hist, target_gradient_y_hist) 215 | loss = self.criterion(output_gradient_x,target_gradient_x)+self.criterion(output_gradient_y,target_gradient_y) 216 | return loss 217 | -------------------------------------------------------------------------------- /models/modules/module_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | import numpy as np 6 | 7 | def initialize_weights(net_l, scale=1): 8 | if not isinstance(net_l, list): 9 | net_l = [net_l] 10 | for net in net_l: 11 | for m in net.modules(): 12 | if isinstance(m, nn.Conv2d): 13 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 14 | m.weight.data *= scale # for residual block 15 | if m.bias is not None: 16 | m.bias.data.zero_() 17 | elif isinstance(m, nn.Linear): 18 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 19 | m.weight.data *= scale 20 | if m.bias is not None: 21 | m.bias.data.zero_() 22 | elif isinstance(m, nn.BatchNorm2d): 23 | init.constant_(m.weight, 1) 24 | init.constant_(m.bias.data, 0.0) 25 | 26 | 27 | def initialize_weights_xavier(net_l, scale=1): 28 | if not isinstance(net_l, list): 29 | net_l = [net_l] 30 | for net in net_l: 31 | for m in net.modules(): 32 | if isinstance(m, nn.Conv2d): 33 | init.xavier_normal_(m.weight) 34 | m.weight.data *= scale # for residual block 35 | if m.bias is not None: 36 | m.bias.data.zero_() 37 | elif isinstance(m, nn.Linear): 38 | init.xavier_normal_(m.weight) 39 | m.weight.data *= scale 40 | if m.bias is not None: 41 | m.bias.data.zero_() 42 | elif isinstance(m, nn.BatchNorm2d): 43 | init.constant_(m.weight, 1) 44 | init.constant_(m.bias.data, 0.0) 45 | 46 | def sine_init(m): 47 | with torch.no_grad(): 48 | if hasattr(m, 
'weight'): 49 | num_input = m.weight.size(-1) 50 | # See supplement Sec. 1.5 for discussion of factor 30 51 | m.weight.uniform_(-np.sqrt(6 / num_input) / 30, np.sqrt(6 / num_input) / 30) 52 | 53 | 54 | def first_layer_sine_init(m): 55 | with torch.no_grad(): 56 | if hasattr(m, 'weight'): 57 | num_input = m.weight.size(-1) 58 | # See paper sec. 3.2, final paragraph, and supplement Sec. 1.5 for discussion of factor 30 59 | m.weight.uniform_(-1 / num_input, 1 / num_input) 60 | 61 | def make_layer(block, n_layers): 62 | layers = [] 63 | for _ in range(n_layers): 64 | layers.append(block()) 65 | return nn.Sequential(*layers) 66 | 67 | 68 | class ResidualBlock_noBN(nn.Module): 69 | '''Residual block w/o BN 70 | ---Conv-ReLU-Conv-+- 71 | |________________| 72 | ''' 73 | 74 | def __init__(self, nf=64): 75 | super(ResidualBlock_noBN, self).__init__() 76 | self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 77 | self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 78 | 79 | # initialization 80 | initialize_weights([self.conv1, self.conv2], 0.1) 81 | 82 | def forward(self, x): 83 | identity = x 84 | out = F.relu(self.conv1(x), inplace=True) 85 | out = self.conv2(out) 86 | return identity + out 87 | 88 | 89 | def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'): 90 | """Warp an image or feature map with optical flow 91 | Args: 92 | x (Tensor): size (N, C, H, W) 93 | flow (Tensor): size (N, H, W, 2), normal value 94 | interp_mode (str): 'nearest' or 'bilinear' 95 | padding_mode (str): 'zeros' or 'border' or 'reflection' 96 | 97 | Returns: 98 | Tensor: warped image or feature map 99 | """ 100 | assert x.size()[-2:] == flow.size()[1:3] 101 | B, C, H, W = x.size() 102 | # mesh grid 103 | grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) 104 | grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 105 | grid.requires_grad = False 106 | grid = grid.type_as(x) 107 | vgrid = grid + flow 108 | # scale grid to [-1,1] 109 | vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0 110 | vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0 111 | vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) 112 | output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode) 113 | return output 114 | -------------------------------------------------------------------------------- /models/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import logging 3 | from models.modules.NAFNet.NAFNet import NAFNet 4 | 5 | import math 6 | logger = logging.getLogger('base') 7 | 8 | 9 | #################### 10 | # define network 11 | #################### 12 | def define_G(opt): 13 | img_channel = 3 14 | width = 32 15 | enc_blks= [2, 2, 4, 8] 16 | middle_blk_num= 6 17 | dec_blks= [2, 2, 2, 2] 18 | 19 | netG = NAFNet(img_channel=img_channel, width=width, middle_blk_num=middle_blk_num, 20 | enc_blk_nums=enc_blks, dec_blk_nums=dec_blks) 21 | 22 | return netG 23 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mdyao/CSNorm/49bf5a07ac1c58c8d2c221ac86022698c7f1c897/options/__init__.py -------------------------------------------------------------------------------- /options/options.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import logging 4 | import yaml 5 | from 
utils.util import OrderedYaml 6 | Loader, Dumper = OrderedYaml() 7 | 8 | 9 | def parse(opt_path, is_train=True): 10 | with open(opt_path, mode='r') as f: 11 | opt = yaml.load(f, Loader=Loader) 12 | # export CUDA_VISIBLE_DEVICES 13 | gpu_list = ','.join(str(x) for x in opt['gpu_ids']) 14 | os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list 15 | print('export CUDA_VISIBLE_DEVICES=' + gpu_list) 16 | 17 | opt['is_train'] = is_train 18 | scale = opt['scale'] 19 | 20 | # datasets 21 | for phase, dataset in opt['datasets'].items(): 22 | phase = phase.split('_')[0] 23 | dataset['phase'] = phase 24 | dataset['scale'] = scale 25 | is_mat = False 26 | if dataset.get('dataroot_gt', None) is not None: 27 | dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) 28 | if dataset['dataroot_gt'].endswith('mat'): 29 | is_mat = True 30 | # if dataset.get('dataroot_GT_bg', None) is not None: 31 | # dataset['dataroot_GT_bg'] = osp.expanduser(dataset['dataroot_GT_bg']) 32 | if dataset.get('dataroot_lq', None) is not None: 33 | dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) 34 | if dataset['dataroot_lq'].endswith('mat'): 35 | is_mat = True 36 | dataset['data_type'] = 'mat' if is_mat else 'img' 37 | if dataset['mode'].endswith('mc'): # for memcached 38 | dataset['data_type'] = 'mc' 39 | dataset['mode'] = dataset['mode'].replace('_mc', '') 40 | 41 | # path 42 | for key, path in opt['path'].items(): 43 | if path and key in opt['path'] and key != 'strict_load': 44 | opt['path'][key] = osp.expanduser(path) 45 | 46 | if opt['path']['root'] is None: 47 | opt['path']['root'] = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) 48 | 49 | if is_train: 50 | experiments_root = osp.join(opt['path']['root'], 'experiments', opt['name']) 51 | opt['path']['experiments_root'] = experiments_root 52 | opt['path']['models'] = osp.join(experiments_root, 'models') 53 | opt['path']['training_state'] = osp.join(experiments_root, 'training_state') 54 | opt['path']['log'] = experiments_root 55 | opt['path']['val_images'] = osp.join(experiments_root, 'val_images') 56 | 57 | # change some options for debug mode 58 | if 'debug' in opt['name']: 59 | opt['train']['val_freq'] = 8 60 | opt['logger']['print_freq'] = 1 61 | opt['logger']['save_checkpoint_freq'] = 8 62 | else: # test 63 | results_root = osp.join(opt['path']['root'], 'results', opt['name']) 64 | opt['path']['results_root'] = results_root 65 | opt['path']['log'] = results_root 66 | opt['network_G']['scale'] = scale 67 | 68 | return opt 69 | 70 | 71 | def dict2str(opt, indent_l=1): 72 | '''dict to string for logger''' 73 | msg = '' 74 | for k, v in opt.items(): 75 | if isinstance(v, dict): 76 | msg += ' ' * (indent_l * 2) + k + ':[\n' 77 | msg += dict2str(v, indent_l + 1) 78 | msg += ' ' * (indent_l * 2) + ']\n' 79 | else: 80 | msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n' 81 | return msg 82 | 83 | 84 | class NoneDict(dict): 85 | def __missing__(self, key): 86 | return None 87 | 88 | 89 | # convert to NoneDict, which returns None for missing keys.
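# A minimal usage sketch (hypothetical option path and keys, for illustration only):
#   opt = parse('options/train/train_InvDN.yml', is_train=True)
#   opt = dict_to_nonedict(opt)
#   opt['train']['lr_G']          # -> 2e-4, read from the YAML file
#   opt['train']['no_such_key']   # -> None instead of raising KeyError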
90 | def dict_to_nonedict(opt): 91 | if isinstance(opt, dict): 92 | new_opt = dict() 93 | for key, sub_opt in opt.items(): 94 | new_opt[key] = dict_to_nonedict(sub_opt) 95 | return NoneDict(**new_opt) 96 | elif isinstance(opt, list): 97 | return [dict_to_nonedict(sub_opt) for sub_opt in opt] 98 | else: 99 | return opt 100 | 101 | 102 | def check_resume(opt, resume_iter): 103 | '''Check resume states and pretrain_model paths''' 104 | logger = logging.getLogger('base') 105 | if opt['path']['resume_state']: 106 | if opt['path'].get('pretrain_model_G', None) is not None or opt['path'].get( 107 | 'pretrain_model_D', None) is not None: 108 | logger.warning('pretrain_model path will be ignored when resuming training.') 109 | 110 | opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'], 111 | '{}_G.pth'.format(resume_iter)) 112 | logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G']) 113 | if 'gan' in opt['model']: 114 | opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'], 115 | '{}_D.pth'.format(resume_iter)) 116 | logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D']) 117 | -------------------------------------------------------------------------------- /options/test/test.yml: -------------------------------------------------------------------------------- 1 | 2 | #### general settings 3 | 4 | name: test 5 | use_tb_logger: False 6 | model: CSNorm 7 | scale: 2 8 | gpu_ids: [0] 9 | 10 | #### datasets 11 | 12 | datasets: 13 | val: 14 | name: data_val 15 | mode: JSH_val 16 | # dataroot_gt: './data/example' # path to validation Clean images 17 | # dataroot_lq: './data/example' # path to validation Noisy images 18 | dataroot_gt: './README' # path to validation Clean images 19 | dataroot_lq: './README' # path to validation Noisy images 20 | 21 | #### network structures 22 | 23 | network_G: 24 | which_model_G: 25 | subnet_type: Resnet 26 | in_nc: 3 27 | out_nc: 3 28 | block_num: [8, 8] 29 | scale: 2 30 | init: xavier 31 | 32 | 33 | #### path 34 | 35 | path: 36 | root: ./ 37 | pretrain_model_G: ./models/ckpts/NAF_LOL.pth 38 | strict_load: true 39 | resume_state: ~ 40 | 41 | 42 | #### training settings: learning rate scheme, loss 43 | 44 | train: 45 | lr_G: !!float 1e-4 46 | beta1: 0.9 47 | beta2: 0.999 48 | niter: 600000 49 | warmup_iter: -1 # no warm up 50 | 51 | lr_scheme: MultiStepLR 52 | lr_steps: [5000, 10000, 15000, 30000, 500000] 53 | lr_gamma: 0.5 54 | 55 | pixel_criterion_forw: l2 56 | pixel_criterion_back: l1 57 | pixel_criterion_hist: l2 58 | 59 | manual_seed: 9 60 | 61 | val_freq: !!float 2000 62 | 63 | vgg16_model: 64 | 65 | lambda_fit_forw: 10 66 | lambda_vgg_forw: 0. 
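# note: a lambda weight of 0. (as for lambda_vgg_forw above) effectively disables that loss term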
67 | lambda_structure_forw: 1 68 | lambda_orth_forw: 1 69 | 70 | lambda_rec_back: 1 71 | lambda_structure_back: 1 72 | lambda_orth_back: 1 73 | 74 | weight_decay_G: !!float 1e-8 75 | gradient_clipping: 10 76 | 77 | 78 | #### logger 79 | 80 | logger: 81 | print_freq: 500 82 | save_checkpoint_freq: !!float 5000 83 | -------------------------------------------------------------------------------- /options/train/train_InvDN.yml: -------------------------------------------------------------------------------- 1 | 2 | #### general settings 3 | 4 | name: CSNorm_log 5 | use_tb_logger: False 6 | model: CSNorm 7 | gpu_ids: [0] 8 | 9 | #### datasets 10 | 11 | datasets: 12 | train: 13 | name: data_train 14 | mode: JSH_train 15 | dataroot_gt: 'TRAINDATA/GT' # path to training Clean images 16 | dataroot_lq: 'TRAINDATA/LQ' # path to training Noisy images 17 | 18 | use_shuffle: true 19 | n_workers: 4 # per GPU 20 | batch_size: 4 21 | GT_size: 256 22 | use_flip: true 23 | use_rot: true 24 | color: RGB 25 | 26 | val: 27 | name: data_val 28 | mode: JSH_val 29 | dataroot_gt: '.VALIDATA/GT' # path to validation Clean images 30 | dataroot_lq: '.VALIDATA/LQ' # path to validation Noisy images 31 | 32 | #### network structures 33 | 34 | network_G: 35 | which_model_G: 36 | subnet_type: Resnet 37 | in_nc: 3 38 | out_nc: 3 39 | block_num: [8, 8] 40 | scale: 2 41 | init: xavier 42 | 43 | 44 | #### path 45 | 46 | path: 47 | root: ./ 48 | pretrain_model_G: 49 | strict_load: true 50 | resume_state: ~ 51 | 52 | 53 | #### training settings: learning rate scheme, loss 54 | 55 | train: 56 | lr_G: !!float 2e-4 57 | beta1: 0.9 58 | beta2: 0.999 59 | niter: 600000 60 | warmup_iter: -1 # no warm up 61 | 62 | lr_scheme: MultiStepLR 63 | lr_steps: [50000, 80000, 100000, 200000, 500000] 64 | lr_gamma: 0.5 65 | 66 | pixel_criterion_forw: l2 67 | pixel_criterion_back: l2 68 | pixel_criterion_hist: l2 69 | 70 | manual_seed: 9 71 | 72 | val_freq: !!float 100 73 | 74 | vgg16_model: 75 | 76 | lambda_fit_forw: 10 77 | lambda_vgg_forw: 0. 
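# The lambda_* weights presumably combine the individual losses into a weighted sum,
# roughly L_forw = 10 * L_fit + 0. * L_vgg + 1 * L_structure + 1 * L_orth and
# L_back = 1 * L_rec + 1 * L_structure + 1 * L_orth; the exact form is defined in the
# model/loss code, so treat this as a sketch, not the authoritative objective.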
78 | lambda_structure_forw: 1 79 | lambda_orth_forw: 1 80 | 81 | lambda_rec_back: 1 82 | lambda_structure_back: 1 83 | lambda_orth_back: 1 84 | 85 | weight_decay_G: !!float 1e-8 86 | gradient_clipping: 10 87 | 88 | 89 | #### logger 90 | 91 | logger: 92 | print_freq: 40 93 | save_checkpoint_freq: !!float 5000 94 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import argparse 4 | import random 5 | import logging 6 | 7 | import torch 8 | import torch.distributed as dist 9 | import torch.multiprocessing as mp 10 | from data.data_sampler import DistIterSampler 11 | 12 | import options.options as option 13 | from utils import util 14 | from data import create_dataloader, create_dataset 15 | from models import create_model 16 | import numpy as np 17 | 18 | 19 | def init_dist(backend='nccl', **kwargs): 20 | ''' initialization for distributed training''' 21 | # if mp.get_start_method(allow_none=True) is None: 22 | if mp.get_start_method(allow_none=True) != 'spawn': 23 | mp.set_start_method('spawn') 24 | rank = int(os.environ['RANK']) 25 | num_gpus = torch.cuda.device_count() 26 | torch.cuda.set_device(rank % num_gpus) 27 | dist.init_process_group(backend=backend, **kwargs) 28 | 29 | 30 | def main(): 31 | #### options 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('-opt', type=str, help='Path to option YAML file.') 34 | parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', 35 | help='job launcher') 36 | parser.add_argument('--local_rank', type=int, default=0) 37 | args = parser.parse_args() 38 | opt = option.parse(args.opt, is_train=True) 39 | 40 | #### distributed training settings 41 | if args.launcher == 'none': # disabled distributed training 42 | opt['dist'] = False 43 | rank = -1 44 | print('Disabled distributed training.') 45 | else: 46 | opt['dist'] = True 47 | init_dist() 48 | world_size = torch.distributed.get_world_size() 49 | rank = torch.distributed.get_rank() 50 | 51 | #### loading resume state if exists 52 | if opt['path'].get('resume_state', None): 53 | # distributed resuming: all load into default GPU 54 | device_id = torch.cuda.current_device() 55 | resume_state = torch.load(opt['path']['resume_state'], 56 | map_location=lambda storage, loc: storage.cuda(device_id)) 57 | option.check_resume(opt, resume_state['iter']) # check resume options 58 | else: 59 | resume_state = None 60 | 61 | #### mkdir and loggers 62 | if rank <= 0: # normal training (rank -1) OR distributed training (rank 0) 63 | if resume_state is None: 64 | util.mkdir_and_rename( 65 | opt['path']['experiments_root']) # rename experiment folder if exists 66 | util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' 67 | and 'pretrain_model' not in key and 'resume' not in key)) 68 | 69 | # config loggers.
Before it, the log will not work 70 | util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO, 71 | screen=True, tofile=True) 72 | util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO, 73 | screen=True, tofile=True) 74 | logger = logging.getLogger('base') 75 | logger.info(option.dict2str(opt)) 76 | # tensorboard logger 77 | if opt['use_tb_logger'] and 'debug' not in opt['name']: 78 | version = float(torch.__version__[0:3]) 79 | if version >= 1.1: # PyTorch 1.1 80 | from tensorboardX import SummaryWriter 81 | else: 82 | logger.info( 83 | 'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version)) 84 | from tensorboardX import SummaryWriter 85 | tb_logger = SummaryWriter(log_dir='./tb_logger/' + opt['name']) 86 | else: 87 | util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True) 88 | logger = logging.getLogger('base') 89 | 90 | # convert to NoneDict, which returns None for missing keys 91 | opt = option.dict_to_nonedict(opt) 92 | 93 | #### random seed 94 | seed = opt['train']['manual_seed'] 95 | if seed is None: 96 | seed = random.randint(1, 10000) 97 | if rank <= 0: 98 | logger.info('Random seed: {}'.format(seed)) 99 | util.set_random_seed(seed) 100 | 101 | torch.backends.cudnn.benchmark = True 102 | # torch.backends.cudnn.deterministic = True 103 | 104 | #### create train and val dataloader 105 | dataset_ratio = 200 # enlarge the size of each epoch 106 | for phase, dataset_opt in opt['datasets'].items(): 107 | if phase == 'val': 108 | val_set = create_dataset(dataset_opt) 109 | val_loader = create_dataloader(val_set, dataset_opt, opt, None) 110 | if rank <= 0: 111 | logger.info('Number of val images in [{:s}]: {:d}'.format( 112 | dataset_opt['name'], len(val_set))) 113 | else: 114 | raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase)) 115 | 116 | #### create model 117 | model = create_model(opt) 118 | 119 | # #### resume training 120 | # if resume_state: 121 | # logger.info('Resuming training from epoch: {}, iter: {}.'.format( 122 | # resume_state['epoch'], resume_state['iter'])) 123 | # 124 | # start_epoch = resume_state['epoch'] 125 | # current_step = resume_state['iter'] 126 | # model.resume_training(resume_state) # handle optimizers and schedulers 127 | # else: 128 | # current_step = 0 129 | # start_epoch = 0 130 | 131 | #### test 132 | avg_psnr = 0.0 133 | idx = 0 134 | for val_data in val_loader: 135 | idx += 1 136 | # img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0] 137 | # img_dir = os.path.join(opt['path']['val_images'], img_name) 138 | # util.mkdir(img_dir) 139 | model.feed_data_test(val_data) 140 | model.test() 141 | visuals = model.get_current_visuals() 142 | img_input = visuals['img_input'].numpy() 143 | img_pred = visuals['img_pred'].numpy() 144 | img_gt = visuals['img_gt'].numpy() 145 | 146 | ########################## save images for visualization################### 147 | img_input = img_input[::-1,:,:] 148 | img_pred1 = img_pred[::-1,:,:] 149 | img_gt1 = img_gt[::-1,:,:] 150 | 151 | img_input = img_input.transpose(1,2,0) 152 | img_pred1 = img_pred1.transpose(1,2,0) 153 | img_gt1 = img_gt1.transpose(1,2,0) 154 | 155 | from PIL import Image 156 | 157 | img_pred1 = np.clip(img_pred1,0,1) 158 | Image.fromarray((img_pred1*255).astype(np.uint8)).save(os.path.join(opt['path']['val_images'], '%03d.png'%idx)) 159 | 160 | img_input = np.clip(img_input,0,1) 161 | 
Image.fromarray((img_input*255).astype(np.uint8)).save(os.path.join(opt['path']['val_images'], '%03d_i.png'%idx)) 162 | 163 | img_gt1 = np.clip(img_gt1,0,1) 164 | Image.fromarray((img_gt1*255).astype(np.uint8)).save(os.path.join(opt['path']['val_images'], '%03d_t.png'%idx)) 165 | 166 | 167 | def compute_psnr(img_orig, img_out, peak): 168 | mse = np.mean(np.square(img_orig - img_out)) 169 | psnr = 10 * np.log10(peak * peak / mse) 170 | return psnr 171 | curr_psnr = compute_psnr(img_pred, img_gt, 1) 172 | avg_psnr += curr_psnr 173 | print('idx', idx, curr_psnr) 174 | 175 | avg_psnr = avg_psnr / idx 176 | 177 | logger.info('# Validation # PSNR: {:.4e}.'.format(avg_psnr)) 178 | 179 | 180 | 181 | if __name__ == '__main__': 182 | main() 183 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import argparse 4 | import random 5 | import logging 6 | 7 | import torch 8 | import torch.distributed as dist 9 | import torch.multiprocessing as mp 10 | from data.data_sampler import DistIterSampler 11 | 12 | import options.options as option 13 | from utils import util 14 | from data import create_dataloader, create_dataset 15 | from models import create_model 16 | import numpy as np 17 | 18 | 19 | def init_dist(backend='nccl', **kwargs): 20 | ''' initialization for distributed training''' 21 | # if mp.get_start_method(allow_none=True) is None: 22 | if mp.get_start_method(allow_none=True) != 'spawn': 23 | mp.set_start_method('spawn') 24 | rank = int(os.environ['RANK']) 25 | num_gpus = torch.cuda.device_count() 26 | torch.cuda.set_device(rank % num_gpus) 27 | dist.init_process_group(backend=backend, **kwargs) 28 | 29 | 30 | def main(): 31 | #### options 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument('-opt', type=str, help='Path to option YAML file.') 34 | parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', 35 | help='job launcher') 36 | parser.add_argument('--local_rank', type=int, default=0) 37 | args = parser.parse_args() 38 | opt = option.parse(args.opt, is_train=True) 39 | 40 | #### distributed training settings 41 | if args.launcher == 'none': # disabled distributed training 42 | opt['dist'] = False 43 | rank = -1 44 | print('Disabled distributed training.') 45 | else: 46 | opt['dist'] = True 47 | init_dist() 48 | world_size = torch.distributed.get_world_size() 49 | rank = torch.distributed.get_rank() 50 | 51 | #### loading resume state if exists 52 | if opt['path'].get('resume_state', None): 53 | # distributed resuming: all load into default GPU 54 | device_id = torch.cuda.current_device() 55 | resume_state = torch.load(opt['path']['resume_state'], 56 | map_location=lambda storage, loc: storage.cuda(device_id)) 57 | option.check_resume(opt, resume_state['iter']) # check resume options 58 | else: 59 | resume_state = None 60 | 61 | #### mkdir and loggers 62 | if rank <= 0: # normal training (rank -1) OR distributed training (rank 0) 63 | if resume_state is None: 64 | util.mkdir_and_rename( 65 | opt['path']['experiments_root']) # rename experiment folder if exists 66 | util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' 67 | and 'pretrain_model' not in key and 'resume' not in key)) 68 | 69 | # config loggers.
Before it, the log will not work 70 | util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO, 71 | screen=True, tofile=True) 72 | util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO, 73 | screen=True, tofile=True) 74 | logger = logging.getLogger('base') 75 | logger.info(option.dict2str(opt)) 76 | # tensorboard logger 77 | if opt['use_tb_logger'] and 'debug' not in opt['name']: 78 | version = float(torch.__version__[0:3]) 79 | if version >= 1.1: # PyTorch 1.1 80 | from tensorboardX import SummaryWriter 81 | else: 82 | logger.info( 83 | 'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version)) 84 | from tensorboardX import SummaryWriter 85 | tb_logger = SummaryWriter(log_dir='./tb_logger/' + opt['name']) 86 | else: 87 | util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True) 88 | logger = logging.getLogger('base') 89 | 90 | # convert to NoneDict, which returns None for missing keys 91 | opt = option.dict_to_nonedict(opt) 92 | 93 | #### random seed 94 | seed = opt['train']['manual_seed'] 95 | if seed is None: 96 | seed = random.randint(1, 10000) 97 | if rank <= 0: 98 | logger.info('Random seed: {}'.format(seed)) 99 | util.set_random_seed(seed) 100 | 101 | torch.backends.cudnn.benchmark = True 102 | # torch.backends.cudnn.deterministic = True 103 | 104 | #### create train and val dataloader 105 | dataset_ratio = 1 # enlarge the size of each epoch 106 | for phase, dataset_opt in opt['datasets'].items(): 107 | if phase == 'train': 108 | train_set = create_dataset(dataset_opt) 109 | train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size'])) 110 | 111 | total_iters = int(opt['train']['niter']) 112 | total_epochs = int(math.ceil(total_iters / train_size)) 113 | if opt['dist']: 114 | train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio) 115 | total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio))) 116 | else: 117 | train_sampler = None 118 | train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler) 119 | if rank <= 0: 120 | logger.info('Number of train images: {:,d}, iters: {:,d}'.format( 121 | len(train_set), train_size)) 122 | logger.info('Total epochs needed: {:d} for iters {:,d}'.format( 123 | total_epochs, total_iters)) 124 | elif phase == 'val': 125 | val_set = create_dataset(dataset_opt) 126 | val_loader = create_dataloader(val_set, dataset_opt, opt, None) 127 | if rank <= 0: 128 | logger.info('Number of val images in [{:s}]: {:d}'.format( 129 | dataset_opt['name'], len(val_set))) 130 | else: 131 | raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase)) 132 | assert train_loader is not None 133 | 134 | #### create model 135 | model = create_model(opt) 136 | 137 | #### resume training 138 | if resume_state: 139 | logger.info('Resuming training from epoch: {}, iter: {}.'.format( 140 | resume_state['epoch'], resume_state['iter'])) 141 | 142 | start_epoch = resume_state['epoch'] 143 | current_step = resume_state['iter'] 144 | model.resume_training(resume_state) # handle optimizers and schedulers 145 | else: 146 | current_step = 0 147 | start_epoch = 0 148 | 149 | val_freq = opt['train']['val_freq'] 150 | #### training 151 | logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step)) 152 | for epoch in range(start_epoch, total_epochs + 1): 153 | if opt['dist']: 154 | train_sampler.set_epoch(epoch) 155 | for _, train_data in enumerate(train_loader): 156 | 
current_step += 1 157 | 158 | if current_step > total_iters: 159 | break 160 | #### training 161 | model.feed_data(train_data) 162 | model.optimize_parameters(current_step) 163 | 164 | #### update learning rate 165 | model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter']) 166 | 167 | #### log 168 | if current_step % opt['logger']['print_freq'] == 0: 169 | logs = model.get_current_log() 170 | message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format( 171 | epoch, current_step, model.get_current_learning_rate()) 172 | for k, v in logs.items(): 173 | message += '{:s}: {:.4e} '.format(k, v) 174 | # tensorboard logger 175 | if opt['use_tb_logger'] and 'debug' not in opt['name']: 176 | if rank <= 0: 177 | tb_logger.add_scalar(k, v, current_step) 178 | if rank <= 0: 179 | logger.info(message) 180 | 181 | # validation 182 | if current_step % val_freq == 0 and rank <= 0: 183 | avg_psnr = 0.0 184 | idx = 0 185 | for val_data in val_loader: 186 | idx += 1 187 | model.feed_data_test(val_data) 188 | model.test() 189 | visuals = model.get_current_visuals() 190 | img_pred = visuals['img_pred'].numpy() 191 | img_gt = visuals['img_gt'].numpy() 192 | 193 | def compute_psnr(img_orig, img_out, peak): 194 | mse = np.mean(np.square(img_orig - img_out)) 195 | psnr = 10 * np.log10(peak * peak / mse) 196 | return psnr 197 | curr_psnr = compute_psnr(img_pred, img_gt, 1) 198 | avg_psnr += curr_psnr 199 | print('idx', idx, curr_psnr) 200 | 201 | avg_psnr = avg_psnr / idx 202 | 203 | # log 204 | logger.info('# Validation # PSNR: {:.4e}.'.format(avg_psnr)) 205 | logger_val = logging.getLogger('val') # validation logger 206 | logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}.'.format( 207 | epoch, current_step, avg_psnr)) 208 | 209 | #### save models and training states 210 | if current_step % opt['logger']['save_checkpoint_freq'] == 0: 211 | if rank <= 0: 212 | logger.info('Saving models and training states.') 213 | model.save(current_step) 214 | # model.save_training_state(epoch, current_step) 215 | 216 | if rank <= 0: 217 | logger.info('Saving the final model.') 218 | model.save('latest') 219 | logger.info('End of training.') 220 | 221 | 222 | if __name__ == '__main__': 223 | main() 224 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mdyao/CSNorm/49bf5a07ac1c58c8d2c221ac86022698c7f1c897/utils/__init__.py -------------------------------------------------------------------------------- /utils/pytorch_ssim/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | def gaussian(window_size, sigma): 8 | gauss = torch.Tensor([exp(-(x - window_size//2)**2/float(2*sigma**2)) for x in range(window_size)]) 9 | return gauss/gauss.sum() 10 | 11 | def create_window(window_size, channel): 12 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 13 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 14 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 15 | return window 16 | 17 | def _ssimmap(img1, window, window_size, channel, size_average = True): 18 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 19 | mu1_sq = mu1.pow(2) 20 | 21 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) -
mu1_sq 22 | # sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 23 | # sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 24 | 25 | C1 = 0.01**2 26 | C2 = 0.03**2 27 | 28 | # ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 29 | feat_map = torch.cat((mu1, C1/(mu1+C1), sigma1_sq, C2/(sigma1_sq+C2)), 1) 30 | 31 | return feat_map 32 | 33 | def _ssim(img1, img2, window, window_size, channel, size_average = True): 34 | mu1 = F.conv2d(img1, window, padding = window_size//2, groups = channel) 35 | mu2 = F.conv2d(img2, window, padding = window_size//2, groups = channel) 36 | 37 | mu1_sq = mu1.pow(2) 38 | mu2_sq = mu2.pow(2) 39 | mu1_mu2 = mu1*mu2 40 | 41 | sigma1_sq = F.conv2d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 42 | sigma2_sq = F.conv2d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 43 | sigma12 = F.conv2d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 44 | 45 | C1 = 0.01**2 46 | C2 = 0.03**2 47 | 48 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 49 | 50 | if size_average: 51 | return (-1)*ssim_map.mean() 52 | else: 53 | return (-1)*ssim_map.mean(1).mean(1).mean(1) 54 | 55 | class SSIMMap(torch.nn.Module): 56 | def __init__(self, window_size = 11, size_average = True): 57 | super(SSIMMap, self).__init__() 58 | self.window_size = window_size 59 | self.size_average = size_average 60 | self.channel = 1 61 | self.window = create_window(window_size, self.channel) 62 | 63 | def forward(self, img1): 64 | (_, channel, _, _) = img1.size() 65 | 66 | if channel == self.channel and self.window.data.type() == img1.data.type(): 67 | window = self.window 68 | else: 69 | window = create_window(self.window_size, channel) 70 | 71 | if img1.is_cuda: 72 | window = window.cuda(img1.get_device()) 73 | window = window.type_as(img1) 74 | 75 | self.window = window 76 | self.channel = channel 77 | 78 | return _ssimmap(img1, window, self.window_size, channel, self.size_average) 79 | 80 | class SSIM(torch.nn.Module): 81 | def __init__(self, window_size = 11, size_average = True): 82 | super(SSIM, self).__init__() 83 | self.window_size = window_size 84 | self.size_average = size_average 85 | self.channel = 1 86 | self.window = create_window(window_size, self.channel) 87 | 88 | def forward(self, img1, img2): 89 | (_, channel, _, _) = img1.size() 90 | 91 | if channel == self.channel and self.window.data.type() == img1.data.type(): 92 | window = self.window 93 | else: 94 | window = create_window(self.window_size, channel) 95 | 96 | if img1.is_cuda: 97 | window = window.cuda(img1.get_device()) 98 | window = window.type_as(img1) 99 | 100 | self.window = window 101 | self.channel = channel 102 | 103 | 104 | return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 105 | 106 | def ssim(img1, img2, window_size = 11, size_average = True): 107 | (_, channel, _, _) = img1.size() 108 | window = create_window(window_size, channel) 109 | 110 | if img1.is_cuda: 111 | window = window.cuda(img1.get_device()) 112 | window = window.type_as(img1) 113 | 114 | return _ssim(img1, img2, window, window_size, channel, size_average) 115 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 
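# (options/options.py does `from utils.util import OrderedYaml` and then
#  `Loader, Dumper = OrderedYaml()`, so its yaml.load() calls return OrderedDict
#  and the key order of the config files is preserved)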
4 | import math 5 | from datetime import datetime 6 | import random 7 | import logging 8 | from collections import OrderedDict 9 | import numpy as np 10 | import cv2 11 | import torch 12 | from torchvision.utils import make_grid 13 | from shutil import get_terminal_size 14 | import imageio 15 | import yaml 16 | try: 17 | from yaml import CLoader as Loader, CDumper as Dumper 18 | except ImportError: 19 | from yaml import Loader, Dumper 20 | 21 | def rgb2yuv(img): 22 | y = 0.299 * img[:, 0] + 0.587 * img[:, 1] + 0.114 * img[:, 2] 23 | u = -0.169 * img[:, 0] - 0.331 * img[:, 1] + 0.5 * img[:, 2] + 0.5 24 | v = 0.5 * img[:, 0] - 0.419 * img[:, 1] - 0.081 * img[:, 2] + 0.5 25 | out = torch.stack((y, u, v)) 26 | out = out.transpose(0, 1) 27 | return out 28 | 29 | 30 | def yuv2rgb(img): 31 | r = img[:, 0] + 1.4075 * (img[:, 2] - 0.5) 32 | g = img[:, 0] - 0.3455 * (img[:, 1] - 0.5) - 0.7169 * (img[:, 2] - 0.5) 33 | b = img[:, 0] + 1.779 * (img[:, 1] - 0.5) 34 | out = torch.stack((r, g, b)) 35 | out = out.transpose(0, 1) 36 | return out 37 | 38 | def OrderedYaml(): 39 | '''yaml orderedDict support''' 40 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG 41 | 42 | def dict_representer(dumper, data): 43 | return dumper.represent_dict(data.items()) 44 | 45 | def dict_constructor(loader, node): 46 | return OrderedDict(loader.construct_pairs(node)) 47 | 48 | Dumper.add_representer(OrderedDict, dict_representer) 49 | Loader.add_constructor(_mapping_tag, dict_constructor) 50 | return Loader, Dumper 51 | 52 | def save_results_yuv(pred, index, test_img_dir): 53 | test_pred = np.squeeze(pred) 54 | test_pred = np.clip(test_pred, 0, 1) * 1023 55 | test_pred = np.uint16(test_pred) 56 | 57 | # split image 58 | pred_y = test_pred[:, :, 0] 59 | pred_u = test_pred[:, :, 1] 60 | pred_v = test_pred[:, :, 2] 61 | 62 | # save prediction - must be saved in separate channels due to 16-bit pixel depth 63 | imageio.imwrite(os.path.join(test_img_dir, "{}-y_pred.png".format(str(int(index) + 1).zfill(2))), 64 | pred_y) 65 | imageio.imwrite(os.path.join(test_img_dir, "{}-u_pred.png".format(str(int(index) + 1).zfill(2))), 66 | pred_u) 67 | imageio.imwrite(os.path.join(test_img_dir, "{}-v_pred.png".format(str(int(index) + 1).zfill(2))), 68 | pred_v) 69 | 70 | 71 | 72 | #################### 73 | # miscellaneous 74 | #################### 75 | 76 | 77 | def get_timestamp(): 78 | return datetime.now().strftime('%y%m%d-%H%M%S') 79 | 80 | 81 | def mkdir(path): 82 | if not os.path.exists(path): 83 | os.makedirs(path) 84 | 85 | 86 | def mkdirs(paths): 87 | if isinstance(paths, str): 88 | mkdir(paths) 89 | else: 90 | for path in paths: 91 | mkdir(path) 92 | 93 | 94 | def mkdir_and_rename(path): 95 | if os.path.exists(path): 96 | new_name = path + '_archived_' + get_timestamp() 97 | print('Path already exists. Rename it to [{:s}]'.format(new_name)) 98 | logger = logging.getLogger('base') 99 | logger.info('Path already exists. 
Rename it to [{:s}]'.format(new_name)) 100 | os.rename(path, new_name) 101 | os.makedirs(path) 102 | 103 | 104 | def set_random_seed(seed): 105 | random.seed(seed) 106 | np.random.seed(seed) 107 | torch.manual_seed(seed) 108 | torch.cuda.manual_seed_all(seed) 109 | 110 | 111 | def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False): 112 | '''set up logger''' 113 | lg = logging.getLogger(logger_name) 114 | formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', 115 | datefmt='%y-%m-%d %H:%M:%S') 116 | lg.setLevel(level) 117 | if tofile: 118 | log_file = os.path.join(root, phase + '_{}.log'.format(get_timestamp())) 119 | fh = logging.FileHandler(log_file, mode='w') 120 | fh.setFormatter(formatter) 121 | lg.addHandler(fh) 122 | if screen: 123 | sh = logging.StreamHandler() 124 | sh.setFormatter(formatter) 125 | lg.addHandler(sh) 126 | 127 | 128 | #################### 129 | # image convert 130 | #################### 131 | 132 | 133 | def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): 134 | ''' 135 | Converts a torch Tensor into an image Numpy array 136 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order 137 | Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) 138 | ''' 139 | tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp 140 | tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] 141 | n_dim = tensor.dim() 142 | if n_dim == 4: 143 | n_img = len(tensor) 144 | img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() 145 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR 146 | elif n_dim == 3: 147 | img_np = tensor.numpy() 148 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR 149 | elif n_dim == 2: 150 | img_np = tensor.numpy() 151 | else: 152 | raise TypeError( 153 | 'Only 4D, 3D and 2D tensors are supported, but received one with dimension: {:d}'.format(n_dim)) 154 | if out_type == np.uint8: 155 | img_np = (img_np * 255.0).round() 156 | # Important. Unlike matlab, numpy.uint8() WILL NOT round by default. 157 | return img_np.astype(out_type) 158 | 159 | def tensor2img_Real(tensor, out_type=np.uint8, min_max=(0, 1)): 160 | ''' 161 | Converts a torch Tensor into an image Numpy array 162 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order 163 | Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) 164 | ''' 165 | tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp 166 | tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] 167 | n_dim = tensor.dim() 168 | if n_dim == 4: 169 | # n_img = len(tensor) 170 | # img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() 171 | img_np = tensor.numpy() 172 | # img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR 173 | elif n_dim == 3: 174 | img_np = tensor.numpy() 175 | # img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR 176 | elif n_dim == 2: 177 | img_np = tensor.numpy() 178 | else: 179 | raise TypeError( 180 | 'Only 4D, 3D and 2D tensors are supported, but received one with dimension: {:d}'.format(n_dim)) 181 | if out_type == np.uint8: 182 | img_np = (img_np * 255.0).round() 183 | # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
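# e.g. np.array([254.97]).astype(np.uint8) truncates to 254, while
# np.array([254.97]).round().astype(np.uint8) yields 255; hence the round() above.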
184 | return img_np.astype(out_type) 185 | 186 | def save_img(img, img_path, mode='RGB'): 187 | cv2.imwrite(img_path, img) 188 | 189 | 190 | #################### 191 | # metric 192 | #################### 193 | 194 | 195 | def calculate_psnr(img1, img2): 196 | # img1 and img2 have range [0, 255] 197 | img1 = img1.astype(np.float64) 198 | img2 = img2.astype(np.float64) 199 | mse = np.mean((img1 - img2)**2) 200 | if mse == 0: 201 | return float('inf') 202 | return 20 * math.log10(255.0 / math.sqrt(mse)) 203 | 204 | 205 | def ssim(img1, img2): 206 | C1 = (0.01 * 255)**2 207 | C2 = (0.03 * 255)**2 208 | 209 | img1 = img1.astype(np.float64) 210 | img2 = img2.astype(np.float64) 211 | kernel = cv2.getGaussianKernel(11, 1.5) 212 | window = np.outer(kernel, kernel.transpose()) 213 | 214 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 215 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 216 | mu1_sq = mu1**2 217 | mu2_sq = mu2**2 218 | mu1_mu2 = mu1 * mu2 219 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 220 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 221 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 222 | 223 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 224 | (sigma1_sq + sigma2_sq + C2)) 225 | return ssim_map.mean() 226 | 227 | 228 | def calculate_ssim(img1, img2): 229 | '''calculate SSIM 230 | produces the same outputs as MATLAB's SSIM implementation 231 | img1, img2: [0, 255] 232 | ''' 233 | if not img1.shape == img2.shape: 234 | raise ValueError('Input images must have the same dimensions.') 235 | if img1.ndim == 2: 236 | return ssim(img1, img2) 237 | elif img1.ndim == 3: 238 | if img1.shape[2] == 3: 239 | ssims = [] 240 | for i in range(3): 241 | ssims.append(ssim(img1[..., i], img2[..., i])) # per-channel SSIM 242 | return np.array(ssims).mean() 243 | elif img1.shape[2] == 1: 244 | return ssim(np.squeeze(img1), np.squeeze(img2)) 245 | else: 246 | raise ValueError('Wrong input image dimensions.') 247 | 248 | 249 | class ProgressBar(object): 250 | '''A progress bar which can print the progress 251 | modified from https://github.com/hellock/cvbase/blob/master/cvbase/progress.py 252 | ''' 253 | 254 | def __init__(self, task_num=0, bar_width=50, start=True): 255 | self.task_num = task_num 256 | max_bar_width = self._get_max_bar_width() 257 | self.bar_width = (bar_width if bar_width <= max_bar_width else max_bar_width) 258 | self.completed = 0 259 | if start: 260 | self.start() 261 | 262 | def _get_max_bar_width(self): 263 | terminal_width, _ = get_terminal_size() 264 | max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50) 265 | if max_bar_width < 10: 266 | print('terminal width is too small ({}), please consider widening the terminal for better ' 267 | 'progressbar visualization'.format(terminal_width)) 268 | max_bar_width = 10 269 | return max_bar_width 270 | 271 | def start(self): 272 | if self.task_num > 0: 273 | sys.stdout.write('[{}] 0/{}, elapsed: 0s, ETA:\n{}\n'.format( 274 | ' ' * self.bar_width, self.task_num, 'Start...')) 275 | else: 276 | sys.stdout.write('completed: 0, elapsed: 0s') 277 | sys.stdout.flush() 278 | self.start_time = time.time() 279 | 280 | def update(self, msg='In progress...'): 281 | self.completed += 1 282 | elapsed = time.time() - self.start_time 283 | fps = self.completed / elapsed 284 | if self.task_num > 0: 285 | percentage = self.completed / float(self.task_num) 286 | eta = int(elapsed * (1 - percentage) / percentage + 0.5) 287 | mark_width = int(self.bar_width * percentage) 288 | bar_chars =
'>' * mark_width + '-' * (self.bar_width - mark_width) 289 | sys.stdout.write('\033[2F') # cursor up 2 lines 290 | sys.stdout.write('\033[J') # clean the output (remove extra chars since last display) 291 | sys.stdout.write('[{}] {}/{}, {:.1f} tasks/s, elapsed: {}s, ETA: {:5}s\n{}\n'.format( 292 | bar_chars, self.completed, self.task_num, fps, int(elapsed + 0.5), eta, msg)) 293 | else: 294 | sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format( 295 | self.completed, int(elapsed + 0.5), fps)) 296 | sys.stdout.flush() 297 | --------------------------------------------------------------------------------
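A minimal usage sketch for the ProgressBar utility defined above (hypothetical task loop, not part of the repository):

    from utils.util import ProgressBar

    pbar = ProgressBar(task_num=100)  # prints the initial bar and starts timing
    for i in range(100):
        # ... process one task here ...
        pbar.update(msg='task {} done'.format(i))  # redraws the bar with speed and ETA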